# RetailGenie / app.py
# shubh7's picture
# Adding application file
# 5f946b0
import gradio as gr
import os
import sys
import pandas as pd
import sqlite3
from pathlib import Path
import matplotlib.pyplot as plt
import re
# For Hugging Face Spaces, set project root to current directory
PROJECT_ROOT = Path(__file__).parent.resolve()
# Insert at the FRONT of sys.path: the local package is named ``code``, which
# is also a standard-library module. Appending would let the stdlib module win
# whenever this file is imported rather than executed as a script, and the
# ``from code.<module> import ...`` lines below would then fail.
sys.path.insert(0, str(PROJECT_ROOT))

# Import model loading and utility functions
from code.train_sqlgen_t5_local import load_model as load_sql_model, generate_sql, get_schema_from_csv
from code.train_intent_classifier_local import load_model as load_intent_model, classify_intent

# Load models once at import time so every request reuses them.
# Both loaders return a ``device``; the value from the intent loader is the
# one that remains bound afterwards.
sql_model, sql_tokenizer, device = load_sql_model()
intent_model, intent_tokenizer, device, label_mapping = load_intent_model()

# Path to the built-in data file in the data folder
DATA_FILE = str(PROJECT_ROOT / "data" / "testing_sql_data.csv")

# Verify data file exists; fail at startup rather than on the first query.
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Data file not found at {DATA_FILE}. Please ensure testing_sql_data.csv exists in the data folder.")
# Matches the table reference after FROM/JOIN, whether it is double-quoted,
# single-quoted or a bare identifier.
_TABLE_REF_RE = re.compile(r"""(FROM|JOIN)\s+(?:"[^"]+"|'[^']+'|\w+)""", re.IGNORECASE)


def _normalize_sql(sql_query):
    """Rewrite generated SQL so it targets the ``data`` table and its columns."""
    # Every table reference is redirected to the single in-memory table.
    sql_query = _TABLE_REF_RE.sub(r'\1 data', sql_query)
    # Map column names the model tends to emit onto the replacements used here.
    sql_query = sql_query.replace('product_price', 'total_price')
    sql_query = sql_query.replace('store_name', 'store_id')
    sql_query = sql_query.replace('sales_method', 'date')
    sql_query = re.sub(r'\bsales\b', 'total_price', sql_query)
    return sql_query


def _run_sql(sql_query):
    """Load the CSV into an in-memory SQLite table ``data`` and run the query."""
    df = pd.read_csv(DATA_FILE)
    conn = sqlite3.connect(":memory:")
    try:
        df.to_sql("data", conn, index=False, if_exists="replace")
        return pd.read_sql_query(sql_query, conn)
    finally:
        # Close even when the generated SQL is invalid and the query raises.
        conn.close()


def _render_chart(result_df, chart_type, intent, question, chart_path):
    """Plot the first two result columns and save the figure to ``chart_path``."""
    if chart_type == "auto":
        # Only the "trend" intent gets a line chart; everything else is a bar.
        chart_type = "line" if intent == "trend" else "bar"

    x_col, y_col = result_df.columns[0], result_df.columns[1]
    # Draw on an explicit Axes: DataFrame.plot opens its own figure when no
    # ``ax`` is given, which would orphan this one and ignore its figsize.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if chart_type == "bar":
            result_df.plot(kind="bar", x=x_col, y=y_col, ax=ax)
        elif chart_type == "line":
            result_df.plot(kind="line", x=x_col, y=y_col, marker='o', ax=ax)
        elif chart_type == "pie":
            result_df.plot(kind="pie", y=y_col, labels=result_df[x_col], ax=ax)
        ax.set_title(question)
        fig.tight_layout()
        fig.savefig(chart_path)
    finally:
        # Release the figure even if plotting fails (e.g. non-numeric data).
        plt.close(fig)


def process_query(question, chart_type="auto"):
    """Answer a natural-language question with SQL results, a chart and insights.

    Args:
        question: The user's question in natural language.
        chart_type: One of "auto", "bar", "line" or "pie". "auto" picks a line
            chart for the "trend" intent and a bar chart otherwise.

    Returns:
        A 5-tuple ``(result_df, intent, sql_query, chart_path, insights)``.
        ``chart_path`` is ``None`` when the result is empty or has fewer than
        two columns. On any failure the tuple is
        ``(None, "Error", <message>, None, "Error: <message>")``.
    """
    try:
        # Generate schema from CSV, then SQL from the question.
        schema = get_schema_from_csv(DATA_FILE)
        raw_sql = generate_sql(question, schema, sql_model, sql_tokenizer, device)
        sql_query = _normalize_sql(raw_sql)

        # Classify intent
        intent = classify_intent(question, intent_model, intent_tokenizer, device, label_mapping)

        # NOTE(review): the SQL is model-generated from user input and is
        # executed as-is against a throwaway in-memory database.
        result_df = _run_sql(sql_query)

        # A chart needs one column for labels and one for values.
        if result_df.empty or len(result_df.columns) < 2:
            insights = "No results or not enough columns to display chart/insights."
            return result_df, intent, sql_query, None, insights

        chart_path = os.path.join(PROJECT_ROOT, "chart.png")
        _render_chart(result_df, chart_type, intent, question, chart_path)

        insights = generate_insights(result_df, intent, question)
        return result_df, intent, sql_query, chart_path, insights
    except Exception as e:
        # UI boundary: surface the message in the output widgets, don't crash.
        return None, "Error", str(e), None, f"Error: {str(e)}"
def generate_insights(result_df, intent, question):
    """Build a bulleted plain-text summary of a query result.

    The first column is treated as the label and the second as the value.

    Args:
        result_df: Query result as a DataFrame, or ``None``.
        intent: Predicted intent; "summary", "comparison" and "trend" each add
            one intent-specific line when the value column is numeric.
        question: The original question (currently unused; kept for callers).

    Returns:
        A string with one "• "-prefixed insight per line, or
        "No data available for insights." when there is nothing to analyse.
    """
    if result_df is None or result_df.empty or len(result_df.columns) < 2:
        return "No data available for insights."

    # Positional access keeps this working with duplicate column names and
    # with a non-default index.
    labels = result_df.iloc[:, 0].reset_index(drop=True)
    values = result_df.iloc[:, 1].reset_index(drop=True)
    value_name = result_df.columns[1]
    has_numbers = pd.api.types.is_numeric_dtype(values) and values.notna().any()

    insights = []
    if has_numbers:
        if intent == "summary":
            insights.append(f"Total {value_name}: {values.sum():,.2f}")
        elif intent == "comparison" and len(values) >= 2:
            # Find the real extremes instead of assuming the rows are sorted.
            hi_pos, lo_pos = values.idxmax(), values.idxmin()
            hi_val, lo_val = values[hi_pos], values[lo_pos]
            # Skip when all values are equal or the baseline is zero; a
            # percentage would be meaningless or infinite.
            if hi_pos != lo_pos and lo_val != 0:
                diff = (hi_val / lo_val - 1) * 100
                insights.append(f"{labels[hi_pos]} is {diff:.1f}% higher than {labels[lo_pos]}")
        elif intent == "trend" and len(values) >= 2:
            first, last = values.iloc[0], values.iloc[-1]
            # numpy division by zero yields inf/nan instead of raising, so
            # guard explicitly rather than relying on an exception.
            if pd.notna(first) and pd.notna(last) and first != 0:
                change = (last / first - 1) * 100
                insights.append(f"Overall change: {change:+.1f}%")

    insights.append(f"Analysis covers {len(result_df)} records")
    if "category" in result_df.columns:
        insights.append(f"Number of categories: {result_df['category'].nunique()}")
    return "\n".join(f"• {insight}" for insight in insights)
# Six example questions offered in the UI as one-click autofill choices.
faqs = [
    "What are the top 5 products by quantity sold?",
    "What is the total sales amount for each category?",
    "Which store had the highest total sales?",
    "What are the most popular payment methods?",
    "What is the sales trend over time?",
    "What is the average transaction value?",
]
def fill_question(faq):
    """Return a Gradio update that sets the target component's value to ``faq``."""
    autofill = gr.update(value=faq)
    return autofill
# Build the Gradio interface: inputs on the left, outputs on the right.
# NOTE(review): the original indentation was not preserved in this copy; the
# nesting below (in particular which outputs sit inside the accordion) is
# reconstructed — confirm against the deployed layout.
with gr.Blocks(title="RetailGenie - Natural Language to SQL") as demo:
    # The Markdown text is kept flush-left inside the string literal.
    gr.Markdown("""
# RetailGenie - Natural Language to SQL
Ask questions in natural language to generate SQL queries and visualizations. Using retail dataset with product sales information.
""")
    with gr.Row():
        # Input column: question box, FAQ shortcuts, chart selector, submit.
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Enter your question",
                placeholder="What is the total sales amount for each product category?"
            )
            faq_radio = gr.Radio(faqs, label="FAQs (click to autofill)", interactive=True)
            # Selecting an FAQ copies its text into the question box.
            faq_radio.change(fn=fill_question, inputs=faq_radio, outputs=question)
            chart_type = gr.Radio(
                ["auto", "bar", "line", "pie"],
                label="Chart Type",
                value="auto"
            )
            submit_btn = gr.Button("Generate", variant="primary")
        # Output column, given twice the width of the input column.
        with gr.Column(scale=2):
            # Diagnostic details, collapsed by default (open=False).
            with gr.Accordion("SQL and Intent Details", open=False):
                intent_output = gr.Textbox(label="Predicted Intent")
                sql_output = gr.Textbox(label="Generated SQL", lines=3)
            results_df = gr.DataFrame(label="Query Results")
            chart_output = gr.Image(label="Chart")
            insights_output = gr.Textbox(label="Insights", lines=5)
    # The outputs list follows the order of the 5-tuple returned by
    # process_query: (result_df, intent, sql_query, chart_path, insights).
    submit_btn.click(
        fn=process_query,
        inputs=[question, chart_type],
        outputs=[results_df, intent_output, sql_output, chart_output, insights_output]
    )

# Launch the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()