import torch
from typing import ClassVar, List, Tuple, Union, Optional
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import resize, center_crop, normalize
from transformers.utils.generic import TensorType
from transformers.image_processing_utils import BatchFeature
from PIL import Image, ImageOps
import torchvision.transforms
import numpy as np


class EncoderImageProcessor(BaseImageProcessor):
    """Image processor producing normalized ``pixel_values`` for the EncoderModel.

    Per-image pipeline: convert to RGB -> resize to ``input_size`` with the
    configured ``resize_strategy`` -> rescale to [0, 1] -> normalize with
    ``mean`` / ``std`` -> channels-first float32 tensor.
    """

    model_input_names: List[str] = ["pixel_values"]

    # Values accepted for ``resize_strategy``.
    SUPPORTED_RESIZE_STRATEGIES: ClassVar[Tuple[str, ...]] = (
        "letterbox",
        "resize-crop",
        "resize-naive",
    )

    def __init__(
        self,
        input_size: Tuple[int, int] = (224, 224),
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        resize_strategy: str = "letterbox",
        **kwargs,
    ):
        """
        Initialize an image processor for the EncoderModel.

        Args:
            input_size (Tuple[int, int]): Output size (height, width) of images.
            mean (Optional[List[float]]): Per-channel (R, G, B) mean applied after
                rescaling to [0, 1]. Defaults to [0.5, 0.5, 0.5].
            std (Optional[List[float]]): Per-channel (R, G, B) std applied after
                rescaling to [0, 1]. Defaults to [0.5, 0.5, 0.5].
            resize_strategy (str): How images are fitted to ``input_size``:
                "letterbox" (keep aspect ratio, pad with black),
                "resize-crop" (keep aspect ratio, center-crop overflow), or
                "resize-naive" (stretch, ignoring aspect ratio).
            **kwargs: Forwarded to ``BaseImageProcessor``.

        Raises:
            ValueError: If ``resize_strategy`` is not one of the supported values.
        """
        super().__init__(**kwargs)
        if resize_strategy not in self.SUPPORTED_RESIZE_STRATEGIES:
            raise ValueError(
                f"Unsupported resize_strategy {resize_strategy!r}; expected one of "
                f"{self.SUPPORTED_RESIZE_STRATEGIES}"
            )
        # Copy inputs so caller-owned lists are never aliased; configs reloaded
        # from JSON may provide input_size as a list.
        self.input_size = tuple(input_size)
        self.mean = list(mean) if mean is not None else [0.5, 0.5, 0.5]
        self.std = list(std) if std is not None else [0.5, 0.5, 0.5]
        self.resize_strategy = resize_strategy

    @staticmethod
    def _to_rgb_image(image: Union[Image.Image, np.ndarray]) -> Image.Image:
        """Return ``image`` as a 3-channel RGB PIL image.

        NumPy arrays are passed through ``PIL.Image.fromarray``, so they must be
        in a layout it accepts (e.g. uint8 ``(H, W)`` or ``(H, W, 3/4)``).
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Grayscale / RGBA / palette images would otherwise produce a channel
        # count that does not match the 3-element mean/std.
        return image.convert("RGB")

    def _resize(self, image: Image.Image) -> Image.Image:
        """Fit ``image`` to ``self.input_size`` according to ``self.resize_strategy``."""
        height, width = self.input_size
        size = (width, height)  # PIL expects (width, height)
        # Bilinear matches the default resampling of transformers' ``resize``.
        if self.resize_strategy == "resize-naive":
            # Stretch to the exact target size, ignoring aspect ratio.
            return image.resize(size, resample=Image.BILINEAR)
        if self.resize_strategy == "resize-crop":
            # Scale so the image covers the target, then center-crop the excess.
            return ImageOps.fit(image, size, method=Image.BILINEAR)
        if self.resize_strategy == "letterbox":
            # Scale so the image fits inside the target, then center-pad with black.
            return ImageOps.pad(image, size, method=Image.BILINEAR, color=(0, 0, 0))
        raise ValueError(f"Unsupported resize_strategy {self.resize_strategy!r}")

    def apply_transform(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Apply transformations: RGB conversion -> Resize -> Rescale -> Normalize.

        Args:
            image (Union[Image.Image, np.ndarray]): A PIL image, or a uint8 array
                accepted by ``PIL.Image.fromarray``.

        Returns:
            torch.Tensor: float32 tensor of shape (3, height, width).
        """
        image = self._resize(self._to_rgb_image(image))
        # Rescale to [0, 1] *before* normalizing so mean/std apply on the
        # intended scale (e.g. mean=std=0.5 maps pixels to [-1, 1]).
        pixels = np.asarray(image, dtype=np.float32) / 255.0
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        pixels = (pixels - mean) / std
        # (H, W, C) -> (C, H, W)
        return torch.from_numpy(pixels).permute(2, 0, 1)

    def preprocess(
        self,
        images: Union[
            Image.Image, np.ndarray, List[Union[Image.Image, np.ndarray]]
        ],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a batch of images.

        Args:
            images: A single image or a list/tuple of images (PIL images or
                uint8 NumPy arrays).
            return_tensors (Optional[Union[str, TensorType]]): "pt" for a PyTorch
                tensor; None or "np" for a NumPy array.
            **kwargs: Accepted for API compatibility; ignored.

        Returns:
            BatchFeature: ``{"pixel_values": ...}`` with shape (N, 3, height, width).

        Raises:
            TypeError: If an item is neither a PIL image nor a NumPy array.
            ValueError: If no images are given or ``return_tensors`` is unsupported.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        if not images:
            raise ValueError("preprocess() requires at least one image")
        for item in images:
            if not isinstance(item, (np.ndarray, Image.Image)):
                raise TypeError(
                    f"Expected PIL.Image.Image or np.ndarray, got {type(item).__name__}"
                )

        pixel_values = torch.stack([self.apply_transform(image) for image in images])

        # TensorType is a str enum, so TensorType.PYTORCH == "pt" etc.
        if return_tensors == "pt":
            return BatchFeature(data={"pixel_values": pixel_values})
        if return_tensors is None or return_tensors == "np":
            return BatchFeature(data={"pixel_values": pixel_values.numpy()})
        raise ValueError(f"Unsupported tensor type: {return_tensors}")

    def __call__(
        self,
        images: Union[
            Image.Image, np.ndarray, List[Union[Image.Image, np.ndarray]]
        ],
        **kwargs,
    ) -> BatchFeature:
        """
        Callable interface for preprocessing images.

        Args:
            images: A single image or a list of images.
            **kwargs: Forwarded to ``preprocess`` (e.g. ``return_tensors``).

        Returns:
            BatchFeature: Preprocessed image batch.
        """
        return self.preprocess(images, **kwargs)