Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
from keras import ops | |
from skimage import filters, morphology | |
from zea.utils import translate | |
def L1(x):
    """Return the L1 norm of a tensor.

    The L1 norm is the sum of the absolute values of every element, see
    https://mathworld.wolfram.com/L1-Norm.html

    Args:
        x: Input tensor.

    Returns:
        Scalar tensor holding the sum of ``|x|`` over all elements.
    """
    magnitudes = ops.abs(x)
    total = ops.sum(magnitudes)
    return total
def smooth_L1(x, beta=0.4):
    """Smooth L1 loss, summed over all elements of ``x``.

    Elements with ``|x| < beta`` are penalised quadratically
    (``0.5 * x**2 / beta``); all other elements are penalised linearly
    (``|x| - 0.5 * beta``).

    Small ``beta`` values shrink the quadratic zone and make the loss similar
    to the L1 loss, while large ``beta`` values widen it and make the loss
    similar to a (scaled) L2 loss.

    Args:
        x: Input tensor (typically a residual).
        beta: Switch-over point between the quadratic and the linear regime.
            ``beta == 0`` yields the plain L1 loss.

    Returns:
        Scalar tensor holding the summed loss.
    """
    abs_x = ops.abs(x)

    if isinstance(beta, (int, float)) and beta == 0:
        # The quadratic branch would divide by zero. ``ops.where`` receives
        # both branches fully evaluated, so the resulting inf/NaN values,
        # although never selected, can still surface as NaN gradients.
        return ops.sum(abs_x)

    quadratic = 0.5 * x**2 / beta
    linear = abs_x - 0.5 * beta
    loss = ops.where(abs_x < beta, quadratic, linear)
    return ops.sum(loss)
def postprocess(data, normalization_range):
    """Postprocess data from model output to ``uint8`` image(s).

    The data is clipped to ``normalization_range``, passed through
    ``translate`` to go from ``normalization_range`` to ``(0, 255)``,
    converted to a NumPy array, stripped of its trailing axis, rounded to the
    nearest integer and cast to ``uint8``.

    Args:
        data: Model output whose last axis has size 1 (``np.squeeze`` with
            ``axis=-1`` raises otherwise).
        normalization_range: ``(low, high)`` value range of the model output.

    Returns:
        ``uint8`` NumPy array without the trailing axis.
    """
    data = ops.clip(data, *normalization_range)
    data = translate(data, normalization_range, (0, 255))
    data = ops.convert_to_numpy(data)
    data = np.squeeze(data, axis=-1)
    # Round before casting: ``astype("uint8")`` truncates, so a value such as
    # 99.99999 (floating point error after rescaling) would otherwise become
    # 99 and a preprocess -> postprocess round trip could lose a grey level.
    data = np.rint(data)
    return np.clip(data, 0, 255).astype("uint8")
def preprocess(data, normalization_range):
    """Prepare uint8 image(s) in [0, 255] for use as model input.

    Args:
        data: Image or batch of images with values in [0, 255].
        normalization_range: ``(low, high)`` target range handed to
            ``translate``.

    Returns:
        ``float32`` tensor in the model input range with a new trailing axis
        of size 1 appended.
    """
    as_float = ops.convert_to_tensor(data, dtype="float32")
    rescaled = translate(as_float, (0, 255), normalization_range)
    return ops.expand_dims(rescaled, axis=-1)
def apply_bottom_preservation(
    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
    """Apply bottom preservation with smooth windowed transition.

    Rows in the bottom ``preserve_bottom_percent`` of the image are taken from
    ``input_images``, rows above the transition band are taken from
    ``output_images``, and rows inside the band are blended with a
    raised-cosine weight that goes from 0 (output) to 1 (input).

    If the preserved part plus the transition band do not fit in the image,
    the band is shortened to the rows that remain above the preserved part so
    that the weight still ramps from 0 up to the preserved region instead of
    jumping at its boundary.

    Args:
        output_images: Model output images, (batch, height, width, channels)
        input_images: Original input images, (batch, height, width, channels)
        preserve_bottom_percent: Percentage of bottom to preserve from input
            (default 30%). Values outside [0, 100] are clamped.
        transition_width: Percentage of image height for smooth transition
            (default 10%). Negative values are treated as 0 (hard switch).

    Returns:
        Blended images with preserved bottom portion
    """
    _, height, _, _ = ops.shape(output_images)

    # Clamp the requested extents so the row indices below stay in [0, height].
    preserve_height = int(height * preserve_bottom_percent / 100.0)
    preserve_height = min(max(preserve_height, 0), height)
    requested_transition = max(int(height * transition_width / 100.0), 0)

    preserve_start = height - preserve_height
    transition_start = max(0, preserve_start - requested_transition)
    # Use the span that is actually available. Keeping the requested height
    # after clamping ``transition_start`` would stop the ramp short of 1 and
    # produce a visible step at ``preserve_start``.
    transition_height = preserve_start - transition_start

    y_coords = ops.arange(height, dtype="float32")
    y_coords = ops.reshape(y_coords, (height, 1, 1))

    if transition_height > 0:
        # Smooth transition using cosine interpolation
        transition_region = ops.logical_and(
            y_coords >= transition_start, y_coords < preserve_start
        )
        transition_progress = (y_coords - transition_start) / transition_height
        transition_progress = ops.clip(transition_progress, 0.0, 1.0)
        # Use cosine for smooth transition (0.5 * (1 - cos(pi * t)))
        cosine_weight = 0.5 * (1.0 - ops.cos(np.pi * transition_progress))
        # The explicit 0.0 / 1.0 outside the band keeps those rows exact.
        blend_weight = ops.where(
            y_coords < transition_start,
            0.0,
            ops.where(transition_region, cosine_weight, 1.0),
        )
    else:
        # No transition, just hard switch
        blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)

    # (height, 1, 1) -> (1, height, 1, 1) so it broadcasts over the batch.
    blend_weight = ops.expand_dims(blend_weight, axis=0)
    return (1.0 - blend_weight) * output_images + blend_weight * input_images
def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
    """Compute soft skeleton masks for a batch of images.

    The batch is clipped to ``input_range``, passed through ``translate`` to
    go to ``(0, 1)`` and stripped of its trailing axis. Each image is then
    processed independently: values below ``threshold`` are zeroed, the image
    is Gaussian-smoothed (``sigma_pre``), binarised with an Otsu threshold,
    skeletonised, dilated with a disk of radius 2 and smoothed again
    (``sigma_post``). Finally the stacked masks are min-max normalised using
    the minimum and maximum of the whole batch.

    Args:
        images: Batch of images whose last axis has size 1.
        input_range: ``(low, high)`` value range of ``images``.
        sigma_pre: Gaussian sigma applied before thresholding.
        sigma_post: Gaussian sigma applied to the dilated skeleton.
        threshold: Cut-off (in the rescaled image) below which pixels are
            set to 0.

    Returns:
        Tensor of skeleton masks with a trailing axis of size 1 and the dtype
        of ``images``.
    """
    frames = ops.convert_to_numpy(images)
    frames = np.clip(frames, input_range[0], input_range[1])
    frames = np.squeeze(translate(frames, input_range, (0, 1)), axis=-1)

    def _soft_skeleton(frame):
        # Suppress weak responses in place before smoothing.
        frame[frame < threshold] = 0
        blurred = filters.gaussian(frame, sigma=sigma_pre)
        foreground = blurred > filters.threshold_otsu(blurred)
        thin = morphology.skeletonize(foreground)
        thick = morphology.dilation(thin, morphology.disk(2))
        return filters.gaussian(thick.astype(np.float32), sigma=sigma_post)

    masks = np.array([_soft_skeleton(frame) for frame in frames])
    masks = masks[..., np.newaxis]

    # normalize to [0, 1]; the epsilon guards against a constant batch
    lowest = np.min(masks)
    highest = np.max(masks)
    masks = (masks - lowest) / (highest - lowest + 1e-8)
    return ops.convert_to_tensor(masks, dtype=images.dtype)