# Hugging Face Space — status: "Running on Zero" (ZeroGPU hardware).
# from responses import start | |
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
model_id = "BSC-LT/salamandraTA-2b"

# Fetch the SalamandraTA tokenizer and causal-LM weights from the Hub.
# Weights are loaded in float16; device_map="auto" leaves device placement to
# transformers/accelerate, so there is no explicit .to(device) call afterwards.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Language names offered in both dropdowns. The selected names are inserted
# verbatim into the "[<language>]" tags of the translation prompt.
languages = [
    "Spanish", "Catalan", "English", "French", "German", "Italian",
    "Portuguese", "Euskera", "Galician", "Bulgarian", "Czech", "Lithuanian",
    "Croatian", "Dutch", "Romanian", "Danish", "Greek", "Finnish",
    "Hungarian", "Slovak", "Slovenian", "Estonian", "Polish", "Latvian",
    "Swedish", "Maltese", "Irish", "Aranese", "Aragonese", "Asturian",
]

# Catalan sample shown as a clickable example under the input box
# (the source dropdown defaults to Catalan).
example_sentence = [
    "Ahir se'n va anar, va agafar les seves coses i es va posar a navegar.",
]
@spaces.GPU  # ZeroGPU: a GPU is only attached while a decorated function runs
def translate(input_text, source, target):
    """Translate ``input_text`` line by line from ``source`` to ``target``.

    Each non-blank line is sent to the model separately with the prompt
    ``[<source>] <sentence> `` followed by a newline and ``[<target>]``; the
    tokens generated after the prompt are decoded as that line's translation.

    Args:
        input_text: Text to translate; may span several lines.
        source: Source language name, as shown in the dropdown.
        target: Target language name, as shown in the dropdown.

    Returns:
        A ``(translation, info_html)`` tuple: the translated lines joined with
        newlines (one output line per input line, blank lines preserved) and
        an empty string that clears the info label.
    """
    generated_text = []
    for sentence in input_text.split('\n'):
        # Pass blank lines straight through: a 5-beam generation on an empty
        # sentence wastes GPU time and can make the model invent text.
        if not sentence.strip():
            generated_text.append('')
            continue
        prompt = f'[{source}] {sentence} \n[{target}]'
        inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
        # max_new_tokens bounds only the generated continuation; max_length
        # would also count the prompt and cut off translations of long lines.
        output_ids = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=500,
            num_beams=5,
        )
        input_length = inputs.input_ids.shape[1]
        # Decode only the tokens produced after the prompt.
        generated_text.append(
            tokenizer.decode(output_ids[0, input_length:], skip_special_tokens=True).strip()
        )
    return '\n'.join(generated_text), ""
with gr.Blocks() as demo:
    # Page title. A single inline-styled <h1> replaces a full
    # <html>/<head>/<body> document: those tags are not valid inside a
    # component, and a <style> rule there would apply to every <h1> on the page.
    gr.HTML('<h1 style="text-align: center;">SalamandraTA 2B Translate</h1>')
    with gr.Row():
        # Left column: source language and the text to translate.
        with gr.Column():
            source_language_dropdown = gr.Dropdown(choices=languages,
                                                   value="Catalan",
                                                   label="Source Language")
            input_textbox = gr.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text")
        # Right column: target language and the translation result.
        with gr.Column():
            target_language_dropdown = gr.Dropdown(choices=languages,
                                                   value="English",
                                                   label="Target Language")
            translated_textbox = gr.Textbox(lines=5, placeholder="", label="Translated Text")
    # Second output of translate(); it is set to "" on every run.
    info_label = gr.HTML("")
    btn = gr.Button("Translate")
    # translate(text, source, target) -> (translation, info_html)
    btn.click(translate, inputs=[input_textbox,
                                 source_language_dropdown,
                                 target_language_dropdown],
              outputs=[translated_textbox, info_label])
    gr.Examples(example_sentence, inputs=[input_textbox])

if __name__ == "__main__":
    demo.launch()