import gradio as gr
from transformers import pipeline

def generate_text(
    model_name,
    text,
    min_length,
    max_length,
    temperature,
    top_k,
    top_p
):
    models_map = {
        "Мои любимые юморески": "gpt2-vk-aneki",
        "бугро тред": "gpt2-vk-bugro",
        "Калик)": "gpt2-vk-kalik"
    }

    model = "MesonWarrior/" + models_map[model_name]

    pipe = pipeline(
        'text-generation',
        model=model,
        tokenizer=model,
        min_length=min_length,
        max_length=max_length
    )

    return pipe(text, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=True)[0]['generated_text']

def interface():
    with gr.Row():
        with gr.Column():
            with gr.Row():
                model = gr.Dropdown(
                    ["Мои любимые юморески", "бугро тред", "Калик)"],
                    label="Модель (Текст какого паблика генерировать)",
                    value="Мои любимые юморески",
                )
            text = gr.Textbox(lines=7, label="Входной текст", placeholder="Введите текст который продолжит нейросеть...")
        output = gr.Textbox(lines=12, label="Выходной текст", placeholder="Здесь будет текст сгенерированный нейросетью...")
    with gr.Row():
        with gr.Column():
            min_length = gr.Slider(
                minimum=0, maximum=100, value=32, step=1,
                label="Минимальная длина",
                info="Минимальное количество символов в выходном тексте."
            )
            max_length = gr.Slider(
                minimum=0, maximum=200, value=64, step=1,
                label="Максимальная длина",
                info="Максимальное количество символов в выходном тексте."
            )
            temperature = gr.Slider(
                minimum=0.05, maximum=1.95, value=0.9, step=0.05,
                label="Температура",
                info="Чем выше тем рандомнее, чем ниже тем больше повторений."
            )
            top_k = gr.Slider(
                minimum=0, maximum=100, value=50, step=0.05,
                label="Top K",
            )
            top_p = gr.Slider(
                minimum=0, maximum=1, value=0.9, step=0.05,
                label="Top P",
            )
        with gr.Column():
            with gr.Row():
                generate_btn = gr.Button(
                    "Сгенерировать", variant="primary", label="Generate",
                )

        generate_btn.click(
            fn=generate_text,
            inputs=[
                model,
                text,
                min_length,
                max_length,
                temperature,
                top_k,
                top_p
            ],
            outputs=output,
        )

with gr.Blocks(
    title="GPT2 VK") as demo:
        gr.Markdown("""
        # GPT2 VK
        Файнтюны [этой](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) модели по вашим любимым пабликам ВКонтакте.
        #### Паблики представленные в моделях:
        - [Мои любимые юморески 🎩](https://huggingface.co/MesonWarrior/gpt2-vk-aneki)
        - [бугро тред 💥](https://huggingface.co/MesonWarrior/gpt2-vk-bugro)
        - [Калик) 🍏🍎💨](https://huggingface.co/MesonWarrior/gpt2-vk-kalik) <sub><sup>(Обучено на спорном датасете из постов и комментариев, надо бы переобучить на данных получше)</sup></sub>
        """)
        interface()

demo.queue().launch()