kkhushisaid's picture
Update app.py
6849134 verified
import functools
import os
import warnings

import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Merge variables from a local .env file (if any) into the process
# environment, then pick up the Groq credential from there. The value is
# None when the variable is not set.
load_dotenv()
groq_api_key = os.environ.get("groq_api_key")
# Use the current directory for Hugging Face Spaces
dataset_folder = "./data" # Assuming files are in a 'data/' folder
# Verify the folder exists
if not os.path.exists(dataset_folder):
print(f"Warning: Dataset folder '{dataset_folder}' not found. Using current directory instead.")
dataset_folder = "." # Fallback: Look in the current directory
# Print available files for debugging
print("Available files:", os.listdir(dataset_folder))
import warnings
# Ignore DtypeWarning
warnings.simplefilter("ignore", category=pd.errors.DtypeWarning)
# Load all CSV files in the dataset folder
dataframes = []
for file in os.listdir(dataset_folder):
if file.endswith(".csv"): # Check if the file is a CSV
try:
# Read first few rows to identify column names
sample_df = pd.read_csv(
os.path.join(dataset_folder, file),
nrows=5, # Read only first 5 rows for column type inference
encoding="utf-8",
errors="replace" # Replace encoding errors with a placeholder
)
column_types = {col: str for col in sample_df.columns} # Force all columns to string
# Read the entire file with enforced column types
df = pd.read_csv(
os.path.join(dataset_folder, file),
dtype=column_types, # Apply enforced string types
low_memory=False, # Avoid chunk-based reading issues
encoding="utf-8",
errors="replace"
).fillna('') # Fill NaN values with empty strings
dataframes.append(df) # Append DataFrame to the list
except Exception as e:
print(f"Error reading {file}: {e}")
# Merge all CSV files into one DataFrame (only if there are valid files)
if dataframes:
full_data = pd.concat(dataframes, ignore_index=True)
else:
print("Warning: No valid CSV files found in the dataset folder.")
full_data = pd.DataFrame() # Create an empty DataFrame as a fallback
def load_dataset_metadata(dataset_folder):
    """Load every CSV in *dataset_folder* and describe each one as a table.

    Args:
        dataset_folder: Directory scanned (non-recursively) for files whose
            name ends in ``.csv``.

    Returns:
        A ``(dataframes, metadata_list)`` tuple. ``dataframes`` is a list of
        ``(file_name, DataFrame)`` pairs; ``metadata_list`` holds, in the same
        order, one text description per table (the table name followed by its
        comma-separated column names).

    Raises:
        Whatever ``os.listdir`` or ``pandas.read_csv`` raises for a missing
        folder or an unreadable file; nothing is caught here.
    """
    dataframes = []
    metadata_list = []
    for file in os.listdir(dataset_folder):
        if not file.endswith(".csv"):
            continue
        df = pd.read_csv(os.path.join(dataset_folder, file))
        dataframes.append((file, df))
        # Drop only the trailing extension. str.replace('.csv', '') would also
        # eat a ".csv" in the middle of the name (e.g. "sales.csv.v2.csv").
        table_name = os.path.splitext(file)[0]
        columns = ", ".join(df.columns.tolist())
        metadata_list.append(f"\nTable: {table_name}\nColumns:\n{columns}\n")
    return dataframes, metadata_list
@functools.lru_cache(maxsize=4)
def _load_embedding_model(model_name):
    """Return the SentenceTransformer for *model_name*, constructing it once and reusing it."""
    return SentenceTransformer(model_name)


def create_metadata_embeddings(metadata_list, model_name='all-MiniLM-L6-v2'):
    """Create embeddings for all table metadata.

    The model is cached per process: constructing a SentenceTransformer on
    every call meant reloading it for every user query.

    Args:
        metadata_list: List of table description strings to embed.
        model_name: Name of the sentence-transformers model to use.

    Returns:
        An ``(embeddings, model)`` tuple: the result of ``model.encode`` on
        ``metadata_list`` and the model that produced it, so callers can embed
        queries with the same model.
    """
    model = _load_embedding_model(model_name)
    embeddings = model.encode(metadata_list)
    return embeddings, model
def find_best_fit(embeddings, model, user_query, metadata_list):
    """Find the table description that best matches the user query.

    Args:
        embeddings: Embeddings of ``metadata_list``, one per entry and in the
            same order.
        model: Object whose ``encode`` method embeds a list of strings; it
            should be the model that produced ``embeddings``.
        user_query: Natural-language question from the user.
        metadata_list: Table description strings to choose from.

    Returns:
        The entry of ``metadata_list`` whose embedding has the highest cosine
        similarity to the query embedding.

    Raises:
        ValueError: If ``metadata_list`` is empty, i.e. there is nothing to
            match against.
    """
    # Fail early with a clear message instead of an opaque shape error from
    # the similarity computation further down.
    if not metadata_list:
        raise ValueError("No table metadata available to match the query against.")
    query_embedding = model.encode([user_query])
    similarities = cosine_similarity(query_embedding, embeddings)
    # Only one query is encoded, so the flat argmax over the similarity matrix
    # is the index of the best-matching table.
    best_match_index = int(similarities.argmax())
    return metadata_list[best_match_index]
def create_prompt(user_query, table_metadata):
    """Build the instruction prompt that asks the model for a single SQL query.

    Args:
        user_query: The user's natural-language question.
        table_metadata: Description of the selected table (name and columns).

    Returns:
        The prompt text: the table description, the question, strict
        output-format rules and two example question/answer pairs.
    """
    prompt_lines = (
        "",  # the prompt opens with a newline
        "You are an AI assistant that generates precise SQL queries based on user questions.",
        "**Table Name & Columns:**",
        f"{table_metadata}",
        "**User Query:**",
        f"{user_query}",
        "**Output Format (STRICT):**",
        "- Provide ONLY the SQL query.",
        "- Do NOT include explanations, comments, or unnecessary text.",
        "- Ensure the table and column names match exactly.",
        '- If the query is impossible, return: "ERROR: Unable to generate query."',
        "**Example Queries:**",
        '- User: "Show all startups founded in 2020."',
        "- AI Response: SELECT * FROM startups WHERE founded_year = 2020;",
        '- User: "List the top 5 startups by total funding."',
        "- AI Response: SELECT name, total_funding FROM startups ORDER BY total_funding DESC LIMIT 5;",
        "",  # ... and closes with one
    )
    return "\n".join(prompt_lines)
def generate_sql_query(system_prompt):
"""Uses Groq API to generate an SQL query with better debugging."""
try:
client = Groq(api_key=groq_api_key)
chat_completion = client.chat.completions.create(
messages=[{"role": "system", "content": system_prompt}],
model="llama3-70b-8192"
)
# Debug: Print entire response
print("πŸ” Full API Response:", chat_completion)
# Extract AI response
result = chat_completion.choices[0].message.content.strip()
print(f"βœ… AI Response: {result}") # Debugging
# Check if the response starts with "SELECT"
if result.lower().startswith("select"):
return result
else:
print("⚠️ AI did not generate a valid SQL query!")
return "⚠️ AI response is not a valid SQL query."
except Exception as e:
print(f"❌ API Error: {e}")
return "⚠️ API failed. Check logs."
def response(user_query, dataset_folder):
    """Process the user query and return an SQL query.

    Loads table metadata from the CSV files in ``dataset_folder``, selects
    the table whose description is most similar to the question, builds the
    prompt for it and asks the language model for a query.

    Args:
        user_query: The user's natural-language question.
        dataset_folder: Directory containing the CSV tables.

    Returns:
        The string returned by ``generate_sql_query`` (an SQL query or a
        warning message), or a warning message when the folder contains no
        CSV tables.
    """
    # Only the descriptions are needed here; the loaded frames are discarded.
    _, metadata_list = load_dataset_metadata(dataset_folder)
    if not metadata_list:
        # Without any table there is nothing to match the question against;
        # report it instead of failing inside the similarity search.
        message = f"⚠️ No CSV tables found in '{dataset_folder}'."
        print(message)
        return message
    embeddings, model = create_metadata_embeddings(metadata_list)
    table_metadata = find_best_fit(embeddings, model, user_query, metadata_list)
    system_prompt = create_prompt(user_query, table_metadata)
    return generate_sql_query(system_prompt)
# NOTE: `dataset_folder` is resolved near the top of this file, including the
# fallback to the current directory when './data' does not exist. It is
# deliberately NOT re-assigned here: resetting it to "./data" discarded that
# fallback and made every query fail when the folder was missing.

# Sample question, kept for manual experiments; the UI supplies its own input.
user_query = "Show me the top 10 startups with the highest funding."
def sql_query_interface(user_query):
    """Gradio callback: turn the text-box content into a generated SQL query.

    Args:
        user_query: Text entered by the user; may be empty or ``None``.

    Returns:
        The result of ``response`` for the module-level ``dataset_folder``,
        or a short hint when no question was entered.
    """
    # Skip the model and API round-trip when there is nothing to ask.
    if not user_query or not user_query.strip():
        return "⚠️ Please enter a question."
    return response(user_query, dataset_folder)
# Gradio UI: one free-text question in, one generated SQL string out.
_query_box = gr.Textbox(label="Enter your query")
_sql_box = gr.Textbox(label="Generated SQL Query")
iface = gr.Interface(
    fn=sql_query_interface,
    inputs=_query_box,
    outputs=_sql_box,
    title="AI-Powered SQL Query Generator",
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()