# webui / app.py
# Last commit: 54e2073 (verified) by glodov
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset
import torch
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v0.6"

# Module-level tokenizer/model pair used by chat_fn (generation) and
# train_model (in-place fine-tuning).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Place the weights on the GPU when one is available; otherwise they stay
# where from_pretrained put them (CPU).
_use_cuda = torch.cuda.is_available()
if _use_cuda:
    model.to("cuda")

# NOTE(review): `history` is never read in this file (chat_fn receives the
# conversation from Gradio instead); kept only in case external code imports it.
history = []
def chat_fn(message, chat_history):
    """Generate one reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The user's latest message text.
        chat_history: Prior turns supplied by Gradio. It is not used to build
            the prompt (each message is answered on its own) and is not
            mutated; ChatInterface maintains the history itself.

    Returns:
        Only the newly generated reply text. This is the return contract
        ``gr.ChatInterface`` expects from ``fn``. The previous
        ``("", chat_history)`` tuple belongs to the Blocks + Chatbot pattern
        and is not valid here.
    """
    # Calling the tokenizer (rather than .encode) also returns an
    # attention_mask for generate(). Sending the tensors to model.device
    # follows wherever the weights were actually placed.
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.9)

    # A decoder-only generate() returns prompt + continuation. Slice off the
    # prompt so the reply does not echo the user's message back.
    prompt_len = inputs["input_ids"].shape[-1]
    # NOTE(review): TinyLlama-Chat is tuned on a chat template; prompting via
    # tokenizer.apply_chat_template may improve replies.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
def train_model(text, max_length=128):
    """Fine-tune the shared ``model`` in place on ``text`` (causal LM objective).

    The text is tokenized and split into consecutive chunks of at most
    ``max_length`` tokens. Text longer than one chunk is therefore trained on
    in full instead of being silently truncated. Training runs for one epoch
    with batch size 1. No checkpoints are written, because chat_fn uses the
    updated weights directly from memory.

    Args:
        text: Raw training text from the Gradio textbox.
        max_length: Maximum tokens per training example (default 128, the
            original fixed truncation length).

    Returns:
        A status message for the Gradio output textbox.
    """
    if not text or not text.strip():
        return "⚠️ Введіть текст для донавчання."

    # Llama-family tokenizers may ship without a pad token. The LM data
    # collator needs one to pad batches and to mask padding out of the labels.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    token_ids = tokenizer(text)["input_ids"]
    # Drop chunks of fewer than 2 tokens: after the causal shift they have no
    # target, which yields an undefined (NaN) loss.
    chunks = [
        chunk
        for chunk in (token_ids[i:i + max_length] for i in range(0, len(token_ids), max_length))
        if len(chunk) > 1
    ]
    if not chunks:
        return "⚠️ Введіть текст для донавчання."

    dataset = Dataset.from_dict({"input_ids": chunks})

    args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        # Weights are used in memory by the chat tab; periodic checkpoints of a
        # 1.1B model (plus optimizer state) would only consume disk.
        save_strategy="no",
        logging_steps=5,
        report_to="none",
        fp16=torch.cuda.is_available(),
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        # mlm=False: labels are a copy of input_ids with padding masked to -100.
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    trainer.train()

    # Trainer switches the model to train mode; restore eval mode so chat
    # generation runs without dropout.
    model.eval()
    return "✅ Модель донавчена на вашому тексті!"
# Tab 1: free-form chat backed by chat_fn.
chat_ui = gr.ChatInterface(fn=chat_fn)

# Tab 2: paste text and fine-tune the loaded model on it via train_model.
train_ui = gr.Interface(
    fn=train_model,
    inputs=gr.Textbox(lines=10, label="Введіть текст для донавчання"),
    outputs="text",
)

demo = gr.TabbedInterface(
    interface_list=[chat_ui, train_ui],
    tab_names=["💬 Chat", "🧠 Train"],
)
demo.launch()