import glob
import json

import datasets  # type: ignore
from huggingface_hub import snapshot_download  # type: ignore
import pandas as pd  # type: ignore

from backend.envs import EVAL_DATASET, TRACES_DATASET, TOKEN, EVAL_RESULTS_PATH


# Result subdirectories scanned per model when collecting eval result files.
SUBSETS = ["base", "cot", "orig"]


def load_cot_data():
    """Download and assemble chain-of-thought (CoT) evaluation dataframes.

    Downloads the raw eval results snapshot and the traces dataset from the
    Hugging Face Hub, then builds:

    * ``df_cot`` — one row per (model, task, config) with CoT accuracy,
      baseline accuracy, absolute/relative gain (accuracies in percent),
      per-(model, config) "all"-task averages, and CoT-config metadata
      joined in from the traces dataset.
    * ``df_cot_err`` — per-(model, task) aggregates of ``df_cot`` with a
      symmetric error bar derived from the min/max accuracy gain.

    Returns:
        tuple: ``(df_cot_err, df_cot)``.

    Raises:
        KeyError: if a (model, task) pair in the CoT subset has no
            corresponding baseline entry.
    """

    ####
    # Load the evaluation results data
    ####

    # download raw data
    print("Downloading evaluation results...")
    snapshot_download(
        repo_id=EVAL_DATASET,
        revision="main",
        local_dir=EVAL_RESULTS_PATH,
        repo_type="dataset",
        max_workers=8,
        token=TOKEN
    )

    # get all models for which results are stored; results live under
    # data/<org>/<model>, so the stripped path is the "org/model" id
    models = [
        path.replace(f"{EVAL_RESULTS_PATH}/data/", "")
        for path in glob.glob(f"{EVAL_RESULTS_PATH}/data/*/*", recursive=False)
    ]

    # load the evaluation results and create a dataframe
    results = []
    for model in models:
        for subset in SUBSETS:
            result_files = glob.glob(
                f"{EVAL_RESULTS_PATH}/data/{model}/{subset}/**/*.json", recursive=True
            )
            for json_filepath in result_files:
                with open(json_filepath) as fp:
                    data = json.load(fp)
                # silently skip JSON files without a "results" payload
                if "results" in data:
                    for v in data["results"].values():
                        record = v.copy()
                        record["model"] = model
                        record["subset"] = subset
                        results.append(record)

    df_results = pd.DataFrame(results)
    del results  # free the raw records before the heavy dataframe work

    # postprocess task/config data
    def split_alias(alias: str) -> pd.Series:
        """Split an alias like "<config>_<task>[_base|_cot]" into config/task."""
        if alias.endswith("_base"):
            alias = alias[: -len("_base")]
        elif alias.endswith("_cot"):
            alias = alias[: -len("_cot")]

        if "_" not in alias:
            task = alias
            config = ""
        else:
            # maxsplit=1 keeps underscores inside the task name intact
            # (a plain two-value unpack raised ValueError on such aliases);
            # assumes the config part itself contains no underscore.
            config, task = alias.split("_", 1)

        return pd.Series({"task": task, "config": config})

    df_results = pd.concat([df_results, df_results.alias.apply(split_alias)], axis=1)

    # baseline accuracies in separate df, indexed by (model, task)
    df_baseline = (
        df_results[df_results.subset.eq("base")]
        .groupby(["model", "task"])[["acc,none"]]
        .mean()
    )

    # build cot eval df with baseline accuracies in separate column
    df_tmp1 = df_results[df_results.subset.eq("cot")].sort_values(by=["model", "task", "config"])
    df_tmp1.reset_index(inplace=True, drop=True)

    df_cot = df_tmp1[["model", "task", "config"]].copy()
    df_cot["acc_cot"] = df_tmp1["acc,none"]
    # look up the baseline accuracy for each (model, task);
    # raises KeyError if a CoT row has no baseline counterpart
    baseline_acc = df_baseline["acc,none"]
    df_cot["acc_base"] = [
        baseline_acc.loc[(model, task)]
        for model, task in zip(df_cot.model, df_cot.task)
    ]

    df_cot["acc_gain"] = df_cot.acc_cot - df_cot.acc_base
    # relative gain; NOTE(review): yields inf/NaN when baseline accuracy is 0
    df_cot["delta_rel"] = (df_cot.acc_cot - df_cot.acc_base) / df_cot.acc_base

    # average eval results over all tasks in extra df
    df_cot_avg = df_cot.groupby(["model", "config"]).mean(numeric_only=True).reset_index()
    df_cot_avg["task"] = "all"

    # add average results to cot df
    df_cot = pd.concat([df_cot_avg, df_cot], ignore_index=True)


    ####
    # Load the traces data
    ####

    # load traces data and extract the per-trace config records
    print("Loading traces data...")
    dataset = datasets.load_dataset(TRACES_DATASET, split="test", token=TOKEN, num_proc=8)
    dataset = dataset.select_columns(["config_data"])
    config_data = [dict(data) for data in dataset["config_data"] if data is not None]
    del dataset  # release the HF dataset before building pandas frames
    df_cotconfigs = pd.DataFrame(config_data)
    df_cotconfigs.drop_duplicates(inplace=True, ignore_index=True)

    # add cot configs data to df_cot
    def select_config_data(row):
        """Return config metadata matching (row.config, row.model), or None."""
        df_selected = df_cotconfigs[
            df_cotconfigs.name.eq(row.config) & df_cotconfigs.model.eq(row.model)
        ]
        if len(df_selected) == 0:
            print(f"Config {row.config} not found for model {row.model}")
            return None
        # drop the join keys so only the metadata columns get concatenated
        return df_selected.drop(columns=["name", "model", "task"]).iloc[0]

    df_cot = pd.concat(
        [
            df_cot,
            df_cot.apply(select_config_data, axis=1)
        ],
        axis=1
    )

    # accuracy values in percent
    for col in ['acc_base', 'acc_cot', 'acc_gain']:
        df_cot[col] = 100 * df_cot[col]

    print("Regimes dataframe created:")
    print(df_cot.head(3))

    ####
    # Create error dataframe
    ####

    # aggregate per (model, task): mean/min/max gain plus mean accuracies
    df_cot_err = df_cot.groupby(["model", "task"]).agg(
        {'acc_gain': ['mean', 'min', 'max'], "acc_base": "mean", "acc_cot": "mean"}
    )
    # flatten the MultiIndex columns into "acc_gain-mean"-style names
    df_cot_err.columns = ['-'.join(col).strip() for col in df_cot_err.columns.values]
    # symmetric error bar: half the min-max spread of the gain
    df_cot_err["acc_gain-err"] = 0.5 * (df_cot_err["acc_gain-max"] - df_cot_err["acc_gain-min"])
    df_cot_err.reset_index(inplace=True)
    df_cot_err.rename(
        columns={
            "acc_base-mean": "base accuracy",
            "acc_cot-mean": "cot accuracy",
            "acc_gain-mean": "marginal acc. gain",
        },
        inplace=True,
    )

    print("Error dataframe created:")
    print(df_cot_err.head(3))


    return df_cot_err, df_cot