import gradio as gr
import os
import sys
import inspect
import json
from pathlib import Path
from datetime import datetime
# Add the duckdb-nsql directory to the Python path
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir / 'duckdb-nsql'
eval_dir = duckdb_nsql_dir / 'eval'
sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])
# Import necessary functions and classes from predict.py and evaluate.py
from eval.predict import predict, console, get_manifest, DefaultLoader
from eval.constants import PROMPT_FORMATTERS
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
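# NOTE: PROMPT_FORMATTERS is assumed to map prompt-format names (e.g. the
# "duckdbinstgraniteshort" format used below) to formatter classes that build the
# model prompt, and get_manifest to return the client used to query the model.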


def run_evaluation(model_name):
    results = []

    if "OPENROUTER_API_KEY" not in os.environ:
        return "Error: OPENROUTER_API_KEY not found in environment variables."

    try:
        # Set up the arguments similar to the CLI in predict.py
        dataset_path = "duckdb-nsql/eval/data/dev.json"
        table_meta_path = "duckdb-nsql/eval/data/tables.json"
        output_dir = "duckdb-nsql/output/"
        prompt_format = "duckdbinstgraniteshort"
        stop_tokens = [';']
        max_tokens = 30000
        temperature = 0.1
        num_beams = -1
        manifest_client = "openrouter"
        manifest_engine = model_name
        manifest_connection = "http://localhost:5000"
        overwrite_manifest = True
        parallel = False
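
        # manifest_engine is the OpenRouter model slug entered in the UI
        # (e.g. "qwen/qwen-2.5-72b-instruct").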

        # Initialize necessary components
        data_formatter = DefaultLoader()
        prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )
        results.append(f"Using model: {manifest_engine}")

        # Load data and metadata
        results.append("Loading metadata and data...")
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)
        data = data_formatter.load_data(dataset_path)

        # Generate output filename
        date_today = datetime.now().strftime("%y-%m-%d")
        pred_filename = (
            f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_"
            f"{Path(dataset_path).stem}_{date_today}.json"
        )
        pred_path = Path(output_dir) / pred_filename
        results.append(f"Prediction will be saved to: {pred_path}")

        # Debug: record the predict function signature
        results.append(f"Predict function signature: {inspect.signature(predict)}")

        # Run prediction
        results.append("Starting prediction...")
        try:
            predict(
                dataset_path=dataset_path,
                table_meta_path=table_meta_path,
                output_dir=output_dir,
                prompt_format=prompt_format,
                stop_tokens=stop_tokens,
                max_tokens=max_tokens,
                temperature=temperature,
                num_beams=num_beams,
                manifest_client=manifest_client,
                manifest_engine=manifest_engine,
                manifest_connection=manifest_connection,
                overwrite_manifest=overwrite_manifest,
                parallel=parallel,
            )
        except TypeError as e:
            results.append(f"TypeError in predict function: {str(e)}")
            results.append("Attempting to call predict with only expected arguments...")
            # Retry with only the keyword arguments predict() actually accepts
            predict_args = inspect.getfullargspec(predict).args
            filtered_args = {k: v for k, v in locals().items() if k in predict_args}
            predict(**filtered_args)
        results.append("Prediction completed.")

        # Run evaluation
        results.append("Starting evaluation...")

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = "duckdb-nsql/eval/data/databases/"
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(line) for line in pred_path.open("r")]
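
        # The gold file is a JSON array (query, setup/validation SQL, target database,
        # and category per example); the prediction file is assumed to be JSON lines
        # with one {"pred": ...} record per example, in the same order.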
        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )
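        # compute_metrics is assumed to return one aggregate metrics dict; get_to_print
        # below formats it as the "all" slice for this model over len(gold_sqls) examples.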
results.append("Evaluation completed.")
# Format and add the evaluation metrics to the results
if metrics:
to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
formatted_metrics = "\n".join([f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]])
results.append(f"Evaluation metrics:\n{formatted_metrics}")
else:
results.append("No evaluation metrics returned.")
except Exception as e:
results.append(f"An unexpected error occurred: {str(e)}")
return "\n\n".join(results)


with gr.Blocks() as demo:
    gr.Markdown("# DuckDB SQL Evaluation App")
    model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
    start_btn = gr.Button("Start Evaluation")
    output = gr.Textbox(label="Output", lines=20)
    start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)

demo.launch()
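
# demo.launch() is called with Gradio's defaults; OPENROUTER_API_KEY must be set in
# the environment (e.g. as a Space secret) for run_evaluation to proceed.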