import os
import torch
import gradio as gr
import requests
from typing import List, Dict
from threading import Lock
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# --- 1. Конфигурация и загрузка модели ---
# ID базовой модели
BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
# ID адаптера и ключи API загружаются из переменных окружения Render
ADAPTER_ID = os.getenv("ADAPTER_ID")
YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
YANDEX_FOLDER_ID = os.getenv("YANDEX_FOLDER_ID")
# Проверяем, что все переменные окружения установлены
if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
# Параметры генерации
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.05
SYS_PROMPT_TT = (
"Син - татар цифрлы ярдәмчесе. Татар телендә һәрвакыт ачык һәм дустанә җавап бир."
"мәгълүмат җитәрлек булмаса, 1-2 кыска аныклаучы сорау бир. "
"Һәрвакыт татарча гына җавап бир."
)
print("Загрузка модели с 4-битной квантизацией...")
# Используем квантизацию для экономии оперативной памяти
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
# Загружаем токенизатор из приватного репозитория
# Библиотека transformers автоматически использует токен HF_TOKEN из переменных окружения
tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
base = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_ID,
quantization_config=quantization_config,
device_map="auto",
)
print("Применяем LoRA адаптер...")
model = PeftModel.from_pretrained(base, ADAPTER_ID)
model.config.use_cache = True
model.eval()
print("✅ Модель успешно загружена!")
# --- 2. Логика приложения (функции перевода и генерации) ---
YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
generation_lock = Lock() # Для обработки одного запроса за раз
def _yandex_translate(texts: List[str], source: str, target: str) -> List[str]:
headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
payload = {
"folderId": YANDEX_FOLDER_ID,
"texts": texts,
"sourceLanguageCode": source,
"targetLanguageCode": target,
}
try:
resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
return [item["text"] for item in data["translations"]]
except requests.exceptions.RequestException as e:
print(f"Ошибка перевода: {e}")
return [f"Ошибка перевода: {text}" for text in texts]
def ru2tt(text: str) -> str:
return _yandex_translate([text], "ru", "tt")[0]
def tt2ru(text: str) -> str:
return _yandex_translate([text], "tt", "ru")[0]
def render_prompt(messages: List[Dict[str, str]]) -> str:
# Ваша функция рендеринга промпта без изменений
if getattr(tok, "chat_template", None):
try:
return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
except Exception:
pass
sys_text = ""
turns = []
for m in messages:
if m["role"] == "system":
sys_text += m["content"].strip() + "\n"
i = 0
while i < len(messages):
m = messages[i]
if m["role"] == "user":
next_assistant = None
if i + 1 < len(messages) and messages[i + 1]["role"] == "assistant":
next_assistant = messages[i + 1]["content"]
if len(turns) == 0 and sys_text:
user_block = f"<>\n{sys_text.strip()}\n<>\n\n{m['content']}"
else:
user_block = m["content"]
if next_assistant is None:
turns.append(f"[INST] {user_block} [/INST]")
else:
turns.append(f"[INST] {user_block} [/INST] {next_assistant}")
i += 1
i += 1
if not turns:
return f"[INST] <>\n{sys_text.strip()}\n<>\n\n [/INST]" if sys_text else "[INST] [/INST]"
return "".join(turns)
@torch.inference_mode()
def generate_tt_reply(messages: List[Dict[str, str]]) -> str:
with generation_lock:
prompt = render_prompt(messages)
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=True,
temperature=TEMPERATURE,
top_p=TOP_P,
repetition_penalty=REPETITION_PENALTY,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
)
gen_text = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
return gen_text.strip()
# --- 3. Gradio интерфейс ---
def chat_fn(message, history):
messages = [{"role": "system", "content": SYS_PROMPT_TT}]
for user_msg, bot_msg in history:
messages.append({"role": "user", "content": ru2tt(user_msg)})
messages.append({"role": "assistant", "content": ru2tt(bot_msg)})
user_tt = ru2tt(message)
messages.append({"role": "user", "content": user_tt})
tt_reply = generate_tt_reply(messages)
ru_reply = tt2ru(tt_reply)
return ru_reply
# Создаем и запускаем интерфейс
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Татарский Чат-Бот на базе Tweety-7B")
chatbot = gr.Chatbot(label="Диалог", height=500)
msg = gr.Textbox(label="Ваше сообщение (на русском)", placeholder="Как дела?")
clear = gr.Button("🗑️ Очистить")
msg.submit(chat_fn, [msg, chatbot], chatbot)
clear.click(lambda: None, None, chatbot, queue=False)
# server_name="0.0.0.0" и server_port=int(os.getenv("PORT", 7860)) важны для Render
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))