# MedSam / medsam2_infer_3D_CT.py
# Committed by Dhruv-Ty -- "initial commit" (ac239ba)
from glob import glob
from tqdm import tqdm
import os
from os.path import join, basename
import re
import matplotlib.pyplot as plt
from collections import OrderedDict
import pandas as pd
import numpy as np
import argparse
from PIL import Image
import SimpleITK as sitk
import torch
import torch.multiprocessing as mp
from sam2.build_sam import build_sam2_video_predictor_npz
import SimpleITK as sitk
from skimage import measure, morphology
# Global numeric setup: allow reduced-precision (e.g. TF32) float32 matmuls
# and seed torch (CPU + CUDA) and NumPy with one fixed value so that runs are
# repeatable (NumPy's RNG drives the box jitter in mask2D_to_bbox/mask3D_to_bbox).
torch.set_float32_matmul_precision('high')
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(2024)
del _seed_fn

# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str,
                    default="checkpoints/MedSAM2_latest.pt", help='checkpoint path')
parser.add_argument('--cfg', type=str,
                    default="configs/sam2.1_hiera_t512.yaml", help='model config')
parser.add_argument('-i', '--imgs_path', type=str,
                    default="CT_DeepLesion/images", help='imgs path')
# NOTE(review): gts_path is parsed but not referenced anywhere in the visible script.
parser.add_argument('--gts_path', default=None,
                    help='simulate prompts based on ground truth')
parser.add_argument('-o', '--pred_save_dir', type=str,
                    default="./DeeLesion_results", help='path to save segmentation results')
# add option to propagate with either box or mask
# NOTE(review): action='store_true' combined with default=True means this flag
# is always True -- it cannot be switched off from the command line (and the
# mask-prompt branch it would select is not implemented below).
parser.add_argument('--propagate_with_box', action='store_true', default=True,
                    help='whether to propagate with box')

args = parser.parse_args()
checkpoint, model_cfg = args.checkpoint, args.cfg
imgs_path, gts_path = args.imgs_path, args.gts_path
pred_save_dir = args.pred_save_dir
propagate_with_box = args.propagate_with_box
# Make sure the output directory exists before any result is written.
os.makedirs(pred_save_dir, exist_ok=True)
def getLargestCC(segmentation):
    """Keep only the largest connected non-background component of `segmentation`.

    Components are found with ``skimage.measure.label`` using its default
    connectivity; label 0 is background and is never selected.

    Parameters
    ----------
    segmentation : numpy.ndarray
        Binary (or integer label) array of any dimensionality.

    Returns
    -------
    numpy.ndarray of bool
        Same shape as `segmentation`; True only on the largest component.
        All False when `segmentation` contains no foreground at all.
    """
    labels = measure.label(segmentation)
    # bincount index k holds the voxel count of component k; drop k=0 (background).
    component_sizes = np.bincount(labels.ravel())[1:]
    if component_sizes.size == 0:
        # No foreground: np.argmax would raise on an empty sequence.
        return np.zeros(labels.shape, dtype=bool)
    return labels == np.argmax(component_sizes) + 1
def dice_multi_class(preds, targets):
    """Mean smoothed Dice coefficient over the foreground labels of `targets`.

    Parameters
    ----------
    preds, targets : numpy.ndarray
        Label maps of identical shape; label 0 is treated as background.

    Returns
    -------
    float
        Mean over every non-zero label present in `targets` of
        ``(2*|P∩T| + 1) / (|P| + |T| + 1)``, or NaN when `targets` contains
        no foreground label.
    """
    smooth = 1.0
    assert preds.shape == targets.shape, (
        f'shape mismatch: preds {preds.shape} vs targets {targets.shape}'
    )
    labels = np.unique(targets)
    # Remove background explicitly: slicing [1:] would instead drop a real
    # label whenever 0 does not occur in `targets`.
    labels = labels[labels != 0]
    if labels.size == 0:
        return float('nan')
    dices = []
    for label in labels:
        pred = preds == label
        target = targets == label
        intersection = np.logical_and(pred, target).sum()
        dices.append((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))
    return np.mean(dices)
def show_mask(mask, ax, mask_color=None, alpha=0.5):
    """Draw `mask` on `ax` as a single-colour translucent RGBA overlay.

    Parameters
    ----------
    mask : numpy.ndarray
        Mask whose last two dimensions are (height, width). It is reshaped to
        (height, width), so any leading dimensions must have size 1.
    ax : matplotlib.axes.Axes
        Axes that receives the overlay via ``imshow``.
    mask_color : array-like of 3 floats, optional
        RGB colour of the overlay; defaults to yellow (251, 252, 30) / 255.
    alpha : float
        Value written into the overlay's alpha channel.
    """
    rgb = np.array([251 / 255, 252 / 255, 30 / 255]) if mask_color is None else mask_color
    rgba = np.concatenate([rgb, np.array([alpha])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4): colour only where mask is set.
    overlay = mask.reshape(height, width)[..., np.newaxis] * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_box(box, ax, edgecolor='blue'):
    """Outline an axis-aligned box on `ax` with an unfilled rectangle.

    Parameters
    ----------
    box : sequence of 4 numbers
        ``(x0, y0, x1, y1)`` corner coordinates in image pixels.
    ax : matplotlib.axes.Axes
        Axes the rectangle patch is added to.
    edgecolor : str
        Colour of the rectangle outline (line width 2, transparent fill).
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor=edgecolor,
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
def resize_grayscale_to_rgb_and_resize(array, image_size):
    """Turn every slice of a grayscale volume into a resized, channel-first RGB frame.

    Parameters
    ----------
    array : numpy.ndarray
        Volume of shape (d, h, w). Each slice is cast to uint8 before conversion.
    image_size : int
        Target width and height of every output frame.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (d, 3, image_size, image_size) holding values
        in [0, 255].
    """
    num_slices, _height, _width = array.shape
    frames = np.zeros((num_slices, 3, image_size, image_size))
    for idx, slice_2d in enumerate(array):
        rgb = Image.fromarray(slice_2d.astype(np.uint8)).convert("RGB")
        resized = rgb.resize((image_size, image_size))
        # PIL gives (H, W, 3); the model-facing layout is (3, H, W).
        frames[idx] = np.array(resized).transpose(2, 0, 1)
    return frames
def mask2D_to_bbox(gt2D, max_shift=20):
    """Return a randomly enlarged bounding box ``[x_min, y_min, x_max, y_max]`` of `gt2D`.

    All four sides are pushed outward by the same shift, drawn uniformly from
    ``[0, max_shift]`` with NumPy's global RNG, and then clipped to the image.
    Raises ValueError (from ``np.min``) when `gt2D` has no positive pixel.
    """
    rows, cols = np.where(gt2D > 0)
    left, right = np.min(cols), np.max(cols)
    top, bottom = np.min(rows), np.max(rows)
    height, width = gt2D.shape
    pad = np.random.randint(0, max_shift + 1, 1)[0]
    return np.array([
        max(0, left - pad),
        max(0, top - pad),
        min(width - 1, right + pad),
        min(height - 1, bottom + pad),
    ])
def mask3D_to_bbox(gt3D, max_shift=20):
    """Return a jittered 3D box ``[x_min, y_min, z_min, x_max, y_max, z_max]`` of `gt3D`.

    `gt3D` is indexed (z, y, x). The in-plane (x, y) bounds are pushed outward
    by one random shift drawn uniformly from ``[0, max_shift]`` (NumPy global
    RNG) and clipped to the volume. The z bounds are the tight foreground
    extent and are not padded. Raises ValueError (from ``np.min``) when `gt3D`
    has no positive voxel.
    """
    zs, ys, xs = np.where(gt3D > 0)
    x_lo, x_hi = np.min(xs), np.max(xs)
    y_lo, y_hi = np.min(ys), np.max(ys)
    z_lo, z_hi = np.min(zs), np.max(zs)
    _depth, height, width = gt3D.shape
    pad = np.random.randint(0, max_shift + 1, 1)[0]
    # z_lo/z_hi come straight from voxel indices, so they are already in range.
    return np.array([
        max(0, x_lo - pad),
        max(0, y_lo - pad),
        z_lo,
        min(width - 1, x_hi + pad),
        min(height - 1, y_hi + pad),
        z_hi,
    ])
# Lesion annotations: one row per lesion (key slice, slice range, box, window).
DL_info = pd.read_csv('CT_DeepLesion/DeepLesion_Dataset_Info.csv')
# Volumes to process: *.nii.gz files, skipping macOS "._" resource-fork entries.
nii_fnames = [
    fname
    for fname in sorted(os.listdir(imgs_path))
    if fname.endswith('.nii.gz') and not fname.startswith('._')
]
print(f'Processing {len(nii_fnames)} nii files')
# Per-lesion output metadata, written to CSV at the end of the run.
seg_info = OrderedDict(nii_name=[], key_slice_index=[], DICOM_windows=[])
# initialized predictor
predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint)
# ImageNet mean/std used to normalise the RGB frames; they never change, so
# build them once instead of once per lesion.
_IMG_MEAN = torch.tensor((0.485, 0.456, 0.406), dtype=torch.float32)[:, None, None].cuda()
_IMG_STD = torch.tensor((0.229, 0.224, 0.225), dtype=torch.float32)[:, None, None].cuda()
# Square frame size the volume is resized to before being fed to the predictor.
_MODEL_IMAGE_SIZE = 512
# CSV slice ranges with whitespace removed, for exact matching in the loop.
_dl_slice_ranges = DL_info['Slice_range'].astype(str).str.replace(' ', '', regex=False)


def _apply_dicom_window(volume, window):
    """Clip `volume` to a "lower,upper" intensity window and rescale it to uint8 [0, 255].

    If the clipped volume is constant, the min-max rescale would divide by
    zero, so an all-zero uint8 volume is returned instead.
    """
    lower_bound, upper_bound = (float(v) for v in window.split(','))
    clipped = np.clip(volume, lower_bound, upper_bound)
    lo, hi = np.min(clipped), np.max(clipped)
    if hi == lo:
        return np.zeros(volume.shape, dtype=np.uint8)
    return np.uint8((clipped - lo) / (hi - lo) * 255.0)


def _to_model_input(volume_u8):
    """Convert a (D, H, W) uint8 volume into a normalised (D, 3, S, S) float64 CUDA tensor."""
    frames = resize_grayscale_to_rgb_and_resize(volume_u8, _MODEL_IMAGE_SIZE) / 255.0
    frames = torch.from_numpy(frames).cuda()
    frames -= _IMG_MEAN
    frames /= _IMG_STD
    return frames


def _segment_lesion(frames, volume_shape, key_frame_idx, box):
    """Prompt `key_frame_idx` with `box` and return the forward+backward propagated mask.

    Returns a fresh uint8 array of `volume_shape` (D, H, W). Each lesion gets
    its own array, so masks never leak from one lesion into the next.
    """
    seg = np.zeros(volume_shape, dtype=np.uint8)
    video_height, video_width = volume_shape[1], volume_shape[2]
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = predictor.init_state(frames, video_height, video_width)
        # Propagate from the key frame towards the end, then towards the start.
        for reverse in (False, True):
            if propagate_with_box:
                predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=key_frame_idx,
                    obj_id=1,
                    box=box,
                )
            # NOTE(review): the non-box (ground-truth mask) prompt path is not
            # implemented; propagation would then run without any prompt.
            for frame_idx, _, mask_logits in predictor.propagate_in_video(state, reverse=reverse):
                seg[frame_idx, (mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
            predictor.reset_state(state)
    return seg


for nii_fname in tqdm(nii_fnames):
    # The file name carries a zero-padded "NNN-NNN" slice range and a
    # "NNNNNN_NN_NN" case prefix; both are used to find the lesion rows.
    range_suffix = re.findall(r'\d{3}-\d{3}', nii_fname)[0]
    slice_range = ', '.join(str(int(s)) for s in range_suffix.split('-'))
    nii_image = sitk.ReadImage(join(imgs_path, nii_fname))
    nii_image_data = sitk.GetArrayFromImage(nii_image)
    case_name = re.findall(r'^(\d{6}_\d{2}_\d{2})', nii_fname)[0]
    # Exact, whitespace-insensitive range match: a substring test would let
    # "1, 10" match a row whose range is "11, 100".
    case_df = DL_info[
        DL_info['File_name'].str.contains(case_name, regex=False)
        & (_dl_slice_ranges == slice_range.replace(' ', ''))
    ].copy()
    for _, row in case_df.iterrows():
        volume_u8 = _apply_dicom_window(nii_image_data, row['DICOM_windows'])

        key_slice_idx = int(row['Key_slice_index'])
        slice_idx_start, _slice_idx_end = (int(s) for s in row['Slice_range'].split(','))
        # Position of the key slice inside this sub-volume.
        key_slice_idx_offset = key_slice_idx - slice_idx_start
        if not 0 <= key_slice_idx_offset < volume_u8.shape[0]:
            raise IndexError(
                f'{nii_fname}: key slice {key_slice_idx} is outside the volume '
                f'(slice range {row["Slice_range"]}, depth {volume_u8.shape[0]})'
            )

        coords = [int(float(c)) for c in row['Bounding_boxes'].split(',')]
        # The CSV order is treated as (y_min, x_min, y_max, x_max) and reordered
        # to (x_min, y_min, x_max, y_max) for the box prompt.
        # NOTE(review): the DeepLesion docs describe boxes as x1,y1,x2,y2; the
        # swap presumably compensates for the axis order of these NIfTI volumes.
        bbox = np.array([coords[1], coords[0], coords[3], coords[2]])

        frames = _to_model_input(volume_u8)
        segs_3D = _segment_lesion(frames, nii_image_data.shape, key_slice_idx_offset, bbox)
        if segs_3D.any():
            segs_3D = getLargestCC(segs_3D)
        segs_3D = np.uint8(segs_3D)

        sitk_image = sitk.GetImageFromArray(volume_u8)
        sitk_image.CopyInformation(nii_image)
        sitk_mask = sitk.GetImageFromArray(segs_3D)
        sitk_mask.CopyInformation(nii_image)
        # save single lesion
        key_slice_label = row['Key_slice_index']
        save_seg_name = nii_fname.split('.nii.gz')[0] + f'_k{key_slice_label}_mask.nii.gz'
        # NOTE(review): the windowed image is rewritten for every lesion row of
        # this file, so the saved "_img" volume uses the last row's window.
        sitk.WriteImage(sitk_image, os.path.join(pred_save_dir, nii_fname.replace('.nii.gz', '_img.nii.gz')))
        sitk.WriteImage(sitk_mask, os.path.join(pred_save_dir, save_seg_name))
        seg_info['nii_name'].append(save_seg_name)
        seg_info['key_slice_index'].append(key_slice_label)
        seg_info['DICOM_windows'].append(row['DICOM_windows'])

seg_info_df = pd.DataFrame(seg_info)
seg_info_df.to_csv(join(pred_save_dir, 'tiny_seg_info202412.csv'), index=False)