# Source: app.py by RishabA, last commit 4ad5c1b ("Update app.py", verified), 2.63 kB.
import os
import torch
import spacy
import gradio as gr
from model import make_model, translate_sentence, Vocab
import __main__
# vocab.pt was presumably pickled by a script in which Vocab lived in __main__
# (TODO confirm). Pickle resolves classes by module + name, so expose the class
# under __main__ as well, so that unpickling can find it.
setattr(__main__, "Vocab", Vocab)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_tokenizers():
try:
spacy_es = spacy.load("es_core_news_sm")
except OSError:
os.system("python -m spacy download es_core_news_sm")
spacy_es = spacy.load("es_core_news_sm")
try:
spacy_en = spacy.load("en_core_web_sm")
except OSError:
os.system("python -m spacy download en_core_web_sm")
spacy_en = spacy.load("en_core_web_sm")
print("Tokenizers loaded.")
return spacy_es, spacy_en
spacy_es, spacy_en = load_tokenizers()
if os.path.exists("vocab.pt"):
torch.serialization.add_safe_globals([__main__.Vocab])
vocab_src, vocab_trg = torch.load("vocab.pt", weights_only=False)
else:
raise FileNotFoundError(
"vocab.pt not found. Please build and save the vocabularies first."
)
model = make_model(
device,
vocab_src,
vocab_trg,
n_layers=3,
d_model=512,
d_ffn=512,
n_heads=8,
dropout=0.1,
max_length=50,
)
model.to(device)
if os.path.exists("translation_model.pt"):
model.load_state_dict(torch.load("translation_model.pt", map_location=device))
print("Pretrained model loaded.")
else:
raise FileNotFoundError(
"translation_model.pt not found. Please train and save the model first."
)
def translate(text):
translation = translate_sentence(
text, model, vocab_src, vocab_trg, spacy_es, device, max_length=50
)
return translation
# Dark-theme stylesheet passed to gr.Blocks(css=...). The .container and
# .description class rules target elements created by the UI layout below.
css_str = """
body {
background-color: #121212;
color: #e0e0e0;
}
.container {
max-width: 700px;
margin: 10px auto;
}
h1 {
font-size: 36px;
font-weight: bold;
text-align: center;
color: #ffffff;
}
.description {
font-size: 18px;
text-align: center;
color: #b0b0b0;
}
"""
with gr.Blocks(css=css_str) as demo:
gr.HTML("<div class='container'>")
gr.Markdown("<h1>Spanish-to-English Translator</h1>")
gr.Markdown(
"<div class='description'>Enter a Spanish sentence below to receive its English translation.</div>"
)
with gr.Row():
txt_input = gr.Textbox(
label="Enter Spanish sentence", lines=2, placeholder="Ej: ¿Cómo estás?"
)
translate_btn = gr.Button("Translate")
txt_output = gr.Textbox(label="English Translation", lines=2)
translate_btn.click(fn=translate, inputs=txt_input, outputs=txt_output)
gr.HTML("</div>")
if __name__ == "__main__":
demo.launch(share=True)