|
import gradio as gr |
|
import torch |
|
|
|
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
|
|
|
# Load the 1.2B-parameter M2M100 many-to-many translation model and its
# matching tokenizer from the Hugging Face hub (weights are downloaded and
# cached on first run; this happens once at import time).
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
|
|
|
# Human-readable dropdown choices in "Name (code)" form.  The parenthesized
# ISO 639 code in the last token is what translate() extracts and passes to
# the tokenizer/model.  Fix: "Greeek" -> "Greek" (typo in the visible label;
# the (el) code was already correct).
langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""

# One dropdown entry per comma-separated "Name (code)" item; strip() removes
# the surrounding whitespace/newlines left by the multi-line literal.
lang_list = [lang.strip() for lang in langs.split(',')]
|
|
|
examples = [["Korean (ko)", "English (en)", "μλ§νλ€λ μλΌκ° μλ€."]] |
|
|
|
def translate(src, tgt, text):
    """Translate `text` from the `src` language into the `tgt` language.

    `src` and `tgt` are the dropdown strings, e.g. "Korean (ko)"; the ISO
    code inside the trailing parentheses is extracted and fed to the
    tokenizer/model.  Returns the translated text as a string.
    """
    # Take the last whitespace-separated token and drop the surrounding
    # parentheses: "Korean (ko)" -> "ko".  Using the LAST token also handles
    # entries with extra parentheses, e.g. "Occitan (post 1500) (oc)" -> "oc".
    src = src.split(" ")[-1][1:-1]
    tgt = tgt.split(" ")[-1][1:-1]

    tokenizer.src_lang = src
    encoded_src = tokenizer(text, return_tensors="pt")

    # Inference only: no_grad() skips autograd bookkeeping, reducing memory
    # use and latency during generation.  forced_bos_token_id steers M2M100
    # to decode in the requested target language.
    with torch.no_grad():
        generated_tokens = model.generate(
            **encoded_src,
            forced_bos_token_id=tokenizer.get_lang_id(tgt),
            use_cache=True,
        )

    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
|
# Plain textbox that displays the translated result.
# NOTE(review): the gr.inputs / gr.outputs namespaces were deprecated in
# Gradio 3.x and removed in 4.x — either pin `gradio<3` in the environment
# or migrate to gr.Dropdown(...) / gr.Textbox(...).
output_text = gr.outputs.Textbox()

# Wire up the UI (two language dropdowns + free-text input) and start the
# server; launch() blocks, so this is the script's entry point.
# NOTE(review): the Korean half of the description string appears mojibake
# (mis-decoded bytes) — verify the file/source encoding and restore it.
gr.Interface(translate, inputs=[gr.inputs.Dropdown(lang_list, label="Source Language"), gr.inputs.Dropdown(lang_list, label="Target Language"), 'text'], outputs=output_text, title="Translate Between 100 languages",

description="M2M100-1.2B λͺ¨λΈμ κ°μ§κ³ 100κ°κ΅μ΄ μΈμ΄λ₯Ό λ²μνλ λ°λͺ¨νμ΄μ§ μ
λλ€.(a demo page that translate between 100 languages using M2M100-1.2B)",

examples=examples

).launch()
|
|