krishna195 committed
Commit f91db4a · verified · 1 Parent(s): ac3fa16

Update app.py

Files changed (1)
  1. app.py +65 -49
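The substantive change in this commit is that the ZeroGPU decorator now wraps the inference function itself, and `generate_response` becomes a generator that yields a status message before the final answer, so Gradio streams both to the output box. A minimal sketch of that pattern (illustrative only, not the committed app.py; `run_model` is a placeholder for the real tokenize/generate/decode steps) might look like:

# Illustrative sketch only -- `run_model` stands in for the actual
# tokenizer(...) / model.generate(...) / decode logic in app.py.
import spaces
import gradio as gr

def run_model(question: str) -> str:
    # placeholder for the real MedGemma inference call
    return f"(answer to: {question})"

@spaces.GPU(duration=60)              # GPU is only attached while this function runs
def generate_response(question: str):
    if not question.strip():
        yield "Please enter a medical question."
        return
    yield "Processing your question..."   # shown in the output box immediately
    yield run_model(question)             # final answer replaces the status text

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    output = gr.Markdown()
    gr.Button("Generate").click(generate_response, inputs=question, outputs=output)

demo.launch()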
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import time
- import spaces # ADDED THIS IMPORT
+ import spaces

# Model configuration
MODEL_NAME = "krishna195/medgemma-anatomy-v1.2"
@@ -48,7 +48,7 @@ def load_model():
print("Initializing MedGemma...")
model, tokenizer = load_model()

- @spaces.GPU(duration=60) # MOVED DECORATOR HERE - applied to inference function
+ @spaces.GPU(duration=60)
def generate_response(question, max_tokens=512, temperature=0.7, top_p=0.9):
    """
    Generate medical response for a given question
@@ -59,51 +59,59 @@ def generate_response(question, max_tokens=512, temperature=0.7, top_p=0.9):
        temperature: Sampling temperature (0.0-1.0)
        top_p: Nucleus sampling parameter
    """
-     if not question.strip():
-         return "Please enter a medical question."
-
-     # Format prompt with Gemma chat template
-     prompt = f"""<start_of_turn>user
+     try:
+         if not question.strip():
+             return "⚠️ Please enter a medical question."
+
+         # Show processing message
+         yield "🔄 **Processing your question...**\n\nGenerating response, please wait..."
+
+         # Format prompt with Gemma chat template
+         prompt = f"""<start_of_turn>user
{question}<end_of_turn>
<start_of_turn>model
"""
-
-     # Tokenize
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-     # Generate
-     start_time = time.time()
-
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=max_tokens,
-             temperature=temperature,
-             do_sample=True,
-             top_p=top_p,
-             repetition_penalty=1.1,
-             pad_token_id=tokenizer.eos_token_id
-         )
-
-     generation_time = time.time() - start_time
-
-     # Decode response
-     full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
-
-     # Extract model response
-     if "<start_of_turn>model" in full_output:
-         response = full_output.split("<start_of_turn>model")[-1]
-         response = response.split("<end_of_turn>")[0].strip()
-     else:
-         response = full_output.strip()
-
-     # Add metadata
-     tokens_generated = outputs.shape[1] - inputs['input_ids'].shape[1]
-     tokens_per_sec = tokens_generated / generation_time if generation_time > 0 else 0
-
-     metadata = f"\n\n---\n*Generated in {generation_time:.2f}s ({tokens_per_sec:.1f} tokens/sec)*"
-
-     return response + metadata
+
+         # Tokenize
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         # Generate
+         start_time = time.time()
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 do_sample=True,
+                 top_p=top_p,
+                 repetition_penalty=1.1,
+                 pad_token_id=tokenizer.eos_token_id
+             )
+
+         generation_time = time.time() - start_time
+
+         # Decode response
+         full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+         # Extract model response
+         if "<start_of_turn>model" in full_output:
+             response = full_output.split("<start_of_turn>model")[-1]
+             response = response.split("<end_of_turn>")[0].strip()
+         else:
+             response = full_output.strip()
+
+         # Add metadata
+         tokens_generated = outputs.shape[1] - inputs['input_ids'].shape[1]
+         tokens_per_sec = tokens_generated / generation_time if generation_time > 0 else 0
+
+         metadata = f"\n\n---\n✅ *Generated in {generation_time:.2f}s ({tokens_per_sec:.1f} tokens/sec) | Device: {DEVICE.upper()}*"
+
+         yield response + metadata
+
+     except Exception as e:
+         error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n```\n\nPlease try again or contact support if the issue persists."
+         yield error_msg

# Example questions
examples = [
@@ -120,6 +128,7 @@ css = """
#warning {background-color: #FFCCCB; padding: 10px; border-radius: 5px; margin-bottom: 10px;}
.generate-btn {background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white;}
footer {visibility: hidden;}
+ #output-box {min-height: 200px; border: 1px solid #e0e0e0; border-radius: 8px; padding: 15px;}
"""

# Build Gradio interface
@@ -183,29 +192,36 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                    info="Nucleus sampling parameter"
                )

-             generate_btn = gr.Button("Generate Response", variant="primary", elem_classes="generate-btn")
+             generate_btn = gr.Button("🚀 Generate Response", variant="primary", elem_classes="generate-btn")
+             clear_btn = gr.ClearButton([question_input], value="🗑️ Clear")

        with gr.Column(scale=3):
-             output = gr.Markdown(label="Response")
+             output = gr.Markdown(
+                 label="Response",
+                 value="*Your medical answer will appear here...*",
+                 elem_id="output-box"
+             )

    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=question_input,
-             label="Example Questions"
+             label="📋 Example Questions - Click to try"
        )

    # Event handlers
    generate_btn.click(
        fn=generate_response,
        inputs=[question_input, max_tokens, temperature, top_p],
-         outputs=output
+         outputs=output,
+         show_progress=True
    )

    question_input.submit(
        fn=generate_response,
        inputs=[question_input, max_tokens, temperature, top_p],
-         outputs=output
+         outputs=output,
+         show_progress=True
    )

    gr.Markdown(