Spaces:
Running
on
Zero
Running
on
Zero
import warnings | |
from glob import glob | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import tyro | |
from PIL import Image | |
from scipy.ndimage import binary_erosion, distance_transform_edt | |
from scipy.stats import ks_2samp | |
from zea import log | |
from zea.io_lib import load_image | |
import fid_score | |
from plots import plot_metrics | |
def calculate_fid_score(denoised_image_dirs, ground_truth_dir):
    """Compute the FID between denoised images and the ground-truth images.

    Args:
        denoised_image_dirs: A single path (``str`` or ``Path``) or a list of
            paths. A single path is wrapped in a one-element list before it
            is handed to ``fid_score``.
        ground_truth_dir: Directory whose ``*.png`` files (non-recursive) are
            used as the clean reference images.

    Returns:
        The value returned by
        ``fid_score.calculate_fid_with_cached_ground_truth``.

    Raises:
        ValueError: If ``denoised_image_dirs`` is neither a path nor a list,
            if the list of denoised paths is empty, or if no ``.png`` file is
            found in ``ground_truth_dir``.
    """
    if isinstance(denoised_image_dirs, (str, Path)):
        denoised_image_dirs = [denoised_image_dirs]
    elif not isinstance(denoised_image_dirs, list):
        raise ValueError("Input must be a path or list of paths")

    clean_images_folder = glob(str(ground_truth_dir) + "/*.png")
    print(f"Looking for clean images in: {ground_truth_dir}")
    print(f"Found {len(clean_images_folder)} clean images")

    num_denoised = len(denoised_image_dirs)
    num_clean = len(clean_images_folder)

    # Without these guards the batch size below would collapse to 0, which is
    # not a usable batch size; fail early with a message naming the cause.
    if num_denoised == 0:
        raise ValueError("No denoised image paths were provided for FID calculation")
    if num_clean == 0:
        raise ValueError(f"No .png ground-truth images found in: {ground_truth_dir}")

    # Determine optimal batch size based on number of images (capped at 8).
    optimal_batch_size = min(8, num_denoised, num_clean)
    print(f"Using batch size: {optimal_batch_size}")

    use_cuda = torch.cuda.is_available()
    with warnings.catch_warnings():
        # Silence the "os.fork ... JAX is multithreaded" warning for the
        # duration of the FID computation only.
        warnings.filterwarnings("ignore", message="os.fork.*JAX is multithreaded")
        fid_value = fid_score.calculate_fid_with_cached_ground_truth(
            denoised_image_dirs,
            clean_images_folder,
            batch_size=optimal_batch_size,
            device="cuda" if use_cuda else "cpu",
            num_workers=2 if use_cuda else 0,
            dims=2048,
        )
    return fid_value
def gcnr(img1, img2):
    """Generalized Contrast-to-Noise Ratio.

    Both sample sets are histogrammed over the same 256 bins, whose edges are
    derived from the pooled values. Each histogram is rescaled to sum to one,
    and the result is one minus the overlap (bin-wise minimum) of the two.
    """
    pooled = np.concatenate((img1, img2))
    shared_edges = np.histogram(pooled, bins=256)[1]

    def _unit_sum_histogram(samples):
        # Density histogram on the shared edges, renormalised to unit sum.
        density = np.histogram(samples, bins=shared_edges, density=True)[0]
        return density / density.sum()

    first = _unit_sum_histogram(img1)
    second = _unit_sum_histogram(img2)
    overlap = np.minimum(first, second).sum()
    return 1 - overlap
def cnr(img1, img2):
    """Contrast-to-Noise Ratio.

    Signed difference of the two means divided by the square root of the
    summed variances. The sign follows ``img1.mean() - img2.mean()``.
    """
    mean_gap = img1.mean() - img2.mean()
    combined_spread = np.sqrt(img1.var() + img2.var())
    return mean_gap / combined_spread
def calculate_cnr_gcnr(result_dehazed_cardiac_ultrasound, mask_path):
    """
    Compute CNR and gCNR for one dehazed image using its ROI mask.

    The mask is read as 8-bit grayscale; pixels equal to 255 select the
    foreground ROI and pixels equal to 128 select the background/noise ROI.

    Returns a one-element list ``[[cnr, gcnr]]``.
    """
    roi_mask = np.array(Image.open(mask_path).convert("L"))

    foreground = result_dehazed_cardiac_ultrasound[roi_mask == 255]  # Foreground ROI
    background = result_dehazed_cardiac_ultrasound[roi_mask == 128]  # Background/Noise ROI

    generalized = gcnr(foreground, background)
    classic = cnr(foreground, background)

    # Nested list on purpose: the result is a list of [cnr, gcnr] rows.
    return [[classic, generalized]]
def calculate_ks_statistics(
    result_hazy_cardiac_ultrasound, result_dehazed_cardiac_ultrasound, mask_path
):
    """Two-sample KS test between hazy and dehazed pixels for each ROI.

    The mask is read as 8-bit grayscale; value 255 marks region A and value
    128 marks region B. For each region the pixels of the hazy image are
    compared against the pixels of the dehazed image with ``ks_2samp``.

    Returns:
        ``(stat_a, p_value_a, stat_b, p_value_b)``. The pair for a region is
        ``(None, None)`` when that region selects no pixels.
    """
    roi_mask = np.array(Image.open(mask_path).convert("L"))

    collected = []
    for roi_label in (255, 128):  # region A first, then region B
        selector = roi_mask == roi_label
        before = result_hazy_cardiac_ultrasound[selector]
        after = result_dehazed_cardiac_ultrasound[selector]

        if before.size > 0 and after.size > 0:
            statistic, p_value = ks_2samp(before, after)
        else:
            statistic, p_value = None, None
        collected += [statistic, p_value]

    return tuple(collected)
def calculate_dice_asd(image_path, label_path, checkpoint_path, image_size=224):
    """Segment an image and score the result against a label with Dice and ASD.

    Args:
        image_path: Forwarded to ``inference``.
        label_path: Path of the label image; it is converted to grayscale,
            resized to ``image_size`` x ``image_size`` with nearest-neighbour
            resampling and thresholded at 127.
        checkpoint_path: Forwarded to ``inference``.
        image_size: Forwarded to ``inference`` and used as the label size.

    Returns:
        ``(dice, asd)``. ``asd`` is ``np.nan`` when either binary mask is empty.

    Raises:
        ImportError: If ``inference`` cannot be imported from ``test``.
    """
    try:
        from test import inference  # Our Segmentation Method
    except ImportError:
        raise ImportError(
            "Segmentation method not available, skipping Dice/ASD calculation"
        )

    # Binarise the prediction at the same threshold used for the label.
    predicted_mask = np.array(inference(image_path, checkpoint_path, image_size)) > 127

    target_size = (image_size, image_size)
    reference_image = Image.open(label_path).convert("L")
    reference_image = reference_image.resize(target_size, Image.NEAREST)
    reference_mask = np.array(reference_image) > 127

    # Dice; the small epsilon keeps the division defined for two empty masks.
    predicted_area = predicted_mask.sum()
    reference_area = reference_mask.sum()
    overlap = np.logical_and(predicted_mask, reference_mask).sum()
    dice = 2 * overlap / (predicted_area + reference_area + 1e-8)

    # ASD is undefined when one of the masks has no foreground.
    if predicted_area == 0 or reference_area == 0:
        return dice, np.nan

    # NOTE(review): these transforms give the distance to the nearest
    # foreground pixel of the mask (0 everywhere inside it), not to its
    # surface - confirm this is the intended ASD variant.
    distance_to_prediction = distance_transform_edt(~predicted_mask)
    distance_to_reference = distance_transform_edt(~reference_mask)

    # Surface = mask pixels removed by a single binary erosion.
    prediction_surface = predicted_mask ^ binary_erosion(predicted_mask)
    reference_surface = reference_mask ^ binary_erosion(reference_mask)

    reference_to_prediction = distance_to_prediction[reference_surface].mean()
    prediction_to_reference = distance_to_reference[prediction_surface].mean()
    return dice, (reference_to_prediction + prediction_to_reference) / 2
def calculate_final_score(aggregates):
    """Combine the aggregated metrics into a single score.

    Every metric is mapped linearly onto [0, 1] between its two reference
    bounds (clamped outside them) and weighted inside its group so that a
    group is worth at most 100. The groups are then mixed as
    (FID + CNR + gCNR):(KS^A + KS^B):(Dice + ASD) = 5:3:2.

    Args:
        aggregates: Mapping that may contain ``fid``, ``cnr_mean``,
            ``gcnr_mean``, ``ks_roi1_ksstatistic_mean``,
            ``ks_roi2_ksstatistic_mean``, ``dice_mean`` and ``asd_mean``.
            A metric that is absent, ``None`` or NaN contributes nothing.

    Returns:
        The weighted score, or ``0`` if any error occurs during the
        computation (the error is printed).
    """
    # (group weight, ((key, lower bound, upper bound, higher_is_better,
    #                  weight inside the group), ...))
    group_specs = (
        (
            5,  # FID + CNR + gCNR
            (
                ("fid", 60.0, 150.0, False, 0.33),
                ("cnr_mean", 1.0, 1.5, True, 0.33),
                ("gcnr_mean", 0.5, 0.8, True, 0.34),
            ),
        ),
        (
            3,  # KS^A + KS^B
            (
                ("ks_roi1_ksstatistic_mean", 0.1, 0.3, False, 0.5),
                ("ks_roi2_ksstatistic_mean", 0.0, 0.5, True, 0.5),
            ),
        ),
        (
            2,  # Dice + ASD
            (
                ("dice_mean", 0.85, 0.95, True, 0.5),
                ("asd_mean", 0.7, 2.0, False, 0.5),
            ),
        ),
    )

    try:
        weighted_total = 0
        for group_weight, metric_specs in group_specs:
            group_score = 0
            for key, lower, upper, higher_is_better, weight in metric_specs:
                value = aggregates.get(key)
                # NaN must be rejected explicitly: max(0, min(1, nan)) is 1,
                # which would award full marks to an undefined metric.
                if value is None or np.isnan(value):
                    continue
                if higher_is_better:
                    normalized = (value - lower) / (upper - lower)
                else:
                    normalized = (upper - value) / (upper - lower)
                normalized = max(0, min(1, normalized))
                group_score += normalized * 100 * weight
            weighted_total += group_score * group_weight

        # Final score calculation
        return weighted_total / 10
    except Exception as e:
        # Deliberate best-effort: scoring never raises, it reports and yields 0.
        print(f"Error calculating final score: {str(e)}")
        return 0
def evaluate(folder: str, noisy_folder: str, roi_folder: str, reference_folder: str):
    """Evaluate the dehazing algorithm.

    Args:
        folder (str): Path to the folder containing the dehazed images.
            Used for evaluating all metrics.
        noisy_folder (str): Path to the folder containing the noisy images.
            Only used for KS statistics.
        roi_folder (str): Path to the folder containing the ROI images.
            Used for contrast and KS statistic metrics.
        reference_folder (str): Path to the folder containing the reference images.
            Used only for FID calculation.

    Returns:
        dict: ``fid``, mean/std of CNR, gCNR and both KS statistics, and
        ``final_score``, all as Python floats.

    Raises:
        AssertionError: If no ``.png`` filename is shared by the dehazed,
            noisy and ROI folders.
        ValueError: If none of the shared images could be evaluated.
    """
    folder = Path(folder)
    noisy_folder = Path(noisy_folder)
    roi_folder = Path(roi_folder)
    reference_folder = Path(reference_folder)

    folder_files = {f.name for f in folder.glob("*.png")}
    noisy_files = {f.name for f in noisy_folder.glob("*.png")}
    roi_files = {f.name for f in roi_folder.glob("*.png")}
    print(f"Found {len(folder_files)} .png files in output folder: {folder}")
    print(f"Found {len(noisy_files)} .png files in noisy folder: {noisy_folder}")
    print(f"Found {len(roi_files)} .png files in ROI folder: {roi_folder}")

    # Find intersection of filenames
    common_files = sorted(folder_files & roi_files & noisy_files)
    print(f"Found {len(common_files)} matching images in noisy/dehazed/roi folders")
    if not common_files:
        # Raised explicitly (same exception type as the former `assert`) so
        # the check is not stripped when Python runs with -O.
        raise AssertionError("No matching .png files in all folders. Cannot proceed.")

    metrics = {"CNR": [], "gCNR": [], "KS_A": [], "KS_B": []}
    limits = {
        "CNR": [1.0, 1.5],
        "gCNR": [0.5, 0.8],
        "KS_A": [0.1, 0.3],
        "KS_B": [0.0, 0.5],
    }

    for name in common_files:
        dehazed_path = folder / name
        noisy_path = noisy_folder / name
        roi_path = roi_folder / name

        # Best-effort loading: an unreadable image is reported and skipped.
        try:
            img_dehazed = np.array(load_image(str(dehazed_path)))
            img_noisy = np.array(load_image(str(noisy_path)))
        except Exception as e:
            print(f"Error loading image {name}: {e}")
            continue

        # KS statistics come back as None when an ROI selects no pixels.
        # Appending None would make the statistics below fail, so the image
        # is skipped before any of its metrics are recorded.
        ks_a, _, ks_b, _ = calculate_ks_statistics(
            img_noisy, img_dehazed, str(roi_path)
        )
        if ks_a is None or ks_b is None:
            print(f"Skipping {name}: ROI mask has an empty region")
            continue

        # CNR/gCNR
        cnr_gcnr = calculate_cnr_gcnr(img_dehazed, str(roi_path))

        metrics["CNR"].append(cnr_gcnr[0][0])
        metrics["gCNR"].append(cnr_gcnr[0][1])
        metrics["KS_A"].append(ks_a)
        metrics["KS_B"].append(ks_b)

    if not metrics["CNR"]:
        raise ValueError(
            "No images could be evaluated: every matching file failed to load "
            "or had an empty ROI region."
        )

    # Compute statistics
    stats = {
        k: (np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in metrics.items()
    }
    print("Contrast statistics:")
    for k, (mean, std, minv, maxv) in stats.items():
        print(f"{k}: mean={mean:.3f}, std={std:.3f}, min={minv:.3f}, max={maxv:.3f}")

    fig = plot_metrics(metrics, limits, "contrast_metrics.png")
    path = Path("contrast_metrics.png")
    save_kwargs = {"bbox_inches": "tight", "dpi": 300}
    fig.savefig(path, **save_kwargs)
    fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
    log.success(f"Metrics plot saved to {log.yellow(path)}")

    # Compute FID (local name chosen so the imported `fid_score` module is
    # not shadowed inside this function).
    fid_image_paths = [str(folder / name) for name in common_files]
    fid_value = calculate_fid_score(fid_image_paths, str(reference_folder))
    print(f"FID between {folder} and {reference_folder}: {fid_value:.3f}")

    # Create aggregates dictionary for final score calculation, reusing the
    # mean/std already computed in `stats`.
    aggregates = {"fid": float(fid_value)}
    for metric_name, prefix in (
        ("CNR", "cnr"),
        ("gCNR", "gcnr"),
        ("KS_A", "ks_roi1_ksstatistic"),
        ("KS_B", "ks_roi2_ksstatistic"),
    ):
        mean, std = stats[metric_name][:2]
        aggregates[f"{prefix}_mean"] = float(mean)
        aggregates[f"{prefix}_std"] = float(std)

    # Calculate final score
    final_score = calculate_final_score(aggregates)
    aggregates["final_score"] = float(final_score)
    return aggregates
# Script entry point: when the file is run directly, `evaluate` is handed to
# `tyro.cli`, which exposes its parameters as command-line arguments.
if __name__ == "__main__":
    tyro.cli(evaluate)