Nefertury commited on
Commit
3297f8d
·
verified ·
1 Parent(s): aecd924

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import requests
5
+ from typing import List, Dict
6
+ from threading import Lock
7
+
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
9
+ from peft import PeftModel
10
+
11
+ # --- 1. Конфигурация и загрузка модели ---
12
+
13
+ # ID базовой модели
14
+ BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
15
+
16
+ # ID адаптера и ключи API загружаются из переменных окружения Render
17
+ ADAPTER_ID = os.getenv("ADAPTER_ID")
18
+ YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
19
+ YANDEX_FOLDER_ID = os.getenv("YANDEX_FOLDER_ID")
20
+
21
+ # Проверяем, что все переменные окружения установлены
22
+ if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
23
+ raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
24
+
25
+ # Параметры генерации
26
+ MAX_NEW_TOKENS = 256
27
+ TEMPERATURE = 0.7
28
+ TOP_P = 0.9
29
+ REPETITION_PENALTY = 1.05
30
+ SYS_PROMPT_TT = (
31
+ "Син - татар цифрлы ярдәмчесе. Татар телендә һәрвакыт ачык һәм дустанә җавап бир."
32
+ "мәгълүмат җитәрлек булмаса, 1-2 кыска аныклаучы сорау бир. "
33
+ "Һәрвакыт татарча гына җавап бир."
34
+ )
35
+
36
+ print("Загрузка модели с 4-битной квантизацией...")
37
+ # Используем квантизацию для экономии оперативной памяти
38
+ quantization_config = BitsAndBytesConfig(
39
+ load_in_4bit=True,
40
+ bnb_4bit_compute_dtype=torch.bfloat16
41
+ )
42
+
43
+ # Загружаем токенизатор из приватного репозитория
44
+ # Библиотека transformers автоматически использует токен HF_TOKEN из переменных окружения
45
+ tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
46
+ if tok.pad_token is None:
47
+ tok.pad_token = tok.eos_token
48
+
49
+ base = AutoModelForCausalLM.from_pretrained(
50
+ BASE_MODEL_ID,
51
+ quantization_config=quantization_config,
52
+ device_map="auto",
53
+ )
54
+
55
+ print("Применяем LoRA адаптер...")
56
+ model = PeftModel.from_pretrained(base, ADAPTER_ID)
57
+ model.config.use_cache = True
58
+ model.eval()
59
+ print("✅ Модель успешно загружена!")
60
+
61
+
62
+ # --- 2. Логика приложения (функции перевода и генерации) ---
63
+
64
+ YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
65
+ generation_lock = Lock() # Для обработки одного запроса за раз
66
+
67
+ def _yandex_translate(texts: List[str], source: str, target: str) -> List[str]:
68
+ headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
69
+ payload = {
70
+ "folderId": YANDEX_FOLDER_ID,
71
+ "texts": texts,
72
+ "sourceLanguageCode": source,
73
+ "targetLanguageCode": target,
74
+ }
75
+ try:
76
+ resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
77
+ resp.raise_for_status()
78
+ data = resp.json()
79
+ return [item["text"] for item in data["translations"]]
80
+ except requests.exceptions.RequestException as e:
81
+ print(f"Ошибка перевода: {e}")
82
+ return [f"Ошибка перевода: {text}" for text in texts]
83
+
84
+ def ru2tt(text: str) -> str:
85
+ return _yandex_translate([text], "ru", "tt")[0]
86
+
87
+ def tt2ru(text: str) -> str:
88
+ return _yandex_translate([text], "tt", "ru")[0]
89
+
90
+ def render_prompt(messages: List[Dict[str, str]]) -> str:
91
+ # Ваша функция рендеринга промпта без изменений
92
+ if getattr(tok, "chat_template", None):
93
+ try:
94
+ return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
95
+ except Exception:
96
+ pass
97
+ sys_text = ""
98
+ turns = []
99
+ for m in messages:
100
+ if m["role"] == "system":
101
+ sys_text += m["content"].strip() + "\n"
102
+ i = 0
103
+ while i < len(messages):
104
+ m = messages[i]
105
+ if m["role"] == "user":
106
+ next_assistant = None
107
+ if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
108
+ next_assistant = messages[i + 1]["content"]
109
+ if len(turns) == 0 and sys_text:
110
+ user_block = f"<<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n{m['content']}"
111
+ else:
112
+ user_block = m["content"]
113
+ if next_assistant is None:
114
+ turns.append(f"<s>[INST] {user_block} [/INST]")
115
+ else:
116
+ turns.append(f"<s>[INST] {user_block} [/INST] {next_assistant}</s>")
117
+ i += 1
118
+ i += 1
119
+ if not turns:
120
+ return f"<s>[INST] <<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n [/INST]" if sys_text else "<s>[INST] [/INST]"
121
+ return "".join(turns)
122
+
123
+ @torch.inference_mode()
124
+ def generate_tt_reply(messages: List[Dict[str, str]]) -> str:
125
+ with generation_lock:
126
+ prompt = render_prompt(messages)
127
+ inputs = tok(prompt, return_tensors="pt").to(model.device)
128
+ out = model.generate(
129
+ **inputs,
130
+ max_new_tokens=MAX_NEW_TOKENS,
131
+ do_sample=True,
132
+ temperature=TEMPERATURE,
133
+ top_p=TOP_P,
134
+ repetition_penalty=REPETITION_PENALTY,
135
+ eos_token_id=tok.eos_token_id,
136
+ pad_token_id=tok.pad_token_id,
137
+ )
138
+ gen_text = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
139
+ return gen_text.strip()
140
+
141
+ # --- 3. Gradio интерфейс ---
142
+
143
+ def chat_fn(message, history):
144
+ messages = [{"role": "system", "content": SYS_PROMPT_TT}]
145
+ for user_msg, bot_msg in history:
146
+ messages.append({"role": "user", "content": ru2tt(user_msg)})
147
+ messages.append({"role": "assistant", "content": ru2tt(bot_msg)})
148
+
149
+ user_tt = ru2tt(message)
150
+ messages.append({"role": "user", "content": user_tt})
151
+
152
+ tt_reply = generate_tt_reply(messages)
153
+ ru_reply = tt2ru(tt_reply)
154
+
155
+ return ru_reply
156
+
157
+ # Создаем и запускаем интерфейс
158
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
159
+ gr.Markdown("## Татарский Чат-Бот на базе Tweety-7B")
160
+ chatbot = gr.Chatbot(label="Диалог", height=500)
161
+ msg = gr.Textbox(label="Ваше сообщение (на русском)", placeholder="Как дела?")
162
+ clear = gr.Button("🗑️ Очистить")
163
+
164
+ msg.submit(chat_fn, [msg, chatbot], chatbot)
165
+ clear.click(lambda: None, None, chatbot, queue=False)
166
+
167
+ # server_name="0.0.0.0" и server_port=int(os.getenv("PORT", 7860)) важны для Render
168
+ if __name__ == "__main__":
169
+ demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))