"""RetailGenie Gradio app: natural-language question -> SQL -> results, chart, insights."""

import math
import os
import re
import sqlite3
import sys
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd

# For Hugging Face Spaces, set project root to current directory
PROJECT_ROOT = Path(__file__).parent.resolve()
sys.path.append(str(PROJECT_ROOT))

# Import model loading and utility functions (must come after the sys.path
# tweak above so the project-local ``code`` package is importable).
from code.train_sqlgen_t5_local import load_model as load_sql_model, generate_sql, get_schema_from_csv
from code.train_intent_classifier_local import load_model as load_intent_model, classify_intent

# Load models once at import time.
# NOTE: ``device`` is assigned by both loaders; the value returned by the
# intent loader is the one that remains bound and is used for both models.
sql_model, sql_tokenizer, device = load_sql_model()
intent_model, intent_tokenizer, device, label_mapping = load_intent_model()

# Path to the built-in data file in the data folder
DATA_FILE = str(PROJECT_ROOT / "data" / "testing_sql_data.csv")

# Verify data file exists
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Data file not found at {DATA_FILE}. Please ensure testing_sql_data.csv exists in the data folder.")


def _normalize_sql(sql_query):
    """Rewrite table and column names in generated SQL to match the loaded data.

    Every table referenced after FROM/JOIN (bare, double-quoted or
    single-quoted) is replaced by ``data``, the name the CSV is loaded under.
    A fixed set of column names is then remapped.
    """
    sql_query = re.sub(r'(FROM|JOIN)\s+\w+', r'\1 data', sql_query, flags=re.IGNORECASE)
    sql_query = re.sub(r'(FROM|JOIN)\s+"[^"]+"', r'\1 data', sql_query, flags=re.IGNORECASE)
    # The quoted name must stop at the *next* single quote; excluding the wrong
    # quote character here lets the match run on to the last quote in the query.
    sql_query = re.sub(r"(FROM|JOIN)\s+'[^']+'", r'\1 data', sql_query, flags=re.IGNORECASE)

    # Presumably these map names the model emits onto columns that exist in
    # the CSV -- TODO confirm against testing_sql_data.csv.
    sql_query = sql_query.replace('product_price', 'total_price')
    sql_query = sql_query.replace('store_name', 'store_id')
    sql_query = sql_query.replace('sales_method', 'date')
    sql_query = re.sub(r'\bsales\b', 'total_price', sql_query)
    return sql_query


def _run_sql(sql_query):
    """Load the CSV into an in-memory SQLite table ``data`` and run *sql_query*.

    Returns the query result as a DataFrame. The connection is always closed,
    including when loading or querying raises.
    """
    df = pd.read_csv(DATA_FILE)
    conn = sqlite3.connect(":memory:")
    try:
        df.to_sql("data", conn, index=False, if_exists="replace")
        return pd.read_sql_query(sql_query, conn)
    finally:
        conn.close()


def _render_chart(result_df, question, chart_type, intent):
    """Plot the first two result columns and save the figure as ``chart.png``.

    ``chart_type`` of ``"auto"`` resolves to ``"line"`` for the ``"trend"``
    intent and ``"bar"`` otherwise. Returns the path of the saved image.
    """
    chart_path = str(PROJECT_ROOT / "chart.png")
    if chart_type == "auto":
        chart_type = "line" if intent == "trend" else "bar"

    x_col, y_col = result_df.columns[0], result_df.columns[1]

    # Draw on an explicitly created Axes: DataFrame.plot() without ``ax``
    # opens its own figure, which would ignore this figsize and leave an
    # extra figure open on every call.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if chart_type == "bar":
            result_df.plot(kind="bar", x=x_col, y=y_col, ax=ax)
        elif chart_type == "line":
            result_df.plot(kind="line", x=x_col, y=y_col, marker='o', ax=ax)
        elif chart_type == "pie":
            result_df.plot(kind="pie", y=y_col, labels=result_df[x_col], ax=ax)
        ax.set_title(question)
        fig.tight_layout()
        fig.savefig(chart_path)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    return chart_path


def process_query(question, chart_type="auto"):
    """Answer a natural-language question against the built-in CSV.

    Generates SQL from the question, normalizes table/column names, classifies
    the intent, runs the SQL on an in-memory SQLite copy of the CSV, then
    renders a chart and textual insights.

    Args:
        question: The user's question text.
        chart_type: One of ``"auto"``, ``"bar"``, ``"line"``, ``"pie"``.

    Returns:
        A 5-tuple ``(result_df, intent, sql_query, chart_path, insights)``.
        When the result is empty or has fewer than two columns, ``chart_path``
        is ``None`` and ``insights`` explains why. On any exception the tuple
        is ``(None, "Error", <message>, None, "Error: <message>")``.
    """
    try:
        # Generate schema from CSV
        schema = get_schema_from_csv(DATA_FILE)

        # Generate SQL and adapt it to the table/columns actually loaded
        sql_query = generate_sql(question, schema, sql_model, sql_tokenizer, device)
        sql_query = _normalize_sql(sql_query)

        # Classify intent
        intent = classify_intent(question, intent_model, intent_tokenizer, device, label_mapping)

        # Execute SQL on the CSV data
        result_df = _run_sql(sql_query)

        # Charting and insights both need at least two columns of data.
        if result_df.empty or len(result_df.columns) < 2:
            chart_path = None
            insights = "No results or not enough columns to display chart/insights."
            return result_df, intent, sql_query, chart_path, insights

        chart_path = _render_chart(result_df, question, chart_type, intent)
        insights = generate_insights(result_df, intent, question)
        return result_df, intent, sql_query, chart_path, insights
    except Exception as e:
        # UI boundary: surface the failure in the output widgets instead of raising.
        return None, "Error", str(e), None, f"Error: {str(e)}"


def _percent_change(new, base):
    """Return the percent change from *base* to *new*, or None when undefined.

    Undefined means a zero baseline or a non-finite result (NaN/inf), which
    would otherwise be reported as a meaningless percentage.
    """
    if base == 0:
        return None
    change = (new / base - 1) * 100
    return change if math.isfinite(change) else None


def generate_insights(result_df, intent, question):
    """Build a bulleted text summary of *result_df* tailored to *intent*.

    Args:
        result_df: Query result; the first column is used as the label and
            the second as the value.
        intent: ``"summary"``, ``"comparison"`` or ``"trend"`` add an
            intent-specific bullet; any other value adds none.
        question: Currently unused; kept for interface compatibility.

    Returns:
        Bullet lines joined by newlines, or ``"No data available for
        insights."`` when there is no usable data.
    """
    if result_df is None or result_df.empty or len(result_df.columns) < 2:
        return "No data available for insights."

    # Intent-specific bullets are best-effort: values that are non-numeric or
    # cannot be formatted simply produce no bullet.
    best_effort_errors = (TypeError, ValueError, ArithmeticError)
    value_col = result_df.columns[1]
    insights = []

    if intent == "summary":
        try:
            total = result_df[value_col].sum()
            insights.append(f"Total {value_col}: {total:,.2f}")
        except best_effort_errors:
            pass
    elif intent == "comparison":
        if len(result_df) >= 2:
            try:
                # First and last rows are compared as returned by the query.
                highest = result_df.iloc[0]
                lowest = result_df.iloc[-1]
                diff = _percent_change(highest.iloc[1], lowest.iloc[1])
                if diff is not None:
                    insights.append(f"{highest.iloc[0]} is {diff:.1f}% higher than {lowest.iloc[0]}")
            except best_effort_errors:
                pass
    elif intent == "trend":
        if len(result_df) >= 2:
            try:
                first = result_df.iloc[0][value_col]
                last = result_df.iloc[-1][value_col]
                change = _percent_change(last, first)
                if change is not None:
                    insights.append(f"Overall change: {change:+.1f}%")
            except best_effort_errors:
                pass

    insights.append(f"Analysis covers {len(result_df)} records")
    if "category" in result_df.columns:
        insights.append(f"Number of categories: {result_df['category'].nunique()}")

    return "\n".join(f"• {insight}" for insight in insights)


# Clickable FAQs (6 only)
faqs = [
    "What are the top 5 products by quantity sold?",
    "What is the total sales amount for each category?",
    "Which store had the highest total sales?",
    "What are the most popular payment methods?",
    "What is the sales trend over time?",
    "What is the average transaction value?"
]


def fill_question(faq):
    """Return a Gradio update that sets the target component's value to *faq*."""
    return gr.update(value=faq)


with gr.Blocks(title="RetailGenie - Natural Language to SQL") as demo:
    gr.Markdown("""
    # RetailGenie - Natural Language to SQL
    Ask questions in natural language to generate SQL queries and visualizations.
    Using retail dataset with product sales information.
    """)

    with gr.Row():
        # Left column: question input, FAQ shortcuts and chart options.
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Enter your question",
                placeholder="What is the total sales amount for each product category?"
            )
            faq_radio = gr.Radio(faqs, label="FAQs (click to autofill)", interactive=True)
            # Selecting an FAQ copies its text into the question box.
            faq_radio.change(fn=fill_question, inputs=faq_radio, outputs=question)
            chart_type = gr.Radio(
                ["auto", "bar", "line", "pie"],
                label="Chart Type",
                value="auto"
            )
            submit_btn = gr.Button("Generate", variant="primary")

        # Right column: model details, query results, chart and insights.
        with gr.Column(scale=2):
            with gr.Accordion("SQL and Intent Details", open=False):
                intent_output = gr.Textbox(label="Predicted Intent")
                sql_output = gr.Textbox(label="Generated SQL", lines=3)
            results_df = gr.DataFrame(label="Query Results")
            chart_output = gr.Image(label="Chart")
            insights_output = gr.Textbox(label="Insights", lines=5)

    # Output order must match the tuple returned by process_query.
    submit_btn.click(
        fn=process_query,
        inputs=[question, chart_type],
        outputs=[results_df, intent_output, sql_output, chart_output, insights_output]
    )

if __name__ == "__main__":
    demo.launch()