|
|
import os |
|
|
import torch |
|
|
import gradio as gr |
|
|
import requests |
|
|
from typing import List, Dict, Iterator |
|
|
from threading import Thread |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer |
|
|
from peft import PeftModel |
|
|
import json |
|
|
|
|
|
# --- Configuration ---------------------------------------------------------
# Base model: Tweety Tatar 7B (trans-tokenized Mistral derivative).
BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"

# LoRA adapter repo id; also used as the tokenizer source below.
ADAPTER_ID = os.getenv("ADAPTER_ID")

# Credentials for the Yandex Translate REST API (detect + translate).
YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")

YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID")

# Fail fast at startup if any required environment variable is missing.
if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):

    raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")

# --- Generation hyperparameters -------------------------------------------
MAX_NEW_TOKENS = 2048

TEMPERATURE = 1

TOP_P = 0.9

REPETITION_PENALTY = 1.05

# System prompt (in Tatar): the assistant is a digital helper created by
# Sberbank's ESG directorate, an expert on Tatarstan-related topics.
SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.")
|
|
|
|
|
# --- Model loading (runs once at import time) ------------------------------
print("Загрузка модели с 4-битной квантизацией...")

# 4-bit NF4-style quantization with bf16 compute to fit the 7B model on a
# single GPU.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# Tokenizer is loaded from the adapter repo (it may carry a chat template
# customized during fine-tuning), not from the base model.
tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)

if tok.pad_token is None:

    # Reuse EOS as PAD so generation has a valid pad_token_id.
    tok.pad_token = tok.eos_token

base = AutoModelForCausalLM.from_pretrained(

    BASE_MODEL_ID,

    quantization_config=quantization_config,

    device_map="auto"

)

print("Применяем LoRA адаптер...")

# Wrap the quantized base model with the fine-tuned LoRA weights.
model = PeftModel.from_pretrained(base, ADAPTER_ID)

# NOTE(review): use_cache=False disables the KV cache and slows streaming
# generation considerably — presumably left over from training; confirm
# whether it is intentional here.
model.config.use_cache = False

model.eval()

print("✅ Модель успешно загружена!")

# --- Yandex Translate REST endpoints ---------------------------------------
YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"

YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
|
|
|
|
|
def detect_language(text: str) -> str:
    """Detect the language of *text* via the Yandex Translate Detect API.

    Args:
        text: Arbitrary user text.

    Returns:
        The language code reported by the API (e.g. "ru", "tt"), or "ru"
        as a safe fallback on any network or response-format failure, so
        the caller proceeds down the translate-to-Tatar path.
    """
    headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
    payload = {"folderId": YANDEX_FOLDER_ID, "text": text}
    try:
        resp = requests.post(YANDEX_DETECT_URL, headers=headers, json=payload, timeout=10)
        resp.raise_for_status()
        return resp.json().get("languageCode", "ru")
    except (requests.exceptions.RequestException, ValueError):
        # ValueError: resp.json() raises it on a non-JSON body; the original
        # code only caught RequestException and let that crash the handler.
        return "ru"
|
|
|
|
|
def ru2tt(text: str) -> str:
    """Translate *text* from Russian to Tatar via the Yandex Translate API.

    Args:
        text: Russian source text.

    Returns:
        The Tatar translation, or the original *text* unchanged if the
        request or the response payload fails in any way (best-effort
        contract: the chat must keep working without translation).
    """
    headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
    payload = {"folderId": YANDEX_FOLDER_ID, "texts": [text], "sourceLanguageCode": "ru", "targetLanguageCode": "tt"}
    try:
        resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        return resp.json()["translations"][0]["text"]
    except (requests.exceptions.RequestException, ValueError, KeyError, IndexError):
        # The original caught only RequestException; a non-JSON body
        # (ValueError) or an unexpected payload shape (KeyError/IndexError
        # from ["translations"][0]["text"]) would crash the handler instead
        # of degrading gracefully.
        return text
|
|
|
|
|
def render_prompt(messages: List[Dict[str, str]]) -> str:
    """Serialize a chat history into the model's prompt string.

    Args:
        messages: Chat turns as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The prompt text produced by the tokenizer's chat template, with
        the assistant header appended so the model continues the reply.
    """
    rendered = tok.apply_chat_template(
        messages,
        tokenize=False,             # return a plain string, not token ids
        add_generation_prompt=True, # open the assistant turn for generation
    )
    return rendered
|
|
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
    """Stream the model's Tatar reply for *messages*.

    Runs ``model.generate`` on a background thread and yields the reply
    accumulated so far (the full text, not a delta) each time the streamer
    produces a new decoded chunk. The prompt and special tokens are skipped.

    Args:
        messages: Full chat history including the system prompt.

    Yields:
        Progressively longer prefixes of the assistant's reply.
    """
    encoded = tok(render_prompt(messages), return_tensors="pt")
    model_inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    generation_args = {
        "streamer": streamer,
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "repetition_penalty": REPETITION_PENALTY,
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
    }
    generation_args.update(model_inputs)

    # Generation blocks, so it runs on a worker thread while this generator
    # drains the streamer on the caller's thread.
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    reply_so_far = ""
    for piece in streamer:
        reply_so_far += piece
        yield reply_so_far
|
|
|
|
|
|
|
|
|
|
|
def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]):
    """Gradio handler: take a user message, stream the bot reply.

    Detects the message language, translates Russian input to Tatar, then
    streams the model's reply, yielding ``(ui_history, messages_state)``
    pairs that Gradio writes into the Chatbot and State components.

    Args:
        message: Raw user input (Russian or Tatar).
        ui_history: Chatbot history as [user, bot] pairs.
        messages_state: Chat history as role/content dicts, starting with
            the system prompt.

    Yields:
        Updated ``(ui_history, messages_state)`` after each streamed chunk,
        plus one final pair once generation completes.
    """
    # Re-seed the state if it was lost or malformed (e.g. a fresh session).
    if not messages_state or messages_state[0].get("role") != "system":
        messages_state = [{"role": "system", "content": SYS_PROMPT_TT}]

    detected = detect_language(message)
    user_tt = ru2tt(message) if detected != "tt" else message

    messages = messages_state + [{"role": "user", "content": user_tt}]
    ui_history = ui_history + [[user_tt, ""]]

    last = ""
    for partial in generate_tt_reply_stream(messages):
        last = partial
        ui_history[-1][1] = partial
        yield ui_history, messages_state + [
            {"role": "user", "content": user_tt},
            {"role": "assistant", "content": partial},
        ]

    final_state = messages + [{"role": "assistant", "content": last}]
    print("STATE:", json.dumps(final_state, ensure_ascii=False))
    # Bug fix: the original computed final_state but never yielded it, so if
    # the stream produced no chunks neither the chatbot nor the state was
    # ever updated. Always emit a final pair.
    yield ui_history, final_state
|
|
|
|
|
|
|
|
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    gr.Markdown("## Татарский чат-бот от команды Сбера")

    # Per-session chat history (role/content dicts), seeded with the system prompt.
    messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}])

    chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False)

    msg = gr.Textbox(

        label="Хәбәрегезне рус яки татар телендә языгыз",

        placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?"

    )

    clear = gr.Button("🗑️ Чистарту")

    # Stream the reply into the chatbot and update the state on submit.
    msg.submit(

        chat_fn,

        inputs=[msg, chatbot, messages_state],

        outputs=[chatbot, messages_state],

    )

    # Clear the input box immediately after submit (runs alongside chat_fn).
    msg.submit(lambda: "", None, msg)

    def _reset():
        # Reset handler: empty the chatbot and re-seed the state with the
        # system prompt only.
        return [], [{"role": "system", "content": SYS_PROMPT_TT}]

    # queue=False: reset bypasses the generation queue so it feels instant.
    clear.click(_reset, inputs=None, outputs=[chatbot, messages_state], queue=False)

    clear.click(lambda: "", None, msg, queue=False)
|
|
|
|
|
if __name__ == "__main__":

    # Bind on all interfaces; port comes from $PORT (default 7860) so the
    # app works in containerized/Spaces deployments.
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
|
|
|