vaibhavm29 commited on
Commit
f48d0d7
·
1 Parent(s): a7f8e09

included new user and updated buttons

Browse files
Files changed (2) hide show
  1. interface.py +18 -17
  2. medrax/tools/llava_med.py +2 -0
interface.py CHANGED
@@ -16,7 +16,8 @@ REPORT_DIR.mkdir(exist_ok=True)
16
  SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'
17
 
18
  USERS = {
19
- 'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6'
 
20
  }
21
 
22
  class ChatInterface:
@@ -73,7 +74,7 @@ class ChatInterface:
73
  else:
74
  self.display_file_path = str(saved_path)
75
 
76
- return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
77
 
78
  def add_message(
79
  self, message: str, display_image: str, history: List[dict]
@@ -266,7 +267,7 @@ def create_demo(agent, tools_dict):
266
  )
267
  with gr.Row():
268
  analyze_btn = gr.Button("Analyze", interactive=False)
269
- ground_btn = gr.Button("Ground", interactive=False)
270
  segment_btn = gr.Button("Segment", interactive=False)
271
  with gr.Row():
272
  clear_btn = gr.Button("Clear Chat")
@@ -394,38 +395,38 @@ def create_demo(agent, tools_dict):
394
  bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
395
 
396
  analyze_btn.click(
397
- lambda: gr.update(value="Analyze the above image and identify the probabilites of occurrence of various diseases along with proper reason"), None, txt
398
  ).then(
399
- interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt] # add message & clear box
400
  ).then(
401
  interface.process_message,
402
  inputs=[txt, image_display, chatbot],
403
  outputs=[chatbot, image_display, txt],
404
  ).then(lambda: gr.Textbox(interactive=True), None, [txt])
405
 
406
- ground_btn.click(
407
- lambda: gr.update(value="Ground the main disease in this CXR"), None, txt
408
- ).then(
409
- interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt] # add message & clear box
410
- ).then(
411
- interface.process_message,
412
- inputs=[txt, image_display, chatbot],
413
- outputs=[chatbot, image_display, txt],
414
- ).then(lambda: gr.Textbox(interactive=True), None, [txt])
415
 
416
  segment_btn.click(
417
  lambda: gr.update(value="Segment the major affected lung"), None, txt
418
  ).then(
419
- interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt] # add message & clear box
420
  ).then(
421
  interface.process_message,
422
  inputs=[txt, image_display, chatbot],
423
  outputs=[chatbot, image_display, txt],
424
  ).then(lambda: gr.Textbox(interactive=True), None, [txt])
425
 
426
- upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display, analyze_btn, ground_btn, segment_btn])
427
 
428
- dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display, analyze_btn, ground_btn, segment_btn])
429
 
430
  clear_btn.click(clear_chat, outputs=[chatbot, image_display])
431
  new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
 
16
  SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'
17
 
18
  USERS = {
19
+ 'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6',
20
+ 'pna': b'$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C'
21
  }
22
 
23
  class ChatInterface:
 
74
  else:
75
  self.display_file_path = str(saved_path)
76
 
77
+ return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)
78
 
79
  def add_message(
80
  self, message: str, display_image: str, history: List[dict]
 
267
  )
268
  with gr.Row():
269
  analyze_btn = gr.Button("Analyze", interactive=False)
270
+ # ground_btn = gr.Button("Ground", interactive=False)
271
  segment_btn = gr.Button("Segment", interactive=False)
272
  with gr.Row():
273
  clear_btn = gr.Button("Clear Chat")
 
395
  bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
396
 
397
  analyze_btn.click(
398
+ lambda: gr.update(value="Analyze this xray and give me a detailed response. Use the medgemma_xray_expert tool"), None, txt
399
  ).then(
400
+ interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
401
  ).then(
402
  interface.process_message,
403
  inputs=[txt, image_display, chatbot],
404
  outputs=[chatbot, image_display, txt],
405
  ).then(lambda: gr.Textbox(interactive=True), None, [txt])
406
 
407
+ # ground_btn.click(
408
+ # lambda: gr.update(value="Ground the main disease in this CXR"), None, txt
409
+ # ).then(
410
+ # interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
411
+ # ).then(
412
+ # interface.process_message,
413
+ # inputs=[txt, image_display, chatbot],
414
+ # outputs=[chatbot, image_display, txt],
415
+ # ).then(lambda: gr.Textbox(interactive=True), None, [txt])
416
 
417
  segment_btn.click(
418
  lambda: gr.update(value="Segment the major affected lung"), None, txt
419
  ).then(
420
+ interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
421
  ).then(
422
  interface.process_message,
423
  inputs=[txt, image_display, chatbot],
424
  outputs=[chatbot, image_display, txt],
425
  ).then(lambda: gr.Textbox(interactive=True), None, [txt])
426
 
427
+ upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display, analyze_btn, segment_btn])
428
 
429
+ dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display, analyze_btn, segment_btn])
430
 
431
  clear_btn.click(clear_chat, outputs=[chatbot, image_display])
432
  new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
medrax/tools/llava_med.py CHANGED
@@ -56,6 +56,7 @@ class LlavaMedTool(BaseTool):
56
  def __init__(
57
  self,
58
  model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
 
59
  cache_dir: str = "/model-weights",
60
  low_cpu_mem_usage: bool = True,
61
  torch_dtype: torch.dtype = torch.bfloat16,
@@ -68,6 +69,7 @@ class LlavaMedTool(BaseTool):
68
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
69
  model_path=model_path,
70
  model_base=None,
 
71
  model_name=model_path,
72
  load_in_4bit=load_in_4bit,
73
  load_in_8bit=load_in_8bit,
 
56
  def __init__(
57
  self,
58
  model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
59
+ # model_path: str = "microsoft/llava-rad",
60
  cache_dir: str = "/model-weights",
61
  low_cpu_mem_usage: bool = True,
62
  torch_dtype: torch.dtype = torch.bfloat16,
 
69
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
70
  model_path=model_path,
71
  model_base=None,
72
+ # model_base="lmsys/vicuna-7b-v1.5",
73
  model_name=model_path,
74
  load_in_4bit=load_in_4bit,
75
  load_in_8bit=load_in_8bit,