import hashlib
import json
import os
import sys
import traceback
import uuid
from datetime import datetime
from pathlib import Path

from huggingface_hub import CommitScheduler

# Make this directory, the bundled duckdb-nsql checkout, and its eval directory
# importable; the `eval.*` imports that follow depend on these entries.
current_dir = Path(__file__).resolve().parent
duckdb_nsql_dir = current_dir.joinpath('duckdb-nsql')
eval_dir = duckdb_nsql_dir.joinpath('eval')
sys.path.extend(str(entry) for entry in (current_dir, duckdb_nsql_dir, eval_dir))

from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
from eval.evaluate import evaluate, compute_metrics, get_to_print
from eval.evaluate import test_suite_evaluation, read_tables_json
from eval.schema import TextToSQLParams, Table

# Names of every registered prompt formatter, in registry order.
AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS)

# Local staging directories; each scheduler below uploads one of them.
prediction_folder = Path("prediction_results/")
evaluation_folder = Path("evaluation_results/")

# Random id generated once per process; save_prediction / save_evaluation embed
# it in their file names so each process appends to its own files.
file_uuid = uuid.uuid4()

# Each scheduler periodically (every=10, in minutes) commits its local folder to
# the "data" directory of a Hub dataset repo. Writers in this module hold the
# scheduler's `.lock` while appending so an upload does not race with a write.
prediction_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-predictions",
    folder_path=prediction_folder,
    repo_type="dataset",
    path_in_repo="data",
    every=10,
)

evaluation_scheduler = CommitScheduler(
    repo_id="sql-console/duckdb-nsql-scores",
    folder_path=evaluation_folder,
    repo_type="dataset",
    path_in_repo="data",
    every=10,
)

def save_prediction(inference_api, model_name, prompt_format, question, generated_sql):
    """Append one generated query as a JSON line to this process's prediction file.

    The file is ``prediction_folder / "prediction_<file_uuid>.json"``. It is
    uploaded by ``prediction_scheduler``, whose lock is held while appending.

    Args:
        inference_api: Inference client name the prediction came from.
        model_name: Model that generated the SQL.
        prompt_format: Prompt format used for generation.
        question: Natural-language question that was asked.
        generated_sql: SQL produced by the model.
    """
    prediction_file = prediction_folder / f"prediction_{file_uuid}.json"
    prediction_folder.mkdir(parents=True, exist_ok=True)
    record = {
        "inference_api": inference_api,
        "model_name": model_name,
        "prompt_format": prompt_format,
        "question": question,
        "generated_sql": generated_sql,
        "timestamp": datetime.now().isoformat(),
    }
    with prediction_scheduler.lock:
        with prediction_file.open("a") as f:
            json.dump(record, f)
            # One record per line (JSON Lines), as save_evaluation does. Without
            # the newline, appended records run together and cannot be parsed.
            f.write("\n")

def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics):
    """Append a flattened evaluation record as a JSON line to this process's score file.

    The file is ``evaluation_folder / "evaluation_<file_uuid>.json"``. It is
    uploaded by ``evaluation_scheduler``, whose lock is held while appending.

    Args:
        inference_api: Inference client name that was evaluated.
        model_name: Model that was evaluated.
        prompt_format: Prompt format used; the raw ``custom_prompt`` is stored
            only when this starts with ``"custom"``.
        custom_prompt: Custom prompt template (stringified when stored).
        metrics: Mapping whose ``'exec'`` entry maps category names to dicts
            with ``'count'`` and ``'exec'`` (execution accuracy) values.
    """
    target = evaluation_folder / f"evaluation_{file_uuid}.json"
    evaluation_folder.mkdir(parents=True, exist_ok=True)

    is_custom = prompt_format.startswith("custom")
    record = dict(
        inference_api=inference_api,
        model_name=model_name,
        prompt_format=prompt_format,
        custom_prompt=str(custom_prompt) if is_custom else "",
        timestamp=datetime.now().isoformat(),
    )

    # One `<category>_count` / `<category>_execution_accuracy` column pair per
    # category; categories absent from the metrics are recorded as 0 / 0.0.
    exec_by_category = metrics['exec']
    for name in ('easy', 'medium', 'hard', 'duckdb', 'ddl', 'all'):
        if name in exec_by_category:
            stats = exec_by_category[name]
            count, accuracy = stats['count'], stats['exec']
        else:
            count, accuracy = 0, 0.0
        record[f"{name}_count"] = count
        record[f"{name}_execution_accuracy"] = accuracy

    with evaluation_scheduler.lock, target.open("a") as sink:
        sink.write(json.dumps(record) + '\n')

def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
    """Generate SQL for every dev-set question and write the results to ``output_file``.

    This is a generator yielding human-readable status strings; failures are
    reported as yielded messages rather than raised.

    Args:
        inference_api: Client name passed to ``get_manifest`` as ``manifest_client``.
        model_name: Engine name passed to ``get_manifest`` as ``manifest_engine``.
        prompt_format: Key into ``PROMPT_FORMATTERS``. Any value starting with
            ``"custom"`` selects the ``"custom"`` formatter, with
            ``custom_prompt`` as its template.
        custom_prompt: Template used when ``prompt_format`` starts with ``"custom"``.
        output_file: ``Path`` of the JSON Lines file to create. Each line is
            the original dev record plus a ``"pred"`` key. The file only
            appears once every prediction has been written.

    Yields:
        Status messages.
    """
    dataset_path = str(eval_dir / "data/dev.json")
    table_meta_path = str(eval_dir / "data/tables.json")
    # NOTE(review): the backticks are part of the stop-token string; confirm intended.
    stop_tokens = ['`<|dummy|>`']
    max_tokens = 1000
    temperature = 0
    num_beams = -1
    manifest_client = inference_api
    manifest_engine = model_name
    manifest_connection = "http://localhost:5000"
    overwrite_manifest = True
    parallel = False
    # Write to a sibling file and rename it onto output_file only when complete.
    # run_evaluation skips prediction whenever output_file exists, so a
    # truncated file from a failed run must never appear under that name.
    partial_file = output_file.with_name(output_file.name + ".partial")

    yield "Starting prediction..."

    try:
        # Initialize necessary components
        data_formatter = DefaultLoader()
        if prompt_format.startswith("custom"):
            prompt_formatter_cls = PROMPT_FORMATTERS["custom"]
            # NOTE(review): this sets the template on the shared formatter class,
            # so it affects every later user of PROMPT_FORMATTERS["custom"].
            prompt_formatter_cls.PROMPT_TEMPLATE = custom_prompt
            prompt_formatter = prompt_formatter_cls()
        else:
            prompt_formatter = PROMPT_FORMATTERS[prompt_format]()

        # Load manifest
        manifest = get_manifest(
            manifest_client=manifest_client,
            manifest_connection=manifest_connection,
            manifest_engine=manifest_engine,
        )

        # Load data; materialized because it is iterated twice and measured.
        data = list(data_formatter.load_data(dataset_path))
        db_to_tables = data_formatter.load_table_metadata(table_meta_path)

        # Prepare input for generate_sql
        text_to_sql_inputs = []
        for input_question in data:
            question = input_question["question"]
            db_id = input_question.get("db_id", "none")
            if db_id != "none":
                table_params = list(db_to_tables.get(db_id, {}).values())
            else:
                table_params = []

            text_to_sql_inputs.append(TextToSQLParams(
                instruction=question,
                database=db_id,
                tables=table_params,
            ))

        # Generate SQL
        generated_sqls = list(generate_sql(
            manifest=manifest,
            text_to_sql_in=text_to_sql_inputs,
            retrieved_docs=[[] for _ in text_to_sql_inputs],
            prompt_formatter=prompt_formatter,
            stop_tokens=stop_tokens,
            overwrite_manifest=overwrite_manifest,
            max_tokens=max_tokens,
            temperature=temperature,
            num_beams=num_beams,
            parallel=parallel
        ))
        # Predictions are matched to gold queries by position during
        # evaluation, so a short result must not be silently truncated by zip.
        if len(generated_sqls) != len(data):
            raise ValueError(
                f"generate_sql returned {len(generated_sqls)} results "
                f"for {len(data)} questions"
            )

        # Save results
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with partial_file.open('w') as f:
            for original_data, (sql, _) in zip(data, generated_sqls):
                output = {**original_data, "pred": sql}
                json.dump(output, f)
                f.write('\n')

                # Save prediction to dataset
                save_prediction(inference_api, model_name, prompt_format, original_data["question"], sql)
        # Atomic rename, so output_file is either absent or complete.
        partial_file.replace(output_file)

        yield f"Prediction completed. Results saved to {output_file}"
    except Exception as e:
        error_trace = traceback.format_exc()
        # Best-effort cleanup of an incomplete result; a failed delete must not
        # hide the original error being reported below.
        try:
            partial_file.unlink(missing_ok=True)
        except OSError:
            pass
        yield f"Prediction failed with error: {str(e)}"
        yield f"Error traceback: {error_trace}"

def run_evaluation(inference_api, model_name, prompt_format="duckdbinstgraniteshort", custom_prompt=None):
    if "OPENROUTER_API_KEY" not in os.environ:
        yield "Error: OPENROUTER_API_KEY not found in environment variables."
        return
    if "HF_TOKEN" not in os.environ:
        yield "Error: HF_TOKEN not found in environment variables."
        return

    try:
        # Set up the arguments
        dataset_path = str(eval_dir / "data/dev.json")
        table_meta_path = str(eval_dir / "data/tables.json")
        output_dir = eval_dir / "output"

        yield f"Using model: {model_name}"
        yield f"Using prompt format: {prompt_format}"

        if prompt_format == "custom":
            prompt_format = prompt_format+"_"+str(abs(hash(custom_prompt)) % (10 ** 8))

        output_file = output_dir / f"{prompt_format}_0docs_{model_name.replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"

        # Ensure the output directory exists
        output_dir.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            yield f"Prediction file already exists: {output_file}"
            yield "Skipping prediction step and proceeding to evaluation."
        else:
            # Run prediction
            for output in run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
                yield output

        # Run evaluation
        yield "Starting evaluation..."

        # Set up evaluation arguments
        gold_path = Path(dataset_path)
        db_dir = str(eval_dir / "data/databases/")
        tables_path = Path(table_meta_path)

        kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
        db_schemas = read_tables_json(str(tables_path))

        gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
        pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]

        gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
        setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
        validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
        gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
        pred_sqls = [p["pred"] for p in pred_sqls_dict]
        categories = [p.get("category", "") for p in gold_sqls_dict]

        yield "Computing metrics..."
        metrics = compute_metrics(
            gold_sqls=gold_sqls,
            pred_sqls=pred_sqls,
            gold_dbs=gold_dbs,
            setup_sqls=setup_sqls,
            validate_sqls=validate_sqls,
            kmaps=kmaps,
            db_schemas=db_schemas,
            database_dir=db_dir,
            lowercase_schema_match=False,
            model_name=model_name,
            categories=categories,
        )

        # Save evaluation results to dataset
        save_evaluation(inference_api, model_name, prompt_format, custom_prompt, metrics)

        yield "Evaluation completed."

        if metrics:
            yield "Overall Results:"
            overall_metrics = metrics['exec']['all']
            yield f"All (n={overall_metrics['count']}) - Execution Accuracy: {overall_metrics['exec']:.3f}"
            yield f"All (n={overall_metrics['count']}) - Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"

            categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']

            for category in categories:
                if category in metrics['exec']:
                    category_metrics = metrics['exec'][category]
                    yield f"{category} (n={category_metrics['count']}) - Execution Accuracy: {category_metrics['exec']:.3f}"
                else:
                    yield f"{category}: No data available"
        else:
            yield "No evaluation metrics returned."
    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"
        yield f"Error traceback: {traceback.format_exc()}"

if __name__ == "__main__":
    # Interactive entry point. run_evaluation's parameters are
    # (inference_api, model_name, prompt_format); the previous call passed the
    # model name as the inference API and the prompt format as the model name.
    # NOTE(review): "openrouter" is an assumed default, chosen because
    # run_evaluation requires OPENROUTER_API_KEY; confirm it is a client name
    # get_manifest accepts.
    inference_api = input("Enter the inference API (default is openrouter): ") or "openrouter"
    model_name = input("Enter the model name: ")
    prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
    for result in run_evaluation(inference_api, model_name, prompt_format):
        print(result, flush=True)