# Author: carlosep93
# Commit message: added accelerate support
# Commit: ac24b2e
# from responses import start
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub identifier of the SalamandraTA 2B translation model.
model_id = "BSC-LT/salamandraTA-2b"

# Load the tokenizer and the model weights in half precision (float16).
# device_map="auto" delegates weight placement to accelerate, which puts the
# model on the available device(s); no explicit model.to(device) call is needed,
# and callers should send inputs to `model.device`.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
# Language names offered in both dropdowns. The selected names are passed to
# translate() and placed verbatim inside the "[...]" prompt tags, so the exact
# spelling (e.g. "Euskera") is part of the prompt the model sees.
languages = [
    "Spanish", "Catalan", "English", "French", "German",
    "Italian", "Portuguese", "Euskera", "Galician", "Bulgarian",
    "Czech", "Lithuanian", "Croatian", "Dutch", "Romanian",
    "Danish", "Greek", "Finnish", "Hungarian", "Slovak",
    "Slovenian", "Estonian", "Polish", "Latvian", "Swedish",
    "Maltese", "Irish", "Aranese", "Aragonese", "Asturian",
]

# Sample input offered under the UI: a Catalan sentence, matching the
# default source language selected in the dropdown.
example_sentence = [
    "Ahir se'n va anar, va agafar les seves coses i es va posar a navegar.",
]
@spaces.GPU(duration=120)
def translate(input_text, source, target):
    """Translate ``input_text`` line by line from ``source`` to ``target``.

    Each non-blank line is translated independently using the prompt
    ``[<source>] <sentence> \\n[<target>]`` and beam search (5 beams).
    Blank lines are passed through as empty lines without calling the
    model, so the output keeps the same line structure as the input.

    Args:
        input_text: Text from the input textbox; may contain several lines.
            ``None`` is treated as empty input.
        source: Source language name, inserted verbatim into the prompt tag.
        target: Target language name, inserted verbatim into the prompt tag.

    Returns:
        A ``(translation, info_html)`` tuple: the translated lines joined
        with ``"\\n"``, and an empty string that clears the info label.
    """
    translated_lines = []
    for sentence in (input_text or '').split('\n'):
        # strip() also drops a trailing "\r" left by Windows-style line endings.
        sentence = sentence.strip()
        if not sentence:
            # Skip the model for blank lines: an empty prompt wastes GPU time
            # and can still produce spurious output.
            translated_lines.append('')
            continue
        prompt = f'[{source}] {sentence} \n[{target}]'
        encoded = tokenizer(prompt, return_tensors='pt').to(model.device)
        output_ids = model.generate(
            input_ids=encoded.input_ids,
            attention_mask=encoded.attention_mask,
            # Budget only the generated tokens; max_length would also count the
            # prompt and silently truncate (or skip) translations of long lines.
            max_new_tokens=500,
            num_beams=5,
        )
        # generate() returns prompt + continuation; decode only the continuation.
        input_length = encoded.input_ids.shape[1]
        translated_lines.append(
            tokenizer.decode(output_ids[0, input_length:], skip_special_tokens=True).strip()
        )
    return '\n'.join(translated_lines), ""
with gr.Blocks() as demo:
    # Page title. gr.HTML content is rendered inside the app page, so a nested
    # <html>/<head>/<body> document is invalid markup and a <style> rule in it
    # would apply to every <h1>; an inline-styled <h1> is enough.
    gr.HTML('<h1 style="text-align: center;">SalamandraTA 2B Translate</h1>')
    with gr.Row():
        # Left column: source language and the text to translate.
        with gr.Column():
            source_language_dropdown = gr.Dropdown(
                choices=languages,
                value="Catalan",
                label="Source Language",
            )
            input_textbox = gr.Textbox(
                lines=5,
                placeholder="Enter text to translate",
                label="Input Text",
            )
        # Right column: target language and the translation result.
        with gr.Column():
            target_language_dropdown = gr.Dropdown(
                choices=languages,
                value="English",
                label="Target Language",
            )
            translated_textbox = gr.Textbox(lines=5, placeholder="", label="Translated Text")
            # Second output of translate(); it receives the function's second
            # return value (currently always "").
            info_label = gr.HTML("")
    btn = gr.Button("Translate")
    # Inputs map positionally onto translate(input_text, source, target).
    btn.click(
        translate,
        inputs=[input_textbox, source_language_dropdown, target_language_dropdown],
        outputs=[translated_textbox, info_label],
    )
    gr.Examples(example_sentence, inputs=[input_textbox])
if __name__ == "__main__":
    # Serve the Gradio Blocks app built above when this file is run as a script.
    demo.launch()