Commit 1c76709 · added gif
Parent: 1398519

Files changed:
- .gitignore (+1 -0)
- README.md (+4 -0)
- assets/patient-17-4C-frame-11.png (+0 -0)
- assets/patient-2-4C-frame-9.png (+0 -0)
- assets/patient-50-4C-frame-53.png (+0 -0)
- configs/semantic_dps_opt.yaml (+32 -0)
- main.py (+18 -132)
- plots.py (+62 -0)
- sweeper.py (+0 -1)
- utils.py (+133 -0)
.gitignore CHANGED
@@ -2,6 +2,7 @@
 .env
 temp/
 *.png
+!assets/*.png
 *.pdf
 *.hash
 *.npz
README.md CHANGED
@@ -13,6 +13,9 @@
 <p>Eindhoven University of Technology, the Netherlands</p>
 </div>
 
+<p align="center">
+<img src="animation.gif" alt="Cardiac Ultrasound Dehazing Animation" style="max-width: 100%; height: auto;">
+</p>
 
 ### Installation
 
@@ -21,6 +24,7 @@ The algorithm is implemented using Keras with JAX backend. Furthermore it heavil
 Either install the following in your Python environment, or use the [Dockerfile](./Dockerfile) provided in this repository.
 
 ```bash
+# requires Python>=3.10
 pip install tyro optuna zea==0.0.4
 pip install -U "jax[cuda12]"
 ```
assets/patient-17-4C-frame-11.png DELETED
Binary file (27.4 kB)

assets/patient-2-4C-frame-9.png ADDED
Binary file

assets/patient-50-4C-frame-53.png DELETED
Binary file (22.1 kB)
configs/semantic_dps_opt.yaml ADDED
@@ -0,0 +1,32 @@
+# While these params optimize for the final score of the challenge,
+# we generally observe lower perceptual quality of the dehazed results
+# and in some cases even artifacts.
+# We therefore recommend using configs/semantic_dps.yaml instead.
+diffusion_model_path: "hf://tristan-deep/semantic-diffusion-echo-dehazing"
+segmentation_model_path: "hf://tristan-deep/semantic-segmentation-echo-dehazing"
+seed: 42
+
+params:
+  diffusion_steps: 480
+  initial_diffusion_step: 0
+  batch_size: 16
+  threshold_output_quantile: 0.17447
+  preserve_bottom_percent: 32.0
+  bottom_transition_width: 7.0
+
+  mask_params:
+    sigma: 0.4704516
+    threshold: 0.18935
+  fixed_mask_params:
+    top_px: 20
+    bottom_px: 40
+  skeleton_params:
+    sigma_pre: 9.919
+    sigma_post: 9.5347840479
+    threshold: 0.73917
+  guidance_kwargs:
+    omega: 0.78
+    omega_vent: 0.0001
+    omega_sept: 15.84
+    eta: 0.01105
+    smooth_l1_beta: 6.3726
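For illustration, a minimal sketch of reading this config with plain PyYAML. The repository itself loads configs via `zea.Config` in main.py, and the nesting of `mask_params`/`guidance_kwargs` under `params` is inferred from the rendering, so treat both as assumptions:

```python
import yaml  # pip install pyyaml

# Hypothetical standalone read of the optimized config; the real
# entry point is main.py, which goes through zea.Config instead.
with open("configs/semantic_dps_opt.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["diffusion_model_path"])       # hf://tristan-deep/semantic-diffusion-echo-dehazing
print(cfg["params"]["diffusion_steps"])  # 480
print(cfg["params"]["guidance_kwargs"])  # {'omega': 0.78, 'omega_vent': 0.0001, ...}
```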
main.py CHANGED
@@ -10,7 +10,7 @@ import tyro
 import zea
 from keras import ops
 from PIL import Image
-from skimage import filters, morphology
+from skimage import filters
 from zea import Config, init_device, log
 from zea.internal.operators import Operator
 from zea.models.diffusion import (
@@ -21,136 +21,14 @@ from zea.models.diffusion import (
 from zea.tensor_ops import L2
 from zea.utils import translate
 
-from plots import plot_batch_with_named_masks, plot_dehazed_results
-
-
-def L1(x):
-    """L1 norm of a tensor.
-
-    Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
-    """
-    return ops.sum(ops.abs(x))
-
-
-def smooth_L1(x, beta=0.4):
-    """Smooth L1 loss function.
-
-    Implementation of Smooth L1 loss. Small beta values make it similar to L1 loss,
-    while large beta values make it similar to a scaled L2 loss.
-    """
-    abs_x = ops.abs(x)
-    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
-    return ops.sum(loss)
-
-
-def postprocess(data, normalization_range):
-    """Postprocess data from model output to image."""
-    data = ops.clip(data, *normalization_range)
-    data = translate(data, normalization_range, (0, 255))
-    data = ops.convert_to_numpy(data)
-    data = np.squeeze(data, axis=-1)
-    return np.clip(data, 0, 255).astype("uint8")
-
-
-def preprocess(data, normalization_range):
-    """Preprocess data for model input. Converts uint8 image(s) in [0, 255] to model input range."""
-    data = ops.convert_to_tensor(data, dtype="float32")
-    data = translate(data, (0, 255), normalization_range)
-    data = ops.expand_dims(data, axis=-1)
-    return data
-
-
-def apply_bottom_preservation(
-    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
-):
-    """Apply bottom preservation with smooth windowed transition.
-
-    Args:
-        output_images: Model output images, (batch, height, width, channels)
-        input_images: Original input images, (batch, height, width, channels)
-        preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
-        transition_width: Percentage of image height for smooth transition (default 10%)
-
-    Returns:
-        Blended images with preserved bottom portion
-    """
-    output_shape = ops.shape(output_images)
-
-    batch_size, height, width, channels = output_shape
-
-    preserve_height = int(height * preserve_bottom_percent / 100.0)
-    transition_height = int(height * transition_width / 100.0)
-
-    transition_start = height - preserve_height - transition_height
-    preserve_start = height - preserve_height
-
-    transition_start = max(0, transition_start)
-    preserve_start = min(height, preserve_start)
-
-    if transition_start >= preserve_start:
-        transition_start = preserve_start
-        transition_height = 0
-
-    y_coords = ops.arange(height, dtype="float32")
-    y_coords = ops.reshape(y_coords, (height, 1, 1))
-
-    if transition_height > 0:
-        # Smooth transition using cosine interpolation
-        transition_region = ops.logical_and(
-            y_coords >= transition_start, y_coords < preserve_start
-        )
-
-        transition_progress = (y_coords - transition_start) / transition_height
-        transition_progress = ops.clip(transition_progress, 0.0, 1.0)
-
-        # Use cosine for smooth transition (0.5 * (1 - cos(π * t)))
-        cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
-
-        blend_weight = ops.where(
-            y_coords < transition_start,
-            0.0,
-            ops.where(
-                transition_region,
-                cosine_weight,
-                1.0,
-            ),
-        )
-    else:
-        # No transition, just hard switch
-        blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)
-
-    blend_weight = ops.expand_dims(blend_weight, axis=0)
-
-    blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images
-
-    return blended_images
-
-
-def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
-    """Extract skeletons from the input images."""
-    images_np = ops.convert_to_numpy(images)
-    images_np = np.clip(images_np, input_range[0], input_range[1])
-    images_np = translate(images_np, input_range, (0, 1))
-    images_np = np.squeeze(images_np, axis=-1)
-
-    skeleton_masks = []
-    for img in images_np:
-        img[img < threshold] = 0
-        smoothed = filters.gaussian(img, sigma=sigma_pre)
-        binary = smoothed > filters.threshold_otsu(smoothed)
-        skeleton = morphology.skeletonize(binary)
-        skeleton = morphology.dilation(skeleton, morphology.disk(2))
-        skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
-        skeleton_masks.append(skeleton)
-
-    skeleton_masks = np.array(skeleton_masks)
-    skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)
-
-    # normalize to [0, 1]
-    min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
-    skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
-
-    return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
+from plots import create_animation, plot_batch_with_named_masks, plot_dehazed_results
+from utils import (
+    apply_bottom_preservation,
+    extract_skeleton,
+    postprocess,
+    preprocess,
+    smooth_L1,
+)
 
 
 class IdentityOperator(Operator):
@@ -250,7 +128,6 @@ class SemanticDPS(DPS):
             masks_sept + masks_fixed + masks_skeleton + masks_dark, 0, 1
         )
 
-        # background = not masks_strong, not vent
         background = ops.where(masks_strong < 0.1, 1.0, 0.0) * ops.where(
            masks_vent == 0, 1.0, 0.0
        )
@@ -534,6 +411,7 @@ def main(
     masks_viz = copy.deepcopy(masks)
     masks_viz.pop("haze")
 
+    num_img = 2  # hardcoded as the plotting figure only neatly supports 2 rows
     masks_viz = {k: v[:num_img] for k, v in masks_viz.items()}
 
     fig = plot_batch_with_named_masks(
@@ -553,6 +431,14 @@ def main(
     fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
     log.success(f"Segmentation steps saved to {log.yellow(path)}")
 
+    last_batch_size = len(diffusion_model.track_progress[0])
+    create_animation(
+        preprocess(hazy_images[-last_batch_size:], diffusion_model.input_range),
+        diffusion_model,
+        output_path="animation.gif",
+        fps=10,
+    )
+
     plt.close("all")
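The `smooth_L1` helper moved to utils.py is a Huber-style loss: quadratic below `beta`, linear above it. A quick NumPy check of the limiting behavior (a standalone sketch, not the repository's code path):

```python
import numpy as np

def smooth_l1(x, beta):
    # Same piecewise rule as utils.smooth_L1, element-wise and without the sum
    abs_x = np.abs(x)
    return np.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)

x = np.array([0.05, 2.0])
print(smooth_l1(x, beta=1e-3))  # ~[0.05, 2.0]: approaches L1 for small beta
print(smooth_l1(x, beta=10.0))  # ~[0.000125, 0.2]: quadratic (L2-shaped) for large beta
```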
plots.py CHANGED
@@ -2,6 +2,7 @@ import json
 from pathlib import Path
 from typing import Any, Dict, List
 
+import keras
 import matplotlib.pyplot as plt
 import numpy as np
 import tyro
@@ -9,8 +10,13 @@ from keras import ops
 from matplotlib.patches import PathPatch
 from matplotlib.path import Path as pltPath
 from skimage import measure
+from zea import log
+from zea.io_lib import matplotlib_figure_to_numpy
+from zea.utils import save_to_gif
 from zea.visualize import plot_image_grid
 
+from utils import postprocess
+
 
 def add_shape_from_mask(ax, mask, **kwargs):
     """add a shape to axis from mask array.
@@ -335,6 +341,62 @@ def plot_optimization_history_from_json(
     plt.close(fig)
 
 
+def create_animation_frame(hazy_images, tissue_frame, haze_frame):
+    """Create a single animation frame from the tracked progress."""
+    batch, height, width = ops.shape(hazy_images)
+    frame_stack = ops.stack(
+        [
+            hazy_images,
+            tissue_frame,
+            haze_frame,
+        ]
+    )
+    frame_stack = ops.reshape(frame_stack, (-1, height, width))
+    fig_frame, _ = plot_image_grid(
+        frame_stack,
+        ncols=len(hazy_images),
+        remove_axis=False,
+        vmin=0,
+        vmax=255,
+    )
+    labels = ["Hazy", "Tissue"] if haze_frame is None else ["Hazy", "Tissue", "Haze"]
+    for i, ax in enumerate(fig_frame.axes):
+        label = labels[i % len(labels)]
+        ax.set_ylabel(label, fontsize=12)
+    frame_array = matplotlib_figure_to_numpy(fig_frame)
+    plt.close(fig_frame)
+    return frame_array
+
+
+def create_animation(hazy_images, diffusion_model, output_path, fps):
+    """Create animation from tracked progress frames."""
+    if not (len(diffusion_model.track_progress) > 1):
+        log.warning(
+            "Animation requested but no intermediate frames were tracked. "
+            "Try reducing diffusion_steps or ensure progress tracking is enabled."
+        )
+        return
+
+    log.info(f"Creating animation with {len(diffusion_model.track_progress)} frames...")
+
+    animation_frames = []
+    progbar = keras.utils.Progbar(
+        len(diffusion_model.track_progress), unit_name="frame"
+    )
+    for tissue_frame in diffusion_model.track_progress:
+        haze_frame = hazy_images - tissue_frame - 1
+        tissue_frame = postprocess(tissue_frame, diffusion_model.input_range)
+        haze_frame = postprocess(haze_frame, diffusion_model.input_range)
+        _hazy_images = postprocess(hazy_images, diffusion_model.input_range)
+        frame_array = create_animation_frame(_hazy_images, tissue_frame, haze_frame)
+        animation_frames.append(frame_array)
+        progbar.add(1)
+
+    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    animation_path = Path(output_path).with_suffix(".gif")
+    save_to_gif(animation_frames, animation_path, fps=fps)
+
+
 def main(json_file: str, output_dir: str = "plots", method: str = "semantic_dps"):
     json_path = Path(json_file)
     if not json_path.exists():
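The line `haze_frame = hazy_images - tissue_frame - 1` in `create_animation` is the haze residual expressed in the model's normalized range. Assuming `input_range` is `(-1, 1)` (not confirmed by this diff), the identity checks out:

```python
import numpy as np

hazy, tissue = 0.2, -0.5             # example pixel values in [-1, 1]
to_unit = lambda v: (v + 1.0) / 2.0  # map [-1, 1] -> [0, 1]

haze_unit = to_unit(hazy) - to_unit(tissue)  # haze as a difference in [0, 1] terms
haze_model = 2.0 * haze_unit - 1.0           # back to the model's [-1, 1] range
assert np.isclose(haze_model, hazy - tissue - 1.0)
```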
sweeper.py CHANGED
@@ -136,7 +136,6 @@ class OptunaObjective:
             "omega": trial.suggest_float("omega", 0.5, 50.0, log=True),
             "omega_vent": trial.suggest_float("omega_vent", 0.0001, 50.0, log=True),
             "omega_sept": trial.suggest_float("omega_sept", 0.1, 50.0, log=True),
-            "omega_dark": trial.suggest_float("omega_dark", 0.001, 50.0, log=True),
             "eta": trial.suggest_float("eta", 0.001, 1.0, log=True),
             "smooth_l1_beta": trial.suggest_float(
                 "smooth_l1_beta", 0.1, 10.0, log=True
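For context on the `suggest_float` calls above: with `log=True`, Optuna samples uniformly in log-space, which suits scale parameters like `omega` that span several orders of magnitude. A toy, self-contained sketch; the objective here is a stand-in, not the repository's scoring:

```python
import optuna

def objective(trial):
    # log=True draws uniformly in log-space over [low, high]
    omega = trial.suggest_float("omega", 0.5, 50.0, log=True)
    eta = trial.suggest_float("eta", 0.001, 1.0, log=True)
    return (omega - 5.0) ** 2 + eta  # toy objective

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```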
utils.py ADDED
@@ -0,0 +1,133 @@
+import numpy as np
+from keras import ops
+from skimage import filters, morphology
+from zea.utils import translate
+
+
+def L1(x):
+    """L1 norm of a tensor.
+
+    Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
+    """
+    return ops.sum(ops.abs(x))
+
+
+def smooth_L1(x, beta=0.4):
+    """Smooth L1 loss function.
+
+    Implementation of Smooth L1 loss. Small beta values make it similar to L1 loss,
+    while large beta values make it similar to a scaled L2 loss.
+    """
+    abs_x = ops.abs(x)
+    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
+    return ops.sum(loss)
+
+
+def postprocess(data, normalization_range):
+    """Postprocess data from model output to image."""
+    data = ops.clip(data, *normalization_range)
+    data = translate(data, normalization_range, (0, 255))
+    data = ops.convert_to_numpy(data)
+    data = np.squeeze(data, axis=-1)
+    return np.clip(data, 0, 255).astype("uint8")
+
+
+def preprocess(data, normalization_range):
+    """Preprocess data for model input. Converts uint8 image(s) in [0, 255] to model input range."""
+    data = ops.convert_to_tensor(data, dtype="float32")
+    data = translate(data, (0, 255), normalization_range)
+    data = ops.expand_dims(data, axis=-1)
+    return data
+
+
+def apply_bottom_preservation(
+    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
+):
+    """Apply bottom preservation with smooth windowed transition.
+
+    Args:
+        output_images: Model output images, (batch, height, width, channels)
+        input_images: Original input images, (batch, height, width, channels)
+        preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
+        transition_width: Percentage of image height for smooth transition (default 10%)
+
+    Returns:
+        Blended images with preserved bottom portion
+    """
+    output_shape = ops.shape(output_images)
+
+    batch_size, height, width, channels = output_shape
+
+    preserve_height = int(height * preserve_bottom_percent / 100.0)
+    transition_height = int(height * transition_width / 100.0)
+
+    transition_start = height - preserve_height - transition_height
+    preserve_start = height - preserve_height
+
+    transition_start = max(0, transition_start)
+    preserve_start = min(height, preserve_start)
+
+    if transition_start >= preserve_start:
+        transition_start = preserve_start
+        transition_height = 0
+
+    y_coords = ops.arange(height, dtype="float32")
+    y_coords = ops.reshape(y_coords, (height, 1, 1))
+
+    if transition_height > 0:
+        # Smooth transition using cosine interpolation
+        transition_region = ops.logical_and(
+            y_coords >= transition_start, y_coords < preserve_start
+        )
+
+        transition_progress = (y_coords - transition_start) / transition_height
+        transition_progress = ops.clip(transition_progress, 0.0, 1.0)
+
+        # Use cosine for smooth transition (0.5 * (1 - cos(π * t)))
+        cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
+
+        blend_weight = ops.where(
+            y_coords < transition_start,
+            0.0,
+            ops.where(
+                transition_region,
+                cosine_weight,
+                1.0,
+            ),
+        )
+    else:
+        # No transition, just hard switch
+        blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)
+
+    blend_weight = ops.expand_dims(blend_weight, axis=0)
+
+    blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images
+
+    return blended_images
+
+
+def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
+    """Extract skeletons from the input images."""
+    images_np = ops.convert_to_numpy(images)
+    images_np = np.clip(images_np, input_range[0], input_range[1])
+    images_np = translate(images_np, input_range, (0, 1))
+    images_np = np.squeeze(images_np, axis=-1)
+
+    skeleton_masks = []
+    for img in images_np:
+        img[img < threshold] = 0
+        smoothed = filters.gaussian(img, sigma=sigma_pre)
+        binary = smoothed > filters.threshold_otsu(smoothed)
+        skeleton = morphology.skeletonize(binary)
+        skeleton = morphology.dilation(skeleton, morphology.disk(2))
+        skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
+        skeleton_masks.append(skeleton)
+
+    skeleton_masks = np.array(skeleton_masks)
+    skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)
+
+    # normalize to [0, 1]
+    min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
+    skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
+
+    return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
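The cosine transition in `apply_bottom_preservation` ramps the blend weight smoothly from 0 (pure model output) to 1 (pure original input) across the transition band. A quick check of the weight curve w(t) = 0.5 * (1 - cos(π * t)):

```python
import numpy as np

# Blend weight used for the smooth bottom transition
t = np.linspace(0.0, 1.0, 5)
w = 0.5 * (1.0 - np.cos(np.pi * t))
print(w)  # [0.     0.1464 0.5    0.8536 1.    ] -- flat at both ends, steepest mid-band
```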