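"""Gradio demo: explains Spanish proverbs (refranes) using Llama-2-7b with a LoRA adapter."""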
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

# Default generation settings
MAX_NEW_TOKENS = 100
TEMPERATURE = 0.5
TOP_P = 0.95
TOP_K = 50
REPETITION_PENALTY = 1.05

# Separator placed between the proverb and its explanation in the prompt format
SPECIAL_TOKEN = "->:"

# Access token for the gated Llama-2 repository (read from the environment)
HF_TOKEN = os.getenv('HF_TOKEN')


def load_model():
    """Load the Llama-2 base model and attach the LoRA adapter for proverb explanations."""
    base_model_id = "meta-llama/Llama-2-7b-hf"
    peft_model_id = "somosnlp-hackathon-2025/Llama-2-7b-hf-lora-refranes"

    # Adapter configuration (loaded up front; PeftModel reads it again when attaching the adapter)
    config = PeftConfig.from_pretrained(peft_model_id)

    # The meta-llama repositories are gated, so the token is required here
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype="auto",
        device_map="auto",
        token=HF_TOKEN
    )

    # Attach the LoRA weights on top of the base model
    model = PeftModel.from_pretrained(base_model, peft_model_id)

    # The tokenizer comes from the same gated base repository
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=HF_TOKEN)
    if tokenizer.pad_token is None:
        # Llama-2 defines no pad token by default; reuse EOS for padding
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


# Loaded lazily on the first request so the app starts up quickly
model = None
tokenizer = None


def generate_response(input_text, max_tokens, temperature, top_p, repetition_penalty):
    """Generate an explanation for a proverb with the fine-tuned model."""
    global model, tokenizer

    if model is None or tokenizer is None:
        model, tokenizer = load_model()

    # The prompt ends with the "->:" separator; the explanation is expected after it
    inputs = tokenizer(input_text + SPECIAL_TOKEN, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=top_p,
            top_k=TOP_K,
            repetition_penalty=repetition_penalty
        )

    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Keep only the text generated after the separator (the explanation itself)
    if SPECIAL_TOKEN in full_response:
        response_parts = full_response.split(SPECIAL_TOKEN, 1)
        if len(response_parts) > 1:
            return response_parts[1].strip()

    return full_response.strip()


def chat_interface(message, history, system_message, max_tokens, temperature, top_p, repetition_penalty):
    """Gradio ChatInterface callback: prepend the system message and generate a reply."""
    prompt = message
    if system_message:
        prompt = f"{system_message}\n{message}"

    response = generate_response(
        prompt,
        max_tokens,
        temperature,
        top_p,
        repetition_penalty
    )
    return response


demo = gr.ChatInterface(
    chat_interface,
    title="Sabiduría Popular - Refranes",
    description="Esta aplicación explica el significado de refranes en español utilizando un modelo de lenguaje. Escribe un refrán y el modelo te explicará su significado.",
    examples=[
        ["A caballo regalado no le mires el diente"],
        ["Más vale pájaro en mano que ciento volando"],
        ["Quien a buen árbol se arrima, buena sombra le cobija"],
        ["No por mucho madrugar amanece más temprano"]
    ],
    additional_inputs=[
        gr.Textbox(
            value="Eres un experto en sabiduría popular española. Tu tarea es explicar el significado de refranes en español de manera clara y concisa.",
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TOP_P,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=REPETITION_PENALTY,
            step=0.05,
            label="Repetition penalty"
        ),
    ],
    theme="soft"
)


if __name__ == "__main__":
    print("Iniciando la aplicación. El modelo se cargará con la primera consulta.")
    demo.launch()