Commit f48d0d7
Parent(s): a7f8e09

included new user and updated buttons

Files changed:
- interface.py: +18 -17
- medrax/tools/llava_med.py: +2 -0
interface.py
CHANGED
@@ -16,7 +16,8 @@ REPORT_DIR.mkdir(exist_ok=True)
 SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'

 USERS = {
-    'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6'
+    'test_user': b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6',
+    'pna': b'$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C'
 }

 class ChatInterface:
@@ -73,7 +74,7 @@ class ChatInterface:
         else:
             self.display_file_path = str(saved_path)

-        return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)
+        return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)

     def add_message(
         self, message: str, display_image: str, history: List[dict]
@@ -266,7 +267,7 @@ def create_demo(agent, tools_dict):
         )
         with gr.Row():
             analyze_btn = gr.Button("Analyze", interactive=False)
-            ground_btn = gr.Button("Ground", interactive=False)
+            # ground_btn = gr.Button("Ground", interactive=False)
             segment_btn = gr.Button("Segment", interactive=False)
         with gr.Row():
             clear_btn = gr.Button("Clear Chat")
@@ -394,38 +395,38 @@ def create_demo(agent, tools_dict):
         bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])

         analyze_btn.click(
-            lambda: gr.update(value="Analyze
+            lambda: gr.update(value="Analyze this xray and give me a detailed response. Use the medgemma_xray_expert tool"), None, txt
         ).then(
-            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
+            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
         ).then(
             interface.process_message,
             inputs=[txt, image_display, chatbot],
             outputs=[chatbot, image_display, txt],
         ).then(lambda: gr.Textbox(interactive=True), None, [txt])

-        ground_btn.click(
-            lambda: gr.update(value="Ground the main disease in this CXR"), None, txt
-        ).then(
-            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
-        ).then(
-            interface.process_message,
-            inputs=[txt, image_display, chatbot],
-            outputs=[chatbot, image_display, txt],
-        ).then(lambda: gr.Textbox(interactive=True), None, [txt])
+        # ground_btn.click(
+        #     lambda: gr.update(value="Ground the main disease in this CXR"), None, txt
+        # ).then(
+        #     interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
+        # ).then(
+        #     interface.process_message,
+        #     inputs=[txt, image_display, chatbot],
+        #     outputs=[chatbot, image_display, txt],
+        # ).then(lambda: gr.Textbox(interactive=True), None, [txt])

         segment_btn.click(
             lambda: gr.update(value="Segment the major affected lung"), None, txt
         ).then(
-            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
+            interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
         ).then(
             interface.process_message,
             inputs=[txt, image_display, chatbot],
             outputs=[chatbot, image_display, txt],
         ).then(lambda: gr.Textbox(interactive=True), None, [txt])

-        upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display, analyze_btn, ground_btn, segment_btn])
+        upload_button.upload(handle_file_upload, inputs=upload_button, outputs=[image_display, analyze_btn, segment_btn])

-        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display, analyze_btn, ground_btn, segment_btn])
+        dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=[image_display, analyze_btn, segment_btn])

         clear_btn.click(clear_chat, outputs=[chatbot, image_display])
         new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
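Note that both stored values embed the module-level SALT as their prefix, which is consistent with hashing each password against that one fixed salt. Below is a minimal sketch of how an entry like the new 'pna' credential could be generated and checked with the `bcrypt` package; the verification function is an assumption, since the commit shows only SALT and USERS, not the login path. Reusing one fixed salt across users also means identical passwords produce identical hashes, forgoing bcrypt's usual per-user salting.

```python
# Hypothetical helpers for the USERS table above. Assumes the app verifies
# logins by re-hashing the submitted password with the fixed SALT -- the
# actual check is not shown in this commit.
import bcrypt

SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'  # fixed salt from interface.py

def make_user_hash(password: str) -> bytes:
    # Produces a value suitable for storing in USERS.
    return bcrypt.hashpw(password.encode("utf-8"), SALT)

def check_login(users: dict, username: str, password: str) -> bool:
    # Recompute with the same salt and compare against the stored hash.
    stored = users.get(username)
    return stored is not None and bcrypt.hashpw(password.encode("utf-8"), SALT) == stored
```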
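The Analyze and Segment buttons are wired identically: `.click()` prefills the textbox through `gr.update(value=...)`, two `.then()` steps run `interface.add_message` and `interface.process_message`, and a final `.then()` re-enables the textbox; the upload events then hand `gr.update(interactive=True)` back to the buttons. A self-contained sketch of that chaining pattern, assuming a recent Gradio 4.x API, with an echo handler standing in for the MedRAX ones:

```python
# Minimal Gradio demo of the click -> prefill -> handler -> re-enable chain.
# `respond` is an illustrative stand-in, not MedRAX's process_message.
import gradio as gr

def respond(message, history):
    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": f"Echo: {message}"},
    ]
    return history, ""  # updated chat history, cleared textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    txt = gr.Textbox()
    analyze_btn = gr.Button("Analyze", interactive=False)
    upload_button = gr.UploadButton("Upload image")

    # Prefill the textbox, run the handler, then re-enable the textbox.
    analyze_btn.click(
        lambda: gr.update(value="Analyze this xray"), None, txt
    ).then(
        respond, inputs=[txt, chatbot], outputs=[chatbot, txt]
    ).then(lambda: gr.Textbox(interactive=True), None, [txt])

    # Uploading enables the button, as handle_file_upload does by returning
    # gr.update(interactive=True) for each button output.
    upload_button.upload(lambda f: gr.update(interactive=True), upload_button, analyze_btn)

demo.launch()
```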
medrax/tools/llava_med.py
CHANGED
@@ -56,6 +56,7 @@ class LlavaMedTool(BaseTool):
     def __init__(
         self,
         model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
+        # model_path: str = "microsoft/llava-rad",
         cache_dir: str = "/model-weights",
         low_cpu_mem_usage: bool = True,
         torch_dtype: torch.dtype = torch.bfloat16,
@@ -68,6 +69,7 @@ class LlavaMedTool(BaseTool):
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
             model_path=model_path,
             model_base=None,
+            # model_base="lmsys/vicuna-7b-v1.5",
             model_name=model_path,
             load_in_4bit=load_in_4bit,
             load_in_8bit=load_in_8bit,
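The commented-out defaults sketch a possible swap to the microsoft/llava-rad checkpoint paired with lmsys/vicuna-7b-v1.5 as model_base, suggesting llava-rad is expected to load on top of a base model rather than standalone. For reference, a hypothetical instantiation using only the keyword arguments visible in this diff; whether load_in_4bit/load_in_8bit are constructor parameters is inferred from the load_pretrained_model call above:

```python
# Hypothetical LlavaMedTool construction mirroring the defaults in this diff.
import torch
from medrax.tools.llava_med import LlavaMedTool

tool = LlavaMedTool(
    model_path="microsoft/llava-med-v1.5-mistral-7b",
    cache_dir="/model-weights",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,   # assumed constructor flag, forwarded to load_pretrained_model
    load_in_8bit=False,  # likewise assumed
)
```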