|
from glob import glob |
|
from tqdm import tqdm |
|
import os |
|
from os.path import join, basename |
|
import re |
|
import matplotlib.pyplot as plt |
|
from collections import OrderedDict |
|
import pandas as pd |
|
import numpy as np |
|
import argparse |
|
|
|
from PIL import Image |
|
import SimpleITK as sitk |
|
import torch |
|
import torch.multiprocessing as mp |
|
from sam2.build_sam import build_sam2_video_predictor_npz |
|
import SimpleITK as sitk |
|
from skimage import measure, morphology |
|
|
|
# Fix every random number generator used by this script to one shared seed so
# that repeated runs draw the same random numbers, and select the 'high'
# float32 matmul precision setting in torch.
_SEED = 2024
np.random.seed(_SEED)
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)
torch.set_float32_matmul_precision('high')
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface.
# The parsed values are exposed as module-level names (checkpoint, model_cfg,
# imgs_path, gts_path, pred_save_dir, propagate_with_box) that the rest of the
# script reads directly.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

# Model weights and architecture config.
parser.add_argument('--checkpoint', type=str, default="checkpoints/MedSAM2_latest.pt", help='checkpoint path')
parser.add_argument('--cfg', type=str, default="configs/sam2.1_hiera_t512.yaml", help='model config')

# Input / output locations.
parser.add_argument('-i', '--imgs_path', type=str, default="CT_DeepLesion/images", help='imgs path')
parser.add_argument('--gts_path', default=None, help='simulate prompts based on ground truth')
# NOTE(review): the default directory name reads "DeeLesion" (not
# "DeepLesion"); kept as-is because changing it would move the outputs.
parser.add_argument('-o', '--pred_save_dir', type=str, default="./DeeLesion_results", help='path to save segmentation results')

# NOTE(review): with action='store_true' AND default=True this option is True
# whether or not the flag is passed, so it cannot be disabled from the CLI.
parser.add_argument('--propagate_with_box', default=True, action='store_true', help='whether to propagate with box')

args = parser.parse_args()

checkpoint, model_cfg = args.checkpoint, args.cfg
imgs_path, gts_path = args.imgs_path, args.gts_path
propagate_with_box = args.propagate_with_box
pred_save_dir = args.pred_save_dir
# Create the output directory up front; an existing directory is not an error.
os.makedirs(pred_save_dir, exist_ok=True)
|
|
|
def getLargestCC(segmentation):
    """Return a boolean mask of the largest connected component.

    Components are obtained from ``skimage.measure.label`` called with its
    default settings; the component with the most voxels is kept.

    Parameters
    ----------
    segmentation : numpy.ndarray
        Label/binary array in which zero is background.

    Returns
    -------
    numpy.ndarray
        Boolean array with the same shape as ``segmentation`` that is True
        only on the largest component. If ``segmentation`` has no non-zero
        element the result is all False (previously this case raised a
        ``ValueError`` from ``np.argmax`` on an empty array).
    """
    if not np.any(segmentation):
        # No foreground at all: there is no component to select.
        return np.zeros(np.shape(segmentation), dtype=bool)

    labels = measure.label(segmentation)
    # bincount index 0 is the background count; drop it so that argmax ranks
    # only real components, then shift by one to get back the label id.
    component_sizes = np.bincount(labels.flat)[1:]
    return labels == (np.argmax(component_sizes) + 1)
|
|
|
def dice_multi_class(preds, targets):
    """Mean smoothed Dice score over the foreground labels found in ``targets``.

    For every non-zero label ``l`` present in ``targets`` the score
    ``(2 * |P_l & T_l| + 1) / (|P_l| + |T_l| + 1)`` is computed, where ``P_l``
    and ``T_l`` are the voxels equal to ``l`` in ``preds`` / ``targets``.

    Parameters
    ----------
    preds : numpy.ndarray
        Predicted label map.
    targets : numpy.ndarray
        Reference label map with the same shape as ``preds``; label 0 is
        treated as background and ignored.

    Returns
    -------
    float
        Mean Dice over the foreground labels, or NaN when ``targets`` has no
        foreground label.

    Raises
    ------
    AssertionError
        If ``preds`` and ``targets`` differ in shape.
    """
    smooth = 1.0
    assert preds.shape == targets.shape, (
        f'shape mismatch: preds {preds.shape} vs targets {targets.shape}'
    )

    # Exclude background explicitly. Slicing off the first unique value (the
    # previous behaviour) silently discarded a real label whenever ``targets``
    # contained no background voxel.
    labels = np.unique(targets)
    labels = labels[labels != 0]
    if labels.size == 0:
        # Same value the mean of an empty list would give, without the
        # numpy "mean of empty slice" warning.
        return np.nan

    dices = []
    for label in labels:
        pred = preds == label
        target = targets == label
        intersection = np.logical_and(pred, target).sum()
        dices.append((2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth))
    return np.mean(dices)
|
|
|
def show_mask(mask, ax, mask_color=None, alpha=0.5):
    """
    Draw a mask as a semi-transparent colour overlay.

    Parameters
    ----------
    mask : numpy.ndarray
        mask to draw; its last two dimensions are taken as (height, width)
    ax : matplotlib.axes.Axes
        axes that receives the overlay through ``ax.imshow``
    mask_color : numpy.ndarray
        RGB colour of the overlay; a yellow default is used when None
    alpha : float
        opacity appended to the colour as its fourth (alpha) channel
    """
    if mask_color is None:
        rgba = np.array([251/255, 252/255, 30/255, alpha])
    else:
        rgba = np.concatenate([mask_color, np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, channels): each pixel becomes the
    # colour scaled by the mask value at that pixel.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
|
|
|
|
|
def show_box(box, ax, edgecolor='blue'):
    """
    Draw a bounding box outline on the given axes.

    Parameters
    ----------
    box : numpy.ndarray
        box corners as (x0, y0, x1, y1)
    ax : matplotlib.axes.Axes
        axes the rectangle patch is added to
    edgecolor : str
        colour of the rectangle outline
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor=edgecolor,
        facecolor=(0,0,0,0),  # fully transparent fill: only the outline shows
        lw=2,
    )
    ax.add_patch(outline)
|
|
|
|
|
def resize_grayscale_to_rgb_and_resize(array, image_size):
    """
    Convert each slice of a 3D grayscale volume to RGB and resize it to a square.

    Parameters:
        array (np.ndarray): Input volume of shape (d, h, w). Every slice is
            cast to uint8 before being handed to PIL.
        image_size (int): Output width and height in pixels.

    Returns:
        np.ndarray: Array of shape (d, 3, image_size, image_size), allocated
        with numpy's default floating-point dtype.
    """
    depth, _, _ = array.shape  # unpacking also rejects non-3D input
    target_size = (image_size, image_size)
    output = np.zeros((depth, 3) + target_size)

    for idx, slice_2d in enumerate(array):
        rgb = Image.fromarray(slice_2d.astype(np.uint8)).convert("RGB").resize(target_size)
        # PIL yields (h, w, channel); store it channel-first.
        output[idx] = np.array(rgb).transpose(2, 0, 1)

    return output
|
|
|
def mask2D_to_bbox(gt2D, max_shift=20):
    """Compute a randomly enlarged bounding box around the foreground of a 2D mask.

    The tight box around all pixels ``> 0`` is grown on every side by one
    random integer drawn from ``[0, max_shift]`` (a single draw from numpy's
    global random state, shared by all four sides) and clipped to the image.

    Parameters
    ----------
    gt2D : numpy.ndarray
        2D mask of shape (H, W); values ``> 0`` are foreground.
    max_shift : int
        Largest enlargement in pixels; ``0`` gives the tight box.

    Returns
    -------
    numpy.ndarray
        ``[x_min, y_min, x_max, y_max]`` with inclusive pixel coordinates.

    Raises
    ------
    ValueError
        If ``gt2D`` contains no foreground pixel.
    """
    y_indices, x_indices = np.where(gt2D > 0)
    if y_indices.size == 0:
        # Fail with an explicit message instead of numpy's generic
        # "zero-size array to reduction operation" error.
        raise ValueError('mask2D_to_bbox: gt2D has no foreground pixels (no value > 0)')

    H, W = gt2D.shape
    bbox_shift = np.random.randint(0, max_shift + 1, 1)[0]
    x_min = max(0, np.min(x_indices) - bbox_shift)
    x_max = min(W-1, np.max(x_indices) + bbox_shift)
    y_min = max(0, np.min(y_indices) - bbox_shift)
    y_max = min(H-1, np.max(y_indices) + bbox_shift)
    return np.array([x_min, y_min, x_max, y_max])
|
|
|
def mask3D_to_bbox(gt3D, max_shift=20):
    """Compute a 3D bounding box around the foreground of a volume mask.

    The in-plane (x, y) extent is grown on every side by one random integer
    drawn from ``[0, max_shift]`` (a single draw from numpy's global random
    state) and clipped to the volume; the z extent is left tight.

    Parameters
    ----------
    gt3D : numpy.ndarray
        3D mask of shape (D, H, W); values ``> 0`` are foreground.
    max_shift : int
        Largest in-plane enlargement in voxels; ``0`` gives the tight box.

    Returns
    -------
    numpy.ndarray
        ``[x_min, y_min, z_min, x_max, y_max, z_max]`` with inclusive indices.

    Raises
    ------
    ValueError
        If ``gt3D`` contains no foreground voxel.
    """
    z_indices, y_indices, x_indices = np.where(gt3D > 0)
    if z_indices.size == 0:
        # Fail with an explicit message instead of numpy's generic
        # "zero-size array to reduction operation" error.
        raise ValueError('mask3D_to_bbox: gt3D has no foreground voxels (no value > 0)')

    D, H, W = gt3D.shape
    bbox_shift = np.random.randint(0, max_shift + 1, 1)[0]
    x_min = max(0, np.min(x_indices) - bbox_shift)
    x_max = min(W-1, np.max(x_indices) + bbox_shift)
    y_min = max(0, np.min(y_indices) - bbox_shift)
    y_max = min(H-1, np.max(y_indices) + bbox_shift)
    # z is not shifted, and indices returned by np.where are already inside
    # [0, D-1], so no clamping is required along this axis.
    z_min, z_max = np.min(z_indices), np.max(z_indices)
    return np.array([x_min, y_min, z_min, x_max, y_max, z_max])
|
|
|
|
|
# ---------------------------------------------------------------------------
# Inference driver.
# For every .nii.gz volume in `imgs_path`, look up its matching rows in the
# dataset info CSV, and for every row: window the volume, prompt the predictor
# with the row's bounding box on the key slice, propagate forwards and
# backwards through the slices, keep the largest connected component and write
# the windowed image and the mask to `pred_save_dir`. A CSV summary of all
# written masks is saved at the end.
# ---------------------------------------------------------------------------


def _window_to_uint8(volume, lower_bound, upper_bound):
    """Clip ``volume`` to ``[lower_bound, upper_bound]`` and min-max rescale to uint8.

    A volume that is constant after clipping is returned as all zeros rather
    than dividing by zero.
    """
    clipped = np.clip(volume, lower_bound, upper_bound)
    v_min, v_max = np.min(clipped), np.max(clipped)
    if v_max == v_min:
        return np.zeros(clipped.shape, dtype=np.uint8)
    return np.uint8((clipped - v_min) / (v_max - v_min) * 255.0)


def _segment_with_box(predictor, frames, video_height, video_width,
                      key_frame_idx, box, volume_shape, use_box=True):
    """Segment one object by propagating a key-frame box prompt in both directions.

    Returns a freshly allocated uint8 volume of shape ``volume_shape`` that is
    1 wherever the forward or the reverse pass produced a positive mask logit.
    """
    mask = np.zeros(volume_shape, dtype=np.uint8)
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = predictor.init_state(frames, video_height, video_width)
        # First pass uses the predictor's default direction, second pass runs
        # in reverse; the state is reset after each pass, so the prompt has to
        # be added again before the second one.
        for direction_kwargs in ({}, {'reverse': True}):
            if use_box:
                predictor.add_new_points_or_box(
                    inference_state=state,
                    frame_idx=key_frame_idx,
                    obj_id=1,
                    box=box,
                )
            # NOTE(review): when use_box is False no prompt is added at all;
            # no alternative prompt type is implemented in this script.
            for frame_idx, _obj_ids, mask_logits in predictor.propagate_in_video(state, **direction_kwargs):
                mask[frame_idx, (mask_logits[0] > 0.0).cpu().numpy()[0]] = 1
            predictor.reset_state(state)
    return mask


DL_info = pd.read_csv('CT_DeepLesion/DeepLesion_Dataset_Info.csv')

# Keep only NIfTI volumes and skip files starting with '._'.
nii_fnames = [
    fname for fname in sorted(os.listdir(imgs_path))
    if fname.endswith('.nii.gz') and not fname.startswith('._')
]
print(f'Processing {len(nii_fnames)} nii files')

seg_info = OrderedDict()
seg_info['nii_name'] = []
seg_info['key_slice_index'] = []
seg_info['DICOM_windows'] = []

predictor = build_sam2_video_predictor_npz(model_cfg, checkpoint)

# Per-channel normalisation constants, built once instead of once per lesion.
img_mean = torch.tensor((0.485, 0.456, 0.406), dtype=torch.float32)[:, None, None].cuda()
img_std = torch.tensor((0.229, 0.224, 0.225), dtype=torch.float32)[:, None, None].cuda()

for nii_fname in tqdm(nii_fnames):
    # The file name carries the slice range as 'AAA-BBB'; convert it to the
    # 'A, B' form that is searched for in the CSV's Slice_range column.
    range_suffix = re.findall(r'\d{3}-\d{3}', nii_fname)[0]
    slice_range_key = ', '.join(str(int(part)) for part in range_suffix.split('-'))
    case_name = re.findall(r'^(\d{6}_\d{2}_\d{2})', nii_fname)[0]

    nii_image = sitk.ReadImage(join(imgs_path, nii_fname))
    nii_image_data = sitk.GetArrayFromImage(nii_image)

    # NOTE(review): both filters are substring matches, so e.g. '3, 115' also
    # matches a row whose range is '103, 115' — confirm this is acceptable.
    case_df = DL_info[
        DL_info['File_name'].str.contains(case_name) &
        DL_info['Slice_range'].str.contains(slice_range_key)
    ].copy()

    for _, row in case_df.iterrows():
        # Intensity window for this row, given as 'lower,upper'.
        lower_bound, upper_bound = (float(v) for v in row['DICOM_windows'].split(','))
        img_3D_ori = _window_to_uint8(nii_image_data, lower_bound, upper_bound)

        # Key slice index is relative to the start of the row's slice range.
        key_slice_idx = int(row['Key_slice_index'])
        slice_idx_start, _ = (int(v) for v in row['Slice_range'].split(','))
        key_slice_idx_offset = key_slice_idx - slice_idx_start

        # Box entries are reordered from (0, 1, 2, 3) to (1, 0, 3, 2).
        # NOTE(review): presumably this swaps x/y to match the axis order of
        # the NIfTI volumes — confirm against the data conversion step.
        bbox_coords = [int(float(coord)) for coord in row['Bounding_boxes'].split(',')]
        bbox = np.array([bbox_coords[1], bbox_coords[0], bbox_coords[3], bbox_coords[2]])

        video_height, video_width = img_3D_ori[key_slice_idx_offset, :, :].shape[:2]

        img_resized = resize_grayscale_to_rgb_and_resize(img_3D_ori, 512)
        img_resized = torch.from_numpy(img_resized / 255.0).cuda()
        img_resized -= img_mean
        img_resized /= img_std

        # A fresh mask volume per lesion: reusing one volume across rows let
        # earlier lesions leak into (and possibly dominate) later masks.
        segs_3D = _segment_with_box(
            predictor,
            img_resized,
            video_height,
            video_width,
            key_slice_idx_offset,
            bbox,
            nii_image_data.shape,
            use_box=propagate_with_box,
        )
        if np.max(segs_3D) > 0:
            segs_3D = np.uint8(getLargestCC(segs_3D))

        sitk_image = sitk.GetImageFromArray(img_3D_ori)
        sitk_image.CopyInformation(nii_image)
        sitk_mask = sitk.GetImageFromArray(segs_3D)
        sitk_mask.CopyInformation(nii_image)

        # Use the raw CSV value (not the int cast) so file names and the
        # summary keep the exact formatting they had before.
        key_slice_label = row['Key_slice_index']
        save_seg_name = nii_fname.split('.nii.gz')[0] + f'_k{key_slice_label}_mask.nii.gz'
        sitk.WriteImage(sitk_image, os.path.join(pred_save_dir, nii_fname.replace('.nii.gz', '_img.nii.gz')))
        sitk.WriteImage(sitk_mask, os.path.join(pred_save_dir, save_seg_name))

        seg_info['nii_name'].append(save_seg_name)
        seg_info['key_slice_index'].append(key_slice_label)
        seg_info['DICOM_windows'].append(row['DICOM_windows'])

seg_info_df = pd.DataFrame(seg_info)
seg_info_df.to_csv(join(pred_save_dir, 'tiny_seg_info202412.csv'), index=False)
|
|
|
|
|
|
|
|