import logging

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger()

AVAILABLE_MODELS = {
    "NLLB for transliteration": "kesha-humonen/tr-eng_checkpoint-8556",
    "NLLB for hieroglyphs": "kesha-humonen/hi-eng_dpo_checkpoint-3342"
}

LANGUAGES = {
    "English": "eng_Latn",
    "German": "deu_Latn"
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

current_model = None
current_tokenizer = None
current_model_name = None


def unload_model():
    """Free the currently loaded model and release GPU memory."""
    global current_model
    if current_model is not None:
        logger.info(f"Unloading current model: {current_model.name_or_path}")
        del current_model
        torch.cuda.empty_cache()
        current_model = None


def load_model(model_name: str):
    """Load the selected checkpoint and its tokenizer, replacing any previous model."""
    global current_model, current_tokenizer, current_model_name
    unload_model()
    logger.info(f"Loading model: {model_name}")
    current_model = AutoModelForSeq2SeqLM.from_pretrained(
        AVAILABLE_MODELS[model_name]
    ).to(device)
    current_tokenizer = NllbTokenizer.from_pretrained(AVAILABLE_MODELS[model_name])
    current_model_name = model_name
    return "The model has been loaded successfully!"


def generate(input_texts: str, model_name: str, language: str) -> str:
    """Generate a translation for the given input text."""
    if current_model is None or current_tokenizer is None:
        return "Please select and load a model first."

    # The source-language tag depends on which fine-tuned checkpoint is active.
    if model_name == "NLLB for transliteration":
        current_tokenizer.src_lang = 'egy_Tnt'
    elif model_name == "NLLB for hieroglyphs":
        current_tokenizer.src_lang = 'egy_Hiero'

    encoded_inputs = current_tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(current_model.device)

    with torch.no_grad():
        # Set the output language based on the user's choice
        forced_bos_token_id = current_tokenizer.convert_tokens_to_ids(LANGUAGES[language])
        generated_tokens = current_model.generate(
            **encoded_inputs,
            forced_bos_token_id=forced_bos_token_id,
            num_beams=4,
            early_stopping=True,
            repetition_penalty=3.0
        )

    output_text = current_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    # Strip the target-language tag if it leaks into the decoded text.
    response = output_text.replace(f'{LANGUAGES[language]} ', '')
    return response


def predict(model_choice, message, language):
    """Gradio callback: (re)load the model if the selection changed, then translate."""
    if current_model is None or model_choice != current_model_name:
        load_model(model_choice)
    return generate(message, model_choice, language)


demo = gr.Interface(
    allow_flagging="never",
    fn=predict,
    inputs=[
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            label="Select a model",
            value=list(AVAILABLE_MODELS.keys())[0]
        ),
        gr.Textbox(
            label="Enter the sentence using transliteration or hieroglyphs.",
            placeholder="",
            lines=3
        ),
        gr.Dropdown(
            choices=list(LANGUAGES.keys()),
            label="Select output language",
            value="English"
        )
    ],
    outputs=[
        gr.Textbox(
            label="Translation",
            lines=10
        )
    ],
    title="",
    examples=[
        ["NLLB for transliteration", "wn sbte nb ẖn =f", "English"],
        ["NLLB for hieroglyphs", "𓅭 𓆑 𓉐𓉻𓌕𓏌 𓋴𓌉𓂖 𓊪𓏏𓎛𓊵𓏏𓊪", "English"],
        ["NLLB for transliteration", "m wṯs jb =k n z", "German"],
        ["NLLB for hieroglyphs", "𓍹𓅃𓇋𓂓𓅱𓍺𓌳𓍘𓃫𓌸𓊖", "German"],
    ],
    theme='base',
    css="""
    button {
        background-color: rgb(15, 138, 129) !important;
        color: white !important;
    }
    button:hover {
        background-color: rgb(15, 138, 129) !important;
    }
    footer {
        display: none !important;
    }
    """,
)

if __name__ == "__main__":
    default_model = list(AVAILABLE_MODELS.keys())[0]
    load_model(default_model)
    demo.launch(share=False)
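
# --- Optional usage sketch (not part of the app itself) ---
# A minimal, hypothetical example of querying a running instance of this app
# programmatically with `gradio_client`. The URL and the default "/predict"
# endpoint name are assumptions for a locally launched Interface; adjust both
# for a hosted deployment.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(
#       "NLLB for transliteration",  # model choice
#       "wn sbte nb ẖn =f",          # input sentence
#       "English",                   # output language
#       api_name="/predict",
#   )
#   print(result)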