import gradio as gr
import os
import sys
import inspect  # used below to introspect the predict() signature
from pathlib import Path
from datetime import datetime
import json

# Add the duckdb-nsql directory (and its eval/ subdirectory) to the Python path
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])

# Import necessary functions and classes from predict.py and evaluate.py
from eval.predict import predict, console, get_manifest, DefaultLoader
from eval.constants import PROMPT_FORMATTERS
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
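
# Note: these imports assume the duckdb-nsql repository is checked out next to
# this file (as duckdb-nsql/), so its eval/ package is importable through the
# sys.path entries added above.
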
def run_evaluation(model_name):
    results = []

    if "OPENROUTER_API_KEY" not in os.environ:
        # This function is a generator (it yields progress below), so errors
        # must be yielded rather than returned for Gradio to display them.
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return

    try:
        # Set up the arguments similar to the CLI in predict.py
        dataset_path = "duckdb-nsql/eval/data/dev.json"
        table_meta_path = "duckdb-nsql/eval/data/tables.json"
        output_dir = "duckdb-nsql/output/"
        prompt_format = "duckdbinstgraniteshort"
        stop_tokens = [';']
        max_tokens = 30000
        temperature = 0.1
        num_beams = -1
        manifest_client = "openrouter"
        manifest_engine = model_name
        manifest_connection = "http://localhost:5000"
        overwrite_manifest = True
        parallel = False
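
        # Note: everything above is fixed in code; only the model name comes
        # from the Gradio textbox (via manifest_engine = model_name).
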
        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )
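
        # Assumption: the "openrouter" manifest client reads OPENROUTER_API_KEY
        # from the environment (hence the check at the top of this function) and
        # routes requests to the model named by manifest_engine.
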
        results.append(f"Using model: {manifest_engine}")

        # Load data and metadata
        results.append("Loading metadata and data...")
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        data = data_formatter.load_data(dataset_path)

        # Generate output filename
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_{Path(dataset_path).stem}_{date_today}.json"
        pred_path = Path(output_dir) / pred_filename
        results.append(f"Prediction will be saved to: {pred_path}")
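
        # Assumption: this filename mirrors the naming scheme predict() uses for
        # its output, so pred_path points at the file the evaluation step below
        # reads; if the upstream scheme changes, this must change with it.
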
        # Debug: record the predict() signature so argument mismatches are easy to spot
        results.append(f"Predict function signature: {inspect.signature(predict)}")

        # Run prediction
        results.append("Starting prediction...")
        yield "\n\n".join(results)

        try:
            predict(
                dataset_path=dataset_path,
                table_meta_path=table_meta_path,
                output_dir=output_dir,
                prompt_format=prompt_format,
                stop_tokens=stop_tokens,
                max_tokens=max_tokens,
                temperature=temperature,
                num_beams=num_beams,
                manifest_client=manifest_client,
                manifest_engine=manifest_engine,
                manifest_connection=manifest_connection,
                overwrite_manifest=overwrite_manifest,
                parallel=parallel
            )
        except TypeError as e:
            results.append(f"TypeError in predict function: {str(e)}")
            results.append("Attempting to call predict with only expected arguments...")
            yield "\n\n".join(results)
            # Retry with only the keyword arguments predict() actually declares
            predict_args = inspect.getfullargspec(predict).args
            filtered_args = {k: v for k, v in locals().items() if k in predict_args}
            predict(**filtered_args)
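
        # The fallback above guards against signature drift in the vendored
        # predict(): if the keyword set used here no longer matches, the call is
        # retried with only the parameters predict() actually declares.
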
        results.append("Prediction completed.")

        # Run evaluation
        results.append("Starting evaluation...")

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = "duckdb-nsql/eval/data/databases/"
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in pred_path.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]
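
        # Assumption: predict() writes one JSON object per line (JSONL) with a
        # "pred" field, in the same order as the gold examples, so the gold and
        # predicted lists line up index-by-index for compute_metrics.
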
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
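
        # Presumably compute_metrics executes the setup and validation SQL
        # against the DuckDB databases under database_dir and compares gold vs.
        # predicted results; see eval/evaluate.py in duckdb-nsql for details.
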
        results.append("Evaluation completed.")

        # Format and add the evaluation metrics to the results
        if metrics:
            to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
            formatted_metrics = "\n".join([f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]])
            results.append(f"Evaluation metrics:\n{formatted_metrics}")
        else:
            results.append("No evaluation metrics returned.")

    except Exception as e:
        results.append(f"An unexpected error occurred: {str(e)}")

    # Yield (not return) the final report: Gradio streams generator output.
    yield "\n\n".join(results)
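
# Local usage sketch (an assumption: this file is saved as app.py and run from
# the directory that contains the duckdb-nsql/ checkout, since the data and
# output paths above are relative):
#   export OPENROUTER_API_KEY=<your key>
#   python app.py
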
with gr.Blocks() as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")

    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)

    start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)

demo.launch()