Update app.py
app.py CHANGED

@@ -17,13 +17,13 @@ def load_model_and_processor(hf_token: str):
         return _model_cache[hf_token]
     device = torch.device("cpu")
     model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     processor = AutoProcessor.from_pretrained(
-        "microsoft/maira-2",
-        trust_remote_code=True,
+        "microsoft/maira-2",
+        trust_remote_code=True,
         use_auth_token=hf_token
     )
     model.eval()
@@ -33,7 +33,7 @@ def load_model_and_processor(hf_token: str):
 
 def get_sample_data() -> dict:
     """
-
+    Downloads sample chest X-ray images and associated data.
     """
     frontal_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-1001.png"
     lateral_image_url = "https://openi.nlm.nih.gov/imgs/512/145/145/CXR145_IM-0290-2001.png"
@@ -86,7 +86,14 @@ def generate_report(hf_token, frontal, lateral, indication, technique, compariso
         return_tensors="pt",
         get_grounding=use_grounding,
     )
+    # Move all tensors to the CPU
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     max_tokens = 450 if use_grounding else 300
     with torch.no_grad():
         output_decoding = model.generate(
@@ -121,6 +128,12 @@ def run_phrase_grounding(hf_token, frontal, phrase):
         return_tensors="pt",
     )
     processed_inputs = {k: v.to(device) for k, v in processed_inputs.items()}
+    # Remove keys containing "image_sizes" to prevent unexpected keyword errors.
+    processed_inputs = dict(processed_inputs)
+    keys_to_remove = [k for k in processed_inputs if "image_sizes" in k]
+    for key in keys_to_remove:
+        processed_inputs.pop(key, None)
+
     with torch.no_grad():
         output_decoding = model.generate(
             **processed_inputs,
@@ -132,6 +145,7 @@ def run_phrase_grounding(hf_token, frontal, phrase):
     prediction = processor.convert_output_to_plaintext_or_grounded_sequence(decoded_text)
     return prediction
 
+
 def login_ui(hf_token):
     """Authenticate the user by loading the model."""
     try:
@@ -177,14 +191,14 @@ def load_sample_findings():
     sample = get_sample_data()
     return [
         save_temp_image(sample["frontal"]), # frontal image file path
-        save_temp_image(sample["lateral"]),
+        save_temp_image(sample["lateral"]), # lateral image file path
         sample["indication"],
         sample["technique"],
         sample["comparison"],
         None, # prior frontal (not used)
         None, # prior lateral (not used)
         None, # prior report (not used)
-        False
+        False # grounding checkbox default
     ]
 
 def load_sample_phrase():
@@ -276,4 +290,4 @@ with gr.Blocks(title="MAIRA-2 Medical Assistant") as demo:
         outputs=pg_output
     )
 
-demo.launch()
+demo.launch()
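
The same copy-filter-pop block now appears in both generate_report and run_phrase_grounding. As a minimal sketch of the idea behind it (the helper name drop_image_size_keys and the toy inputs below are hypothetical, not code from this repo), the fix amounts to stripping processor outputs that model.generate() would reject as keyword arguments:

```python
# Hypothetical helper, not part of app.py: drop processor outputs that
# model.generate(**inputs) would reject as unexpected keyword arguments.
from typing import Any, Dict


def drop_image_size_keys(processed_inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of processed_inputs without any "image_sizes" keys.

    Processors can emit bookkeeping entries (here, image_sizes) alongside
    input_ids/pixel_values; forwarding them to generate() raises
    TypeError: got an unexpected keyword argument 'image_sizes'.
    """
    return {k: v for k, v in processed_inputs.items() if "image_sizes" not in k}


# Toy usage: only the generate()-compatible keys survive.
inputs = {"input_ids": [[1, 2]], "pixel_values": [[0.0]], "image_sizes": [(512, 512)]}
assert sorted(drop_image_size_keys(inputs)) == ["input_ids", "pixel_values"]
```

A dict comprehension like this is behaviorally equivalent to the commit's dict() copy plus pop() loop, and it never mutates the mapping returned by the processor.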
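The hunk at lines 17-29 shows only the cache-hit return and the two from_pretrained calls; the rest of load_model_and_processor sits outside the diff. Below is a hedged reconstruction of the loader pattern those lines imply; the cache-miss check, the stored (model, processor, device) tuple, and the model.to(device) call are assumptions for illustration, not the repository's actual code:

```python
# Hedged sketch of the token-keyed cache implied by "return _model_cache[hf_token]".
# The cache-miss branch and the return shape are assumptions, not app.py's real code.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

_model_cache: dict = {}


def load_model_and_processor(hf_token: str):
    # Serve a repeat login with the same token from the cache.
    if hf_token in _model_cache:
        return _model_cache[hf_token]
    device = torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        use_auth_token=hf_token,
    )
    processor = AutoProcessor.from_pretrained(
        "microsoft/maira-2",
        trust_remote_code=True,
        use_auth_token=hf_token,
    )
    model.eval()
    model.to(device)
    _model_cache[hf_token] = (model, processor, device)
    return _model_cache[hf_token]
```

Note that recent transformers releases deprecate use_auth_token= in favor of token=, so the calls kept by this commit may emit a FutureWarning on newer versions.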