Spaces:
Running
Running
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
# Hugging Face Hub repository holding the fine-tuned checkpoint, and the exact
# commit to load so that later pushes to the repository do not change results.
MODEL_NAME = "Tamazight-NLP/NLLB-200-600M-Tamazight-All-Data-1.25-epoch"
REVISION = "7cacdb000c5a6264150f203a595f6cd681e20844"

# Human-readable language name -> language tag handed to the tokenizer/model.
# Insertion order matters: the UI dropdowns list the keys in this order.
#
# NOTE(review): in stock NLLB-200 the `taq_*` tags denote Tamasheq and
# `kab_Latn` denotes Kabyle; presumably this checkpoint repurposes them for
# Tachelhit/Central Atlas Tamazight and Tarifit - confirm against the model card.
NLLB_LANG_MAPPING = {
    # English and the Amazigh varieties
    "English": "eng_Latn",
    "Standard Moroccan Tamazight": "tzm_Tfng",
    "Tachelhit/Central Atlas Tamazight": "taq_Tfng",
    "Tachelhit/Central Atlas Tamazight (Latin)": "taq_Latn",
    "Tarifit (Latin)": "kab_Latn",
    # Other supported languages
    "Moroccan Darija": "ary_Arab",
    "Catalan": "cat_Latn",
    "Spanish": "spa_Latn",
    "French": "fra_Latn",
    "Modern Standard Arabic": "arb_Arab",
    "German": "deu_Latn",
    "Dutch": "nld_Latn",
    "Russian": "rus_Cyrl",
    "Italian": "ita_Latn",
    "Turkish": "tur_Latn",
    "Esperanto": "epo_Latn",
}
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model and tokenizer are both loaded from the same pinned revision of
# MODEL_NAME; the model is then moved onto the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, revision=REVISION)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, revision=REVISION)
def translate(text, source_lang, target_lang, max_length=238, num_beams=4):
    """Translate ``text`` from ``source_lang`` into ``target_lang``.

    Args:
        text: The text to translate.
        source_lang: Display name of the input language; must be a key of
            ``NLLB_LANG_MAPPING``.
        target_lang: Display name of the output language; must be a key of
            ``NLLB_LANG_MAPPING``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The first decoded sequence with special tokens stripped, or ``""``
        when ``text`` is empty, whitespace-only or ``None``.

    Raises:
        KeyError: If ``source_lang`` or ``target_lang`` is not a key of
            ``NLLB_LANG_MAPPING`` (only checked for non-empty ``text``).
    """
    # Kept from the original: echo the request to stdout.
    print(text)

    # Nothing to translate: skip tokenization and beam search entirely rather
    # than asking the model to generate from an input with no content.
    if not text or not text.strip():
        return ""

    # Resolve both names first so an unknown target language fails before the
    # shared tokenizer is modified.
    source_code = NLLB_LANG_MAPPING[source_lang]
    target_code = NLLB_LANG_MAPPING[target_lang]

    # NOTE: `tokenizer` is a module-level object, so this assignment is
    # visible to every other caller of this function.
    tokenizer.src_lang = source_code
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    translated_tokens = model.generate(
        **inputs,
        # Force generation to start with the target language's tag token.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_code),
        # Coerce defensively in case the UI sliders hand the values over as
        # floats (TODO confirm for the Gradio version in use).
        max_length=int(max_length),
        num_beams=int(num_beams),
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
def _build_interface():
    """Assemble the Gradio form whose fields map onto ``translate``'s arguments."""
    # Components are created in the same order as translate()'s parameters:
    # text, source_lang, target_lang, max_length, num_beams.
    text_box = gr.components.Textbox(
        label="Text",
        lines=4,
        placeholder="ⵙⵙⴽⵛⵎ ⴰⴹⵕⵉⵚ...\nEnter text to translate...",
    )
    source_dropdown = gr.components.Dropdown(
        label="Source Language",
        choices=list(NLLB_LANG_MAPPING.keys()),
        value="English",
    )
    target_dropdown = gr.components.Dropdown(
        label="Target Language",
        choices=list(NLLB_LANG_MAPPING.keys()),
        value="Standard Moroccan Tamazight",
    )
    max_length_slider = gr.components.Slider(
        8,
        400,
        value=238,
        step=8,
        label="Max Length (in tokens). Increase in case the output looks truncated.",
    )
    num_beams_slider = gr.components.Slider(
        1,
        25,
        value=4,
        step=1,
        label="Number of beams. Higher values might improve translation accuracy at the cost of speed.",
    )
    result_box = gr.components.Textbox(label="Translated text")

    return gr.Interface(
        fn=translate,
        title="NLLB Tamazight Translation Demo",
        inputs=[
            text_box,
            source_dropdown,
            target_dropdown,
            max_length_slider,
            num_beams_slider,
        ],
        outputs=result_box,
    )


gradio_ui = _build_interface()

# Start serving the demo as soon as the module is executed.
gradio_ui.launch()