tristan-deep committed on
Commit
1c76709
·
1 Parent(s): 1398519
.gitignore CHANGED
@@ -2,6 +2,7 @@
2
  .env
3
  temp/
4
  *.png
 
5
  *.pdf
6
  *.hash
7
  *.npz
 
2
  .env
3
  temp/
4
  *.png
5
+ !assets/*.png
6
  *.pdf
7
  *.hash
8
  *.npz
README.md CHANGED
@@ -13,6 +13,9 @@
13
  <p>Eindhoven University of Technology, the Netherlands</p>
14
  </div>
15
 
 
 
 
16
 
17
  ### Installation
18
 
@@ -21,6 +24,7 @@ The algorithm is implemented using Keras with JAX backend. Furthermore it heavil
21
  Either install the following in your Python environment, or use the [Dockerfile](./Dockerfile) provided in this repository.
22
 
23
  ```bash
 
24
  pip install tyro optuna zea==0.0.4
25
  pip install -U "jax[cuda12]"
26
  ```
 
13
  <p>Eindhoven University of Technology, the Netherlands</p>
14
  </div>
15
 
16
+ <p align="center">
17
+ <img src="animation.gif" alt="Cardiac Ultrasound Dehazing Animation" style="max-width: 100%; height: auto;">
18
+ </p>
19
 
20
  ### Installation
21
 
 
24
  Either install the following in your Python environment, or use the [Dockerfile](./Dockerfile) provided in this repository.
25
 
26
  ```bash
27
+ # requires Python>=3.10
28
  pip install tyro optuna zea==0.0.4
29
  pip install -U "jax[cuda12]"
30
  ```
assets/patient-17-4C-frame-11.png DELETED
Binary file (27.4 kB)
 
assets/patient-2-4C-frame-9.png ADDED
assets/patient-50-4C-frame-53.png DELETED
Binary file (22.1 kB)
 
configs/semantic_dps_opt.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # While these params optimize for the final score of the challenge,
2
+ # generally we observe lower perceptual quality of the dehazed results
3
+ # and in some cases even artifacts.
4
+ # We therefore recommend using configs/semantic_dps.yaml instead.
5
+ diffusion_model_path: "hf://tristan-deep/semantic-diffusion-echo-dehazing"
6
+ segmentation_model_path: "hf://tristan-deep/semantic-segmentation-echo-dehazing"
7
+ seed: 42
8
+
9
+ params:
10
+ diffusion_steps: 480
11
+ initial_diffusion_step: 0
12
+ batch_size: 16
13
+ threshold_output_quantile: 0.17447
14
+ preserve_bottom_percent: 32.0
15
+ bottom_transition_width: 7.0
16
+
17
+ mask_params:
18
+ sigma: 0.4704516
19
+ threshold: 0.18935
20
+ fixed_mask_params:
21
+ top_px: 20
22
+ bottom_px: 40
23
+ skeleton_params:
24
+ sigma_pre: 9.919
25
+ sigma_post: 9.5347840479
26
+ threshold: 0.73917
27
+ guidance_kwargs:
28
+ omega: 0.78
29
+ omega_vent: 0.0001
30
+ omega_sept: 15.84
31
+ eta: 0.01105
32
+ smooth_l1_beta: 6.3726
main.py CHANGED
@@ -10,7 +10,7 @@ import tyro
10
  import zea
11
  from keras import ops
12
  from PIL import Image
13
- from skimage import filters, morphology
14
  from zea import Config, init_device, log
15
  from zea.internal.operators import Operator
16
  from zea.models.diffusion import (
@@ -21,136 +21,14 @@ from zea.models.diffusion import (
21
  from zea.tensor_ops import L2
22
  from zea.utils import translate
23
 
24
- from plots import plot_batch_with_named_masks, plot_dehazed_results
25
-
26
-
27
- def L1(x):
28
- """L1 norm of a tensor.
29
-
30
- Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
31
- """
32
- return ops.sum(ops.abs(x))
33
-
34
-
35
- def smooth_L1(x, beta=0.4):
36
- """Smooth L1 loss function.
37
-
38
- Implementation of Smooth L1 loss. Large beta values make it similar to L1 loss,
39
- while small beta values make it similar to L2 loss.
40
- """
41
- abs_x = ops.abs(x)
42
- loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
43
- return ops.sum(loss)
44
-
45
-
46
- def postprocess(data, normalization_range):
47
- """Postprocess data from model output to image."""
48
- data = ops.clip(data, *normalization_range)
49
- data = translate(data, normalization_range, (0, 255))
50
- data = ops.convert_to_numpy(data)
51
- data = np.squeeze(data, axis=-1)
52
- return np.clip(data, 0, 255).astype("uint8")
53
-
54
-
55
- def preprocess(data, normalization_range):
56
- """Preprocess data for model input. Converts uint8 image(s) in [0, 255] to model input range."""
57
- data = ops.convert_to_tensor(data, dtype="float32")
58
- data = translate(data, (0, 255), normalization_range)
59
- data = ops.expand_dims(data, axis=-1)
60
- return data
61
-
62
-
63
- def apply_bottom_preservation(
64
- output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
65
- ):
66
- """Apply bottom preservation with smooth windowed transition.
67
-
68
- Args:
69
- output_images: Model output images, (batch, height, width, channels)
70
- input_images: Original input images, (batch, height, width, channels)
71
- preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
72
- transition_width: Percentage of image height for smooth transition (default 10%)
73
-
74
- Returns:
75
- Blended images with preserved bottom portion
76
- """
77
- output_shape = ops.shape(output_images)
78
-
79
- batch_size, height, width, channels = output_shape
80
-
81
- preserve_height = int(height * preserve_bottom_percent / 100.0)
82
- transition_height = int(height * transition_width / 100.0)
83
-
84
- transition_start = height - preserve_height - transition_height
85
- preserve_start = height - preserve_height
86
-
87
- transition_start = max(0, transition_start)
88
- preserve_start = min(height, preserve_start)
89
-
90
- if transition_start >= preserve_start:
91
- transition_start = preserve_start
92
- transition_height = 0
93
-
94
- y_coords = ops.arange(height, dtype="float32")
95
- y_coords = ops.reshape(y_coords, (height, 1, 1))
96
-
97
- if transition_height > 0:
98
- # Smooth transition using cosine interpolation
99
- transition_region = ops.logical_and(
100
- y_coords >= transition_start, y_coords < preserve_start
101
- )
102
-
103
- transition_progress = (y_coords - transition_start) / transition_height
104
- transition_progress = ops.clip(transition_progress, 0.0, 1.0)
105
-
106
- # Use cosine for smooth transition (0.5 * (1 - cos(π * t)))
107
- cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
108
-
109
- blend_weight = ops.where(
110
- y_coords < transition_start,
111
- 0.0,
112
- ops.where(
113
- transition_region,
114
- cosine_weight,
115
- 1.0,
116
- ),
117
- )
118
- else:
119
- # No transition, just hard switch
120
- blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)
121
-
122
- blend_weight = ops.expand_dims(blend_weight, axis=0)
123
-
124
- blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images
125
-
126
- return blended_images
127
-
128
-
129
- def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
130
- """Extract skeletons from the input images."""
131
- images_np = ops.convert_to_numpy(images)
132
- images_np = np.clip(images_np, input_range[0], input_range[1])
133
- images_np = translate(images_np, input_range, (0, 1))
134
- images_np = np.squeeze(images_np, axis=-1)
135
-
136
- skeleton_masks = []
137
- for img in images_np:
138
- img[img < threshold] = 0
139
- smoothed = filters.gaussian(img, sigma=sigma_pre)
140
- binary = smoothed > filters.threshold_otsu(smoothed)
141
- skeleton = morphology.skeletonize(binary)
142
- skeleton = morphology.dilation(skeleton, morphology.disk(2))
143
- skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
144
- skeleton_masks.append(skeleton)
145
-
146
- skeleton_masks = np.array(skeleton_masks)
147
- skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)
148
-
149
- # normalize to [0, 1]
150
- min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
151
- skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
152
-
153
- return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)
154
 
155
 
156
  class IdentityOperator(Operator):
@@ -250,7 +128,6 @@ class SemanticDPS(DPS):
250
  masks_sept + masks_fixed + masks_skeleton + masks_dark, 0, 1
251
  )
252
 
253
- # background = not masks_strong, not vent
254
  background = ops.where(masks_strong < 0.1, 1.0, 0.0) * ops.where(
255
  masks_vent == 0, 1.0, 0.0
256
  )
@@ -534,6 +411,7 @@ def main(
534
  masks_viz = copy.deepcopy(masks)
535
  masks_viz.pop("haze")
536
 
 
537
  masks_viz = {k: v[:num_img] for k, v in masks_viz.items()}
538
 
539
  fig = plot_batch_with_named_masks(
@@ -553,6 +431,14 @@ def main(
553
  fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
554
  log.success(f"Segmentation steps saved to {log.yellow(path)}")
555
 
 
 
 
 
 
 
 
 
556
  plt.close("all")
557
 
558
 
 
10
  import zea
11
  from keras import ops
12
  from PIL import Image
13
+ from skimage import filters
14
  from zea import Config, init_device, log
15
  from zea.internal.operators import Operator
16
  from zea.models.diffusion import (
 
21
  from zea.tensor_ops import L2
22
  from zea.utils import translate
23
 
24
+ from plots import create_animation, plot_batch_with_named_masks, plot_dehazed_results
25
+ from utils import (
26
+ apply_bottom_preservation,
27
+ extract_skeleton,
28
+ postprocess,
29
+ preprocess,
30
+ smooth_L1,
31
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  class IdentityOperator(Operator):
 
128
  masks_sept + masks_fixed + masks_skeleton + masks_dark, 0, 1
129
  )
130
 
 
131
  background = ops.where(masks_strong < 0.1, 1.0, 0.0) * ops.where(
132
  masks_vent == 0, 1.0, 0.0
133
  )
 
411
  masks_viz = copy.deepcopy(masks)
412
  masks_viz.pop("haze")
413
 
414
+ num_img = 2 # hardcoded as the plotting figure only neatly supports 2 rows
415
  masks_viz = {k: v[:num_img] for k, v in masks_viz.items()}
416
 
417
  fig = plot_batch_with_named_masks(
 
431
  fig.savefig(path.with_suffix(".pdf"), **save_kwargs)
432
  log.success(f"Segmentation steps saved to {log.yellow(path)}")
433
 
434
+ last_batch_size = len(diffusion_model.track_progress[0])
435
+ create_animation(
436
+ preprocess(hazy_images[-last_batch_size:], diffusion_model.input_range),
437
+ diffusion_model,
438
+ output_path="animation.gif",
439
+ fps=10,
440
+ )
441
+
442
  plt.close("all")
443
 
444
 
plots.py CHANGED
@@ -2,6 +2,7 @@ import json
2
  from pathlib import Path
3
  from typing import Any, Dict, List
4
 
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import tyro
@@ -9,8 +10,13 @@ from keras import ops
9
  from matplotlib.patches import PathPatch
10
  from matplotlib.path import Path as pltPath
11
  from skimage import measure
 
 
 
12
  from zea.visualize import plot_image_grid
13
 
 
 
14
 
15
  def add_shape_from_mask(ax, mask, **kwargs):
16
  """add a shape to axis from mask array.
@@ -335,6 +341,62 @@ def plot_optimization_history_from_json(
335
  plt.close(fig)
336
 
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  def main(json_file: str, output_dir: str = "plots", method: str = "semantic_dps"):
339
  json_path = Path(json_file)
340
  if not json_path.exists():
 
2
  from pathlib import Path
3
  from typing import Any, Dict, List
4
 
5
+ import keras
6
  import matplotlib.pyplot as plt
7
  import numpy as np
8
  import tyro
 
10
  from matplotlib.patches import PathPatch
11
  from matplotlib.path import Path as pltPath
12
  from skimage import measure
13
+ from zea import log
14
+ from zea.io_lib import matplotlib_figure_to_numpy
15
+ from zea.utils import save_to_gif
16
  from zea.visualize import plot_image_grid
17
 
18
+ from utils import postprocess
19
+
20
 
21
  def add_shape_from_mask(ax, mask, **kwargs):
22
  """add a shape to axis from mask array.
 
341
  plt.close(fig)
342
 
343
 
344
+ def create_animation_frame(hazy_images, tissue_frame, haze_frame):
345
+ """Create a single animation frame from the tracked progress."""
346
+ batch, height, width = ops.shape(hazy_images)
347
+ frame_stack = ops.stack(
348
+ [
349
+ hazy_images,
350
+ tissue_frame,
351
+ haze_frame,
352
+ ]
353
+ )
354
+ frame_stack = ops.reshape(frame_stack, (-1, height, width))
355
+ fig_frame, _ = plot_image_grid(
356
+ frame_stack,
357
+ ncols=len(hazy_images),
358
+ remove_axis=False,
359
+ vmin=0,
360
+ vmax=255,
361
+ )
362
+ labels = ["Hazy", "Tissue"] if haze_frame is None else ["Hazy", "Tissue", "Haze"]
363
+ for i, ax in enumerate(fig_frame.axes):
364
+ label = labels[i % len(labels)]
365
+ ax.set_ylabel(label, fontsize=12)
366
+ frame_array = matplotlib_figure_to_numpy(fig_frame)
367
+ plt.close(fig_frame)
368
+ return frame_array
369
+
370
+
371
+ def create_animation(hazy_images, diffusion_model, output_path, fps):
372
+ """Create animation from tracked progress frames."""
373
+ if not (len(diffusion_model.track_progress) > 1):
374
+ log.warning(
375
+ "Animation requested but no intermediate frames were tracked. "
376
+ "Try reducing diffusion_steps or ensure progress tracking is enabled."
377
+ )
378
+ return
379
+
380
+ log.info(f"Creating animation with {len(diffusion_model.track_progress)} frames...")
381
+
382
+ animation_frames = []
383
+ progbar = keras.utils.Progbar(
384
+ len(diffusion_model.track_progress), unit_name="frame"
385
+ )
386
+ for tissue_frame in diffusion_model.track_progress:
387
+ haze_frame = hazy_images - tissue_frame - 1
388
+ tissue_frame = postprocess(tissue_frame, diffusion_model.input_range)
389
+ haze_frame = postprocess(haze_frame, diffusion_model.input_range)
390
+ _hazy_images = postprocess(hazy_images, diffusion_model.input_range)
391
+ frame_array = create_animation_frame(_hazy_images, tissue_frame, haze_frame)
392
+ animation_frames.append(frame_array)
393
+ progbar.add(1)
394
+
395
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
396
+ animation_path = Path(output_path).with_suffix(".gif")
397
+ save_to_gif(animation_frames, animation_path, fps=fps)
398
+
399
+
400
  def main(json_file: str, output_dir: str = "plots", method: str = "semantic_dps"):
401
  json_path = Path(json_file)
402
  if not json_path.exists():
sweeper.py CHANGED
@@ -136,7 +136,6 @@ class OptunaObjective:
136
  "omega": trial.suggest_float("omega", 0.5, 50.0, log=True),
137
  "omega_vent": trial.suggest_float("omega_vent", 0.0001, 50.0, log=True),
138
  "omega_sept": trial.suggest_float("omega_sept", 0.1, 50.0, log=True),
139
- "omega_dark": trial.suggest_float("omega_dark", 0.001, 50.0, log=True),
140
  "eta": trial.suggest_float("eta", 0.001, 1.0, log=True),
141
  "smooth_l1_beta": trial.suggest_float(
142
  "smooth_l1_beta", 0.1, 10.0, log=True
 
136
  "omega": trial.suggest_float("omega", 0.5, 50.0, log=True),
137
  "omega_vent": trial.suggest_float("omega_vent", 0.0001, 50.0, log=True),
138
  "omega_sept": trial.suggest_float("omega_sept", 0.1, 50.0, log=True),
 
139
  "eta": trial.suggest_float("eta", 0.001, 1.0, log=True),
140
  "smooth_l1_beta": trial.suggest_float(
141
  "smooth_l1_beta", 0.1, 10.0, log=True
utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from keras import ops
3
+ from skimage import filters, morphology
4
+ from zea.utils import translate
5
+
6
+
7
+ def L1(x):
8
+ """L1 norm of a tensor.
9
+
10
+ Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html
11
+ """
12
+ return ops.sum(ops.abs(x))
13
+
14
+
15
+ def smooth_L1(x, beta=0.4):
16
+ """Smooth L1 loss function.
17
+
18
+ Implementation of Smooth L1 loss. Large beta values make it similar to L2 loss,
19
+ while small beta values make it similar to L1 loss.
20
+ """
21
+ abs_x = ops.abs(x)
22
+ loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
23
+ return ops.sum(loss)
24
+
25
+
26
+ def postprocess(data, normalization_range):
27
+ """Postprocess data from model output to image."""
28
+ data = ops.clip(data, *normalization_range)
29
+ data = translate(data, normalization_range, (0, 255))
30
+ data = ops.convert_to_numpy(data)
31
+ data = np.squeeze(data, axis=-1)
32
+ return np.clip(data, 0, 255).astype("uint8")
33
+
34
+
35
+ def preprocess(data, normalization_range):
36
+ """Preprocess data for model input. Converts uint8 image(s) in [0, 255] to model input range."""
37
+ data = ops.convert_to_tensor(data, dtype="float32")
38
+ data = translate(data, (0, 255), normalization_range)
39
+ data = ops.expand_dims(data, axis=-1)
40
+ return data
41
+
42
+
43
+ def apply_bottom_preservation(
44
+ output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
45
+ ):
46
+ """Apply bottom preservation with smooth windowed transition.
47
+
48
+ Args:
49
+ output_images: Model output images, (batch, height, width, channels)
50
+ input_images: Original input images, (batch, height, width, channels)
51
+ preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
52
+ transition_width: Percentage of image height for smooth transition (default 10%)
53
+
54
+ Returns:
55
+ Blended images with preserved bottom portion
56
+ """
57
+ output_shape = ops.shape(output_images)
58
+
59
+ batch_size, height, width, channels = output_shape
60
+
61
+ preserve_height = int(height * preserve_bottom_percent / 100.0)
62
+ transition_height = int(height * transition_width / 100.0)
63
+
64
+ transition_start = height - preserve_height - transition_height
65
+ preserve_start = height - preserve_height
66
+
67
+ transition_start = max(0, transition_start)
68
+ preserve_start = min(height, preserve_start)
69
+
70
+ if transition_start >= preserve_start:
71
+ transition_start = preserve_start
72
+ transition_height = 0
73
+
74
+ y_coords = ops.arange(height, dtype="float32")
75
+ y_coords = ops.reshape(y_coords, (height, 1, 1))
76
+
77
+ if transition_height > 0:
78
+ # Smooth transition using cosine interpolation
79
+ transition_region = ops.logical_and(
80
+ y_coords >= transition_start, y_coords < preserve_start
81
+ )
82
+
83
+ transition_progress = (y_coords - transition_start) / transition_height
84
+ transition_progress = ops.clip(transition_progress, 0.0, 1.0)
85
+
86
+ # Use cosine for smooth transition (0.5 * (1 - cos(π * t)))
87
+ cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
88
+
89
+ blend_weight = ops.where(
90
+ y_coords < transition_start,
91
+ 0.0,
92
+ ops.where(
93
+ transition_region,
94
+ cosine_weight,
95
+ 1.0,
96
+ ),
97
+ )
98
+ else:
99
+ # No transition, just hard switch
100
+ blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)
101
+
102
+ blend_weight = ops.expand_dims(blend_weight, axis=0)
103
+
104
+ blended_images = (1.0 - blend_weight) * output_images + blend_weight * input_images
105
+
106
+ return blended_images
107
+
108
+
109
+ def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
110
+ """Extract skeletons from the input images."""
111
+ images_np = ops.convert_to_numpy(images)
112
+ images_np = np.clip(images_np, input_range[0], input_range[1])
113
+ images_np = translate(images_np, input_range, (0, 1))
114
+ images_np = np.squeeze(images_np, axis=-1)
115
+
116
+ skeleton_masks = []
117
+ for img in images_np:
118
+ img[img < threshold] = 0
119
+ smoothed = filters.gaussian(img, sigma=sigma_pre)
120
+ binary = smoothed > filters.threshold_otsu(smoothed)
121
+ skeleton = morphology.skeletonize(binary)
122
+ skeleton = morphology.dilation(skeleton, morphology.disk(2))
123
+ skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
124
+ skeleton_masks.append(skeleton)
125
+
126
+ skeleton_masks = np.array(skeleton_masks)
127
+ skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)
128
+
129
+ # normalize to [0, 1]
130
+ min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
131
+ skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
132
+
133
+ return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)