Spaces:
Runtime error
Runtime error
| from .sql_runtime import SQLRuntime | |
| from pydantic import BaseModel, Field | |
| from .load_llm import load_llm | |
| from .prompts import sql_query_prompt, sql_query_summary_prompt, sql_query_visualization_prompt | |
| from langchain_core.runnables import chain | |
| from typing import Optional | |
| from dotenv import load_dotenv | |
class Generated_query(BaseModel):
    """
    The SQL query to execute, make sure to use semicolon at the end of the query, do not execute harmful queries
    """
    # Structured-output schema for the query-generation LLM call made in
    # sql_generator via load_llm().with_structured_output(Generated_query).
    # NOTE(review): with_structured_output presumably forwards the class
    # docstring above and the Field description below to the model as schema
    # text, so editing either one is effectively a prompt change.
    # NOTE(review): the docstring speaks of a single "query", but the model
    # carries a list of queries — consider aligning the wording.
    queries: list[str] = Field(description="List of SQL queries to execute, use title case for strings, make sure to use semicolon at the end of each query, do not execute harmful queries")
class QuerySummary(BaseModel):
    """
    The summary of the SQL query results
    """
    # Structured-output schema returned by analyze_results via
    # load_llm().with_structured_output(QuerySummary).
    # NOTE(review): the docstring and Field descriptions are presumably sent to
    # the model as schema text — treat edits to them as prompt changes.

    # Free-text analysis of the query results produced by the model.
    summary: str = Field(description="The analysis of the SQL query results")
    # Error messages the model reports from the executed queries.
    errors: list[str] = Field(description="The errors in the execution of the queries")
    # The executed queries (with their results) as echoed back by the model.
    queries: list[str] = Field(description="The SQL queries executed and their results")
def sql_generator(input: dict) -> dict:
    """Generate SQL for a natural-language question, execute it, and collect results.

    Args:
        input: dict with key "query" (the natural-language question) and key
            "db_path" (passed to ``SQLRuntime`` as ``dbname``).

    Returns:
        dict with "input" (the original question) and "results" (whatever
        ``SQLRuntime.execute_batch`` returned for the generated queries).
        This is the shape ``sql_formatter`` consumes.
    """
    question, db_path = input["query"], input["db_path"]
    sql_runtime = SQLRuntime(dbname=db_path)
    query_generator_llm = load_llm().with_structured_output(Generated_query)

    # The prompt receives the database schema so the generated SQL can refer
    # to the actual tables and columns.
    schemas = sql_runtime.get_schemas()

    # Named `query_chain` (not `chain`) so it does not shadow the `chain`
    # decorator imported from langchain_core.runnables.
    query_chain = sql_query_prompt | query_generator_llm
    generated = query_chain.invoke({
        "db_schema": schemas,
        "input": question,
    })

    results = sql_runtime.execute_batch(generated.queries)
    return {
        "input": question,
        "results": results,
    }
def sql_formatter(input):
    """Render per-query execution results as readable strings.

    Args:
        input: dict with key "input" (the user's question) and key "results",
            a list of result dicts as returned by ``sql_generator``. Each
            result dict is read for ``code`` (0 means success),
            ``msg["input"]`` (the SQL text), and then either ``data`` (on
            success) or ``msg["traceback"]`` (on failure).

    Returns:
        dict with "query" (the user's question) and "results" (one string per
        executed query).
    """
    def _describe(entry):
        # Check the status code first, then pull the SQL text, mirroring the
        # order in which the fields are needed for either outcome.
        succeeded = entry["code"] == 0
        sql_text = entry["msg"]["input"]
        if not succeeded:
            return f"Query: {sql_text}, Error: {entry['msg']['traceback']}"
        return f"Query: {sql_text}, Result: {entry['data']}"

    rendered = [_describe(entry) for entry in input["results"]]
    return {"query": input["input"], "results": rendered}
def analyze_results(input) -> QuerySummary:
    """Have the LLM summarise the formatted results of the executed SQL queries.

    Args:
        input: dict with key "query" (the user's question) and key "results"
            (the formatted per-query strings, e.g. the output of
            ``sql_formatter``).

    Returns:
        QuerySummary: the structured response of the summary chain.
    """
    structured_llm = load_llm().with_structured_output(QuerySummary)
    summarizer = sql_query_summary_prompt | structured_llm
    # TODO: a second chain built on sql_query_visualization_prompt was
    # sketched here but is currently disabled.
    payload = {
        "query": input["query"],
        "results": input["results"],
    }
    return summarizer.invoke(payload)
if __name__ == '__main__':
    # Local import: the pipeline object is only needed when run as a script.
    from langchain_core.runnables import RunnableSequence

    load_dotenv()

    # question -> generated + executed SQL -> formatted result strings -> LLM
    # summary. Plain Python functions do not support `|` with each other, so
    # the steps are composed with RunnableSequence, which coerces each step
    # (plain callable or Runnable) into a Runnable.
    pipeline = RunnableSequence(sql_generator, sql_formatter, analyze_results)

    # analyze_results returns a single QuerySummary; unpacking it into two
    # names would iterate the model's (field, value) pairs and fail.
    summary = pipeline.invoke(
        {
            "query": "What are the different party symbols in Maharashtra elections 2019, create a list of all the symbols",
            "db_path": "./data/elections.db",
        }
    )
    print(summary.summary)
    print(summary.errors)
    print(summary.queries)