import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the fine-tuned model and tokenizer from the local ./model directory
model_name = "./model"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Run on GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Map target-language codes to the language tokens the tokenizer expects
lang_code_to_token = {
    "eng_Latn": "__eng_Latn__",
    "deu_Latn": "__deu_Latn__",
    "fra_Latn": "__fra_Latn__",
    "spa_Latn": "__spa_Latn__"
}

target_lang_code = "eng_Latn"
src_lang_code = "mya"


def translate(text):
    if not text.strip():
        return "Please enter some Burmese text."

    # Tell the tokenizer which language the input text is in
    tokenizer.src_lang = src_lang_code

    # Tokenize the input and move the tensors onto the same device as the model
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)

    # Force the decoder to start with the target-language token so the output is English
    lang_token = lang_code_to_token[target_lang_code]
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(lang_token)

    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=128
        )

    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
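
# Optional sanity check before wiring up the UI (a sketch; the sample sentence
# below is an illustrative Burmese greeting, not part of the original script):
# print(translate("မင်္ဂလာပါ"))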

# Build the Gradio UI: a Burmese input box, a Translate button, and an output box
with gr.Blocks(title="Burmese to English Translator") as demo:
    gr.Markdown("## 🇲🇲 ➡️ 🇬🇧 Burmese to English Translator")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Burmese Text", placeholder="Type Burmese here...", lines=4)
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translation", interactive=False)

    btn.click(translate, inputs=input_text, outputs=output_text)

demo.launch()
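
# To expose the demo beyond localhost, Gradio can create a temporary public
# link (an optional variation on the launch call above, not in the original):
# demo.launch(share=True)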