Spaces:
Runtime error
Runtime error
| """ | |
| Runtime that accepts a sql statement and runs it on sql server. | |
| Returns the results of sql execution. | |
| """ | |
| import traceback | |
| import sqlite3 | |
| # MODIFY THE PATH BELOW FOR YOUR SYSTEM | |
| my_db = r"../data/elections.db" | |
| class SQLRuntime(object): | |
| def __init__(self, dbname=None): | |
| if dbname is None: | |
| dbname = my_db | |
| conn = sqlite3.connect(dbname) # creating a connection | |
| self.cursor = conn.cursor() # we need the cursor to execute statement | |
| return | |
| def list_tables(self): | |
| result = self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall() | |
| table_names = sorted(list(zip(*result))[0]) | |
| return table_names | |
| def get_schema_for_table(self, table_name): | |
| result = self.cursor.execute("PRAGMA table_info('%s')" % table_name).fetchall() | |
| column_names = list(zip(*result))[1] | |
| return column_names | |
| def get_schemas(self): | |
| schemas = {} | |
| table_names = self.list_tables() | |
| for name in table_names: | |
| fields = self.get_schema_for_table(name) # fields of the table name | |
| schemas[name] = fields | |
| return schemas | |
| def execute(self, statement): | |
| code = 0 | |
| msg = { | |
| "text": "SUCCESS", | |
| "reason": None, | |
| "traceback": None, | |
| } | |
| data = None | |
| try: | |
| self.cursor.execute(statement) | |
| except sqlite3.OperationalError: | |
| code = -1 | |
| msg = { | |
| "text": "ERROR: SQL execution error", | |
| "reason": "possibly due to incorrect table/fields names", | |
| "traceback": traceback.format_exc(), | |
| } | |
| if code == 0: | |
| data = self.cursor.fetchall() | |
| msg["input"] = statement | |
| result = { | |
| "code": code, | |
| "msg": msg, | |
| "data": data | |
| } | |
| return result | |
| def execute_batch(self, queries): | |
| results = [] | |
| for query in queries: | |
| result = self.execute(query) | |
| results.append(result) | |
| return results | |
| def post_process(self, data): | |
| """ | |
| post process the data so that we can identify any harmful code and remove them. | |
| Also, llm output may need an output parser. | |
| :param data: | |
| :return: | |
| """ | |
| # IMPLEMENT YOUR CODE HERE FOR POST-PROCESSING and VALIDATION | |
| return data | |
| def sql_runtime(statement): | |
| """ | |
| Instantiates a sql runtime and executes the given sql statement | |
| :param statement: sql statement | |
| """ | |
| SQL = SQLRuntime() | |
| data = SQL.execute(statement) | |
| return data | |
| if __name__ == '__main__': | |
| # stmt = """ | |
| # SELECT * FROM elections_2019; | |
| # """ | |
| # stmt = input("Enter stmt: ") | |
| sql = SQLRuntime() | |
| tables = sql.list_tables() | |
| print(tables) | |
| schemas = {} | |
| for table in tables: | |
| schemas[table] = sql.get_schema_for_table(table) | |
| print(f"Table: {table}, Schema: {schemas[table]}\n") | |
| # data1 = sql.execute(stmt) | |
| # dat = data1["data"] | |
| # if dat is not None and len(dat) > 0: | |
| # for record in dat: | |
| # print(record) | |
| # print("-" * 100) | |
| # sample question: find out the votes polled by NOTA for each instance of Akkalkuwa in the parliamentary elections 2019. | |
| stmt = """ | |
| SELECT party_name, SUM(nota_votes) | |
| FROM elections_2019 | |
| WHERE constituency='Akkalkuwa' | |
| GROUP BY party_name; | |
| """ | |
| data1 = sql.execute(stmt) | |
| # print(data1) | |
| dat = data1["data"] | |
| if dat is not None and len(dat) > 0: | |
| for record in dat: | |
| print(record) | |
| print("-" * 100) |