# Hugging Face Spaces page header (Space status at capture time: Sleeping)
import gradio as gr | |
import os | |
import sys | |
import pandas as pd | |
import sqlite3 | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import re | |
# For Hugging Face Spaces, set project root to current directory
# (the resolved directory that contains this file).
PROJECT_ROOT = Path(__file__).parent.resolve()
# Appended before the imports below so the local `code` package can be found.
# NOTE(review): the standard library also ships a module named `code`;
# confirm the local package is the one that gets resolved here.
sys.path.append(str(PROJECT_ROOT))

# Import model loading and utility functions
from code.train_sqlgen_t5_local import load_model as load_sql_model, generate_sql, get_schema_from_csv
from code.train_intent_classifier_local import load_model as load_intent_model, classify_intent

# Load models once at import time.
# Both loaders return a `device`; the second assignment overwrites the first.
sql_model, sql_tokenizer, device = load_sql_model()
intent_model, intent_tokenizer, device, label_mapping = load_intent_model()

# Path to the built-in data file in the data folder
DATA_FILE = str(PROJECT_ROOT / "data" / "testing_sql_data.csv")

# Verify data file exists; fail at import time rather than on the first query.
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Data file not found at {DATA_FILE}. Please ensure testing_sql_data.csv exists in the data folder.")
# Table reference following FROM/JOIN: a double-quoted name, a single-quoted
# name, or a bare identifier. The leading \b keeps identifiers that merely end
# in "from"/"join" (e.g. "valid_from") from being treated as keywords.
_TABLE_REF_RE = re.compile(
    r"""\b(FROM|JOIN)\s+(?:"[^"]+"|'[^']+'|\w+)""",
    flags=re.IGNORECASE,
)

# Column names the SQL model tends to emit, mapped to the names used instead.
# Applied in this order, before the stand-alone "sales" rewrite.
_COLUMN_RENAMES = (
    ("product_price", "total_price"),
    ("store_name", "store_id"),
    ("sales_method", "date"),
)


def _normalize_sql(sql_query):
    """Point every FROM/JOIN at the ``data`` table and apply the column renames."""
    sql_query = _TABLE_REF_RE.sub(r"\1 data", sql_query)
    for generated_name, actual_name in _COLUMN_RENAMES:
        sql_query = sql_query.replace(generated_name, actual_name)
    # Whole-word only, so names that merely contain "sales" are left alone.
    return re.sub(r"\bsales\b", "total_price", sql_query)


def _run_sql(sql_query):
    """Load the CSV into an in-memory SQLite table ``data`` and run the query."""
    df = pd.read_csv(DATA_FILE)
    conn = sqlite3.connect(":memory:")
    try:
        df.to_sql("data", conn, index=False, if_exists="replace")
        return pd.read_sql_query(sql_query, conn)
    finally:
        # Close even when the generated SQL is invalid and the query raises.
        conn.close()


def _render_chart(result_df, chart_type, intent, question, chart_path):
    """Plot the first two result columns and save the figure to ``chart_path``."""
    if chart_type == "auto":
        # Only the "trend" intent maps to a line chart; everything else is a bar.
        chart_type = "line" if intent == "trend" else "bar"

    x_col, y_col = result_df.columns[0], result_df.columns[1]
    # Draw on an explicit Axes: without ax=, DataFrame.plot opens its own
    # figure, so the figsize would be ignored and the extra figure never closed.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if chart_type == "bar":
            result_df.plot(kind="bar", x=x_col, y=y_col, ax=ax)
        elif chart_type == "line":
            result_df.plot(kind="line", x=x_col, y=y_col, marker="o", ax=ax)
        elif chart_type == "pie":
            result_df.plot(kind="pie", y=y_col, labels=result_df[x_col], ax=ax)
        ax.set_title(question)
        fig.tight_layout()
        fig.savefig(chart_path)
    finally:
        plt.close(fig)


def process_query(question, chart_type="auto"):
    """Turn a natural-language question into SQL, run it, chart it and summarise it.

    Args:
        question: The user's question in natural language.
        chart_type: "auto", "bar", "line" or "pie". "auto" picks a line chart
            for the "trend" intent and a bar chart otherwise.

    Returns:
        A 5-tuple ``(result_df, intent, sql_query, chart_path, insights)``.
        ``chart_path`` is None when the result is empty or has fewer than two
        columns. On any exception the tuple is
        ``(None, "Error", str(e), None, "Error: <message>")`` so the UI shows
        the failure instead of crashing.
    """
    try:
        # Generate schema from CSV
        schema = get_schema_from_csv(DATA_FILE)
        # Generate SQL, then rewrite table/column names to match the CSV table
        sql_query = generate_sql(question, schema, sql_model, sql_tokenizer, device)
        sql_query = _normalize_sql(sql_query)
        # Classify intent
        intent = classify_intent(question, intent_model, intent_tokenizer, device, label_mapping)
        # Execute SQL on the CSV data
        result_df = _run_sql(sql_query)

        # Charting and insights both need a label column and a value column.
        if result_df.empty or len(result_df.columns) < 2:
            insights = "No results or not enough columns to display chart/insights."
            return result_df, intent, sql_query, None, insights

        # NOTE(review): a single fixed file name is shared by all requests;
        # concurrent users may overwrite each other's chart.
        chart_path = os.path.join(PROJECT_ROOT, "chart.png")
        _render_chart(result_df, chart_type, intent, question, chart_path)

        insights = generate_insights(result_df, intent, question)
        return result_df, intent, sql_query, chart_path, insights
    except Exception as e:
        return None, "Error", str(e), None, f"Error: {str(e)}"
def _percent_change(baseline, value):
    """Return the percentage change from ``baseline`` to ``value``.

    Returns None when either number is missing or the baseline is zero, since
    the percentage is undefined in those cases.
    """
    baseline = float(baseline)
    value = float(value)
    if pd.isna(baseline) or pd.isna(value) or baseline == 0:
        return None
    return (value - baseline) / abs(baseline) * 100


def generate_insights(result_df, intent, question):
    """Summarise a query result as bullet-point text.

    The first column is treated as the label and the second as the value.

    Args:
        result_df: Query result; None, empty, or fewer than two columns yields
            the "no data" message.
        intent: "summary" adds the value total, "comparison" compares the
            largest value with the smallest, "trend" compares the last row
            with the first. Any other intent only gets the generic lines.
        question: Unused; kept so existing callers keep working.

    Returns:
        A string with one "• "-prefixed line per insight.
    """
    if result_df is None or result_df.empty or len(result_df.columns) < 2:
        return "No data available for insights."

    value_name = result_df.columns[1]
    # Positional access works even when the query returns duplicate column
    # names; resetting the index makes labels and positions coincide.
    labels = result_df.iloc[:, 0].reset_index(drop=True)
    # Non-numeric values and NULLs become NaN instead of raising later.
    values = pd.to_numeric(result_df.iloc[:, 1], errors="coerce").reset_index(drop=True)
    valid = values.dropna()

    insights = []
    if intent == "summary":
        if not valid.empty:
            insights.append(f"Total {value_name}: {float(valid.sum()):,.2f}")
    elif intent == "comparison":
        if len(valid) >= 2:
            # Find the real extremes instead of assuming descending order.
            # Ties resolve to the first maximum and the last minimum, which
            # matches first-row/last-row for results already sorted descending.
            top_pos = valid.idxmax()
            low_pos = valid.iloc[::-1].idxmin()
            diff = _percent_change(valid.loc[low_pos], valid.loc[top_pos])
            if diff is not None:
                insights.append(
                    f"{labels.iloc[top_pos]} is {diff:.1f}% higher than {labels.iloc[low_pos]}"
                )
    elif intent == "trend":
        if len(values) >= 2:
            change = _percent_change(values.iloc[0], values.iloc[-1])
            if change is not None:
                insights.append(f"Overall change: {change:+.1f}%")

    insights.append(f"Analysis covers {len(result_df)} records")
    if "category" in result_df.columns:
        insights.append(f"Number of categories: {result_df['category'].nunique()}")
    return "\n".join(f"• {insight}" for insight in insights)
# Clickable FAQs (6 only)
# Example questions offered as radio choices; picking one autofills the
# question box.
faqs = [
    "What are the top 5 products by quantity sold?",
    "What is the total sales amount for each category?",
    "Which store had the highest total sales?",
    "What are the most popular payment methods?",
    "What is the sales trend over time?",
    "What is the average transaction value?"
]
def fill_question(faq):
    """Build a Gradio update whose value is the given FAQ text."""
    selected_text = faq
    return gr.update(value=selected_text)
# Build the Gradio UI: inputs in the left column, outputs in the right column.
# NOTE(review): the nesting below was reconstructed; confirm that only the
# intent and SQL boxes belong inside the accordion.
with gr.Blocks(title="RetailGenie - Natural Language to SQL") as demo:
    gr.Markdown("""
    # RetailGenie - Natural Language to SQL
    Ask questions in natural language to generate SQL queries and visualizations. Using retail dataset with product sales information.
    """)
    with gr.Row():
        # Input column: question box, FAQ shortcuts, chart type and submit button.
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Enter your question",
                placeholder="What is the total sales amount for each product category?"
            )
            faq_radio = gr.Radio(faqs, label="FAQs (click to autofill)", interactive=True)
            # Selecting an FAQ copies its text into the question box.
            faq_radio.change(fn=fill_question, inputs=faq_radio, outputs=question)
            chart_type = gr.Radio(
                ["auto", "bar", "line", "pie"],
                label="Chart Type",
                value="auto"
            )
            submit_btn = gr.Button("Generate", variant="primary")
        # Output column: collapsible SQL/intent details, then table, chart, insights.
        with gr.Column(scale=2):
            with gr.Accordion("SQL and Intent Details", open=False):
                intent_output = gr.Textbox(label="Predicted Intent")
                sql_output = gr.Textbox(label="Generated SQL", lines=3)
            results_df = gr.DataFrame(label="Query Results")
            chart_output = gr.Image(label="Chart")
            insights_output = gr.Textbox(label="Insights", lines=5)
    # The output order must match the tuple returned by process_query:
    # (result_df, intent, sql_query, chart_path, insights).
    submit_btn.click(
        fn=process_query,
        inputs=[question, chart_type],
        outputs=[results_df, intent_output, sql_output, chart_output, insights_output]
    )

# Launch the app only when run as a script.
if __name__ == "__main__":
    demo.launch()