Adding application file
- app.py +165 -0
- code/__init__.py +1 -0
- code/__pycache__/__init__.cpython-310.pyc +0 -0
- code/__pycache__/train_intent_classifier_local.cpython-310.pyc +0 -0
- code/__pycache__/train_sqlgen_t5_local.cpython-310.pyc +0 -0
- code/cloud_train_intent_classifier_script.py +77 -0
- code/cloud_train_sqlgen_t5_script.py +72 -0
- code/train_intent_classifier_local.py +84 -0
- code/train_sqlgen_t5_local.py +76 -0
- data/retail_dataset.csv +11 -0
- data/retail_schema.sql +30 -0
- data/testing_sql_data.csv +21 -0
- model_intent_classifier/config.json +38 -0
- model_intent_classifier/label_mapping.json +1 -0
- model_intent_classifier/pytorch_model.bin +3 -0
- model_intent_classifier/special_tokens_map.json +7 -0
- model_intent_classifier/tokenizer.json +0 -0
- model_intent_classifier/tokenizer_config.json +14 -0
- model_intent_classifier/vocab.txt +0 -0
- model_sqlgen_t5/config.json +61 -0
- model_sqlgen_t5/pytorch_model.bin +3 -0
- model_sqlgen_t5/special_tokens_map.json +107 -0
- model_sqlgen_t5/spiece.model +3 -0
- model_sqlgen_t5/tokenizer_config.json +114 -0
- requirements.txt +15 -0
app.py
ADDED
@@ -0,0 +1,165 @@
import gradio as gr
import os
import sys
import pandas as pd
import sqlite3
from pathlib import Path
import matplotlib.pyplot as plt
import re

# For Hugging Face Spaces, set project root to current directory
PROJECT_ROOT = Path(__file__).parent.resolve()
sys.path.append(str(PROJECT_ROOT))

# Import model loading and utility functions
from code.train_sqlgen_t5_local import load_model as load_sql_model, generate_sql, get_schema_from_csv
from code.train_intent_classifier_local import load_model as load_intent_model, classify_intent

# Load models
sql_model, sql_tokenizer, device = load_sql_model()
intent_model, intent_tokenizer, device, label_mapping = load_intent_model()

# Path to the built-in data file in the data folder
DATA_FILE = str(PROJECT_ROOT / "data" / "testing_sql_data.csv")

# Verify data file exists
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Data file not found at {DATA_FILE}. Please ensure testing_sql_data.csv exists in the data folder.")

def process_query(question, chart_type="auto"):
    try:
        # Generate schema from CSV
        schema = get_schema_from_csv(DATA_FILE)
        # Generate SQL
        sql_query = generate_sql(question, schema, sql_model, sql_tokenizer, device)
        # --- Fix: Table and column name replacements ---
        sql_query = re.sub(r'(FROM|JOIN)\s+\w+', r'\1 data', sql_query, flags=re.IGNORECASE)
        sql_query = re.sub(r'(FROM|JOIN)\s+"[^"]+"', r'\1 data', sql_query, flags=re.IGNORECASE)
        sql_query = re.sub(r"(FROM|JOIN)\s+'[^']+'", r'\1 data', sql_query, flags=re.IGNORECASE)
        sql_query = sql_query.replace('product_price', 'total_price')
        sql_query = sql_query.replace('store_name', 'store_id')
        sql_query = sql_query.replace('sales_method', 'date')
        sql_query = re.sub(r'\bsales\b', 'total_price', sql_query)
        # --- End fix ---
        # Classify intent
        intent = classify_intent(question, intent_model, intent_tokenizer, device, label_mapping)
        # Execute SQL on the CSV data
        df = pd.read_csv(DATA_FILE)
        conn = sqlite3.connect(":memory:")
        df.to_sql("data", conn, index=False, if_exists="replace")
        result_df = pd.read_sql_query(sql_query, conn)
        conn.close()
        # Defensive check for result_df columns
        if result_df.empty or len(result_df.columns) < 2:
            chart_path = None
            insights = "No results or not enough columns to display chart/insights."
            return result_df, intent, sql_query, chart_path, insights
        # Generate chart
        chart_path = os.path.join(PROJECT_ROOT, "chart.png")
        if not result_df.empty:
            plt.figure(figsize=(10, 6))
            if chart_type == "auto":
                if intent == "trend":
                    chart_type = "line"
                elif intent == "comparison":
                    chart_type = "bar"
                else:
                    chart_type = "bar"
            if chart_type == "bar":
                result_df.plot(kind="bar", x=result_df.columns[0], y=result_df.columns[1])
            elif chart_type == "line":
                result_df.plot(kind="line", x=result_df.columns[0], y=result_df.columns[1], marker='o')
            elif chart_type == "pie":
                result_df.plot(kind="pie", y=result_df.columns[1], labels=result_df[result_df.columns[0]])
            plt.title(question)
            plt.tight_layout()
            plt.savefig(chart_path)
            plt.close()
        else:
            chart_path = None
        # Generate insights
        insights = generate_insights(result_df, intent, question)
        return result_df, intent, sql_query, chart_path, insights
    except Exception as e:
        return None, "Error", str(e), None, f"Error: {str(e)}"

def generate_insights(result_df, intent, question):
    if result_df is None or result_df.empty or len(result_df.columns) < 2:
        return "No data available for insights."
    insights = []
    if intent == "summary":
        try:
            total = result_df[result_df.columns[1]].sum()
            insights.append(f"Total {result_df.columns[1]}: {total:,.2f}")
        except Exception:
            pass
    elif intent == "comparison":
        if len(result_df) >= 2:
            try:
                highest = result_df.iloc[0]
                lowest = result_df.iloc[-1]
                diff = ((highest.iloc[1] / lowest.iloc[1] - 1) * 100)
                insights.append(f"{highest.iloc[0]} is {diff:.1f}% higher than {lowest.iloc[0]}")
            except Exception:
                pass
    elif intent == "trend":
        if len(result_df) >= 2:
            try:
                first = result_df.iloc[0][result_df.columns[1]]
                last = result_df.iloc[-1][result_df.columns[1]]
                change = ((last / first - 1) * 100)
                insights.append(f"Overall change: {change:+.1f}%")
            except Exception:
                pass
    insights.append(f"Analysis covers {len(result_df)} records")
    if "category" in result_df.columns:
        insights.append(f"Number of categories: {result_df['category'].nunique()}")
    return "\n".join(f"• {insight}" for insight in insights)

# Clickable FAQs (6 only)
faqs = [
    "What are the top 5 products by quantity sold?",
    "What is the total sales amount for each category?",
    "Which store had the highest total sales?",
    "What are the most popular payment methods?",
    "What is the sales trend over time?",
    "What is the average transaction value?"
]

def fill_question(faq):
    return gr.update(value=faq)

with gr.Blocks(title="RetailGenie - Natural Language to SQL") as demo:
    gr.Markdown("""
    # RetailGenie - Natural Language to SQL
    Ask questions in natural language to generate SQL queries and visualizations, using a retail dataset with product sales information.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Enter your question",
                placeholder="What is the total sales amount for each product category?"
            )
            faq_radio = gr.Radio(faqs, label="FAQs (click to autofill)", interactive=True)
            faq_radio.change(fn=fill_question, inputs=faq_radio, outputs=question)
            chart_type = gr.Radio(
                ["auto", "bar", "line", "pie"],
                label="Chart Type",
                value="auto"
            )
            submit_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=2):
            with gr.Accordion("SQL and Intent Details", open=False):
                intent_output = gr.Textbox(label="Predicted Intent")
                sql_output = gr.Textbox(label="Generated SQL", lines=3)
            results_df = gr.DataFrame(label="Query Results")
            chart_output = gr.Image(label="Chart")
            insights_output = gr.Textbox(label="Insights", lines=5)
    submit_btn.click(
        fn=process_query,
        inputs=[question, chart_type],
        outputs=[results_df, intent_output, sql_output, chart_output, insights_output]
    )

if __name__ == "__main__":
    demo.launch()
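Note on the rewriting step in process_query: the T5 model is prompted with a schema named after the CSV file, so the generated SQL may reference other table or column names; the regex block rewrites them so the query runs against the single in-memory table named data. A minimal sketch of that rewrite, using a hypothetical model output:

import re

# Hypothetical SQL as the T5 model might emit it; the real output varies per question.
sql_query = "SELECT category, SUM(sales) AS total FROM transactions GROUP BY category"

# Same substitutions used in process_query above.
sql_query = re.sub(r'(FROM|JOIN)\s+\w+', r'\1 data', sql_query, flags=re.IGNORECASE)
sql_query = sql_query.replace('product_price', 'total_price')
sql_query = re.sub(r'\bsales\b', 'total_price', sql_query)

print(sql_query)
# SELECT category, SUM(total_price) AS total FROM data GROUP BY category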
code/__init__.py
ADDED
@@ -0,0 +1 @@
# This file makes the code directory a Python package
code/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (124 Bytes).
code/__pycache__/train_intent_classifier_local.cpython-310.pyc
ADDED
Binary file (2.51 kB).
code/__pycache__/train_sqlgen_t5_local.cpython-310.pyc
ADDED
Binary file (2.41 kB).
code/cloud_train_intent_classifier_script.py
ADDED
@@ -0,0 +1,77 @@
import pandas as pd
import os
import argparse
import shutil
import tempfile
import json
from google.cloud import storage
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
import torch

# CLI arguments
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
args = parser.parse_args()

# Load dataset
print("📦 Loading dataset from:", args.dataset_path)
df = pd.read_csv(args.dataset_path)
df = df[["question", "intent"]]

# Label encoding
le = LabelEncoder()
df["label"] = le.fit_transform(df["intent"])
label_mapping = {label: int(idx) for label, idx in zip(le.classes_, le.transform(le.classes_))}  # cast numpy ints so json.dump below works
dataset = Dataset.from_pandas(df)

# Tokenizer and model
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=len(label_mapping))

def tokenize(example):
    return tokenizer(example["question"], truncation=True, padding="max_length", max_length=128)

dataset = dataset.map(tokenize)

training_args = TrainingArguments(
    output_dir="./results_intent_classifier",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs_intent",
    logging_steps=5,
    save_strategy="epoch",
    evaluation_strategy="no"
)

trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()

# Save to temp dir
local_dir = tempfile.mkdtemp()
model.save_pretrained(local_dir)
tokenizer.save_pretrained(local_dir)

with open(os.path.join(local_dir, "label_mapping.json"), "w") as f:
    json.dump(label_mapping, f)

# Upload to GCS
gcs_model_path = os.path.join(args.output_dir, "intent")
bucket_name = gcs_model_path.split("/")[2]
base_path = "/".join(gcs_model_path.split("/")[3:])

client = storage.Client()

for fname in os.listdir(local_dir):
    local_path = os.path.join(local_dir, fname)
    gcs_blob_path = os.path.join(base_path, fname)

    print(f"⬆️ Uploading {fname} to gs://{bucket_name}/{gcs_blob_path}")
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(gcs_blob_path)
    blob.upload_from_filename(local_path)

print(f"✅ Intent model successfully uploaded to gs://{bucket_name}/{base_path}")
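For context, the GCS upload at the end of this script derives the bucket name and object prefix by splitting the --output_dir value on "/". A small sketch of that parsing, with gs://my-bucket/models as a purely hypothetical argument:

import os

output_dir = "gs://my-bucket/models"                 # hypothetical value for --output_dir
gcs_model_path = os.path.join(output_dir, "intent")  # gs://my-bucket/models/intent
bucket_name = gcs_model_path.split("/")[2]           # "my-bucket"
base_path = "/".join(gcs_model_path.split("/")[3:])  # "models/intent"
print(bucket_name, base_path)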
code/cloud_train_sqlgen_t5_script.py
ADDED
@@ -0,0 +1,72 @@
import pandas as pd
import os
import argparse
import shutil
import tempfile
from google.cloud import storage
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
from datasets import Dataset
import torch

# CLI arguments
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
args = parser.parse_args()

print("📦 Loading dataset from:", args.dataset_path)
df = pd.read_csv(args.dataset_path)
df = df[["question", "sql"]].rename(columns={"question": "input_text", "sql": "target_text"})
df["input_text"] = "translate question to SQL: " + df["input_text"]
dataset = Dataset.from_pandas(df)

# Load tokenizer and model
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

def preprocess(example):
    input_enc = tokenizer(example["input_text"], truncation=True, padding="max_length", max_length=128)
    target_enc = tokenizer(example["target_text"], truncation=True, padding="max_length", max_length=128)
    input_enc["labels"] = target_enc["input_ids"]
    return input_enc

tokenized_dataset = dataset.map(preprocess)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results_t5_sqlgen",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=5,
    save_strategy="epoch",
    evaluation_strategy="no"
)

# Train model
trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset)
trainer.train()

# Save model to temporary local directory
local_dir = tempfile.mkdtemp()
model.save_pretrained(local_dir)
tokenizer.save_pretrained(local_dir)

# Upload all files to GCS
gcs_model_path = os.path.join(args.output_dir, "sqlgen")
bucket_name = gcs_model_path.split("/")[2]
base_path = "/".join(gcs_model_path.split("/")[3:])

client = storage.Client()

for fname in os.listdir(local_dir):
    local_path = os.path.join(local_dir, fname)
    gcs_blob_path = os.path.join(base_path, fname)

    print(f"⬆️ Uploading {fname} to gs://{bucket_name}/{gcs_blob_path}")
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(gcs_blob_path)
    blob.upload_from_filename(local_path)

print(f"✅ Model successfully uploaded to gs://{bucket_name}/{base_path}")
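For reference, the preprocessing above turns each row of the question/sql training CSV into a prefixed input/target pair. A minimal sketch using one row from data/retail_dataset.csv:

import pandas as pd

# One row in the same shape as data/retail_dataset.csv.
df = pd.DataFrame([{
    "question": "What is the total sales for each product category?",
    "sql": "SELECT category, SUM(total_price) AS total_sales FROM transactions GROUP BY category;",
}])
df = df.rename(columns={"question": "input_text", "sql": "target_text"})
df["input_text"] = "translate question to SQL: " + df["input_text"]
print(df.loc[0, "input_text"])   # translate question to SQL: What is the total sales for each product category?
print(df.loc[0, "target_text"])  # the SQL string becomes the generation target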
code/train_intent_classifier_local.py
ADDED
@@ -0,0 +1,84 @@
import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
import os
import json

# Get project root directory
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

def load_model():
    print("📦 Loading pre-trained intent classification model...")
    model_name = "distilbert-base-uncased"
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define intent labels
    intent_labels = [
        "summary", "comparison", "trend", "anomaly", "forecast"
    ]
    num_labels = len(intent_labels)

    # Create label mapping
    label_mapping = {label: idx for idx, label in enumerate(intent_labels)}

    # Load model with our number of labels
    model = DistilBertForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
    model = model.to(device)
    model.eval()

    return model, tokenizer, device, label_mapping

def classify_intent(question, model, tokenizer, device, label_mapping):
    # Tokenize input
    inputs = tokenizer(
        question,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128
    ).to(device)

    # Get prediction
    with torch.no_grad():
        outputs = model(**inputs)
        predicted_class_id = outputs.logits.argmax().item()

    # Convert back to label
    id2label = {v: k for k, v in label_mapping.items()}
    intent = id2label[predicted_class_id]

    return intent

if __name__ == "__main__":
    # Load the model
    model, tokenizer, device, label_mapping = load_model()

    # Save the model and label mapping
    output_dir = os.path.join(PROJECT_ROOT, "model_intent_classifier")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save label mapping
    with open(os.path.join(output_dir, "label_mapping.json"), "w") as f:
        json.dump(label_mapping, f)

    print(f"✅ Model successfully saved to {output_dir}")

    # Example usage
    test_questions = [
        "What is the total sales amount for each product category?",
        "Compare sales between March and April",
        "Show me the sales trend over the last 6 months",
        "Which products have unusual sales patterns?",
        "What will be the sales forecast for next month?"
    ]

    print("\nTesting intent classification:")
    for question in test_questions:
        intent = classify_intent(question, model, tokenizer, device, label_mapping)
        print(f"Question: {question}")
        print(f"Predicted intent: {intent}\n")
code/train_sqlgen_t5_local.py
ADDED
@@ -0,0 +1,76 @@
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import os
import pandas as pd

# Get project root directory
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

def load_model():
    print("📦 Loading pre-trained text-to-SQL model...")
    model_name = "cssupport/t5-small-awesome-text-to-sql"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)
    model.eval()
    return model, tokenizer, device

def generate_sql(question, schema, model, tokenizer, device):
    # Format input as expected by the model
    input_prompt = f"tables:\n{schema}\nquery for: {question}"

    # Tokenize the input prompt
    inputs = tokenizer(input_prompt, padding=True, truncation=True, return_tensors="pt").to(device)

    # Generate SQL
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512)

    # Decode the output
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_sql

def get_schema_from_csv(csv_path):
    """Generate CREATE TABLE statements from CSV file"""
    df = pd.read_csv(csv_path)
    columns = []
    for col in df.columns:
        # Infer column type
        dtype = df[col].dtype
        if dtype == 'int64':
            col_type = 'INT'
        elif dtype == 'float64':
            col_type = 'DECIMAL(10,2)'
        else:
            col_type = 'VARCHAR(255)'
        columns.append(f"{col} {col_type}")

    table_name = os.path.splitext(os.path.basename(csv_path))[0]
    create_table = f"CREATE TABLE {table_name} (\n " + ",\n ".join(columns) + "\n);"
    return create_table

if __name__ == "__main__":
    # Load the pre-trained model
    model, tokenizer, device = load_model()

    # Save the model locally for future use
    output_dir = os.path.join(PROJECT_ROOT, "model_sqlgen_t5")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Model successfully saved to {output_dir}")

    # Example usage with CSV
    csv_path = os.path.join(PROJECT_ROOT, "data", "retail_dataset.csv")
    if os.path.exists(csv_path):
        schema = get_schema_from_csv(csv_path)
        print("\nGenerated schema from CSV:")
        print(schema)

        question = "What is the total sales amount for each product category?"
        sql_query = generate_sql(question, schema, model, tokenizer, device)
        print("\nExample usage:")
        print(f"Question: {question}")
        print(f"Generated SQL: {sql_query}")
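As a rough illustration, get_schema_from_csv builds a CREATE TABLE statement from the CSV header and pandas dtypes. Run from the repository root against data/testing_sql_data.csv it should produce something close to the following (the exact types depend on how pandas infers the columns):

from code.train_sqlgen_t5_local import get_schema_from_csv  # assumes the repo root is on sys.path

print(get_schema_from_csv("data/testing_sql_data.csv"))
# Expected, roughly:
# CREATE TABLE testing_sql_data (
#  product_id INT,
#  product_name VARCHAR(255),
#  category VARCHAR(255),
#  quantity INT,
#  total_price INT,
#  store_id INT,
#  payment_method VARCHAR(255),
#  date VARCHAR(255)
# );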
data/retail_dataset.csv
ADDED
@@ -0,0 +1,11 @@
question,intent,sql
What are the top 5 selling products in March?,summary,"SELECT product_name, SUM(quantity) AS total_sold FROM transactions WHERE MONTH(date) = 3 GROUP BY product_name ORDER BY total_sold DESC LIMIT 5;"
Which store had the highest revenue in April?,summary,"SELECT store_id, SUM(total_price) AS revenue FROM transactions WHERE MONTH(date) = 4 GROUP BY store_id ORDER BY revenue DESC LIMIT 1;"
Compare returns between electronics and clothing in Q1.,comparison,"SELECT category, COUNT(*) AS return_count FROM returns WHERE category IN ('electronics', 'clothing') AND QUARTER(date) = 1 GROUP BY category;"
Which products saw a sales drop compared to last month?,anomaly,"SELECT t1.product_id, t1.month, t1.sales, t2.sales AS last_month_sales FROM monthly_sales t1 JOIN monthly_sales t2 ON t1.product_id = t2.product_id AND t1.month = t2.month + 1 WHERE t1.sales < t2.sales;"
Show the sales trend of iPhone 14 in the last 6 months.,trend,"SELECT MONTH(date) AS month, SUM(quantity) AS total_sales FROM transactions WHERE product_name = 'iPhone 14' AND date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) GROUP BY month ORDER BY month;"
Which category has the highest number of returns?,summary,"SELECT category, COUNT(*) AS total_returns FROM returns GROUP BY category ORDER BY total_returns DESC LIMIT 1;"
What is the total sales for each product category?,summary,"SELECT category, SUM(total_price) AS total_sales FROM transactions GROUP BY category;"
List the most returned products in the last month.,summary,"SELECT product_name, COUNT(*) AS return_count FROM returns WHERE date >= DATE_SUB(CURDATE(), INTERVAL 1 MONTH) GROUP BY product_name ORDER BY return_count DESC;"
Which store had the lowest performance in Q2?,summary,"SELECT store_id, SUM(total_price) AS total_sales FROM transactions WHERE QUARTER(date) = 2 GROUP BY store_id ORDER BY total_sales ASC LIMIT 1;"
What are the top 3 most popular products in electronics?,summary,"SELECT product_name, SUM(quantity) AS total_sold FROM transactions WHERE category = 'electronics' GROUP BY product_name ORDER BY total_sold DESC LIMIT 3;"
data/retail_schema.sql
ADDED
@@ -0,0 +1,30 @@
-- Table: transactions
CREATE TABLE transactions (
    transaction_id INT PRIMARY KEY,
    product_id INT,
    product_name VARCHAR(100),
    category VARCHAR(50),
    quantity INT,
    total_price DECIMAL(10, 2),
    store_id INT,
    date DATE
);

-- Table: returns
CREATE TABLE returns (
    return_id INT PRIMARY KEY,
    product_id INT,
    product_name VARCHAR(100),
    category VARCHAR(50),
    store_id INT,
    date DATE
);

-- Table: monthly_sales
CREATE TABLE monthly_sales (
    product_id INT,
    product_name VARCHAR(100),
    month INT,
    sales INT
);
data/testing_sql_data.csv
ADDED
@@ -0,0 +1,21 @@
product_id,product_name,category,quantity,total_price,store_id,payment_method,date
1,Sneakers,Footwear,5,500,1,Credit Card,2023-03-01
2,T-Shirt,Apparel,3,90,1,Cash,2023-03-02
3,Laptop,Electronics,2,2000,2,Credit Card,2023-04-01
4,Running Shoes,Footwear,7,700,2,Debit Card,2023-03-15
5,Polo Shirt,Apparel,2,60,1,Cash,2023-03-20
6,Jeans,Apparel,3,240,3,Credit Card,2023-03-05
7,Smartwatch,Electronics,1,350,2,Mobile Payment,2023-03-10
8,Hoodie,Apparel,2,100,1,Credit Card,2023-03-22
9,Tablet,Electronics,1,800,3,Debit Card,2023-04-05
10,Backpack,Accessories,4,150,1,Cash,2023-04-10
11,Sports Shoes,Footwear,6,600,2,Credit Card,2023-03-12
12,Headphones,Electronics,3,450,1,Mobile Payment,2023-03-18
13,Baseball Cap,Apparel,5,75,3,Cash,2023-03-25
14,Gaming Monitor,Electronics,2,400,2,Credit Card,2023-04-12
15,Leather Wallet,Accessories,3,90,1,Debit Card,2023-03-28
16,Winter Jacket,Apparel,2,180,2,Credit Card,2023-03-30
17,Wireless Mouse,Electronics,4,120,3,Cash,2023-04-02
18,Leather Belt,Accessories,2,60,1,Credit Card,2023-03-08
19,Mechanical Keyboard,Electronics,1,100,2,Mobile Payment,2023-04-08
20,Athletic Socks,Apparel,6,60,3,Cash,2023-03-14
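This is the file app.py queries: it is loaded into an in-memory SQLite table named data, so generated SQL must use that table name and these column headers. A minimal, hypothetical query against the sample rows:

import sqlite3
import pandas as pd

df = pd.read_csv("data/testing_sql_data.csv")  # assumes running from the repo root
conn = sqlite3.connect(":memory:")
df.to_sql("data", conn, index=False, if_exists="replace")
print(pd.read_sql_query(
    "SELECT category, SUM(total_price) AS total_sales FROM data GROUP BY category",
    conn,
))
conn.close()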
model_intent_classifier/config.json
ADDED
@@ -0,0 +1,38 @@
{
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4"
  },
  "initializer_range": 0.02,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4
  },
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float32",
  "transformers_version": "4.20.0",
  "vocab_size": 30522
}
model_intent_classifier/label_mapping.json
ADDED
@@ -0,0 +1 @@
{"summary": 0, "comparison": 1, "trend": 2, "anomaly": 3, "forecast": 4}
model_intent_classifier/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2652d23a878b3575ea53ee5f629b5578acd7bc8f082dcb804050f38ae34ffadb
size 267864195
model_intent_classifier/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
model_intent_classifier/tokenizer.json
ADDED
The diff for this file is too large to render.
model_intent_classifier/tokenizer_config.json
ADDED
@@ -0,0 +1,14 @@
{
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "name_or_path": "distilbert-base-uncased",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "special_tokens_map_file": null,
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "DistilBertTokenizer",
  "unk_token": "[UNK]"
}
model_intent_classifier/vocab.txt
ADDED
The diff for this file is too large to render.
model_sqlgen_t5/config.json
ADDED
@@ -0,0 +1,61 @@
{
  "_name_or_path": "cssupport/t5-small-awesome-text-to-sql",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dense_act_fn": "relu",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": false,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 6,
  "num_heads": 8,
  "num_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to German: "
    },
    "translation_en_to_fr": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to French: "
    },
    "translation_en_to_ro": {
      "early_stopping": true,
      "max_length": 300,
      "num_beams": 4,
      "prefix": "translate English to Romanian: "
    }
  },
  "tf_legacy_loss": false,
  "torch_dtype": "float32",
  "transformers_version": "4.20.0",
  "use_cache": true,
  "vocab_size": 32128
}
model_sqlgen_t5/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f92514c0d85ce64c114ac84d139b3d3c545ef077b66cc98cf5fdd8ae75d9ad8f
size 242070639
model_sqlgen_t5/special_tokens_map.json
ADDED
@@ -0,0 +1,107 @@
{
  "additional_special_tokens": [
    "<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>",
    "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>",
    "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>",
    "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>",
    "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>",
    "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>",
    "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>",
    "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>",
    "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>",
    "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"
  ],
  "eos_token": "</s>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
model_sqlgen_t5/spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
size 791656
model_sqlgen_t5/tokenizer_config.json
ADDED
@@ -0,0 +1,114 @@
{
  "additional_special_tokens": [
    "<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>",
    "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>",
    "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>",
    "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>",
    "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>",
    "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>",
    "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>",
    "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>",
    "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>",
    "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"
  ],
  "clean_up_tokenization_spaces": true,
  "eos_token": "</s>",
  "extra_ids": 100,
  "model_max_length": 512,
  "name_or_path": "cssupport/t5-small-awesome-text-to-sql",
  "pad_token": "<pad>",
  "sp_model_kwargs": {},
  "special_tokens_map_file": null,
  "tokenizer_class": "T5Tokenizer",
  "unk_token": "<unk>"
}
requirements.txt
ADDED
@@ -0,0 +1,15 @@
numpy==2.2.5
pandas==2.2.3
torch==2.7.0
transformers==4.20.0
datasets==3.6.0
scikit-learn==1.5.0
matplotlib==3.10.3
gradio==5.29.1
huggingface-hub==0.31.2
sentencepiece==0.2.0
tokenizers==0.12.1
tqdm==4.67.1
pillow==11.2.1
aiohttp==3.11.18
sqlparse==0.5.0