import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Model configuration
model_name = "ai4bharat/IndicBART"

# Load tokenizer and model on CPU
print("Loading IndicBART tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    do_lower_case=False,
    use_fast=False,
    keep_accents=True
)

print("Loading IndicBART model on CPU...")
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    device_map="cpu"
)

# Language mapping
LANGUAGE_CODES = {
    "Assamese": "<2as>",
    "Bengali": "<2bn>",
    "English": "<2en>",
    "Gujarati": "<2gu>",
    "Hindi": "<2hi>",
    "Kannada": "<2kn>",
    "Malayalam": "<2ml>",
    "Marathi": "<2mr>",
    "Oriya": "<2or>",
    "Punjabi": "<2pa>",
    "Tamil": "<2ta>",
    "Telugu": "<2te>"
}


def generate_response(input_text, source_lang, target_lang, task_type, max_length):
    """Generate response using IndicBART on CPU"""
    if not input_text.strip():
        return "Please enter some text to process."

    try:
        # Get language codes
        src_code = LANGUAGE_CODES[source_lang]
        tgt_code = LANGUAGE_CODES[target_lang]

        # Format input based on task type
        if task_type == "Translation":
            formatted_input = f"{input_text} {src_code}"
            decoder_start_token = tgt_code
        elif task_type == "Text Completion":
            formatted_input = f"{input_text} {tgt_code}"
            decoder_start_token = tgt_code
        else:  # Text Generation
            formatted_input = f"{input_text} {src_code}"
            decoder_start_token = tgt_code

        # FIX 1: Tokenize with explicit token_type_ids=False
        inputs = tokenizer(
            formatted_input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
            return_token_type_ids=False  # KEY FIX: Prevent token_type_ids
        )

        # FIX 2: Alternative approach - manually remove if present
        if 'token_type_ids' in inputs:
            del inputs['token_type_ids']

        # Get decoder start token id
        try:
            decoder_start_token_id = tokenizer._convert_token_to_id_with_added_voc(decoder_start_token)
        except Exception:
            decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)

        # FIX 3: Use explicit parameters instead of **inputs (most reliable)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],            # Explicit parameter
                attention_mask=inputs['attention_mask'],  # Explicit parameter
                decoder_start_token_id=decoder_start_token_id,
                max_length=int(max_length),
                num_beams=2,
                early_stopping=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                do_sample=False
            )

        # Decode output
        generated_text = tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return generated_text

    except Exception as e:
        return f"Error generating response: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicBART Multilingual Assistant (CPU Version)

    Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.
    **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English
    """)

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text in any supported language...",
                lines=3
            )
            output_text = gr.Textbox(
                label="Generated Output",
                lines=5,
                interactive=False
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=1):
            task_type = gr.Dropdown(
                choices=["Translation", "Text Completion", "Text Generation"],
                value="Translation",
                label="Task Type"
            )
            source_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Source Language"
            )
            target_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="Hindi",
                label="Target Language"
            )
            max_length = gr.Slider(
                minimum=20,
                maximum=200,
                value=80,
                step=10,
                label="Max Length"
            )

    # Simple examples without caching
    gr.Markdown("### 💡 Try these examples:")
    with gr.Row():
        with gr.Column():
            gr.Markdown("**English to Hindi**")
            example1_btn = gr.Button("Hello, how are you?")
        with gr.Column():
            gr.Markdown("**Hindi to English**")
            example2_btn = gr.Button("मैं एक छात्र हूं")
        with gr.Column():
            gr.Markdown("**Bengali to English**")
            example3_btn = gr.Button("আমি ভাত খাই")

    # Event handlers
    def clear_fields():
        return "", ""

    def set_example1():
        return "Hello, how are you?", "English", "Hindi", "Translation"

    def set_example2():
        return "मैं एक छात्र हूं", "Hindi", "English", "Translation"

    def set_example3():
        return "আমি ভাত খাই", "Bengali", "English", "Translation"

    # Connect buttons
    generate_btn.click(
        generate_response,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text
    )
    clear_btn.click(
        clear_fields,
        outputs=[input_text, output_text]
    )
    example1_btn.click(
        set_example1,
        outputs=[input_text, source_lang, target_lang, task_type]
    )
    example2_btn.click(
        set_example2,
        outputs=[input_text, source_lang, target_lang, task_type]
    )
    example3_btn.click(
        set_example3,
        outputs=[input_text, source_lang, target_lang, task_type]
    )

# FIX 4: Updated launch parameters (removed cache_examples)
if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True,
        enable_queue=False
        # Removed cache_examples parameter - not supported in newer Gradio versions
    )