|
import warnings |
|
from glob import glob |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
import tyro |
|
from PIL import Image |
|
from scipy.ndimage import binary_erosion, distance_transform_edt |
|
from scipy.stats import ks_2samp |
|
from zea import log |
|
from zea.io_lib import load_image |
|
|
|
import fid_score |
|
from plots import plot_metrics |
|
|
|
|
|
def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
    """Compute the FID between denoised images and clean reference images.

    Args:
        denoised_image_dirs: A single path (``str`` or ``Path``) or a list of
            paths. A single path is wrapped in a list; the list is forwarded
            unchanged to ``fid_score.calculate_fid_with_cached_ground_truth``.
        ground_truth_dir: Folder whose ``*.png`` files are used as the clean
            reference set.

    Returns:
        Whatever ``fid_score.calculate_fid_with_cached_ground_truth`` returns.

    Raises:
        ValueError: If ``denoised_image_dirs`` is neither a path nor a list,
            if it is an empty list, or if ``ground_truth_dir`` contains no
            ``.png`` files.
    """
    if isinstance(denoised_image_dirs, (str, Path)):
        denoised_image_dirs = [denoised_image_dirs]
    elif not isinstance(denoised_image_dirs, list):
        raise ValueError("Input must be a path or list of paths")

    clean_images_folder = glob(str(ground_truth_dir) + "/*.png")

    print(f"Looking for clean images in: {ground_truth_dir}")
    print(f"Found {len(clean_images_folder)} clean images")

    num_denoised = len(denoised_image_dirs)
    num_clean = len(clean_images_folder)

    # An empty side would make the batch size below 0; fail early with a
    # message that names the actual problem instead.
    if num_denoised == 0:
        raise ValueError("No denoised images were provided for FID calculation")
    if num_clean == 0:
        raise ValueError(
            f"No .png reference images found in {ground_truth_dir}; "
            "cannot calculate FID"
        )

    # Never request a batch larger than the smaller of the two image sets.
    optimal_batch_size = min(8, num_denoised, num_clean)
    print(f"Using batch size: {optimal_batch_size}")

    use_cuda = torch.cuda.is_available()

    with warnings.catch_warnings():
        # Silence the fork-related warning matched by this pattern while the
        # FID computation runs.
        warnings.filterwarnings("ignore", message="os.fork.*JAX is multithreaded")

        fid_value = fid_score.calculate_fid_with_cached_ground_truth(
            denoised_image_dirs,
            clean_images_folder,
            batch_size=optimal_batch_size,
            device="cuda" if use_cuda else "cpu",
            num_workers=2 if use_cuda else 0,
            dims=2048,
        )

    return fid_value
|
|
|
|
|
def gcnr(img1, img2):
    """Generalized Contrast-to-Noise Ratio.

    Both inputs are histogrammed over one shared set of 256 bins spanning
    their pooled value range. The result is one minus the overlap (sum of
    the bin-wise minimum) of the two normalised histograms.
    """
    pooled = np.concatenate((img1, img2))
    edges = np.histogram(pooled, bins=256)[1]

    hist_a = np.histogram(img1, bins=edges, density=True)[0]
    hist_b = np.histogram(img2, bins=edges, density=True)[0]

    # Rescale each histogram so that its bins sum to one.
    hist_a = hist_a / hist_a.sum()
    hist_b = hist_b / hist_b.sum()

    overlap = np.minimum(hist_a, hist_b).sum()
    return 1 - overlap
|
|
|
|
|
def cnr(img1, img2):
    """Contrast-to-Noise Ratio.

    Signed difference of the two means, divided by the square root of the
    sum of the two variances.
    """
    contrast = img1.mean() - img2.mean()
    noise = np.sqrt(img1.var() + img2.var())
    return contrast / noise
|
|
|
|
|
def calculate_cnr_gcnr(result_dehazed_cardiac_ultrasound, mask_path):
    """Compute CNR and gCNR for one image from the two regions of its mask.

    The mask at ``mask_path`` is read as 8-bit greyscale. Pixels equal to 255
    form the first region and pixels equal to 128 the second; both regions
    are taken from ``result_dehazed_cardiac_ultrasound``.

    Returns:
        A one-element list ``[[cnr, gcnr]]``.
    """
    roi_mask = np.array(Image.open(mask_path).convert("L"))

    region_a = result_dehazed_cardiac_ultrasound[roi_mask == 255]
    region_b = result_dehazed_cardiac_ultrasound[roi_mask == 128]

    gcnr_val = gcnr(region_a, region_b)
    cnr_val = cnr(region_a, region_b)

    return [[cnr_val, gcnr_val]]
|
|
|
|
|
def calculate_ks_statistics(
    result_hazy_cardiac_ultrasound, result_dehazed_cardiac_ultrasound, mask_path
):
    """Run a two-sample KS test (hazy vs. dehazed) inside each mask region.

    The mask at ``mask_path`` is read as 8-bit greyscale. Region 1 is where
    the mask equals 255 and region 2 where it equals 128.

    Returns:
        ``(roi1_stat, roi1_p_value, roi2_stat, roi2_p_value)``. The pair for
        a region is ``(None, None)`` when that region selects no pixels.
    """
    roi_mask = np.array(Image.open(mask_path).convert("L"))

    outcome = []
    for level in (255, 128):
        selector = roi_mask == level
        before = result_hazy_cardiac_ultrasound[selector]
        after = result_dehazed_cardiac_ultrasound[selector]

        if before.size > 0 and after.size > 0:
            statistic, p_value = ks_2samp(before, after)
        else:
            statistic, p_value = None, None

        outcome.extend((statistic, p_value))

    return tuple(outcome)
|
|
|
|
|
def calculate_dice_asd(image_path, label_path, checkpoint_path, image_size=224):
    """Segment an image and compare the result with a label mask.

    Args:
        image_path: Passed to ``inference`` as the image to segment.
        label_path: Path of the label image; it is read as 8-bit greyscale and
            resized to ``(image_size, image_size)`` with nearest-neighbour
            sampling.
        checkpoint_path: Passed to ``inference``.
        image_size: Side length used for the label resize and passed to
            ``inference``.

    Returns:
        ``(dice, asd)``. ``asd`` is ``np.nan`` when either mask is empty.

    Raises:
        ImportError: If ``inference`` cannot be imported from the ``test``
            module; the underlying import failure is attached as the cause.
    """
    try:
        # Imported lazily so the rest of the file works without the
        # segmentation code being present.
        from test import inference
    except ImportError as err:
        raise ImportError(
            "Segmentation method not available, skipping Dice/ASD calculation"
        ) from err

    # assumes ``inference`` returns something ``np.array`` can convert, with
    # the same shape as the resized label — TODO confirm
    pred_img = inference(image_path, checkpoint_path, image_size)
    pred = np.array(pred_img) > 127

    with Image.open(label_path) as label_img:
        label = label_img.convert("L").resize(
            (image_size, image_size), Image.NEAREST
        )
    label = np.array(label) > 127

    pred_area = pred.sum()
    label_area = label.sum()

    intersection = np.logical_and(pred, label).sum()
    # The small epsilon keeps the division defined when both masks are empty.
    dice = 2 * intersection / (pred_area + label_area + 1e-8)

    if pred_area == 0 or label_area == 0:
        asd = np.nan
    else:
        # Distance from every pixel to the nearest foreground pixel of the
        # mask (zero inside the mask itself).
        # NOTE(review): distances are measured to the other mask's region, not
        # to its surface, so a surface pixel lying inside the other mask
        # contributes 0 — confirm this is the intended ASD definition.
        pred_dt = distance_transform_edt(~pred)
        label_dt = distance_transform_edt(~label)

        # Surface = mask pixels removed by one binary erosion step.
        surface_pred = pred ^ binary_erosion(pred)
        surface_label = label ^ binary_erosion(label)

        d1 = pred_dt[surface_label].mean()
        d2 = label_dt[surface_pred].mean()
        asd = (d1 + d2) / 2

    return dice, asd
|
|
|
|
|
# Scoring table: (group weight, metrics in the group). Each metric is
# (aggregate key, lower limit, upper limit, higher_is_better, weight in group).
_FINAL_SCORE_GROUPS = (
    (
        5,
        (
            ("fid", 60.0, 150.0, False, 0.33),
            ("cnr_mean", 1.0, 1.5, True, 0.33),
            ("gcnr_mean", 0.5, 0.8, True, 0.34),
        ),
    ),
    (
        3,
        (
            ("ks_roi1_ksstatistic_mean", 0.1, 0.3, False, 0.5),
            ("ks_roi2_ksstatistic_mean", 0.0, 0.5, True, 0.5),
        ),
    ),
    (
        2,
        (
            ("dice_mean", 0.85, 0.95, True, 0.5),
            ("asd_mean", 0.7, 2.0, False, 0.5),
        ),
    ),
)


def _clamped_unit_score(value, lower, upper, higher_is_better):
    """Map ``value`` linearly onto [0, 1] between the limits; NaN scores 0."""
    # ``min(1, nan)`` returns 1, so an unchecked NaN would earn full marks.
    if np.isnan(value):
        return 0
    if higher_is_better:
        raw = (value - lower) / (upper - lower)
    else:
        raw = (upper - value) / (upper - lower)
    return max(0, min(1, raw))


def calculate_final_score(aggregates):
    """Combine aggregate metrics into a single score between 0 and 100.

    Each metric present in ``aggregates`` is rescaled linearly to [0, 1]
    between fixed limits, clamped, and weighted within its group. The three
    group scores are then combined with weights 5, 3 and 2 (out of 10).

    Args:
        aggregates: Mapping that may contain ``fid``, ``cnr_mean``,
            ``gcnr_mean``, ``ks_roi1_ksstatistic_mean``,
            ``ks_roi2_ksstatistic_mean``, ``dice_mean`` and ``asd_mean``.
            Keys that are missing or ``None`` contribute nothing, and so do
            NaN values.

    Returns:
        The weighted score, or ``0`` if the calculation raises (the error is
        printed rather than propagated).
    """
    try:
        weighted_total = 0
        for group_weight, metric_specs in _FINAL_SCORE_GROUPS:
            group_score = 0
            for key, lower, upper, higher_is_better, weight in metric_specs:
                value = aggregates.get(key)
                if value is None:
                    continue
                unit_score = _clamped_unit_score(
                    value, lower, upper, higher_is_better
                )
                group_score += unit_score * 100 * weight
            weighted_total += group_score * group_weight

        return weighted_total / 10

    except Exception as e:
        # Deliberate best-effort: scoring must not abort the evaluation run.
        print(f"Error calculating final score: {str(e)}")
        return 0
|
|
|
|
|
def evaluate(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
    """Evaluate the dehazing algorithm.

    Args:
        folder (str): Path to the folder containing the dehazed images.
            Used for evaluating all metrics.
        noisy_folder (str): Path to the folder containing the noisy images.
            Only used for KS statistics.
        roi_folder (str): Path to the folder containing the ROI images.
            Used for contrast and KS statistic metrics.
        reference_folder (str): Path to the folder containing the reference images.
            Used only for FID calculation.

    Returns:
        dict: FID, mean/std of CNR, gCNR and both KS statistics, and the
        combined ``final_score``.

    Raises:
        AssertionError: If no ``.png`` file name is shared by the dehazed,
            noisy and ROI folders.
        ValueError: If none of the shared images could be loaded.
    """
    folder = Path(folder)
    noisy_folder = Path(noisy_folder)
    roi_folder = Path(roi_folder)
    reference_folder = Path(reference_folder)

    folder_files = {f.name for f in folder.glob("*.png")}
    noisy_files = {f.name for f in noisy_folder.glob("*.png")}
    roi_files = {f.name for f in roi_folder.glob("*.png")}

    print(f"Found {len(folder_files)} .png files in output folder: {folder}")
    print(f"Found {len(noisy_files)} .png files in noisy folder: {noisy_folder}")
    print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")

    # Only images present (by file name) in all three folders are evaluated.
    common_files = sorted(folder_files & roi_files & noisy_files)
    print(f"Found {len(common_files)} matching images in noisy/dehazed/roi folders")
    if not common_files:
        # Raised explicitly instead of via ``assert`` so that the check is
        # not stripped when Python runs with ``-O``.
        raise AssertionError("No matching .png files in all folders. Cannot proceed.")

    metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
    limits = {
        "CNR": [1.0, 1.5],
        "gCNR": [0.5, 0.8],
        "KS_A": [0.1, 0.3],
        "KS_B": [0.0, 0.5],
    }

    for name in common_files:
        dehazed_path = folder / name
        noisy_path = noisy_folder / name
        roi_path = str(roi_folder / name)

        try:
            img_dehazed = np.array(load_image(str(dehazed_path)))
            img_noisy = np.array(load_image(str(noisy_path)))
        except Exception as e:
            # Best-effort: one unreadable image should not abort the run.
            print(f"Error loading image {name}: {e}")
            continue

        cnr_val, gcnr_val = calculate_cnr_gcnr(img_dehazed, roi_path)[0]
        metrics["CNR"].append(cnr_val)
        metrics["gCNR"].append(gcnr_val)

        ks_a, _, ks_b, _ = calculate_ks_statistics(img_noisy, img_dehazed, roi_path)
        if ks_a is None or ks_b is None:
            print(
                f"Warning: empty ROI in mask for {name}; "
                "recording KS statistic as NaN"
            )
        # ``None`` would make ``np.mean`` raise TypeError below; NaN keeps the
        # missing value visible in the statistics instead.
        metrics["KS_A"].append(np.nan if ks_a is None else ks_a)
        metrics["KS_B"].append(np.nan if ks_b is None else ks_b)

    if not metrics["CNR"]:
        raise ValueError("None of the matching images could be loaded. Cannot proceed.")

    stats = {
        k: (np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in metrics.items()
    }
    print("Contrast statistics:")
    for k, (mean, std, minv, maxv) in stats.items():
        print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")

    fig = plot_metrics(metrics, limits, "contrast_metrics.png")

    path = Path("contrast_metrics.png")
    save_kwargs = {"bbox_inches": "tight", "dpi": 300}
    fig.savefig(path, **save_kwargs)
    fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
    log.success(f"Metrics plot saved to {log.yellow(path)}")

    fid_image_paths = [str(folder / name) for name in common_files]
    # Named ``fid_value`` so the imported ``fid_score`` module is not shadowed.
    fid_value = calculate_fid_score(fid_image_paths, str(reference_folder))
    print(f"FID between {folder} and {reference_folder}: {fid_value:.3f}")

    aggregates = {
        "fid": float(fid_value),
        "cnr_mean": float(np.mean(metrics["CNR"])),
        "cnr_std": float(np.std(metrics["CNR"])),
        "gcnr_mean": float(np.mean(metrics["gCNR"])),
        "gcnr_std": float(np.std(metrics["gCNR"])),
        "ks_roi1_ksstatistic_mean": float(np.mean(metrics["KS_A"])),
        "ks_roi1_ksstatistic_std": float(np.std(metrics["KS_A"])),
        "ks_roi2_ksstatistic_mean": float(np.mean(metrics["KS_B"])),
        "ks_roi2_ksstatistic_std": float(np.std(metrics["KS_B"])),
    }

    final_score = calculate_final_score(aggregates)
    aggregates["final_score"] = float(final_score)

    return aggregates
|
|
|
|
|
if __name__ == "__main__":
    # When the file is executed as a script, hand ``evaluate`` to ``tyro.cli``.
    tyro.cli(evaluate)
|
|