Thanush1 committed on
Commit 845f6f3 · verified
1 Parent(s): 4320235

Update app.py

Files changed (1)
app.py: +10 -10
app.py CHANGED
@@ -54,17 +54,17 @@ def generate_response(input_text, source_lang, target_lang, task_type, max_lengt
     formatted_input = f"{input_text} </s> {src_code}"
     decoder_start_token = tgt_code
 
-    # Tokenize input - KEY FIX: Explicitly set return_token_type_ids=False
+    # FIX 1: Tokenize with explicit return_token_type_ids=False
     inputs = tokenizer(
         formatted_input,
         return_tensors="pt",
         padding=True,
         truncation=True,
         max_length=512,
-        return_token_type_ids=False  # This prevents the error
+        return_token_type_ids=False  # KEY FIX: Prevent token_type_ids
     )
 
-    # Alternative fix: Remove token_type_ids if present
+    # FIX 2: Alternative approach - manually remove if present
     if 'token_type_ids' in inputs:
         del inputs['token_type_ids']
 
@@ -74,10 +74,11 @@ def generate_response(input_text, source_lang, target_lang, task_type, max_lengt
     except:
         decoder_start_token_id = tokenizer.convert_tokens_to_ids(decoder_start_token)
 
-    # Generate on CPU
+    # FIX 3: Use explicit parameters instead of **inputs (most reliable)
     with torch.no_grad():
         outputs = model.generate(
-            **inputs,
+            input_ids=inputs['input_ids'],  # Explicit parameter
+            attention_mask=inputs['attention_mask'],  # Explicit parameter
             decoder_start_token_id=decoder_start_token_id,
             max_length=max_length,
             num_beams=2,
@@ -150,7 +151,7 @@ with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Sof
             label="Max Length"
         )
 
-        # Simplified examples to avoid caching issues
+        # Simple examples without caching
         gr.Markdown("### 💡 Try these examples:")
 
         with gr.Row():
@@ -204,12 +205,11 @@ with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Sof
         outputs=[input_text, source_lang, target_lang, task_type]
     )
 
-# Launch with all fixes applied
+# FIX 4: Updated launch parameters (removed cache_examples)
 if __name__ == "__main__":
     demo.launch(
         share=True,
-        ssr_mode=False,  # Disable SSR
-        cache_examples=False,  # Disable example caching - KEY FIX
         show_error=True,
-        enable_queue=False  # Disable queue to avoid startup issues
+        enable_queue=False,
+        # Removed cache_examples parameter - not supported in newer Gradio versions
     )
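
Why FIX 1 and FIX 2 are both applied: IndicBART pairs an ALBERT-style tokenizer with an MBart-style seq2seq model, so the tokenizer emits token_type_ids that the model's forward() does not accept, and generate(**inputs) fails on the stray key. A minimal sketch of the guard, assuming the usual ai4bharat/IndicBART checkpoint (the sample text and language tag are illustrative; app.py's own loading code applies):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBART", use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained("ai4bharat/IndicBART")

# FIX 1: suppress token_type_ids at tokenization time.
inputs = tokenizer(
    "I am a boy </s> <2en>",
    return_tensors="pt",
    return_token_type_ids=False,
)

# FIX 2: defensively drop the key in case a tokenizer returns it anyway.
inputs.pop("token_type_ids", None)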
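
FIX 3 enforces the same thing at the call site: naming input_ids and attention_mask explicitly, instead of unpacking **inputs, means no extra key the tokenizer might add can ever reach generate(). A sketch of the pattern with a decode step added for illustration (variable names follow the diff; the actual decoding in app.py may differ):

import torch

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],        # explicit, never **inputs
        attention_mask=inputs["attention_mask"],
        decoder_start_token_id=decoder_start_token_id,
        max_length=max_length,
        num_beams=2,
    )

response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]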
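
On FIX 4: cache_examples is an argument of gr.Examples (and gr.Interface), not of Blocks.launch(), which is why it had to go. Note that enable_queue was likewise removed from launch() in Gradio 4.x, where queueing is configured via demo.queue() instead, so the retained enable_queue=False only works on Gradio 3.x. A hedged sketch of a launch call restricted to arguments that exist across recent Gradio releases:

# Sketch only: queue behaviour, if wanted, would be set beforehand
# with demo.queue() rather than through launch() on Gradio 4.x+.
if __name__ == "__main__":
    demo.launch(
        share=True,
        show_error=True,
    )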