Nefertury AccessAndrei commited on
Commit
2f2eda6
·
verified ·
1 Parent(s): 7327516

version updates (#1)

Browse files

- version updates (cfb35e64155b756114fb1ebc8f6d8e5446356781)


Co-authored-by: Aksenov Andrei <[email protected]>

Files changed (1) hide show
  1. app.py +84 -96
app.py CHANGED
@@ -7,42 +7,45 @@ from threading import Thread
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
8
  from peft import PeftModel
9
 
10
- # --- 1. Конфигурация и загрузка модели ---
11
-
12
- BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
13
- ADAPTER_ID = os.getenv("ADAPTER_ID")
14
- YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
15
- YANDEX_FOLDER_ID = os.getenv("YANDEX_FOLDER_ID")
16
 
17
  if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
18
  raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
19
 
20
- MAX_NEW_TOKENS = 256
21
- TEMPERATURE = 0.7
22
- TOP_P = 0.9
23
- REPETITION_PENALTY = 1.05
24
- SYS_PROMPT_TT = (
25
- "Син - татар цифрлы ярдәмчесе. Татар телендә һәрвакыт ачык һәм дустанә җавап бир."
26
- "мәгълүmat җитәрлек булмаса, 1-2 кыска аныклаучы сорау бир. "
27
- "Һәрвакыт татарча гына җавап бир."
28
  )
29
 
30
  print("Загрузка модели с 4-битной квантизацией...")
31
  quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
 
32
  tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
33
  if tok.pad_token is None:
34
  tok.pad_token = tok.eos_token
35
- base = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, quantization_config=quantization_config, device_map="auto")
 
 
 
 
 
 
 
36
  print("Применяем LoRA адаптер...")
37
  model = PeftModel.from_pretrained(base, ADAPTER_ID)
38
- model.config.use_cache = True
39
  model.eval()
40
  print("✅ Модель успешно загружена!")
41
 
42
- # --- 2. Логика приложения (с изменениями для стриминга) ---
43
-
44
  YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
45
- YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
46
 
47
  def detect_language(text: str) -> str:
48
  headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
@@ -50,10 +53,8 @@ def detect_language(text: str) -> str:
50
  try:
51
  resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10)
52
  resp.raise_for_status()
53
- data = resp.json()
54
- return data.get("languageCode", "ru")
55
- except requests.exceptions.RequestException as e:
56
- print(f"Ошибка определения языка: {e}")
57
  return "ru"
58
 
59
  def ru2tt(text: str) -> str:
@@ -63,48 +64,27 @@ def ru2tt(text: str) -> str:
63
  resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
64
  resp.raise_for_status()
65
  return resp.json()["translations"][0]["text"]
66
- except requests.exceptions.RequestException as e:
67
- print(f"Ошибка перевода: {e}")
68
- return f"Ошибка перевода: {text}"
69
 
70
  def render_prompt(messages: List[Dict[str, str]]) -> str:
71
- # Ваша функция render_prompt остается без изменений
72
- if getattr(tok, "chat_template", None):
73
- try:
74
- return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
75
- except Exception: pass
76
- sys_text = ""
77
- turns = []
78
- for m in messages:
79
- if m["role"] == "system": sys_text += m["content"].strip() + "\n"
80
- i = 0
81
- while i < len(messages):
82
- m = messages[i]
83
- if m["role"] == "user":
84
- next_assistant = None
85
- if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
86
- next_assistant = messages[i + 1]["content"]
87
- user_block = f"<<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n{m['content']}" if len(turns) == 0 and sys_text else m['content']
88
- if next_assistant is None:
89
- turns.append(f"<s>[INST] {user_block} [/INST]")
90
- else:
91
- turns.append(f"<s>[INST] {user_block} [/INST] {next_assistant}</s>")
92
- i += 1
93
- i += 1
94
- return "".join(turns) if turns else (f"<s>[INST] <<SYS>>\n{sys_text.strip()}\n<</SYS>>\n\n [/INST]" if sys_text else "<s>[INST] [/INST]")
95
-
96
- # ❗ ИЗМЕНЕННАЯ ФУНКЦИЯ ГЕНЕРАЦИИ
97
  @torch.inference_mode()
98
  def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
99
  prompt = render_prompt(messages)
100
- inputs = tok(prompt, return_tensors="pt").to(model.device)
101
-
102
- # Создаем streamer
103
- streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
104
 
105
- # Аргументы для генерации
106
- generation_kwargs = dict(
107
- inputs,
108
  streamer=streamer,
109
  max_new_tokens=MAX_NEW_TOKENS,
110
  do_sample=True,
@@ -114,52 +94,60 @@ def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
114
  eos_token_id=tok.eos_token_id,
115
  pad_token_id=tok.pad_token_id,
116
  )
117
-
118
- # Запускаем генерацию в отдельном потоке
119
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
120
  thread.start()
121
 
122
- # Yield'им каждый новый кусочек текста
123
- generated_text = ""
124
- for new_text in streamer:
125
- generated_text += new_text
126
- yield generated_text
127
-
128
- # --- 3. Gradio интерфейс (с изменениями для стриминга) ---
129
-
130
- # ИЗМЕНЕННАЯ ФУНКЦИЯ-КОНТРОЛЛЕР
131
- def chat_fn(message: str, history: list) -> Iterator[list]:
132
- # 1. Формируем историю для модели
133
- messages = [{"role": "system", "content": SYS_PROMPT_TT}]
134
- for user_msg, bot_msg in history:
135
- messages.append({"role": "user", "content": user_msg})
136
- if bot_msg:
137
- messages.append({"role": "assistant", "content": bot_msg})
138
-
139
- # 2. Определяем язык и переводим, если нужно
140
- detected_lang = detect_language(message)
141
- user_tt = ru2tt(message) if detected_lang != "tt" else message
142
-
143
- messages.append({"role": "user", "content": user_tt})
144
 
145
- # 3. Добавляем в историю сообщение пользователя и пустой ответ бота
146
- history.append([user_tt, ""])
 
 
 
 
147
 
148
- # 4. Стримим ответ модели и обновляем историю на лету
149
- for partial_response in generate_tt_reply_stream(messages):
150
- history[-1][1] = partial_response # Обновляем последнее сообщение в истории
151
- yield history # Возвращаем всю историю на каждом шаге
152
 
153
- # Создаем интерфейс
154
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
155
  gr.Markdown("## Татарский чат-бот от команды Сбера")
 
 
 
156
  chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False)
157
- msg = gr.Textbox(label="Хәбәрегезне рус яки татар телендә языгыз", placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?")
 
 
 
158
  clear = gr.Button("🗑️ Чистарту")
159
 
160
- msg.submit(chat_fn, inputs=[msg, chatbot], outputs=chatbot)
161
- clear.click(lambda: None, None, chatbot, queue=False)
 
 
 
 
 
 
 
 
 
 
 
162
 
163
- # Запуск приложения
164
  if __name__ == "__main__":
165
- demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
 
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
8
  from peft import PeftModel
9
 
10
+ # --- 1) Конфигурация и загрузка модели ---
11
+ BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
12
+ ADAPTER_ID = os.getenv("ADAPTER_ID")
13
+ YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
14
+ YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID")
 
15
 
16
  if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
17
  raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
18
 
19
+ MAX_NEW_TOKENS = 256
20
+ TEMPERATURE = 0.7
21
+ TOP_P = 0.9
22
+ REPETITION_PENALTY = 1.05
23
+
24
+ SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә дә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү. Кулланучыга, фактлардан һәм саннардан качып, һәрвакыт кыска җавап бирергә кирәк"
 
 
25
  )
26
 
27
  print("Загрузка модели с 4-битной квантизацией...")
28
  quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
29
+
30
  tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
31
  if tok.pad_token is None:
32
  tok.pad_token = tok.eos_token
33
+
34
+
35
+ base = AutoModelForCausalLM.from_pretrained(
36
+ BASE_MODEL_ID,
37
+ quantization_config=quantization_config,
38
+ device_map="auto"
39
+ )
40
+
41
  print("Применяем LoRA адаптер...")
42
  model = PeftModel.from_pretrained(base, ADAPTER_ID)
43
+ model.config.use_cache = False
44
  model.eval()
45
  print("✅ Модель успешно загружена!")
46
 
 
 
47
  YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
48
+ YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
49
 
50
  def detect_language(text: str) -> str:
51
  headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
 
53
  try:
54
  resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10)
55
  resp.raise_for_status()
56
+ return resp.json().get("languageCode", "ru")
57
+ except requests.exceptions.RequestException:
 
 
58
  return "ru"
59
 
60
  def ru2tt(text: str) -> str:
 
64
  resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
65
  resp.raise_for_status()
66
  return resp.json()["translations"][0]["text"]
67
+ except requests.exceptions.RequestException:
68
+ return text
 
69
 
70
  def render_prompt(messages: List[Dict[str, str]]) -> str:
71
+ return tok.apply_chat_template(
72
+ messages,
73
+ tokenize=False,
74
+ add_generation_prompt=True
75
+ )
76
+
77
+
78
+ # --- 4) Стриминговая генерация (без тримминга) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  @torch.inference_mode()
80
  def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
81
  prompt = render_prompt(messages)
82
+ enc = tok(prompt, return_tensors="pt")
83
+ enc = {k: v.to(model.device) for k, v in enc.items()}
 
 
84
 
85
+ streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
86
+ gen_kwargs = dict(
87
+ **enc,
88
  streamer=streamer,
89
  max_new_tokens=MAX_NEW_TOKENS,
90
  do_sample=True,
 
94
  eos_token_id=tok.eos_token_id,
95
  pad_token_id=tok.pad_token_id,
96
  )
97
+
98
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
 
99
  thread.start()
100
 
101
+ acc = ""
102
+ for chunk in streamer:
103
+ acc += chunk
104
+ yield acc
105
+
106
+
107
+
108
+ def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]):
109
+ if not messages_state or messages_state[0].get("role") != "system":
110
+ messages_state = [{"role": "system", "content": SYS_PROMPT_TT}]
111
+
112
+ detected = detect_language(message)
113
+ user_tt = ru2tt(message) if detected != "tt" else message
114
+
115
+ messages = messages_state + [{"role": "user", "content": user_tt}]
116
+ ui_history = ui_history + [[user_tt, ""]]
 
 
 
 
 
 
117
 
118
+ for partial in generate_tt_reply_stream(messages):
119
+ ui_history[-1][1] = partial
120
+ yield ui_history, messages_state + [
121
+ {"role": "user", "content": user_tt},
122
+ {"role": "assistant", "content": partial},
123
+ ]
124
 
 
 
 
 
125
 
 
126
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
127
  gr.Markdown("## Татарский чат-бот от команды Сбера")
128
+
129
+ messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}])
130
+
131
  chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False)
132
+ msg = gr.Textbox(
133
+ label="Хәбәрегезне рус яки татар телендә языгыз",
134
+ placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?"
135
+ )
136
  clear = gr.Button("🗑️ Чистарту")
137
 
138
+ msg.submit(
139
+ chat_fn,
140
+ inputs=[msg, chatbot, messages_state],
141
+ outputs=[chatbot, messages_state],
142
+ )
143
+
144
+ msg.submit(lambda: "", None, msg)
145
+
146
+ def _reset():
147
+ return [], [{"role": "system", "content": SYS_PROMPT_TT}]
148
+
149
+ clear.click(_reset, inputs=None, outputs=[chatbot, messages_state], queue=False)
150
+ clear.click(lambda: "", None, msg, queue=False)
151
 
 
152
  if __name__ == "__main__":
153
+ demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))