Spaces:
Runtime error
Runtime error
| from typing import List, Dict, Any, Optional, Type | |
| from langchain_core.tools import BaseTool | |
| from pydantic import BaseModel, Field | |
| import pandas as pd | |
| from .sql_runtime import SQLRuntime | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from .load_llm import load_llm | |
| from langchain_core.messages import SystemMessage | |
| from langchain_core.prompts import HumanMessagePromptTemplate | |
| from langchain.agents import AgentExecutor, create_react_agent | |
| from dotenv import load_dotenv | |
| from react import run_agent_executor | |
| from prompts import react_prompt | |
# Pydantic argument schemas for the SQL tools defined below.
class QueryInput(BaseModel):
    # Single required argument for SQLQueryTool: the raw SQL statement to run.
    query: str = Field(
        default=...,
        description="The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries",
    )
class TableNameInput(BaseModel):
    # Single required argument for TableInfoTool: which table to describe.
    table_name: str = Field(
        default=...,
        description="The name of the table to analyze",
    )
class ColumnSearchInput(BaseModel):
    # Arguments for ColumnValuesTool: table and column are required,
    # the result cap defaults to 10 distinct values.
    table_name: str = Field(default=..., description="The name of the table to search")
    column_name: str = Field(default=..., description="The name of the column to search")
    limit: int = Field(10, description="Maximum number of distinct values to return")
class SQLQueryTool(BaseTool):
    """LangChain tool that executes one SQL statement and returns the rows as text.

    Results are rendered with ``pandas.DataFrame.to_string`` so the agent sees
    a plain-text table. Failures are returned as ``"Error..."`` strings rather
    than raised, so the agent can read them and retry.
    """

    name: str = "sql_query"
    description: str = """
Execute a SQL query and return the results.
Use this when you need to run a specific SQL query on the elections database.
The query should be a valid SQL statement and should end with a semicolon.
There should be no harmful queries executed.
There are three tables in the database: elections_2019, elections_2024, maha_2019
"""
    args_schema: Type[BaseModel] = QueryInput
    # Database file handed to SQLRuntime on every call. The default is the
    # path that used to be hard-coded, so SQLQueryTool() behaves as before.
    db_path: str = "../data/elections.db"

    def _run(self, query: str) -> str:
        """Execute *query* and return its result table as a string.

        Args:
            query: SQL text produced by the agent.

        Returns:
            The rows rendered by ``DataFrame.to_string()``,
            ``"Query returned no results"`` for an empty result, or an
            ``"Error..."`` message when SQLRuntime reports a non-zero code or
            anything raises.
        """
        # NOTE(review): the query comes from the LLM and is executed verbatim;
        # "no harmful queries" is only requested in the description, not
        # enforced here. Consider a read-only connection inside SQLRuntime.
        try:
            # Constructed inside the try so a bad path/DB surfaces as an error
            # string for the agent instead of an exception that aborts the run.
            sql_runtime = SQLRuntime(self.db_path)
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error executing query: {result['msg']['reason']}"
            # Convert to DataFrame for a readable tabular string.
            df = pd.DataFrame(result["data"])
            if df.empty:
                return "Query returned no results"
            return df.to_string()
        except Exception as e:
            return f"Error: {e}"
class TableInfoTool(BaseTool):
    """LangChain tool that summarises one table: its columns, row count and three sample rows."""

    name: str = "get_table_info"
    description: str = """
Get information about a specific table including its schema and basic statistics.
Use this when you need to understand the structure of a table or get basic statistics about it.
"""
    args_schema: Type[BaseModel] = TableNameInput
    # Database file handed to SQLRuntime; defaults to the previously hard-coded path.
    db_path: str = "../data/elections.db"

    @staticmethod
    def _match_identifier(candidate: str, known: List[str]) -> Optional[str]:
        """Return the entry of *known* matching *candidate* (exact first, then case-insensitive), else None."""
        names = list(known)
        if candidate in names:
            return candidate
        folded = str(candidate).strip().casefold()
        for name in names:
            if name.casefold() == folded:
                return name
        return None

    def _run(self, table_name: str) -> str:
        """Describe *table_name* for the agent.

        Args:
            table_name: Table name supplied by the agent.

        Returns:
            A multi-line summary (table, columns, row count, sample data), or
            an ``"Error getting table info: ..."`` message if the table is
            unknown or anything raises.
        """
        try:
            sql_runtime = SQLRuntime(self.db_path)
            tables = list(sql_runtime.list_tables())
            # Identifiers cannot be bound as SQL parameters, and this one comes
            # from the LLM. Only names that actually exist in the database are
            # allowed into the query text, which blocks injection through
            # table_name.
            resolved = self._match_identifier(table_name, tables)
            if resolved is None:
                return (
                    f"Error getting table info: table '{table_name}' does not exist. "
                    f"Available tables: {', '.join(tables)}"
                )

            schema = sql_runtime.get_schema_for_table(resolved)

            count_result = sql_runtime.execute(f"SELECT COUNT(*) FROM {resolved}")
            row_count = count_result["data"][0][0] if count_result["code"] == 0 else "Error"

            sample_result = sql_runtime.execute(f"SELECT * FROM {resolved} LIMIT 3")
            if sample_result["code"] == 0:
                sample_text = pd.DataFrame(sample_result["data"], columns=schema).to_string()
            else:
                sample_text = "Error getting sample data"

            sections = [
                f"Table: {resolved}",
                f"Columns: {', '.join(schema)}",
                f"Row Count: {row_count}",
                "Sample Data:",
                sample_text,
            ]
            return "\n" + "\n".join(sections) + "\n"
        except Exception as e:
            return f"Error getting table info: {e}"
class ColumnValuesTool(BaseTool):
    """LangChain tool that lists up to ``limit`` distinct values of one column."""

    name: str = "find_column_values"
    description: str = """
Find distinct values in a specific column of a table.
Use this when you need to know what unique values exist in a particular column.
"""
    args_schema: Type[BaseModel] = ColumnSearchInput
    # Database file handed to SQLRuntime; defaults to the previously hard-coded path.
    db_path: str = "../data/elections.db"

    @staticmethod
    def _match_identifier(candidate: str, known: List[str]) -> Optional[str]:
        """Return the entry of *known* matching *candidate* (exact first, then case-insensitive), else None."""
        names = list(known)
        if candidate in names:
            return candidate
        folded = str(candidate).strip().casefold()
        for name in names:
            if name.casefold() == folded:
                return name
        return None

    def _run(self, table_name: str, column_name: str, limit: int = 10) -> str:
        """Return the distinct values of *column_name* in *table_name*.

        Args:
            table_name: Table supplied by the agent; must exist in the database.
            column_name: Column supplied by the agent; must exist in that table.
            limit: Maximum number of distinct values to fetch.

        Returns:
            ``"Distinct values in <column>: v1, v2, ..."`` or an error message.
        """
        try:
            sql_runtime = SQLRuntime(self.db_path)
            # All three arguments end up in the SQL text. Table and column must
            # match real identifiers, and limit must be an int, so no
            # LLM-supplied text is spliced into the query unchecked.
            table = self._match_identifier(table_name, sql_runtime.list_tables())
            if table is None:
                return f"Error finding values: unknown table '{table_name}'"
            column = self._match_identifier(column_name, sql_runtime.get_schema_for_table(table))
            if column is None:
                return f"Error finding values: unknown column '{column_name}' in table '{table}'"
            row_limit = int(limit)

            query = f"SELECT DISTINCT {column} FROM {table} LIMIT {row_limit}"
            result = sql_runtime.execute(query)
            if result["code"] != 0:
                return f"Error finding values: {result['msg']['reason']}"
            values = [row[0] for row in result["data"]]
            return f"Distinct values in {column_name}: {', '.join(map(str, values))}"
        except Exception as e:
            return f"Error: {e}"
class ListTablesTool(BaseTool):
    """LangChain tool that reports the names of all tables in the database."""

    name: str = "list_tables"
    description: str = """
List all available tables in the database.
Use this when you need to know what tables are available to query.
"""
    # Database file handed to SQLRuntime; defaults to the previously hard-coded path.
    db_path: str = "../data/elections.db"

    def _run(self, *args, **kwargs) -> str:
        """Return ``"Available tables: a, b, ..."`` or an error message.

        No ``args_schema`` is declared, and any positional/keyword input
        passed in is accepted and ignored.
        """
        try:
            # Constructed inside the try so connection problems come back as a
            # readable error string rather than an exception.
            sql_runtime = SQLRuntime(self.db_path)
            tables = sql_runtime.list_tables()
            return f"Available tables: {', '.join(tables)}"
        except Exception as e:
            return f"Error listing tables: {e}"
def create_sql_agent_tools(db_path: Optional[str] = '../data/elections.db') -> List[BaseTool]:
    """Create the list of SQL tools handed to the LangChain agent.

    Args:
        db_path: Database file passed to every tool's constructor as
            ``db_path``. ``None`` falls back to ``'../data/elections.db'``.

    Returns:
        The tool instances: SQL query, table info and table listing.
    """
    # Previously this parameter was accepted but never used.
    path = db_path if db_path is not None else '../data/elections.db'
    return [
        SQLQueryTool(db_path=path),
        TableInfoTool(db_path=path),
        # ColumnValuesTool(db_path=path),  # currently not exposed to the agent
        ListTablesTool(db_path=path),
    ]
if __name__ == "__main__":
    # Load environment variables from a .env file before the LLM is created
    # (presumably the credentials load_llm needs; confirm in load_llm).
    load_dotenv()

    tools = create_sql_agent_tools()
    for tool in tools:
        print(f"Tool: {tool.name}")
        print(f"Description: {tool.description}")

    llm = load_llm()

    # ReAct agent: the LLM alternates Thought/Action/Observation steps using
    # the tools above. The executor drives that loop.
    agent = create_react_agent(llm, tools, react_prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        verbose=True,
        # Feed malformed LLM output back to the model as an observation
        # instead of aborting the whole run with a parsing exception.
        handle_parsing_errors=True,
    )
    print("Agent created successfully")

    res = agent_executor.invoke(
        {"input": "who won elections in maharashtra in Nandurbar in elections 2019?"}
    )
    # invoke() returns a dict; the agent's final answer is under "output".
    print(res["output"])