# haritsahm's picture
# Reformat Overview (#1)
# 1f4f53e
from pathlib import Path
from typing import List
import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from models import phc_models
from utils import utils, page_utils
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Path of the weights loaded into BILATERAL_MODEL below.
BILATERIAL_WEIGHT = 'weights/phresnet18_cbis2views.pt'

# Build the network, restore its weights on the CPU, then move it to `device`
# and switch it to evaluation mode.
BILATERAL_MODEL = phc_models.PHCResNet18(
    channels=2,
    n=2,
    num_classes=1,
    visualize=True,
)
BILATERAL_MODEL.add_top_blocks(num_classes=1)
BILATERAL_MODEL.load_state_dict(
    torch.load(BILATERIAL_WEIGHT, map_location='cpu'),
)
BILATERAL_MODEL = BILATERAL_MODEL.to(device)
BILATERAL_MODEL.eval()

# Spatial size every input image is resized to before inference.
INPUT_HEIGHT = 600
INPUT_WIDTH = 500

# File extensions accepted by the upload widgets and by filter_files().
SUPPORTED_IMG_EXT = ['.png', '.jpg', '.jpeg']

# Example (CC, MLO) pairs offered in the UI.
EXAMPLE_IMAGES = [
    [
        'examples/f4b2d377f43ba0bd_left_cc.png',
        'examples/f4b2d377f43ba0bd_left_mlo.jpg',
    ],
    [
        'examples/f4b2d377f43ba0bd_right_cc.png',
        'examples/f4b2d377f43ba0bd_right_mlo.jpeg',
    ],
    [
        'examples/P_00001_LEFT_cc.jpg',
        'examples/P_00001_LEFT_mlo.jpeg',
    ],
]
# Model warmup
# Run a few throw-away forward passes so lazy initialisation happens before the
# first user request. The dummy batch is float32, matching the dtype that
# predict_bilateral() feeds the model (np.random.randint alone yields integers).
test_images = np.random.randint(
    0, 255, (2, INPUT_HEIGHT, INPUT_WIDTH)).astype(np.float32)
test_images = torch.from_numpy(test_images).to(device)
test_images = test_images.unsqueeze(0)  # Add batch dimension
# Gradients are never needed here; skip building the autograd graph.
with torch.no_grad():
    for _ in range(10):
        BILATERAL_MODEL(test_images)
# Drop the reference so the dummy batch can be freed.
test_images = None
def filter_files(files: List) -> List:
    """Validate an uploaded CC/MLO pair and return the paths CC-first.

    The model requires a pair of CC-MLO views of the same breast. View and
    side are detected from case-insensitive substrings of each file name
    (``cc``, ``mlo``, ``left``, ``right``).

    Filter:
    - Not exactly two files (empty upload slots are ignored when counting)
    - Unsupported extensions
    - Missing required view, or the two files are not from the same side

    Parameters
    ----------
    files : List
        Uploaded file objects exposing the stored path through ``.name``.
        ``None`` entries (an upload slot left empty) are skipped.

    Returns
    -------
    List[pathlib.Path]
        The two file paths, reordered when needed so that a file whose name
        contains ``cc`` comes first.

    Raises
    ------
    gr.Error
        If the number of provided files is not equal to 2.
    gr.Error
        If a file extension is not in ``SUPPORTED_IMG_EXT``.
    gr.Error
        If a view (CC or MLO) is missing, or fewer than two files share the
        same side (left or right).
    """
    # Drop empty upload slots so a missing file is reported by the count check
    # below instead of failing on ``None.name``.
    files = [file for file in files if file is not None]
    if len(files) != 2:
        raise gr.Error(
            f'Need exactly 2 images. Currently have {len(files)} images!')

    file_paths = [Path(file.name) for file in files]

    # Compare extensions case-insensitively so '.PNG' / '.JPG' are accepted.
    if not all(path.suffix.lower() in SUPPORTED_IMG_EXT
               for path in file_paths):
        raise gr.Error(
            'There is a file with unsupported type. '
            f'Make sure all files are in {SUPPORTED_IMG_EXT}!')

    has_cc = False
    has_mlo = False
    cc_first = False
    left_count = 0
    right_count = 0
    for idx, path in enumerate(file_paths):
        name = path.name.lower()

        # Check which view is present.
        if 'cc' in name:
            has_cc = True
            if idx == 0:
                cc_first = True
        if 'mlo' in name:
            has_mlo = True

        # Check which side is present; 'left' takes precedence over 'right'.
        if 'left' in name:
            left_count += 1
        elif 'right' in name:
            right_count += 1

    # Ensure the CC view comes first.
    if not cc_first:
        file_paths.reverse()

    # Both views must be present and both files must come from the same side.
    same_side = left_count >= 2 or right_count >= 2
    if not (has_cc and has_mlo and same_side):
        raise gr.Error('Missing bilateral-view pair for Left or Right side.')

    return file_paths
def _load_grayscale(path):
    """Read an image file as a single-channel array and close the file."""
    with Image.open(str(path)) as img:
        # Bilevel, palette and multi-band (RGB, RGBA, ...) images are collapsed
        # to 8-bit grayscale; other single-band images are left untouched so
        # their full value range reaches the min-max normalisation.
        if img.mode in ('1', 'P') or len(img.getbands()) > 1:
            img = img.convert('L')
        return np.array(img)


def predict_bilateral(cc_file, mlo_file):
    """Predict Bilateral Mammography.

    Parameters
    ----------
    cc_file : file object or None
        Uploaded CC-view file; its path is read from ``.name``.
    mlo_file : file object or None
        Uploaded MLO-view file; its path is read from ``.name``.

    Returns
    -------
    Tuple[List[Tuple[numpy.ndarray, str]], Dict[str, float]]
        ``(gallery_items, labels)``. ``gallery_items`` holds the two RGB input
        views followed by the two heatmap overlays, each paired with its
        caption. ``labels`` maps ``'Malignant'`` to the sigmoid probability
        and ``'Normal/Benign'`` to its complement.

    Raises
    ------
    gr.Error
        Propagated from ``filter_files`` when the uploaded pair is invalid.
    """
    filtered_files = filter_files([cc_file, mlo_file])

    images = []
    for path in filtered_files:
        image = _load_grayscale(path)
        # Rescale to 0-255 so every input arrives in the same value range.
        image = cv2.normalize(
            image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        image = cv2.resize(
            image, (INPUT_WIDTH, INPUT_HEIGHT), interpolation=cv2.INTER_LINEAR)
        images.append(image)

    images = np.asarray(images).astype(np.float32)
    im_h, im_w = images[0].shape[:2]

    images_t = torch.from_numpy(images)
    images_t = images_t.unsqueeze(0)  # Add batch dimension
    images_t = images_t.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out, _, out_refiner = BILATERAL_MODEL(images_t)

    # NOTE(review): assumes utils.mean_activations returns a torch.Tensor (the
    # original code already called .numpy() on it) - confirm. detach()/cpu()
    # make the conversion safe when the model runs on a CUDA device.
    out_refiner = utils.mean_activations(out_refiner).detach().cpu().numpy()

    probability = torch.sigmoid(out).detach().cpu().item()
    # Report each class with its own confidence so a benign prediction is not
    # displayed with the (low) malignancy probability as its score.
    labels_dict = {
        'Malignant': probability,
        'Normal/Benign': 1.0 - probability,
    }

    refined_view_norm = cv2.normalize(
        out_refiner, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
    refined_view = cv2.applyColorMap(refined_view_norm, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert so the heatmap matches the RGB base
    # images it is blended with below.
    refined_view = cv2.cvtColor(refined_view, cv2.COLOR_BGR2RGB)
    refined_view = cv2.resize(
        refined_view, (im_w, im_h), interpolation=cv2.INTER_LINEAR)

    displays_imgs = []
    overlays = []
    for view_name, image in zip(('CC', 'MLO'), images):
        colored = cv2.normalize(
            image, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        colored = cv2.cvtColor(colored, cv2.COLOR_GRAY2RGB)
        overlay = cv2.addWeighted(colored, 1.0, refined_view, 0.5, 0)
        displays_imgs.append((colored, view_name))
        overlays.append((overlay, f'{view_name} Interest Area'))

    # Plain views first, then the overlays, matching the two-column gallery.
    displays_imgs += overlays

    return displays_imgs, labels_dict
def run():
    """Run Gradio App.

    Reads ``index.html`` for the page header, builds the Blocks layout, wires
    the Process and Clear buttons, and launches the server on
    ``0.0.0.0:7860``.
    """
    with open('index.html', encoding='utf-8') as f:
        html_content = f.read()

    theme = gr.themes.Default(
        primary_hue=page_utils.KALBE_THEME_COLOR,
        secondary_hue=page_utils.KALBE_THEME_COLOR,
    ).set(
        button_primary_background_fill='*primary_600',
        button_primary_background_fill_hover='*primary_500',
        button_primary_text_color='white',
    )

    # NOTE(review): the nesting below was reconstructed from a source whose
    # indentation had been lost - confirm it matches the intended layout.
    with gr.Blocks(theme=theme) as demo:
        with gr.Column():
            gr.HTML(html_content)

        with gr.Row():
            with gr.Column():
                cc_file = gr.File(file_count='single',
                                  file_types=SUPPORTED_IMG_EXT,
                                  label='CC View')
                mlo_file = gr.File(file_count='single',
                                   file_types=SUPPORTED_IMG_EXT,
                                   label='MLO View')
                with gr.Row():
                    clear_btn = gr.Button('Clear')
                    process_btn = gr.Button('Process', variant="primary")
            with gr.Column():
                output_gallery = gr.Gallery(
                    label='Highlighted Area').style(grid=[2], height='auto')
                cancer_type = gr.Label(label='Cancer Type')

        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=[cc_file, mlo_file],
        )

        # Adjacent literals keep the sentence free of stray or missing
        # whitespace that a backslash continuation would introduce.
        gr.Markdown(
            'Note that this method is sensitive to input image types. '
            'Current pipeline expect the values between 0.0-255.0')

        process_btn.click(
            fn=predict_bilateral,
            inputs=[cc_file, mlo_file],
            outputs=[output_gallery, cancer_type]
        )

        clear_targets = [
            cc_file,
            mlo_file,
            output_gallery,
            cancer_type,
        ]
        clear_btn.click(
            # No inputs are wired, so the callback must accept zero arguments.
            lambda: [gr.update(value=None) for _ in clear_targets],
            inputs=None,
            outputs=clear_targets,
        )

    demo.launch(server_name='0.0.0.0', server_port=7860)  # nosec B104
# Start the app only when this file is executed as a script, not on import.
if __name__ == '__main__':
    run()