glodov committed on
Commit
54e2073
·
verified ·
1 Parent(s): 654e108
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
3
+ from datasets import Dataset
4
+ import torch
5
+
6
# Hub identifier of the checkpoint used for both chatting and fine-tuning.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v0.6"

# The tokenizer and model are loaded once at import time and shared by
# chat_fn and train_model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Run on the GPU when one is present; otherwise the model stays on the CPU.
_use_cuda = torch.cuda.is_available()
if _use_cuda:
    model.to("cuda")

# Module-level list; nothing else in this file reads or writes it.
history = []
15
+
16
def chat_fn(message, chat_history):
    """Generate one reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: Text of the user's latest message.
        chat_history: Conversation so far, as supplied by Gradio. It is not
            used to build the prompt and is left unmodified, because
            ``gr.ChatInterface`` appends each turn to the history itself.

    Returns:
        The newly generated text only (without the echoed prompt), as a
        string, which is the value ``gr.ChatInterface`` expects from ``fn``.
    """
    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which is forwarded to generate() below. Inputs are moved to
    # whatever device the model currently lives on.
    encoded = tokenizer(message, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **encoded,
        max_new_tokens=128,
        do_sample=True,
        top_p=0.9,
    )

    # For a causal LM, generate() returns the prompt tokens followed by the
    # continuation; drop the prompt so the reply does not repeat the question.
    prompt_length = encoded["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
24
+
25
def train_model(text):
    """Fine-tune the shared model on a single user-supplied text sample.

    The text becomes a one-row dataset, is tokenized to a fixed length of
    128 tokens, and is trained on for one epoch with a causal-LM collator.
    Training updates the module-level ``model`` in place, so later calls to
    the chat function use the updated weights.

    Args:
        text: Raw training text entered by the user.

    Returns:
        A status message for display in the UI: a prompt to enter text when
        the input is empty or whitespace-only, otherwise a success message.
    """
    # An empty sample is tokenized to padding only; the collator masks the
    # labels of padding tokens, so the loss would be undefined and could
    # corrupt the shared model's weights. Refuse to train instead.
    if not text or not text.strip():
        return "⚠️ Введіть текст для донавчання"

    # padding="max_length" requires a pad token; fall back to EOS when the
    # tokenizer does not define one (no-op if it already does).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    dataset = Dataset.from_dict({"text": [text]})
    tokenized = dataset.map(
        lambda batch: tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        ),
        batched=True,
    )

    args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        save_steps=10,
        logging_steps=5,
        report_to="none",
        # Mixed precision is only requested when a CUDA device is present.
        fp16=torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized,
        # mlm=False makes the collator build causal-LM labels from input_ids.
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )

    trainer.train()

    # Training leaves the model in train mode; switch back to eval mode
    # because the same object is used for generation afterwards.
    model.eval()
    return "✅ Модель донавчена на вашому тексті!"
48
+
49
# Chat tab: Gradio's ready-made chat UI driven by chat_fn.
chat_ui = gr.ChatInterface(fn=chat_fn)

# Train tab: a multi-line text box whose content is passed to train_model;
# the returned status message is shown as plain text.
_train_input = gr.Textbox(lines=10, label="Введіть текст для донавчання")
train_ui = gr.Interface(
    fn=train_model,
    inputs=_train_input,
    outputs="text",
)

# Combine both interfaces into one tabbed app and start the server.
_tabs = [chat_ui, train_ui]
_tab_titles = ["💬 Chat", "🧠 Train"]
gr.TabbedInterface(_tabs, _tab_titles).launch()