# Deeptranslation / app.py
# Last change: "Update app.py" by Thanush1 (commit 845f6f3, verified)
import logging

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Model configuration: Hugging Face Hub id of the checkpoint served by this app.
model_name = "ai4bharat/IndicBART"

# Tokenizer options: preserve case and accents, and use the slow
# (non-"fast") tokenizer implementation.
_TOKENIZER_OPTIONS = dict(do_lower_case=False, use_fast=False, keep_accents=True)

print("Loading IndicBART tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, **_TOKENIZER_OPTIONS)

print("Loading IndicBART model on CPU...")
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # full-precision weights
    device_map="cpu",           # place every weight on the CPU
)
# Language mapping: UI display name -> IndicBART language tag ("<2xx>").
# Every tag has the shape "<2" + code + ">", so the table is built from
# (display name, code) pairs. Insertion order is preserved; it is the order
# in which the language dropdowns list their choices.
LANGUAGE_CODES = {
    display_name: f"<2{code}>"
    for display_name, code in (
        ("Assamese", "as"),
        ("Bengali", "bn"),
        ("English", "en"),
        ("Gujarati", "gu"),
        ("Hindi", "hi"),
        ("Kannada", "kn"),
        ("Malayalam", "ml"),
        ("Marathi", "mr"),
        ("Oriya", "or"),
        ("Punjabi", "pa"),
        ("Tamil", "ta"),
        ("Telugu", "te"),
    )
}
def _format_prompt(input_text, source_code, target_code, task_type):
    """Build the encoder input string and the decoder start token for a task.

    The encoder input is laid out as ``"<text> </s> <2xx>"`` and decoding
    always starts from the target-language tag.

    Args:
        input_text: Raw user text.
        source_code: Language tag of the input, e.g. ``"<2en>"``.
        target_code: Language tag of the desired output, e.g. ``"<2hi>"``.
        task_type: ``"Translation"``, ``"Text Completion"``; any other value
            is treated like ``"Text Generation"``.

    Returns:
        ``(formatted_input, decoder_start_token)``.
    """
    # Text Completion tags the input with the *target* language; Translation
    # and Text Generation tag it with the source language.
    input_tag = target_code if task_type == "Text Completion" else source_code
    return f"{input_text} </s> {input_tag}", target_code


def generate_response(input_text, source_lang, target_lang, task_type, max_length):
    """Generate a response with IndicBART on CPU for the selected task.

    Args:
        input_text: Text typed by the user; may be empty or None.
        source_lang: Display name of the input language (a LANGUAGE_CODES key).
        target_lang: Display name of the output language (a LANGUAGE_CODES key).
        task_type: "Translation", "Text Completion" or "Text Generation".
        max_length: Maximum generated sequence length (value of the UI slider).

    Returns:
        The decoded model output, or a human-readable message when the input
        is blank or generation fails. Exceptions raised while generating are
        logged and reported as a message instead of propagating to the UI.
    """
    if not input_text or not input_text.strip():
        return "Please enter some text to process."

    try:
        formatted_input, decoder_start_token = _format_prompt(
            input_text,
            LANGUAGE_CODES[source_lang],
            LANGUAGE_CODES[target_lang],
            task_type,
        )

        # The prompt already carries its own "</s> <2xx>" suffix, so the
        # tokenizer must not wrap it in additional special tokens.
        # NOTE(review): per the IndicBART model card, Indic languages other
        # than Hindi/Marathi were trained in Devanagari script; no script
        # conversion is done here — confirm whether inputs need transliteration.
        inputs = tokenizer(
            formatted_input,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,  # NOTE(review): truncation can cut off the trailing language tag
            return_token_type_ids=False,  # generate() below only takes input_ids/attention_mask
        )

        decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)
        if decoder_start_token_id == tokenizer.unk_token_id:
            return f"Error generating response: unknown language tag {decoder_start_token!r}"

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                decoder_start_token_id=decoder_start_token_id,
                max_length=int(max_length),  # slider value may arrive as a float
                num_beams=2,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                do_sample=False,
            )

        return tokenizer.decode(
            outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    except Exception as e:  # UI boundary: report instead of crashing the app
        logging.exception("IndicBART generation failed")
        return f"Error generating response: {e}"
# Build the Gradio UI. All component handles (textboxes, dropdowns, slider,
# buttons) are created inside the Blocks context and wired to their
# callbacks at the end of it.
with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🇮🇳 IndicBART Multilingual Assistant (CPU Version)

Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.

**Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
""")

    with gr.Row():
        # Left column: input/output text and the action buttons.
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text in any supported language...",
                lines=3,
            )
            output_text = gr.Textbox(
                label="Generated Output",
                lines=5,
                interactive=False,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")

        # Right column: task selection and generation settings.
        with gr.Column(scale=1):
            task_type = gr.Dropdown(
                choices=["Translation", "Text Completion", "Text Generation"],
                value="Translation",
                label="Task Type",
            )
            source_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Source Language",
            )
            target_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="Hindi",
                label="Target Language",
            )
            max_length = gr.Slider(
                minimum=20,
                maximum=200,
                value=80,
                step=10,
                label="Max Length",
            )

    # Example buttons: clicking one only pre-fills the inputs (see the
    # set_example* handlers below); it does not run generation.
    gr.Markdown("### 💡 Try these examples:")
    with gr.Row():
        with gr.Column():
            gr.Markdown("**English to Hindi**")
            example1_btn = gr.Button("Hello, how are you?")
        with gr.Column():
            gr.Markdown("**Hindi to English**")
            example2_btn = gr.Button("मैं एक छात्र हूं")
        with gr.Column():
            gr.Markdown("**Bengali to English**")
            example3_btn = gr.Button("আমি ভাত খাই")

    # Event handlers
    def clear_fields():
        """Reset the input and output textboxes."""
        return "", ""

    def set_example1():
        """Pre-fill the English -> Hindi translation example."""
        return "Hello, how are you?", "English", "Hindi", "Translation"

    def set_example2():
        """Pre-fill the Hindi -> English translation example."""
        return "मैं एक छात्र हूं", "Hindi", "English", "Translation"

    def set_example3():
        """Pre-fill the Bengali -> English translation example."""
        return "আমি ভাত খাই", "Bengali", "English", "Translation"

    # Connect buttons to their callbacks.
    generate_btn.click(
        generate_response,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
    )
    clear_btn.click(
        clear_fields,
        outputs=[input_text, output_text],
    )
    example1_btn.click(
        set_example1,
        outputs=[input_text, source_lang, target_lang, task_type],
    )
    example2_btn.click(
        set_example2,
        outputs=[input_text, source_lang, target_lang, task_type],
    )
    example3_btn.click(
        set_example3,
        outputs=[input_text, source_lang, target_lang, task_type],
    )
# Entry point. Only options accepted by current Gradio releases are passed:
# the deprecated `enable_queue` flag is not a `Blocks.launch()` parameter in
# Gradio 4+, and passing it there raises TypeError.
if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True,  # surface handler exceptions in the browser
    )