# Spaces: Running
import logging | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer | |
# Configure the root handler once: INFO level, timestamp plus source location.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
    datefmt='%H:%M:%S'
)
# Module-scoped logger instead of the root logger so this app's records are
# attributable; they still propagate to the root handler configured above.
logger = logging.getLogger(__name__)
# UI display name -> Hugging Face Hub repository id of the NLLB checkpoint.
AVAILABLE_MODELS = {
    "NLLB for transliteration": "kesha-humonen/tr-eng_checkpoint-8556",
    "NLLB for hieroglyphs": "kesha-humonen/hi-eng_dpo_checkpoint-3342",
}
# UI display name -> NLLB language code, used as the forced first output token.
LANGUAGES = {
    "English": "eng_Latn",
    "German": "deu_Latn",
}
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Module-level state describing the loaded model; all three start empty.
current_model = current_tokenizer = current_model_name = None
def unload_model():
    """Release the currently loaded model and forget its tokenizer and name.

    Resets ``current_model``, ``current_tokenizer`` and ``current_model_name``
    to ``None`` so that no state from the previous model survives, even if a
    later load fails. When a model was loaded, it also asks PyTorch to release
    cached CUDA memory.
    """
    global current_model, current_tokenizer, current_model_name
    if current_model is not None:
        logger.info("Unloading current model: %s", current_model.name_or_path)
        # Dropping the module's reference lets the weights be freed before
        # the CUDA caching allocator is asked to hand memory back.
        current_model = None
        torch.cuda.empty_cache()
    # Clear the companions too; previously they outlived the model.
    current_tokenizer = None
    current_model_name = None
def load_model(model_name: str):
    """Load the checkpoint registered under *model_name* and make it current.

    Args:
        model_name: A key of ``AVAILABLE_MODELS`` (the UI display name).

    Returns:
        A status message string.

    Raises:
        KeyError: If *model_name* is not a key of ``AVAILABLE_MODELS``. This is
            raised before anything is unloaded, so the currently loaded model
            stays usable.
    """
    global current_model, current_tokenizer, current_model_name
    # Resolve the repo id first: an unknown name must not tear down the
    # model that is currently loaded.
    repo_id = AVAILABLE_MODELS[model_name]
    # Free the old weights before loading new ones to keep peak memory low.
    unload_model()
    logger.info("Loading model: %s", model_name)
    # Load into locals and publish only after both succeed, so a failure
    # partway through never leaves a model paired with the wrong tokenizer.
    tokenizer = NllbTokenizer.from_pretrained(repo_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(repo_id).to(device)
    current_model = model
    current_tokenizer = tokenizer
    current_model_name = model_name
    return "The model has been uploaded successfully!"
def generate(input_texts: str, model_name: str, language: str) -> str:
    """Translate *input_texts* with the currently loaded model.

    Args:
        input_texts: Source sentence, in transliteration or hieroglyphs.
        model_name: UI display name of the model. It selects the tokenizer's
            source-language code; for an unlisted name, src_lang is left as-is.
        language: Key of ``LANGUAGES`` naming the output language.

    Returns:
        The decoded translation. If no model is loaded, a prompt to load one.
        If the input is empty or whitespace-only, an empty string.

    Raises:
        KeyError: If *language* is not a key of ``LANGUAGES``.
    """
    if current_model is None or current_tokenizer is None:
        return "Please select and upload the model first."
    # Nothing to translate: skip an expensive beam search over an input
    # that would consist only of special tokens.
    if not input_texts or not input_texts.strip():
        return ""

    # Source-language code associated with each model's input format.
    src_lang = {
        "NLLB for transliteration": 'egy_Tnt',
        "NLLB for hieroglyphs": 'egy_Hiero',
    }.get(model_name)
    if src_lang is not None:
        current_tokenizer.src_lang = src_lang

    target_code = LANGUAGES[language]
    encoded_inputs = current_tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(current_model.device)
    with torch.no_grad():
        # Force the first generated token to the target-language code so the
        # output is in the language the user selected.
        generated_tokens = current_model.generate(
            **encoded_inputs,
            forced_bos_token_id=current_tokenizer.convert_tokens_to_ids(target_code),
            num_beams=4,
            early_stopping=True,
            repetition_penalty=3.0,
        )
    output_text = current_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    # NOTE(review): presumably a safeguard in case the language code survives
    # decoding as plain text; removes every "<code> " occurrence.
    return output_text.replace(f'{target_code} ', '')
def predict(model_choice, message, language):
    """Gradio callback: make sure *model_choice* is loaded, then translate *message*.

    The model is (re)loaded only when none is loaded or a different one was
    selected, so consecutive requests for the same model reuse its weights.
    Module state is only read here; ``load_model`` is what updates it, so no
    ``global`` declaration is needed.
    """
    if current_model is None or model_choice != current_model_name:
        load_model(model_choice)
    return generate(message, model_choice, language)
# Single-page Gradio UI: pick a model, enter a sentence, pick an output language.
demo = gr.Interface(
    allow_flagging="never",
    fn=predict,
    inputs=[
        # Which checkpoint to use; defaults to the first registered model.
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS),
            label="Select a model",
            value=next(iter(AVAILABLE_MODELS)),
        ),
        # Free-text source sentence.
        gr.Textbox(
            label="Enter the sentence using transliteration or hieroglyphs.",
            placeholder="",
            lines=3,
        ),
        # Target language; its code is forced as the first generated token.
        gr.Dropdown(
            choices=list(LANGUAGES),
            label="Select output language",
            value="English",
        ),
    ],
    outputs=[
        gr.Textbox(
            label="Translation",
            lines=10,
        ),
    ],
    title="",
    # Each row follows the order of `inputs`: model, sentence, output language.
    examples=[
        ["NLLB for transliteration", "wn sbte nb ẖn =f", "English"],
        ["NLLB for hieroglyphs", "𓅭 𓆑 𓉐𓉻𓌕𓏌 𓋴𓌉𓂖 𓊪𓏏𓎛𓊵𓏏𓊪", "English"],
        ["NLLB for transliteration", "m wṯs jb =k n z", "German"],
        ["NLLB for hieroglyphs", "𓍹𓅃𓇋𓂓𓅱𓍺𓌳𓍘𓃫𓌸𓊖", "German"],
    ],
    theme='base',
    # Teal buttons (hover keeps the same color) and a hidden Gradio footer.
    css="""
button {
background-color: rgb(15, 138, 129) !important;
color: white !important;
}
button:hover {
background-color: rgb(15, 138, 129) !important;
}
footer {
display: none !important;
}
""",
)
if __name__ == "__main__":
    # Preload the first registered model so predict() finds it already loaded
    # and the first request skips the load step.
    default_model = next(iter(AVAILABLE_MODELS))
    load_model(default_model)
    demo.launch(share=False)