"""Gradio app for English ↔ Runyankore translation with fine-tuned MarianMT models.

The user picks a model version and a translation direction; the matching
Hugging Face checkpoint is loaded with ``from_pretrained`` on first use and
then kept in the in-memory ``loaded_models`` cache.
"""

from transformers import MarianMTModel, MarianTokenizer
import gradio as gr
import torch

# Available model versions. Each dropdown label (which also shows the BLEU
# scores for that version) maps to a short version key and the Hugging Face
# repo ids of the two direction-specific checkpoints.
MODEL_OPTIONS = {
    "Version v1 (EN→RU BLEU: 35.93 | RU→EN BLEU: 41.11)": {
        "key": "v1",
        "en_ru": "kafarasi/marian-en-ru-finetuned",
        "ru_en": "kafarasi/marian-ru-en-finetuned",
    },
    "Version v2 (EN→RU BLEU: 36.61 | RU→EN BLEU: 50.00)": {
        "key": "v2",
        "en_ru": "kafarasi/marian-en-ru-finetunedv2",
        "ru_en": "kafarasi/marian-ru-en-finetunedv2",
    },
    # NOTE(review): the v3 label points at the "*-finetunedv4" checkpoints;
    # confirm this naming mismatch is intentional.
    "Version v3 (EN→RU BLEU: 37.52 | RU→EN BLEU: 47.06)": {
        "key": "v3",
        "en_ru": "kafarasi/marian-en-ru-finetunedv4",
        "ru_en": "kafarasi/marian-ru-en-finetunedv4",
    },
}

# Radio-button labels, mapped to the MODEL_OPTIONS field that holds the
# checkpoint for that direction.
EN_TO_RUNYANKORE = "English → Runyankore"
RUNYANKORE_TO_EN = "Runyankore → English"
DIRECTION_KEYS = {
    EN_TO_RUNYANKORE: "en_ru",
    RUNYANKORE_TO_EN: "ru_en",
}

# Upper bound (in tokens) passed to generate(). Input is translated in one
# pass without sentence splitting, so longer outputs are cut off here.
MAX_OUTPUT_TOKENS = 128

# TODO(review): set this to the feedback form URL; the previous link target
# is missing from this file. When unset, the footer is shown without a link.
FEEDBACK_URL = None

# Use the GPU when one is present (in half precision there); otherwise run on
# CPU in full precision, e.g. on a CPU-only Hugging Face Space.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_float16 = device.type == "cuda"

# In-memory cache: Hugging Face repo id -> (tokenizer, model).
loaded_models = {}


def get_model_and_tokenizer(display_label, direction):
    """Return the ``(tokenizer, model)`` pair for a model version and direction.

    The checkpoint is loaded on the first request and served from
    ``loaded_models`` afterwards.

    Args:
        display_label: A key of ``MODEL_OPTIONS`` (the dropdown label).
        direction: A key of ``DIRECTION_KEYS`` (the radio-button label).

    Raises:
        KeyError: If ``display_label`` is not a known model version.
        ValueError: If ``direction`` is not a supported translation direction.
    """
    version_info = MODEL_OPTIONS[display_label]
    direction_key = DIRECTION_KEYS.get(direction)
    if direction_key is None:
        # Previously any unknown value (including None) silently fell back to
        # the Runyankore → English model.
        raise ValueError(f"Unsupported translation direction: {direction!r}")
    model_name = version_info[direction_key]

    cached = loaded_models.get(model_name)
    if cached is not None:
        return cached

    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_float16 else torch.float32,
    ).to(device)
    loaded_models[model_name] = (tokenizer, model)
    return tokenizer, model


def translate_text(text, direction, version_label):
    """Translate ``text`` with the selected model version and direction.

    The whole input is translated in a single greedy pass (no sentence
    splitting), so output longer than ``MAX_OUTPUT_TOKENS`` tokens is cut off.

    Args:
        text: Text to translate; ``None`` is treated as empty.
        direction: A key of ``DIRECTION_KEYS``.
        version_label: A key of ``MODEL_OPTIONS``.

    Returns:
        The translated string, or ``""`` for empty/whitespace-only input.

    Raises:
        gr.Error: If no valid direction or model version is selected, so the
            message is shown to the user in the UI.
    """
    text = (text or "").strip()
    # Validate cheap conditions before potentially downloading a checkpoint.
    if not text:
        return ""
    if direction not in DIRECTION_KEYS:
        raise gr.Error("Please select a translation direction.")
    if version_label not in MODEL_OPTIONS:
        raise gr.Error("Please select a model version.")

    tokenizer, model = get_model_and_tokenizer(version_label, direction)
    inputs = tokenizer(
        [text], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs, max_length=MAX_OUTPUT_TOKENS, num_beams=1
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def clear_fields():
    """Reset the input and output text boxes."""
    return "", ""


# UI styling
custom_css = """
body {
    background: linear-gradient(to right, #e3f2fd, #fce4ec);
    font-family: 'Segoe UI', sans-serif;
}
h1 {
    color: #2c3e50;
    font-size: 32px;
    text-align: center;
    margin-bottom: 10px;
}
p {
    text-align: center;
    font-size: 16px;
    color: #4e4e4e;
}
.gradio-container {
    max-width: 900px;
    margin: auto;
    padding: 30px;
    background: #ffffff;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    border-radius: 16px;
}
textarea, input {
    font-size: 16px !important;
    border: 2px solid #2980b9;
    border-radius: 8px !important;
}
.gr-button {
    background-color: #3498db !important;
    color: white !important;
    border-radius: 8px !important;
    font-size: 16px !important;
    padding: 10px 20px !important;
    transition: background-color 0.3s ease;
}
.gr-button:hover {
    background-color: #2c81ba !important;
}
"""

# Gradio interface
with gr.Blocks(css=custom_css) as iface:
    gr.Markdown("<h1>Runyankore ↔ English Translator</h1>")
    gr.Markdown(
        "<p>Select a model version and translation direction. "
        "Input text will be translated efficiently even on CPU.</p>"
    )

    version_labels = list(MODEL_OPTIONS)
    with gr.Row():
        with gr.Column(scale=1):
            # Defaults to the second entry (v2).
            model_selector = gr.Dropdown(
                version_labels,
                label="Model Version",
                value=version_labels[1],
            )
            text_input = gr.Textbox(
                lines=5,
                label="Input Text",
                placeholder="Enter text...",
                interactive=True,
                show_copy_button=True,
            )
            # A default direction avoids passing None to translate_text.
            direction_selector = gr.Radio(
                list(DIRECTION_KEYS),
                label="Translation Direction",
                value=EN_TO_RUNYANKORE,
            )
            with gr.Row():
                translate_btn = gr.Button("🔄 Translate")
                clear_btn = gr.Button("🗑️ Clear")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                lines=5,
                label="Translated Output",
                interactive=False,
                show_copy_button=True,
            )

    translate_btn.click(
        fn=translate_text,
        inputs=[text_input, direction_selector, model_selector],
        outputs=output_text,
    )
    clear_btn.click(fn=clear_fields, outputs=[text_input, output_text])

    if FEEDBACK_URL:
        feedback_link = f'<a href="{FEEDBACK_URL}" target="_blank">Click here</a>'
    else:
        feedback_link = "Click here"
    gr.Markdown(f"<p>💬 Feedback? {feedback_link} to help improve</p>")


if __name__ == "__main__":
    iface.launch()