# Source: Hugging Face Space by AccessAndrei — revision 0c085b6 (verified),
# commit message: "prompt, cfg changes", file size 6.26 kB.
import os
import torch
import gradio as gr
import requests
from typing import List, Dict, Iterator
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
from peft import PeftModel
import json
# Base model checkpoint; the LoRA adapter repo id is supplied via the environment.
BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
ADAPTER_ID = os.getenv("ADAPTER_ID")
# Yandex Cloud Translate credentials (API key + folder id).
YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID")
# Fail fast at import time if any required environment variable is missing.
if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
    raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
# Sampling hyperparameters used by generate_tt_reply_stream.
MAX_NEW_TOKENS = 2048
TEMPERATURE = 1
TOP_P = 0.9
REPETITION_PENALTY = 1.05
# System prompt, written in Tatar: defines the assistant persona
# (an ESG digital assistant from Sberbank with Tatarstan expertise).
SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.")
# --- Model loading (runs once at import time) ---
print("Загрузка модели с 4-битной квантизацией...")
# 4-bit quantization (bitsandbytes) so the 7B model fits in limited GPU memory.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
# Tokenizer is loaded from the adapter repo so any tokens added during
# fine-tuning stay consistent with the adapter weights.
tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
if tok.pad_token is None:
    # Reuse EOS as the padding token when none is defined.
    tok.pad_token = tok.eos_token
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto"  # let accelerate place layers on available devices
)
print("Применяем LoRA адаптер...")
# Attach the LoRA adapter on top of the quantized base model.
model = PeftModel.from_pretrained(base, ADAPTER_ID)
# NOTE(review): use_cache=False disables the KV cache and slows generation;
# presumably deliberate (e.g. memory pressure) — confirm before changing.
model.config.use_cache = False
model.eval()
print("✅ Модель успешно загружена!")
# Yandex Cloud Translate REST endpoints (translation and language detection).
YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
def detect_language(text: str) -> str:
    """Return the Yandex-detected language code of *text*, or "ru" on any request failure."""
    auth_header = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
    request_body = {"folderId": YANDEX_FOLDER_ID, "text": text}
    try:
        response = requests.post(YANDEX_DETECT_URL, headers=auth_header, json=request_body, timeout=10)
        response.raise_for_status()
        # Default to Russian when the API omits the language code.
        return response.json().get("languageCode", "ru")
    except requests.exceptions.RequestException:
        return "ru"
def ru2tt(text: str) -> str:
    """Translate Russian *text* to Tatar via Yandex Translate; return it unchanged on request failure."""
    request_body = {
        "folderId": YANDEX_FOLDER_ID,
        "texts": [text],
        "sourceLanguageCode": "ru",
        "targetLanguageCode": "tt",
    }
    try:
        response = requests.post(
            YANDEX_TRANSLATE_URL,
            headers={"Authorization": f"Api-Key {YANDEX_API_KEY}"},
            json=request_body,
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["translations"][0]["text"]
    except requests.exceptions.RequestException:
        # Best-effort: fall back to the untranslated input.
        return text
def render_prompt(messages: List[Dict[str, str]]) -> str:
    """Serialize chat *messages* into one prompt string using the tokenizer's chat template."""
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# --- 4) Streaming generation (no trimming) ---
@torch.inference_mode()
def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
    """Generate a reply for *messages*, yielding the accumulated text after each new chunk."""
    inputs = tok(render_prompt(messages), return_tensors="pt")
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    # skip_prompt keeps the echoed prompt out of the stream; special tokens are dropped.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generation_args = {
        "streamer": streamer,
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": True,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "repetition_penalty": REPETITION_PENALTY,
        "eos_token_id": tok.eos_token_id,
        "pad_token_id": tok.pad_token_id,
        **inputs,
    }
    # generate() blocks, so it runs in a worker thread while we consume the streamer.
    Thread(target=model.generate, kwargs=generation_args).start()
    text_so_far = ""
    for piece in streamer:
        text_so_far += piece
        yield text_so_far
def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]):
    """Gradio streaming handler: translate the user message to Tatar if needed,
    stream the model's reply, and yield (chatbot history, messages state) pairs.

    Parameters:
        message: raw user input (Russian or Tatar).
        ui_history: current Chatbot value, a list of [user, assistant] pairs.
        messages_state: chat-template message dicts, expected to start with the
            system message (re-seeded here if it does not).

    Yields:
        (ui_history, state) tuples; the last yield always carries the complete
        assistant reply.
    """
    # Re-seed the state if it was lost or does not start with the system prompt.
    if not messages_state or messages_state[0].get("role") != "system":
        messages_state = [{"role": "system", "content": SYS_PROMPT_TT}]
    # Translate non-Tatar input to Tatar so the model always sees Tatar text.
    detected = detect_language(message)
    user_tt = ru2tt(message) if detected != "tt" else message
    messages = messages_state + [{"role": "user", "content": user_tt}]
    ui_history = ui_history + [[user_tt, ""]]
    last = ""
    for partial in generate_tt_reply_stream(messages):
        last = partial
        ui_history[-1][1] = partial
        # Build the streamed state from `messages` (system + history + user)
        # instead of re-appending the user turn to messages_state each chunk.
        yield ui_history, messages + [{"role": "assistant", "content": partial}]
    final_state = messages + [{"role": "assistant", "content": last}]
    print("STATE:", json.dumps(final_state, ensure_ascii=False))
    # BUGFIX: always emit a final yield. Previously, if the streamer produced
    # zero chunks the generator never yielded at all, so neither the appended
    # user row nor the updated state ever reached the UI, and final_state was
    # computed but never returned to Gradio.
    ui_history[-1][1] = last
    yield ui_history, final_state
# --- Gradio UI wiring ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Татарский чат-бот от команды Сбера")
    # Per-session conversation state (chat-template messages), seeded with the system prompt.
    messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}])
    chatbot = gr.Chatbot(label="Диалог", height=500, bubble_full_width=False)
    msg = gr.Textbox(
        label="Хәбәрегезне рус яки татар телендә языгыз",
        placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?"
    )
    clear = gr.Button("🗑️ Чистарту")
    # chat_fn is a generator, so the chatbot and state update as tokens stream in.
    msg.submit(
        chat_fn,
        inputs=[msg, chatbot, messages_state],
        outputs=[chatbot, messages_state],
    )
    # Second handler on the same event clears the textbox; chat_fn has already
    # captured the submitted value as its input.
    msg.submit(lambda: "", None, msg)
    def _reset():
        # Empty the visible chat and re-seed the state with only the system prompt.
        return [], [{"role": "system", "content": SYS_PROMPT_TT}]
    clear.click(_reset, inputs=None, outputs=[chatbot, messages_state], queue=False)
    clear.click(lambda: "", None, msg, queue=False)
if __name__ == "__main__":
    # Bind on all interfaces; port is configurable via PORT (defaults to 7860).
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))