from pathlib import Path

import numpy as np
from keras import ops
from PIL import Image
from skimage import filters, morphology

from zea.utils import translate


def L1(x):
    """L1 norm of a tensor.

    Implementation of L1 norm: https://mathworld.wolfram.com/L1-Norm.html

    Args:
        x: Input tensor of any shape.

    Returns:
        Scalar tensor: the sum of absolute values over all elements of ``x``.
    """
    return ops.sum(ops.abs(x))


def smooth_L1(x, beta=0.4):
    """Smooth L1 loss, summed over all elements.

    Elementwise, values with ``|x| < beta`` contribute ``0.5 * x**2 / beta``
    (quadratic, L2-like) and all other values contribute ``|x| - 0.5 * beta``
    (linear, L1-like). The two pieces meet continuously at ``|x| == beta``.

    Small beta values shrink the quadratic region, so the loss approaches the
    L1 loss; large beta values widen it, so the loss behaves like a scaled L2
    loss over a larger range of inputs.

    Args:
        x: Input tensor of any shape.
        beta: Width of the quadratic region. Should be positive, since the
            quadratic branch divides by it.

    Returns:
        Scalar tensor with the summed loss.
    """
    abs_x = ops.abs(x)
    loss = ops.where(abs_x < beta, 0.5 * x**2 / beta, abs_x - 0.5 * beta)
    return ops.sum(loss)


def postprocess(data, normalization_range):
    """Postprocess data from model output to a uint8 image.

    Args:
        data: Model output tensor whose last axis has size 1 (it is squeezed).
        normalization_range: ``(min, max)`` value range of the model output.

    Returns:
        numpy ``uint8`` array in [0, 255] with the last axis removed.
    """
    data = ops.clip(data, *normalization_range)
    data = translate(data, normalization_range, (0, 255))
    data = ops.convert_to_numpy(data)
    data = np.squeeze(data, axis=-1)
    # Round before casting: a bare astype truncates, so float error from the
    # range mapping (e.g. 127.99998) would drop a gray level and break the
    # preprocess -> postprocess round trip.
    return np.clip(np.rint(data), 0, 255).astype("uint8")


def preprocess(data, normalization_range):
    """Preprocess data for model input.

    Converts uint8 image(s) in [0, 255] to the model input range.

    Args:
        data: Image array(s) with values in [0, 255].
        normalization_range: ``(min, max)`` target value range.

    Returns:
        float32 tensor mapped to ``normalization_range`` with a trailing
        channel axis of size 1 appended.
    """
    data = ops.convert_to_tensor(data, dtype="float32")
    data = translate(data, (0, 255), normalization_range)
    data = ops.expand_dims(data, axis=-1)
    return data


def apply_bottom_preservation(
    output_images, input_images, preserve_bottom_percent=30.0, transition_width=10.0
):
    """Apply bottom preservation with smooth windowed transition.

    Rows above the transition come from ``output_images``. Rows from
    ``preserve_start`` to the bottom come from ``input_images``. In between,
    a cosine ramp blends the two. If the requested transition would extend
    above the top of the image, it is shortened so the ramp still ends
    exactly at ``preserve_start``.

    Args:
        output_images: Model output images, (batch, height, width, channels)
        input_images: Original input images, (batch, height, width, channels)
        preserve_bottom_percent: Percentage of bottom to preserve from input (default 30%)
        transition_width: Percentage of image height for smooth transition (default 10%)

    Returns:
        Blended images with preserved bottom portion
    """
    # Unpacking enforces a rank-4 input; only the height is needed.
    _, height, _, _ = ops.shape(output_images)

    preserve_height = int(height * preserve_bottom_percent / 100.0)
    transition_height = int(height * transition_width / 100.0)

    preserve_start = height - preserve_height
    transition_start = preserve_start - transition_height

    transition_start = max(0, transition_start)
    preserve_start = min(height, preserve_start)
    if transition_start >= preserve_start:
        transition_start = preserve_start

    # Recompute the ramp length after clamping. Otherwise a transition clamped
    # at the top of the image would never reach full weight, and the weight
    # would jump to 1.0 at preserve_start.
    transition_height = preserve_start - transition_start

    # Row index broadcast over (batch, height, width, channels).
    y_coords = ops.reshape(ops.arange(height, dtype="float32"), (1, height, 1, 1))

    if transition_height > 0:
        # Smooth transition using cosine interpolation: 0.5 * (1 - cos(pi * t)).
        progress = ops.clip(
            (y_coords - transition_start) / transition_height, 0.0, 1.0
        )
        ramp = 0.5 * (1.0 - ops.cos(np.pi * progress))
        blend_weight = ops.where(
            y_coords < transition_start,
            0.0,
            ops.where(y_coords < preserve_start, ramp, 1.0),
        )
    else:
        # No transition, just a hard switch.
        blend_weight = ops.where(y_coords >= preserve_start, 1.0, 0.0)

    return (1.0 - blend_weight) * output_images + blend_weight * input_images


def extract_skeleton(images, input_range, sigma_pre=4, sigma_post=4, threshold=0.3):
    """Extract soft skeleton masks from the input images.

    For each image (iterating over the first axis), the pipeline is:
    1. Map the image to [0, 1] and zero out values below ``threshold``.
    2. Apply a Gaussian blur with ``sigma_pre``.
    3. Binarize with Otsu's threshold.
    4. Skeletonize, then dilate with a disk of radius 2.
    5. Blur again with ``sigma_post``.

    Args:
        images: Tensor whose last axis has size 1 (it is squeezed);
            presumably (batch, height, width, 1) — TODO confirm with callers.
        input_range: ``(min, max)`` value range of ``images``.
        sigma_pre: Gaussian sigma applied before Otsu thresholding.
        sigma_post: Gaussian sigma applied to the dilated skeleton.
        threshold: Values below this (in the [0, 1] mapped space) are zeroed.

    Returns:
        Tensor of shape ``(N, *image_shape, 1)`` with the dtype of ``images``.
        It is min-max normalized to [0, 1] using a single min/max over the
        whole batch, not per image.
    """
    images_np = ops.convert_to_numpy(images)
    images_np = np.clip(images_np, input_range[0], input_range[1])
    images_np = translate(images_np, input_range, (0, 1))
    images_np = np.squeeze(images_np, axis=-1)

    skeleton_masks = []
    for img in images_np:
        # In-place edit is safe: images_np is a fresh array produced by
        # np.clip / translate, not the caller's data.
        img[img < threshold] = 0
        smoothed = filters.gaussian(img, sigma=sigma_pre)
        binary = smoothed > filters.threshold_otsu(smoothed)
        skeleton = morphology.skeletonize(binary)
        skeleton = morphology.dilation(skeleton, morphology.disk(2))
        skeleton = filters.gaussian(skeleton.astype(np.float32), sigma=sigma_post)
        skeleton_masks.append(skeleton)

    skeleton_masks = np.array(skeleton_masks)
    skeleton_masks = np.expand_dims(skeleton_masks, axis=-1)

    # Normalize to [0, 1] across the whole batch; epsilon avoids 0/0 for a
    # constant result.
    min_val, max_val = np.min(skeleton_masks), np.max(skeleton_masks)
    skeleton_masks = (skeleton_masks - min_val) / (max_val - min_val + 1e-8)
    return ops.convert_to_tensor(skeleton_masks, dtype=images.dtype)


def load_image(filename, grayscale=True):
    """Load an image file and return a numpy array using PIL.

    Args:
        filename (str): The path to the image file.
        grayscale (bool, optional): Whether to convert the image to grayscale.
            Defaults to True.

    Returns:
        numpy.ndarray: A numpy array of the image, converted to PIL mode "L"
        when ``grayscale`` is True and to "RGB" otherwise.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    filename = Path(filename)
    if not filename.exists():
        raise FileNotFoundError(f"File {filename} does not exist")

    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once the pixel data has been converted.
    with Image.open(filename) as img:
        converted = img.convert("L" if grayscale else "RGB")
        return np.array(converted)