# middle_egyptian/app.py
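"""Gradio demo that translates Middle Egyptian, given as transliteration or
hieroglyphs, into English or German using fine-tuned NLLB models."""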
import logging
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
logging.basicConfig(
level=logging.INFO,
format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
datefmt='%H:%M:%S'
)
logger = logging.getLogger(__name__)
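# Display names mapped to fine-tuned NLLB checkpoints on the Hugging Face Hub.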
AVAILABLE_MODELS = {
"NLLB for transliteration": "kesha-humonen/tr-eng_checkpoint-8556",
"NLLB for hieroglyphs": "kesha-humonen/hi-eng_dpo_checkpoint-3342"
}
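# Output languages mapped to NLLB (FLORES-200) language codes.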
LANGUAGES = {
"English": "eng_Latn",
"German": "deu_Latn"
}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
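# Currently loaded model/tokenizer; only one model is kept in memory at a time.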
current_model = None
current_tokenizer = None
current_model_name = None
def unload_model():
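    """Free the currently loaded model and release its GPU memory."""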
global current_model
if current_model is not None:
logger.info(f"Unloading current model: {current_model.name_or_path}")
del current_model
torch.cuda.empty_cache()
current_model = None
def load_model(model_name: str):
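    """Load the selected checkpoint, unloading any previously loaded model first."""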
global current_model, current_tokenizer, current_model_name
unload_model()
logger.info(f"Loading model: {model_name}")
current_model = AutoModelForSeq2SeqLM.from_pretrained(
AVAILABLE_MODELS[model_name]
).to(device)
current_tokenizer = NllbTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
current_model_name = model_name
return "The model has been uploaded successfully!"
def generate(input_texts: str, model_name: str, language: str) -> str:
"""Генерирует текст на основе входных данных"""
if current_model is None or current_tokenizer is None:
return "Please select and upload the model first."
if model_name == "NLLB for transliteration":
current_tokenizer.src_lang = 'egy_Tnt'
elif model_name == "NLLB for hieroglyphs":
current_tokenizer.src_lang = 'egy_Hiero'
encoded_inputs = current_tokenizer(
input_texts,
padding=True,
truncation=True,
return_tensors="pt"
).to(current_model.device)
with torch.no_grad():
        # Set the output language based on the user's choice.
forced_bos_token_id = current_tokenizer.convert_tokens_to_ids(LANGUAGES[language])
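        # Beam search with a strong repetition penalty (3.0) to discourage repeated tokens.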
generated_tokens = current_model.generate(
**encoded_inputs,
forced_bos_token_id=forced_bos_token_id,
num_beams=4,
early_stopping=True,
repetition_penalty=3.0
)
output_text = current_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
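    # skip_special_tokens should already drop the target-language code; stripping
    # it again here is a safeguard.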
response = output_text.replace(f'{LANGUAGES[language]} ', '')
return response
def predict(model_choice, message, language):
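    """Gradio callback: load the requested model on demand, then translate."""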
global current_model_name
if current_model is None or model_choice != current_model_name:
load_model(model_choice)
return generate(message, model_choice, language)
demo = gr.Interface(
allow_flagging="never",
fn=predict,
inputs=[
gr.Dropdown(
choices=list(AVAILABLE_MODELS.keys()),
label="Select a model",
value=list(AVAILABLE_MODELS.keys())[0]
),
gr.Textbox(
label="Enter the sentence using transliteration or hieroglyphs.",
placeholder="",
lines=3
),
gr.Dropdown(
choices=list(LANGUAGES.keys()),
label="Select output language",
value="English"
)
],
outputs=[
gr.Textbox(
label="Translation",
lines=10
)
],
title="",
examples=[
["NLLB for transliteration", "wn sbte nb ẖn =f", "English"],
["NLLB for hieroglyphs", "𓅭 𓆑 𓉐𓉻𓌕𓏌 𓋴𓌉𓂖 𓊪𓏏𓎛𓊵𓏏𓊪", "English"],
["NLLB for transliteration", "m wṯs jb =k n z", "German"],
["NLLB for hieroglyphs", "𓍹𓅃𓇋𓂓𓅱𓍺𓌳𓍘𓃫𓌸𓊖", "German"],
],
theme='base',
css="""
button {
background-color: rgb(15, 138, 129) !important;
color: white !important;
}
button:hover {
background-color: rgb(15, 138, 129) !important;
}
footer {
display: none !important;
}
""",
)
if __name__ == "__main__":
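    # Preload the default model so the first request does not pay the loading cost.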
default_model = list(AVAILABLE_MODELS.keys())[0]
load_model(default_model)
demo.launch(share=False)