import os import torch import gradio as gr import requests from typing import List, Dict, Iterator from threading import Thread from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer from peft import PeftModel import json BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a" ADAPTER_ID = os.getenv("ADAPTER_ID") YANDEX_API_KEY = os.getenv("YANDEX_API_KEY") YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID") if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]): raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID") MAX_NEW_TOKENS = 2048 TEMPERATURE = 1 TOP_P = 0.9 REPETITION_PENALTY = 1.05 SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.") print("Загрузка модели с 4-битной квантизацией...") quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False) if tok.pad_token is None: tok.pad_token = tok.eos_token base = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, quantization_config=quantization_config, device_map="auto" ) print("Применяем LoRA адаптер...") model = PeftModel.from_pretrained(base, ADAPTER_ID) model.config.use_cache = False model.eval() print("✅ Модель успешно загружена!") YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate" YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect" def detect_language(text: str) -> str: headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"} payload = {"folderId": YANDEX_FOLDER_ID, "text": text} try: resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10) resp.raise_for_status() return resp.json().get("languageCode", "ru") except requests.exceptions.RequestException: return "ru" def ru2tt(text: str) -> str: headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"} payload = {"folderId": YANDEX_FOLDER_ID, "texts": [text], "sourceLanguageCode": "ru", "targetLanguageCode": "tt"} try: resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30) resp.raise_for_status() return resp.json()["translations"][0]["text"] except requests.exceptions.RequestException: return text def render_prompt(messages: List[Dict[str, str]]) -> str: return tok.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # --- 4) Стриминговая генерация (без тримминга) --- @torch.inference_mode() def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]: prompt = render_prompt(messages) enc = tok(prompt, return_tensors="pt") enc = {k: v.to(model.device) for k, v in enc.items()} streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True) gen_kwargs = dict( **enc, streamer=streamer, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id, ) thread = Thread(target=model.generate, kwargs=gen_kwargs) thread.start() acc = "" for chunk in streamer: acc += chunk yield acc def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]): if not messages_state or messages_state[0].get("role") != "system": messages_state = [{"role": "system", "content": SYS_PROMPT_TT}] detected = detect_language(message) user_tt = ru2tt(message) if detected != "tt" else message messages = messages_state + [{"role": "user", "content": user_tt}] ui_history = ui_history + [[user_tt, ""]] last = "" for partial in generate_tt_reply_stream(messages): last = partial ui_history[-1][1] = partial yield ui_history, messages_state + [ {"role": "user", "content": user_tt}, {"role": "assistant", "content": partial}, ] final_state = messages + [{"role": "assistant", "content": last}] print("STATE:", json.dumps(final_state, ensure_ascii=False)) with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("## Татарский чат-бот от команды Сбера") messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}]) chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False) msg = gr.Textbox( label="Хәбәрегезне рус яки татар телендә языгыз", placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?" ) clear = gr.Button("🗑️ Чистарту") msg.submit( chat_fn, inputs=[msg, chatbot, messages_state], outputs=[chatbot, messages_state], ) msg.submit(lambda: "", None, msg) def _reset(): return [], [{"role": "system", "content": SYS_PROMPT_TT}] clear.click(_reset, inputs=None, outputs=[chatbot, messages_state], queue=False) clear.click(lambda: "", None, msg, queue=False) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))