"""Start page of the app

This page is used to initialize a model card that is either:

1. based on the skops template,
2. empty, or
3. loaded from an existing model card on the Hugging Face Hub.

Optionally, users can upload a model file and sample data, specify the
requirements, and choose a task.

"""

import glob
import io
import os
import pickle
import shutil
from pathlib import Path
from tempfile import mkdtemp

import pandas as pd
import sklearn
import streamlit as st
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from sklearn.base import BaseEstimator
from sklearn.dummy import DummyClassifier

import skops.io as sio
from skops import card, hub_utils


# Scratch directory for temporary files, e.g. the serialized model that is
# handed to ``hub_utils.init`` when the repo is initialized.
tmp_path = Path(mkdtemp(prefix="skops-"))
# Markdown shown at the top of the start page.
description = """Create an sklearn model card

This Hugging Face Space aims to provide a simple interface to use the
[`skops`](https://skops.readthedocs.io/) model card creation utilities.

"""


def load_model() -> None:
    """Load the uploaded model file into ``st.session_state.model``.

    If no model file is present in the session state, a ``DummyClassifier`` is
    stored instead. Files whose name ends in ``.skops`` are deserialized with
    :func:`skops.io.loads`; any other file is treated as a pickle.

    Raises
    ------
    TypeError
        If the deserialized object is not an sklearn estimator.
    """
    uploaded = st.session_state.get("model_file")
    if uploaded is None:
        st.session_state.model = DummyClassifier()
        return

    bytes_data = uploaded.getvalue()
    if uploaded.name.endswith(".skops"):
        # NOTE(security): trusted=True disables skops' check for untrusted
        # types, so this is only as safe as the uploaded file.
        model = sio.loads(bytes_data, trusted=True)
    else:
        # NOTE(security): unpickling user-provided bytes can execute arbitrary
        # code; only acceptable if uploads come from a trusted user.
        model = pickle.loads(bytes_data)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(model, BaseEstimator):
        raise TypeError("model must be an sklearn model")

    st.session_state.model = model


def load_data() -> None:
    """Store the uploaded CSV sample as a DataFrame in ``st.session_state.data``.

    When no data file is present, an empty DataFrame is stored instead.
    """
    uploaded = st.session_state.get("data_file")
    if not uploaded:
        st.session_state.data = pd.DataFrame([])
        return

    csv_buffer = io.BytesIO(uploaded.getvalue())
    st.session_state.data = pd.read_csv(csv_buffer)


def _clear_repo(path: str) -> None:
    """Delete everything inside the directory ``path`` but keep the directory.

    Hidden entries (names starting with ``.``, e.g. ``.gitattributes``) are
    removed as well. Symlinks are unlinked rather than followed, so a link to a
    directory never causes the target's contents to be deleted. If ``path`` does
    not exist or is not a directory, nothing happens.

    Parameters
    ----------
    path : str or pathlib.Path
        Directory to empty.
    """
    root = Path(path)
    if not root.is_dir():
        return

    for entry in root.iterdir():
        # check for symlinks first: is_file/is_dir follow links
        if entry.is_symlink() or entry.is_file():
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)


def init_repo() -> None:
    """(Re-)initialize the local model repository at ``st.session_state.hf_path``.

    The directory is emptied, the current model is saved in skops format to the
    temporary directory, and :func:`skops.hub_utils.init` is called with the
    selected task, requirements, and sample data. Values not yet present in the
    session state fall back to defaults: task ``"tabular-classification"``, no
    requirements, and an empty DataFrame as data.

    Errors raised while saving the model or initializing the repo are not
    re-raised; they are printed and shown to the user as an error message.
    """
    path = st.session_state.hf_path
    _clear_repo(path)

    requirements = []
    if "requirements" in st.session_state:
        # One requirement per line; blank lines and surrounding whitespace are
        # dropped so no empty entries are passed to hub_utils.init.
        requirements = [
            line.strip()
            for line in st.session_state.requirements.splitlines()
            if line.strip()
        ]

    task = "tabular-classification"
    if "task" in st.session_state:
        task = st.session_state.task

    data = pd.DataFrame([])
    if "data_file" in st.session_state:
        load_data()
        data = st.session_state.data

    if task.startswith("text") and isinstance(data, pd.DataFrame):
        # for text tasks the sample data is passed on as a list of rows
        data = data.values.tolist()

    file_name = tmp_path / "model.skops"
    try:
        sio.dump(st.session_state.model, file_name)
        hub_utils.init(
            model=file_name,
            dst=path,
            task=task,
            data=data,
            requirements=requirements,
        )
    except Exception as exc:
        # Best effort: keep the app running, but surface the problem to the
        # user instead of only printing it to the server console.
        print("Uh oh, something went wrong when initializing the repo:", exc)
        st.error(f"Something went wrong when initializing the repo: {exc}")


def create_skops_model_card() -> None:
    """Create a model card from the skops template and switch to the edit screen.

    The repo is (re-)initialized first; the card's metadata is then read from
    the config found at ``st.session_state.hf_path``.
    """
    init_repo()
    session = st.session_state
    card_metadata = card.metadata_from_config(session.hf_path)
    session.model_card = card.Card(model=session.model, metadata=card_metadata)
    session.model_card_type = "skops"
    session.screen.state = "edit"


def create_empty_model_card() -> None:
    """Create a template-less model card and switch to the edit screen.

    The repo is (re-)initialized first and its config provides the metadata.
    The card starts with a single "Untitled" placeholder section.
    """
    init_repo()
    state = st.session_state
    card_metadata = card.metadata_from_config(state.hf_path)
    blank_card = card.Card(model=state.model, metadata=card_metadata, template=None)
    blank_card.add(Untitled="[More Information Needed]")
    state.model_card = blank_card
    state.model_card_type = "empty"
    state.screen.state = "edit"


def create_hf_model_card() -> None:
    """Load an existing model card from the HF Hub and switch to the edit screen.

    The repo ID entered in the form is stripped of surrounding whitespace and
    quotes. Markdown, text, and image files are downloaded from the repo and
    copied into both ``st.session_state.hf_path`` and the current working
    directory, then ``README.md`` is parsed as the model card.

    Nothing happens if the repo ID is empty. An error message is shown (and the
    screen is not changed) if the repo cannot be found or has no README.md.
    """
    repo_id = (st.session_state.get("hf_repo_id") or "").strip().strip("'").strip('"')
    if not repo_id:
        return

    # fnmatch-style patterns: each needs the leading "*" to match by extension
    allow_patterns = [
        "*.md", "*.txt", "*.png", "*.gif", "*.jpg", "*.jpeg", "*.bmp", "*.webp"
    ]
    try:
        path = snapshot_download(repo_id, allow_patterns=allow_patterns)
    except (HFValidationError, RepositoryNotFoundError):
        st.error(
            f"Repository '{repo_id}' could not be found on HF Hub, "
            "please check that the repo ID is correct."
        )
        return

    if not (Path(path) / "README.md").is_file():
        st.error(
            f"Repository '{repo_id}' does not contain a README.md file, "
            "so there is no model card to load."
        )
        return

    # copy everything to the hf_path and the working dir; hf_path may be a str
    # or a Path, so normalize it before joining
    hf_path = Path(st.session_state.hf_path)
    shutil.copytree(path, hf_path, dirs_exist_ok=True)
    shutil.copytree(path, ".", dirs_exist_ok=True)

    model_card = card.parse_modelcard(hf_path / "README.md")
    st.session_state.model_card = model_card
    st.session_state.model_card_type = "loaded"
    st.session_state.screen.state = "edit"


def start_input_form():
    """Render the start page of the app.

    Missing session-state entries (``model``, ``data``, ``model_card``) are
    seeded with defaults first. The page then shows, in order: the description,
    the model and data uploaders, the task selector, the requirements editor,
    and the three ways to create or load a model card.
    """
    # Only fill in entries that are not already present in the session state.
    session_defaults = (
        ("model", DummyClassifier),
        ("data", lambda: pd.DataFrame([])),
        ("model_card", lambda: None),
    )
    for state_key, make_default in session_defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()

    def divider():
        st.markdown("---")

    st.markdown(description)
    divider()

    st.text(
        "Upload an sklearn model (strongly recommended)\n"
        "The model can be used to automatically populate fields in the model card."
    )
    # The uploader is only drawn while no model file is held in session state.
    if not st.session_state.get("model_file"):
        st.file_uploader(
            "Upload an sklearn model (pickle or skops format)",
            on_change=load_model,
            key="model_file",
        )
    divider()

    st.text(
        "Upload samples from your data (in csv format)\n"
        "This sample data can be attached to the metadata of the model card"
    )
    st.file_uploader(
        "Upload input data (csv)", type=["csv"], on_change=load_data, key="data_file"
    )
    divider()

    task_choices = [
        "tabular-classification",
        "tabular-regression",
        "text-classification",
        "text-regression",
    ]
    st.selectbox(
        label="Choose the task type*",
        options=task_choices,
        key="task",
        on_change=init_repo,
    )
    divider()

    default_requirements = f"scikit-learn=={sklearn.__version__}\n"
    st.text_area(
        label="Requirements*",
        value=default_requirements,
        key="requirements",
        on_change=init_repo,
    )
    divider()

    st.markdown("Choose one of the options below to get started:")
    skops_col, empty_col, hub_col = st.columns([2, 2, 2])
    with skops_col:
        st.button("Create a new skops model card", on_click=create_skops_model_card)
    with empty_col:
        st.button("Create a new empty model card", on_click=create_empty_model_card)
    # The form is created after entering hub_col, so it is placed inside it.
    with hub_col, st.form("Load existing model card from HF Hub", clear_on_submit=False):
        st.text_input("Repo name (e.g. 'gpt2')", key="hf_repo_id")
        st.form_submit_button("Load", on_click=create_hf_model_card)


# Render the start page whenever this script is executed.
start_input_form()