"""Image processor class for MolmoAct""" from typing import TYPE_CHECKING, Tuple, List, Optional, Union, Dict, Any import numpy as np import einops import torch import torchvision.transforms from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import convert_image_dtype from transformers.image_utils import ( OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension, ImageInput, is_valid_image, valid_images, to_numpy_array, ) from transformers.image_transforms import convert_to_rgb, to_channel_dimension_format from transformers.processing_utils import ImagesKwargs from transformers.image_processing_utils import BaseImageProcessor from transformers.utils import logging from transformers.feature_extraction_utils import BatchFeature from transformers.utils import TensorType, logging if TYPE_CHECKING: from transformers.utils import TensorType, logging logger = logging.get_logger(__name__) def is_multi_image(image: Union[ImageInput, List[ImageInput]]) -> bool: return isinstance(image, (list, tuple)) def make_batched_images(images) -> List[ImageInput]: """ Accepts images in list or nested list format. Args: images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): The input image. Returns: list: A list of images or a list of lists of images. """ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): return images elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): return images elif is_valid_image(images): return [images] raise ValueError(f"Could not make batched images from {images}") def normalize_image(image: np.ndarray, normalize_mode: str) -> np.ndarray: if normalize_mode == "openai": image -= np.array(OPENAI_CLIP_MEAN, dtype=np.float32)[None, None, :] image /= np.array(OPENAI_CLIP_STD, dtype=np.float32)[None, None, :] elif normalize_mode == "siglip": image = np.asarray(-1.0, dtype=np.float32) + image * np.asarray(2.0, dtype=np.float32) elif normalize_mode == "dino": image -= np.array([0.485, 0.456, 0.406], dtype=np.float32)[None, None, :] image /= np.array([0.229, 0.224, 0.225], dtype=np.float32)[None, None, :] else: raise NotImplementedError(normalize_mode) return image # Helper to ensure output_size is a 2-tuple of built-in Python ints def _ensure_pyint_size2(size): """ Ensure `size` is a 2-tuple of built-in Python ints. Accepts int, list/tuple, or numpy array of length 1 or 2. """ import numpy as np # If it's an array-like, normalize to length-2 tuple if isinstance(size, (list, tuple, np.ndarray)): if len(size) == 2: return (int(size[0]), int(size[1])) elif len(size) == 1: s = int(size[0]) return (s, s) else: # Fallback: try to interpret as square size using first element s = int(size[0]) return (s, s) # Scalar → square size s = int(size) return (s, s) def resize_and_pad( image, desired_output_size, resize_method="torch-bilinear", pad_value=0, ): """Resize an image while padding to preserve uts aspect ratio.""" desired_output_size = _ensure_pyint_size2(desired_output_size) desired_height, desired_width = desired_output_size height, width = image.shape[:2] # Cast into float32 since the training code did this in float32 and it (very rarely) effects # the results after rounding. image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32) image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32) image_scale = min(image_scale_x, image_scale_y) scaled_height = int(np.array(height, np.float32) * image_scale) scaled_width = int(np.array(width, np.float32) * image_scale) if resize_method in ["torch-bilinear"]: image = torch.permute(torch.from_numpy(image), [2, 0, 1]) image = convert_image_dtype(image) # resize in float32 to match the training code mode = InterpolationMode.BILINEAR image = torchvision.transforms.Resize([scaled_height, scaled_width], mode, antialias=True)(image) image = torch.clip(image, 0.0, 1.0) image = torch.permute(image, [1, 2, 0]).numpy() else: raise NotImplementedError(resize_method) top_pad = (desired_height - scaled_height) // 2 left_pad = (desired_width - scaled_width) // 2 padding = [ [top_pad, desired_height - scaled_height - top_pad], [left_pad, desired_width - scaled_width - left_pad], [0, 0] ] image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2]) image = np.pad(image, padding, constant_values=pad_value) return image, image_mask def metaclip_resize(image, desired_output_size): desired_output_size = _ensure_pyint_size2(desired_output_size) image = torch.permute(torch.from_numpy(image), [2, 0, 1]) if torch.is_floating_point(image): image = torchvision.transforms.Resize( desired_output_size, InterpolationMode.BICUBIC, antialias=True)(image) image = torch.clip(image, 0.0, 1.0) else: assert image.dtype == torch.uint8, "Expected float images or uint8 images, but got {}".format(image.dtype) image = torchvision.transforms.Resize( desired_output_size, InterpolationMode.BICUBIC, antialias=True)(image) image = image.to(torch.float32) image = torch.clip(image, 0, 255) image = image / 255.0 resized = torch.permute(image, [1, 2, 0]).numpy() image_mask = np.ones_like(resized[:, :, 0], dtype=np.bool_) return resized, image_mask def siglip_resize_and_pad( image: np.ndarray, desired_output_size: Tuple[int, int], ) -> Tuple[np.ndarray, np.ndarray]: desired_output_size = _ensure_pyint_size2(desired_output_size) if len(image.shape) == 3: is_video = False image = torch.permute(torch.from_numpy(image), [2, 0, 1]) else: is_video = True image = torch.permute(torch.from_numpy(image), [0, 3, 1, 2]) dtype = image.dtype if torch.is_floating_point(image): in_min = 0.0 in_max = 1.0 resized = torchvision.transforms.Resize( desired_output_size, InterpolationMode.BILINEAR, antialias=False, )(image) resized = torch.clip(resized, 0.0, 1.0).to(dtype) else: assert image.dtype == torch.uint8, "SigLIP expects float images or uint8 images, but got {}".format(image.dtype) in_min = 0.0 in_max = 255.0 resized = torchvision.transforms.Resize( desired_output_size, InterpolationMode.BILINEAR, antialias=False, )(image) resized = torch.clip(resized, 0, 255).to(dtype) resized = resized.to(torch.float32) resized = (resized - in_min) / (in_max - in_min) if is_video: resized = torch.permute(resized, [0, 2, 3, 1]).numpy() image_mask = None else: resized = torch.permute(resized, [1, 2, 0]).numpy() image_mask = np.ones_like(resized[:, :, 0], dtype=np.bool_) return resized, image_mask def dino_resize_and_pad( image: np.ndarray, desired_output_size: Tuple[int, int], ) -> Tuple[np.ndarray, np.ndarray]: desired_output_size = _ensure_pyint_size2(desired_output_size) image = torch.permute(torch.from_numpy(image), [2, 0, 1]) dtype = image.dtype if torch.is_floating_point(image): resized = torchvision.transforms.Resize( desired_output_size, InterpolationMode.BICUBIC, antialias=True, )(image) resized = torch.clip(resized, 0.0, 1.0).to(torch.float32) else: assert image.dtype == torch.uint8, "DINOv2 expects float images or uint8 images, but got {}".format(image.dtype) resized = torchvision.transforms.Resize( desired_output_size, InterpolationMode.BICUBIC, antialias=True, )(image) resized = torch.clip(resized, 0, 255).to(torch.float32) resized = resized / 255.0 resized = torch.permute(resized, [1, 2, 0]).numpy() image_mask = np.ones_like(resized[:, :, 0], dtype=np.bool_) return resized, image_mask def resize_image( image: np.ndarray, resize_mode: str, output_size: Tuple[int, int], pad_value: float, ) -> Tuple[np.ndarray, np.ndarray]: if resize_mode == "siglip": return siglip_resize_and_pad(image, output_size) elif resize_mode == "dino": return dino_resize_and_pad(image, output_size) elif resize_mode == "metaclip": return metaclip_resize(image, output_size) else: resize = "torch-bilinear" if resize_mode == "default" else resize_mode return resize_and_pad( image, output_size, resize_method=resize, pad_value=pad_value, ) def select_tiling(h, w, patch_size, max_num_crops): """Divide in image of size [w, h] in up to max_num_patches of size patch_size""" original_size = np.stack([h, w]) # [1, 2] original_res = h * w tilings = [] for i in range(1, max_num_crops + 1): for j in range(1, max_num_crops + 1): if i*j <= max_num_crops: tilings.append((i, j)) # sort so argmin and argmax favour smaller tilings in the event of a tie tilings.sort(key=lambda x: (x[0]*x[1], x[0])) candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2] candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2] # How much we would need to scale the image to fit exactly in each tiling original_size = np.stack([h, w], dtype=np.float32) # [1, 2] # The original size can be zero in rare cases if the image is smaller than the margin # In those cases letting the scale become infinite means the tiling is based on the # other side, or falls back to the smallest tiling with np.errstate(divide='ignore'): required_scale_d = candidate_resolutions.astype(np.float32) / original_size, required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1] if np.all(required_scale < 1): # We are forced to downscale, so try to minimize the amount of downscaling ix = np.argmax(required_scale) else: # Pick the resolution that required the least upscaling so that it most closely fits the image required_scale = np.where(required_scale < 1.0, 10e9, required_scale) ix = np.argmin(required_scale) return candidate_tilings[ix] def build_resized_image( image: np.ndarray, resize_mode: str, normalized_mode: str, base_image_input_size: List[int], pad_value: float, image_patch_size: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: resized, resized_mask = resize_image( image, resize_mode, base_image_input_size, pad_value, ) resized = normalize_image(resized, normalized_mode) if len(resized.shape) == 3: resized = np.expand_dims(resized, 0) resized_mask = np.expand_dims(resized_mask, 0) crop_patch_w = base_image_input_size[1] // image_patch_size crop_patch_h = base_image_input_size[0] // image_patch_size resize_idx = np.arange(crop_patch_w*crop_patch_h).reshape([crop_patch_h, crop_patch_w]) return resized, resized_mask, resize_idx def build_overlapping_crops( image: np.ndarray, resize_mode: str, normalize_mode: str, max_crops: int, overlap_margins: List[int], base_image_input_size: List[int], pad_value: float, image_patch_size: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Decompose an image into a set of overlapping crops :return crop_arr: [n_crops, h, w, 3] The crops :return mask_arr: [n_crops, h, w] The padding masks :return patch_idx: [overlap_patch_h, overlap_patch_w] For each patch in the resized image the crops were extracted from, what patch in `crop_arr` it corresponds to """ original_image_h, original_image_w = image.shape[:2] crop_size = base_image_input_size[0] assert base_image_input_size[0] == base_image_input_size[1] left_margin, right_margin = overlap_margins total_margin_pixels = image_patch_size * (right_margin + left_margin) # pixels removed per dim crop_patches = base_image_input_size[0] // image_patch_size # patches per crop dim crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches crop_window_size = crop_window_patches * image_patch_size crop_patch_w = base_image_input_size[1] // image_patch_size crop_patch_h = base_image_input_size[0] // image_patch_size original_image_h, original_image_w = image.shape[:2] crop_size = base_image_input_size[0] # Decide how to tile the image, to account for the overlap margins we compute the tiling # as if we had an image without the margins and were using a crop size without the margins tiling = select_tiling( original_image_h - total_margin_pixels, original_image_w - total_margin_pixels, crop_window_size, max_crops, ) src, img_mask = resize_image( image, resize_mode, [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels], pad_value, ) src = normalize_image(src, normalize_mode) # Now we have to split the image into crops, and track what patches came from # where in `patch_idx_arr` n_crops = tiling[0] * tiling[1] crop_arr = np.zeros([n_crops, crop_size, crop_size, 3], dtype=src.dtype) mask_arr = np.zeros([n_crops, crop_size, crop_size], dtype=img_mask.dtype) patch_idx_arr = np.zeros([n_crops, crop_patch_h, crop_patch_w], dtype=np.int32) on = 0 on_crop = 0 for i in range(tiling[0]): # Slide over `src` by `crop_window_size` steps, but extract crops of size `crops_size` # which results in overlapping crop windows y0 = i*crop_window_size for j in range(tiling[1]): x0 = j*crop_window_size crop_arr[on_crop] = src[y0:y0+crop_size, x0:x0+crop_size] mask_arr[on_crop] = img_mask[y0:y0+crop_size, x0:x0+crop_size] patch_idx = np.arange(crop_patch_w*crop_patch_h).reshape(crop_patch_h, crop_patch_w) patch_idx += on_crop * crop_patch_h * crop_patch_w # Mask out idx that are in the overlap region if i != 0: patch_idx[:left_margin, :] = -1 if j != 0: patch_idx[:, :left_margin] = -1 if i != tiling[0]-1: patch_idx[-right_margin:, :] = -1 if j != tiling[1]-1: patch_idx[:, -right_margin:] = -1 patch_idx_arr[on_crop] = patch_idx on_crop += 1 # `patch_idx_arr` is ordered crop-by-crop, here we transpose `patch_idx_arr` # so it is ordered left-to-right order patch_idx_arr = np.reshape( patch_idx_arr, [tiling[0], tiling[1], crop_patch_h, crop_patch_w] ) patch_idx_arr = np.transpose(patch_idx_arr, [0, 2, 1, 3]) patch_idx_arr = np.reshape(patch_idx_arr, [-1]) # Now get the parts not in the overlap region, so it should map each patch in `src` # to the correct patch it should come from in `crop_arr` patch_idx_arr = patch_idx_arr[patch_idx_arr >= 0].reshape( src.shape[0]//image_patch_size, src.shape[1]//image_patch_size, ) return crop_arr, mask_arr, patch_idx_arr def batch_pixels_to_patches(array: np.ndarray, patch_size: int) -> np.ndarray: """Reshape images of [n_images, h, w, 3] -> [n_images, n_patches, pixels_per_patch]""" if len(array.shape) == 3: n_crops, h, w = array.shape h_patches = h//patch_size w_patches = w//patch_size array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size]) array = np.transpose(array, [0, 1, 3, 2, 4]) array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size]) return array else: n_crops, h, w, c = array.shape h_patches = h//patch_size w_patches = w//patch_size array = np.reshape(array, [n_crops, h_patches, patch_size, w_patches, patch_size, c]) array = np.transpose(array, [0, 1, 3, 2, 4, 5]) array = np.reshape(array, [n_crops, h_patches*w_patches, patch_size*patch_size*c]) return array def arange_for_pooling( idx_arr: np.ndarray, pool_h: int, pool_w: int, ) -> np.ndarray: h_pad = pool_h * ((idx_arr.shape[0] + pool_h - 1) // pool_h) - idx_arr.shape[0] w_pad = pool_w * ((idx_arr.shape[1] + pool_w - 1) // pool_w) - idx_arr.shape[1] idx_arr = np.pad(idx_arr, [[h_pad//2, (h_pad+1)//2], [w_pad//2, (w_pad+1)//2]], mode='constant',constant_values=-1) return einops.rearrange( idx_arr, "(h dh) (w dw) -> h w (dh dw)", dh=pool_h, dw=pool_w) def image_to_patches_and_grids( image: ImageInput, crop_mode: str, resize_mode: str, normalize_mode: str, max_crops: int, overlap_margins: List[int], base_image_input_size: List[int], pad_value: float, image_patch_size: int, image_pooling_w: int, image_pooling_h: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ :return image_grids, the shape of each (low-res, high-res) image after pooling :return crops, the image crops to processes with the ViT :return mask, the padding mask for each crop :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the patches in `crops` to pool for that token, masked with -1 """ if isinstance(base_image_input_size, int): base_image_input_size = (base_image_input_size, base_image_input_size) base_image_input_d = image_patch_size pooling_w = image_pooling_w pooling_h = image_pooling_h crop_patch_w = base_image_input_size[1] // base_image_input_d crop_patch_h = base_image_input_size[0] // base_image_input_d if crop_mode == "resize": resized, resized_mask, resize_idx = build_resized_image( image, resize_mode, normalize_mode, base_image_input_size, pad_value, image_patch_size ) pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w) h, w = pooling_idx.shape[:2] pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w]) image_grid = [np.array([h, w])] return ( np.stack(image_grid, 0), batch_pixels_to_patches(resized, image_patch_size), batch_pixels_to_patches(resized_mask, image_patch_size).mean(-1), pooling_idx, ) if crop_mode in ["overlap-and-resize-c2", "overlap-and-resize"]: crop_arr, mask_arr, patch_idx_arr = build_overlapping_crops( image, resize_mode, normalize_mode, max_crops, overlap_margins, base_image_input_size, pad_value, image_patch_size, ) pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w) h, w = pooling_idx.shape[:2] pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w]) image_grid = [np.array([h, w])] if crop_mode == "overlap-and-resize": crop_arr = batch_pixels_to_patches(crop_arr, image_patch_size) mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1) return np.stack(image_grid, 0), crop_arr, mask_arr, pooling_idx # Finally do the same for the global image resized, resized_mask, resize_idx = build_resized_image( image, resize_mode, normalize_mode, base_image_input_size, pad_value, image_patch_size ) crop_arr = np.concatenate([resized, crop_arr], 0) mask_arr = np.concatenate([resized_mask, mask_arr], 0) resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w) h, w = resize_idx.shape[:2] resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w]) # Global image goes first, so the order of patches in previous crops gets increased pooling_idx = np.where( pooling_idx >= 0, pooling_idx + crop_patch_h*crop_patch_w, -1 ) pooling_idx = np.concatenate([resize_idx, pooling_idx]) image_grid = [ np.array([h, w]), ] + image_grid mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1) return ( np.stack(image_grid, 0), batch_pixels_to_patches(crop_arr, image_patch_size), mask_arr, pooling_idx ) else: raise NotImplementedError(crop_mode) def image_to_patches_and_tokens( image: ImageInput, crop_mode: str, use_col_tokens: bool, resize_mode: str, normalize_mode: str, max_crops: int, overlap_margins: List[int], base_image_input_size: List[int], pad_value: float, image_patch_size: int, image_pooling_w: int, image_pooling_h: int, image_patch_token_id: int, image_col_token_id: int, image_start_token_id: int, image_end_token_id: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ :return image_tokens, the token IDS for this image, including special tokens :return crops, the image crops to processes with the ViT :return mask, the padding mask for each crop :return pooled_patch_idx, for each patch_id tokens in `image_tokens`, the indices of the patches in `crops` to pool for that token, masked with -1 """ if isinstance(base_image_input_size, int): base_image_input_size = (base_image_input_size, base_image_input_size) base_image_input_d = image_patch_size pooling_w = image_pooling_w pooling_h = image_pooling_h patch_id = image_patch_token_id col_id = image_col_token_id start_id = image_start_token_id end_id = image_end_token_id crop_patch_w = base_image_input_size[1] // base_image_input_d crop_patch_h = base_image_input_size[0] // base_image_input_d if crop_mode == "resize": resized, resized_mask, resize_idx = build_resized_image( image, resize_mode, normalize_mode, base_image_input_size, pad_value, image_patch_size ) pooling_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w) h, w = pooling_idx.shape[:2] pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w]) per_row = np.full( (w,), patch_id, dtype=np.int32 ) if use_col_tokens: per_row = np.concatenate([per_row, [col_id]], 0) extra_tokens = np.tile(per_row, [h]) joint = [ [start_id], extra_tokens, [end_id], ] return ( np.concatenate(joint, 0), batch_pixels_to_patches(resized, image_patch_size), batch_pixels_to_patches(resized_mask, image_patch_size).mean(-1), pooling_idx, ) if crop_mode in ["overlap-and-resize-c2", "overlap-and-resize"]: crop_arr, mask_arr, patch_idx_arr = build_overlapping_crops( image, resize_mode, normalize_mode, max_crops, overlap_margins, base_image_input_size, pad_value, image_patch_size, ) pooling_idx = arange_for_pooling(patch_idx_arr, pooling_h, pooling_w) h, w = pooling_idx.shape[:2] pooling_idx = pooling_idx.reshape([-1, pooling_h*pooling_w]) # Now build the output tokens per_row = np.full(w, patch_id, dtype=np.int32) if use_col_tokens: per_row = np.concatenate([per_row, [col_id]], 0) joint = np.tile(per_row, [h]) joint = [ [start_id], joint, [end_id] ] if crop_mode == "overlap-and-resize": crop_arr = batch_pixels_to_patches(crop_arr, image_patch_size) mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1) return np.concatenate(joint, 0), crop_arr, mask_arr, pooling_idx # Finally do the same for the global image resized, resized_mask, resize_idx = build_resized_image( image, resize_mode, normalize_mode, base_image_input_size, pad_value, image_patch_size ) crop_arr = np.concatenate([resized, crop_arr], 0) mask_arr = np.concatenate([resized_mask, mask_arr], 0) resize_idx = arange_for_pooling(resize_idx, pooling_h, pooling_w) h, w = resize_idx.shape[:2] resize_idx = resize_idx.reshape([-1, pooling_h*pooling_w]) # Global image goes first, so the order of patches in previous crops gets increased pooling_idx = np.where( pooling_idx >= 0, pooling_idx + crop_patch_h*crop_patch_w, -1 ) pooling_idx = np.concatenate([resize_idx, pooling_idx]) per_row = np.full( (w,), patch_id, dtype=np.int32 ) if use_col_tokens: per_row = np.concatenate([per_row, [col_id]], 0) extra_tokens = np.tile(per_row, [h]) joint = [ [start_id], extra_tokens, [end_id], ] + joint mask_arr = batch_pixels_to_patches(mask_arr, image_patch_size).astype(np.float32).mean(axis=-1) return ( np.concatenate(joint, 0), batch_pixels_to_patches(crop_arr, image_patch_size), mask_arr, pooling_idx ) else: raise NotImplementedError(crop_mode) class MolmoActImagesKwargs(ImagesKwargs, total=False): crop_mode: Optional[str] resize_mode: Optional[str] normalize_mode: Optional[str] max_crops: Optional[int] max_multi_image_crops: Optional[int] overlap_margins: Optional[List[int]] base_image_input_size: Optional[List[int]] pad_value: Optional[float] image_patch_size: Optional[int] image_pooling_w: Optional[int] image_pooling_h: Optional[int] class MolmoActImageProcessor(BaseImageProcessor): model_input_names = ["images", "pooled_patches_idx", "image_masks"] def __init__( self, crop_mode: str = "overlap-and-resize-c2", resize_mode: str = "siglip", normalize_mode: str = "siglip", max_crops: int = 8, max_multi_image_crops: int = 4, overlap_margins: List[int] = [4, 4], base_image_input_size: List[int] = (378, 378), pad_value: float = 0.0, image_patch_size: int = 14, image_pooling_w: int = 2, image_pooling_h: int = 2, do_convert_rgb: bool = True, do_pad: Optional[bool] = True, **kwargs, ) -> None: super().__init__(**kwargs) self.crop_mode = crop_mode self.resize_mode = resize_mode self.normalize_mode = normalize_mode self.overlap_margins = overlap_margins self.max_crops = max_crops self.max_multi_image_crops = max_multi_image_crops self.overlap_margins = overlap_margins self.base_image_input_size = base_image_input_size self.pad_value = pad_value self.image_patch_size = image_patch_size self.image_pooling_w = image_pooling_w self.image_pooling_h = image_pooling_h self.do_convert_rgb = do_convert_rgb self.do_pad = do_pad def to_channel_dimension_last( self, images: List[ImageInput], ) -> List[ImageInput]: """ Convert images to channel dimension last. """ new_images = [] for image in images: if is_multi_image(image): new_images.append([to_channel_dimension_format(img, ChannelDimension.LAST) for img in image]) else: new_images.append(to_channel_dimension_format(image, ChannelDimension.LAST)) return new_images def to_numpy_array( self, images: List[ImageInput], ) -> List[np.ndarray]: """ Convert images to numpy array. """ new_images = [] for image in images: if is_multi_image(image): new_images.append([to_numpy_array(img) for img in image]) else: new_images.append(to_numpy_array(image)) return new_images def to_rgb( self, images: List[ImageInput], ) -> List[ImageInput]: """ Convert images to RGB. """ new_images = [] for image in images: if is_multi_image(image): new_images.append([convert_to_rgb(img) for img in image]) else: new_images.append(convert_to_rgb(image)) return new_images def pad_arrays(self, arrays: List[np.ndarray], pad_value: float = -1) -> np.ndarray: max_len = max(arr.shape[0] for arr in arrays) padded_arr = np.full( [len(arrays), max_len] + list(arrays[0].shape[1:]), pad_value, dtype=arrays[0].dtype ) for ix, arr in enumerate(arrays): padded_arr[ix, :len(arr)] = arr[:max_len] return padded_arr def pad_for_batching(self, data: Dict[str, Any]) -> Dict[str, Any]: """ Pad the data for batching. """ images = self.pad_arrays(data["images"]) pooled_patches_idx = self.pad_arrays(data["pooled_patches_idx"]) image_masks = self.pad_arrays(data["image_masks"]) image_grids = self.pad_arrays(data["image_grids"]) new_data = dict( images=images, pooled_patches_idx=pooled_patches_idx, image_masks=image_masks, image_grids=image_grids, ) return new_data def preprocess( self, images: Union[ImageInput, List[ImageInput]], crop_mode: Optional[str] = None, resize_mode: Optional[str] = None, normalize_mode: Optional[str] = None, max_crops: Optional[int] = None, max_multi_image_crops: Optional[int] = None, overlap_margins: Optional[List[int]] = None, base_image_input_size: Optional[List[int]] = None, pad_value: Optional[float] = None, image_patch_size: Optional[int] = None, image_pooling_w: Optional[int] = None, image_pooling_h: Optional[int] = None, do_convert_rgb: Optional[bool] = None, do_pad: Optional[bool] = None, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs, ) -> BatchFeature: """ Preprocess an image for the model. Args: image: The image to preprocess. crop_mode: The crop mode to use. If None, use the default crop mode. resize_mode: The resize mode to use. If None, use the default resize mode. normalize_mode: The normalization mode to use. If None, use the default normalization mode. max_crops: The maximum number of crops to use. If None, use the default value. max_multi_image_crops: The maximum number of crops to use for multi-image inputs. overlap_margins: The overlap margins to use. If None, use the default values. base_image_input_size: The base image input size to use. If None, use the default size. pad_value: The padding value to use. If None, use the default value. image_patch_size: The size of the image patches. If None, use the default size. image_pooling_h: The height of the image pooling. If None, use the default height. image_pooling_w: The width of the image pooling. If None, use the default width. do_convert_rgb: Whether to convert the image to RGB. If None, use the default value. do_pad: Whether to pad image features. If None, use the default value. Returns: A tuple containing: - The image grids - The preprocessed images - The padding masks - The pooling indices """ images = make_batched_images(images) if not valid_images(images): raise ValueError("Invalid image input") crop_mode = crop_mode or self.crop_mode normalize_mode = normalize_mode or self.normalize_mode resize_mode = resize_mode or self.resize_mode max_crops = max_crops or self.max_crops max_multi_image_crops = max_multi_image_crops or self.max_multi_image_crops overlap_margins = overlap_margins or self.overlap_margins base_image_input_size = base_image_input_size or self.base_image_input_size pad_value = pad_value or self.pad_value image_patch_size = image_patch_size or self.image_patch_size image_pooling_w = image_pooling_w or self.image_pooling_w image_pooling_h = image_pooling_h or self.image_pooling_h do_convert_rgb = do_convert_rgb or self.do_convert_rgb do_pad = do_pad or self.do_pad if do_convert_rgb: images = self.to_rgb(images) # All transformations expect numpy arrays. images = self.to_numpy_array(images) # All transformations expect channel dimension last. images = self.to_channel_dimension_last(images) batch_image_grids = [] batch_crops = [] batch_crop_masks = [] batch_pooled_patches_idx = [] for image in images: if is_multi_image(image): all_image_grids = [] all_crops = [] all_crop_masks = [] pooled_patches_idx = [] for img in image: image_grid, crops, img_mask, pooled_idx = image_to_patches_and_grids( img, crop_mode, resize_mode, normalize_mode, max_multi_image_crops, overlap_margins, base_image_input_size, pad_value, image_patch_size, image_pooling_w, image_pooling_h, ) pooled_patches_idx.append(pooled_idx + sum(np.prod(x.shape[:2]) for x in all_crops)) all_crops.append(crops) all_crop_masks.append(img_mask) all_image_grids.append(image_grid) all_image_grids = np.concatenate(all_image_grids, 0) all_crops = np.concatenate(all_crops, 0) all_crop_masks = np.concatenate(all_crop_masks, 0) pooled_patches_idx = np.concatenate(pooled_patches_idx, 0) batch_image_grids.append(all_image_grids) batch_crops.append(all_crops) batch_crop_masks.append(all_crop_masks) batch_pooled_patches_idx.append(pooled_patches_idx) else: image_grid, crops, img_mask, pooled_idx = image_to_patches_and_grids( image, crop_mode, resize_mode, normalize_mode, max_crops, overlap_margins, base_image_input_size, pad_value, image_patch_size, image_pooling_w, image_pooling_h, ) batch_image_grids.append(image_grid) batch_crops.append(crops) batch_crop_masks.append(img_mask) batch_pooled_patches_idx.append(pooled_idx) data =dict( images=batch_crops, pooled_patches_idx=batch_pooled_patches_idx, image_masks=batch_crop_masks, image_grids=batch_image_grids, ) if do_pad: data = self.pad_for_batching(data) return BatchFeature(data, tensor_type=return_tensors) MolmoActImageProcessor.register_for_auto_class()