import os import torch import gradio as gr import requests from typing import List, Dict from threading import Lock from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from peft import PeftModel # --- 1. Конфигурация и загрузка модели --- # ID базовой модели BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a" # ID адаптера и ключи API загружаются из переменных окружения Render ADAPTER_ID = os.getenv("ADAPTER_ID") YANDEX_API_KEY = os.getenv("YANDEX_API_KEY") YANDEX_FOLDER_ID = os.getenv("YANDEX_FOLDER_ID") # Проверяем, что все переменные окружения установлены if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]): raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID") # Параметры генерации MAX_NEW_TOKENS = 256 TEMPERATURE = 0.7 TOP_P = 0.9 REPETITION_PENALTY = 1.05 SYS_PROMPT_TT = ( "Син - татар цифрлы ярдәмчесе. Татар телендә һәрвакыт ачык һәм дустанә җавап бир." "мәгълүмат җитәрлек булмаса, 1-2 кыска аныклаучы сорау бир. " "Һәрвакыт татарча гына җавап бир." ) print("Загрузка модели с 4-битной квантизацией...") # Используем квантизацию для экономии оперативной памяти quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16 ) # Загружаем токенизатор из приватного репозитория # Библиотека transformers автоматически использует токен HF_TOKEN из переменных окружения tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False) if tok.pad_token is None: tok.pad_token = tok.eos_token base = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, quantization_config=quantization_config, device_map="auto", ) print("Применяем LoRA адаптер...") model = PeftModel.from_pretrained(base, ADAPTER_ID) model.config.use_cache = True model.eval() print("✅ Модель успешно загружена!") # --- 2. Логика приложения (функции перевода и генерации) --- YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate" generation_lock = Lock() # Для обработки одного запроса за раз def _yandex_translate(texts: List[str], source: str, target: str) -> List[str]: headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"} payload = { "folderId": YANDEX_FOLDER_ID, "texts": texts, "sourceLanguageCode": source, "targetLanguageCode": target, } try: resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30) resp.raise_for_status() data = resp.json() return [item["text"] for item in data["translations"]] except requests.exceptions.RequestException as e: print(f"Ошибка перевода: {e}") return [f"Ошибка перевода: {text}" for text in texts] def ru2tt(text: str) -> str: return _yandex_translate([text], "ru", "tt")[0] def tt2ru(text: str) -> str: return _yandex_translate([text], "tt", "ru")[0] def render_prompt(messages: List[Dict[str, str]]) -> str: # Ваша функция рендеринга промпта без изменений if getattr(tok, "chat_template", None): try: return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) except Exception: pass sys_text = "" turns = [] for m in messages: if m["role"] == "system": sys_text += m["content"].strip() + "\n" i = 0 while i < len(messages): m = messages[i] if m["role"] == "user": next_assistant = None if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant": next_assistant = messages[i + 1]["content"] if len(turns) == 0 and sys_text: user_block = f"<>\n{sys_text.strip()}\n<>\n\n{m['content']}" else: user_block = m["content"] if next_assistant is None: turns.append(f"[INST] {user_block} [/INST]") else: turns.append(f"[INST] {user_block} [/INST] {next_assistant}") i += 1 i += 1 if not turns: return f"[INST] <>\n{sys_text.strip()}\n<>\n\n [/INST]" if sys_text else "[INST] [/INST]" return "".join(turns) @torch.inference_mode() def generate_tt_reply(messages: List[Dict[str, str]]) -> str: with generation_lock: prompt = render_prompt(messages) inputs = tok(prompt, return_tensors="pt").to(model.device) out = model.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=TEMPERATURE, top_p=TOP_P, repetition_penalty=REPETITION_PENALTY, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id, ) gen_text = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True) return gen_text.strip() # --- 3. Gradio интерфейс --- def chat_fn(message, history): messages = [{"role": "system", "content": SYS_PROMPT_TT}] for user_msg, bot_msg in history: messages.append({"role": "user", "content": ru2tt(user_msg)}) messages.append({"role": "assistant", "content": ru2tt(bot_msg)}) user_tt = ru2tt(message) messages.append({"role": "user", "content": user_tt}) tt_reply = generate_tt_reply(messages) ru_reply = tt2ru(tt_reply) return ru_reply # Создаем и запускаем интерфейс with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("## Татарский Чат-Бот на базе Tweety-7B") chatbot = gr.Chatbot(label="Диалог", height=500) msg = gr.Textbox(label="Ваше сообщение (на русском)", placeholder="Как дела?") clear = gr.Button("🗑️ Очистить") msg.submit(chat_fn, [msg, chatbot], chatbot) clear.click(lambda: None, None, chatbot, queue=False) # server_name="0.0.0.0" и server_port=int(os.getenv("PORT", 7860)) важны для Render if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))