Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from langchain_community.utilities import SQLDatabase | |
| from langchain_openai import ChatOpenAI | |
| from langchain.chains import create_sql_query_chain | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnablePassthrough | |
| from langchain_core.output_parsers.openai_tools import PydanticToolsParser | |
| from langchain_core.pydantic_v1 import BaseModel, Field | |
| from typing import List | |
| import sqlite3 | |
| from langsmith import traceable | |
| from openai import OpenAI | |
# Load environment variables from a local .env file, if one is present.
load_dotenv()

# Set up LangSmith tracing for this application.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ only accepts string values, so writing os.getenv(...) straight
# back raises TypeError at import time whenever LANGCHAIN_API_KEY is not set.
# Only re-export the key when it actually exists.
_langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
if _langchain_api_key is not None:
    os.environ["LANGCHAIN_API_KEY"] = _langchain_api_key
os.environ["LANGCHAIN_PROJECT"] = "SQLq&a"

# Initialize OpenAI client
openai_client = OpenAI()

# Set up the database connection: chinook.db is looked up in the same
# directory as this source file.
db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
# Function to get table info
def get_table_info(db_path):
    """Return a mapping of table name to column names for a SQLite database.

    Args:
        db_path: Filesystem path of the SQLite database file to inspect.

    Returns:
        A dict with one entry per ``sqlite_master`` row of ``type='table'``.
        Each value is the list of that table's column names, in the order
        reported by ``PRAGMA table_info``.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        table_info = {}
        for table_name in table_names:
            # The name is interpolated into the PRAGMA text, so quote it as
            # an identifier (doubling embedded quotes). This keeps names
            # containing spaces or punctuation from breaking the statement.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            # Index 1 of each PRAGMA table_info row is the column name.
            table_info[table_name] = [column[1] for column in cursor.fetchall()]
        return table_info
    finally:
        # Always release the connection, including when a query fails.
        conn.close()
# Read the schema (table name -> column names) of the database at db_path
# once, at import time.
table_info = get_table_info(db_path=db_path)
# Format table info for display
def format_table_info(table_info):
    """Render a table -> columns mapping as plain text.

    Args:
        table_info: Mapping of table name to an iterable of column names.

    Returns:
        A string that starts with the table count and a heading, followed by
        one section per table: the table name with a trailing colon, one
        `` - column`` line per column, and a blank line.
    """
    pieces = [
        f"Total number of tables: {len(table_info)}\n\n",
        "Tables and their columns:\n\n",
    ]
    for table, columns in table_info.items():
        pieces.append(f"{table}:\n")
        pieces.extend(f" - {column}\n" for column in columns)
        pieces.append("\n")
    return "".join(pieces)
# Chat model used by the table-selection, SQL-generation and answer steps
# below; created with temperature 0.
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")
class Table(BaseModel):
    """Table in SQL database."""

    # NOTE(review): this class is handed to llm.bind_tools, so the docstring
    # above and the description below appear to be sent to the model as part
    # of the tool schema. Treat both as prompt text, not as documentation.
    name: str = Field(..., description="Name of table in SQL database.")
# Table-selection step: the model is given the Table tool and asked to call
# it once for every table that might matter for the question.
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

# System prompt listing every usable table name, one per line.
table_names = "\n".join(db.get_usable_table_names())
system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables are:
{table_names}
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

table_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

# prompt -> tool-calling model -> list of parsed Table objects
table_chain = table_prompt | llm_with_tools | output_parser
# Function to get table names from the output
def get_table_names(output: List[Table]) -> List[str]:
    """Collect the ``name`` attribute of each item in *output*, in order.

    Args:
        output: Parsed ``Table`` objects, as produced by ``table_chain``.

    Returns:
        The table names as a list of strings.
    """
    names = []
    for selected in output:
        names.append(selected.name)
    return names
# SQL-generation step built from the shared model and database.
query_chain = create_sql_query_chain(llm, db)

# Adds "table_names_to_use" to the incoming dict by running the
# table-selection chain on the user's question.
_with_selected_tables = RunnablePassthrough.assign(
    table_names_to_use=lambda inputs: get_table_names(
        table_chain.invoke({"input": inputs["question"]})
    )
)

# Combine table selection and query generation
full_chain = _with_selected_tables | query_chain
# Matches either an opening code fence with an optional SQL-family language
# tag (sql / sqlite / sqlite3, any letter case) plus the whitespace after it,
# or a closing fence plus the whitespace before it.
_CODE_FENCE_RE = re.compile(r'```(?:sqlite3?|sql)?\s*|\s*```', re.IGNORECASE)


# Function to strip markdown formatting from SQL query
def strip_markdown(text):
    """Remove Markdown code-fence markers from a SQL string.

    Fences tagged ``sql``, ``sqlite`` or ``sqlite3`` (case-insensitive) are
    handled, as are untagged fences. Matching only the lowercase ``sql`` tag
    would leave a stray "ite" in front of the query for a "sqlite" fence.

    Args:
        text: Query text, possibly wrapped in a fenced code block.

    Returns:
        The text with fence markers removed and surrounding whitespace
        stripped.
    """
    return _CODE_FENCE_RE.sub('', text).strip()
# Function to execute SQL query
def execute_query(query: str) -> str:
    """Run *query* against the module-level ``db`` and return the outcome as text.

    Args:
        query: SQL text, possibly wrapped in Markdown code fences.

    Returns:
        ``str()`` of whatever ``db.run`` returns. If anything raises, the
        string ``"Error executing query: <message>"`` is returned instead.
        Errors are never propagated.
    """
    # NOTE(review): the query is model-generated and is executed as-is; there
    # is no check here that it is read-only.
    try:
        # Fences are stripped first so only the bare SQL reaches the database.
        return str(db.run(strip_markdown(query)))
    except Exception as exc:
        return f"Error executing query: {str(exc)}"
# Prompt for the final step: turn the question, the generated SQL and its
# result (or error text) into an answer, with the schema as reference.
answer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
If there was an error in executing the SQL query, please explain the error and suggest a correction.
Do not include any SQL code formatting or markdown in your response.
Here is the database schema for reference:
{table_info}""",
        ),
        ("human", "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:"),
    ]
)

# End-to-end pipeline. The input dict gains "query" (generated SQL) and then
# "result" (its execution output) before being formatted into the answer
# prompt and sent to the model; the reply is returned as a plain string.
chain = (
    RunnablePassthrough.assign(query=lambda state: full_chain.invoke(state))
    | RunnablePassthrough.assign(result=lambda state: execute_query(state["query"]))
    | answer_prompt
    | llm
    | StrOutputParser()
)
# Function to process user input and generate response
def process_input(message, history, table_info_str):
    """Answer one chat message by running it through ``chain``.

    Args:
        message: The user's question.
        history: Accepted as the second positional argument but not used.
        table_info_str: Schema text passed to the chain as ``table_info``.

    Returns:
        Whatever ``chain.invoke`` returns for this question.
    """
    chain_inputs = {"question": message, "table_info": table_info_str}
    return chain.invoke(chain_inputs)
# Schema text shown in the UI and passed to every chat call.
formatted_table_info = format_table_info(table_info)

# Read-only textbox holding the schema; its value is supplied to
# process_input as the additional input after (message, history).
_schema_box = gr.Textbox(
    label="Database Schema",
    value=formatted_table_info,
    lines=10,
    max_lines=20,
    interactive=False,
)

# Sample questions offered in the chat UI.
_example_questions = [
    ["Who are the top 5 artists with the most albums in the database?"],
    ["What is the total sales amount for each country?"],
    ["Which employee has made the highest total sales, and what is the amount?"],
    ["What are the top 10 longest tracks in the database, and who are their artists?"],
    ["How many customers are there in each country, and what is the total sales for each?"],
]

# Create Gradio interface
iface = gr.ChatInterface(
    fn=process_input,
    title="SQL Q&A Chatbot for Chinook Database",
    description="Ask questions about the Chinook music store database and get answers!",
    examples=_example_questions,
    additional_inputs=[_schema_box],
    theme="soft",
)

# Launch the interface only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()