from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import center_crop, normalize, resize
from transformers.utils.generic import TensorType


class EncoderImageProcessor(BaseImageProcessor):
    model_input_names: List[str] = ["pixel_values"]

    def __init__(
        self,
        input_size: Tuple[int, int] = (224, 224),
        mean: List[float] = [0.5, 0.5, 0.5],
        std: List[float] = [0.5, 0.5, 0.5],
        resize_strategy: str = "letterbox",
        **kwargs,
    ):
        """
        Initialize an image processor for the EncoderModel.

        Args:
            input_size (Tuple[int, int]): Size (height, width) of input images.
            mean (List[float]): Per-channel mean for normalization.
            std (List[float]): Per-channel standard deviation for normalization.
            resize_strategy (str): Resize strategy ("letterbox", "resize-crop", or "resize-naive").
        """
        self.input_size = input_size
        self.mean = mean
        self.std = std
        self.resize_strategy = resize_strategy

        super().__init__(**kwargs)

    def apply_transform(self, image: Image.Image) -> torch.Tensor:
        """
        Apply transformations: Resize -> CenterCrop -> Rescale -> Normalize.

        Args:
            image (Image.Image): A PIL image to transform.

        Returns:
            torch.Tensor: Transformed image tensor of shape (3, H, W).
        """
        # Note: `resize_strategy` is stored in __init__ but not consulted here;
        # this path is a plain resize + center-crop (see the letterbox sketch
        # below for the strategy named in the default).
        image = np.array(image)

        # Resize and crop while the array is still 8-bit; transformers'
        # resize() round-trips through PIL internally.
        image = resize(image, size=self.input_size)
        image = center_crop(image, size=self.input_size)

        # Rescale to [0, 1] *before* normalizing, so the [0, 1]-scale
        # mean/std are applied to values on the same scale.
        image = image.astype(np.float32) / 255.0
        image = normalize(image, mean=self.mean, std=self.std)

        # HWC -> CHW for PyTorch.
        return torch.from_numpy(image).permute(2, 0, 1)
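
    # A minimal sketch of the "letterbox" strategy named in __init__ (not
    # wired into apply_transform above): pad the image out to a square so a
    # subsequent resize preserves the aspect ratio. The helper name and the
    # fill color (the normalization mean mapped back to 8-bit values) are
    # assumptions, not part of the original processor.
    def _letterbox_pad(self, image: Image.Image) -> Image.Image:
        from PIL import ImageOps

        width, height = image.size
        side = max(width, height)
        fill = tuple(int(m * 255) for m in self.mean)  # assumed fill color
        # Pad symmetrically out to a square canvas.
        pad_left = (side - width) // 2
        pad_top = (side - height) // 2
        border = (pad_left, pad_top, side - width - pad_left, side - height - pad_top)
        return ImageOps.expand(image, border=border, fill=fill)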

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a batch of images.

        Args:
            images (Union[Image.Image, List[Image.Image]]): A single image or a list of images.
            return_tensors (Optional[Union[str, TensorType]]): Format for the output tensors
                ("pt" for PyTorch; None returns NumPy arrays).

        Returns:
            BatchFeature: A batch feature with preprocessed images.
        """
        if not isinstance(images, list):
            images = [images]

        assert all(
            isinstance(item, (np.ndarray, Image.Image)) for item in images
        ), "images must be PIL images or NumPy arrays"

        # Promote NumPy arrays to PIL images, then force three-channel RGB.
        images = [
            img if isinstance(img, Image.Image) else Image.fromarray(img)
            for img in images
        ]
        images = [img.convert("RGB") for img in images]

        pixel_values = torch.stack([self.apply_transform(image) for image in images])

        # TensorType.PYTORCH is a str enum, so it compares equal to "pt".
        if return_tensors == "pt":
            return BatchFeature(data={"pixel_values": pixel_values})
        elif return_tensors is None:
            return BatchFeature(data={"pixel_values": pixel_values.numpy()})
        else:
            raise ValueError(f"Unsupported tensor type: {return_tensors}")

    def __call__(
        self, images: Union[Image.Image, List[Image.Image]], **kwargs
    ) -> BatchFeature:
        """
        Callable interface for preprocessing images.

        Args:
            images (Union[Image.Image, List[Image.Image]]): A single image or a list of images.

        Returns:
            BatchFeature: Preprocessed image batch.
        """
        return self.preprocess(images, **kwargs)
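
# Minimal usage sketch (the image path is hypothetical; the output shape
# assumes the default 224x224 input size and RGB input):
#
#     processor = EncoderImageProcessor()
#     batch = processor(Image.open("example.jpg"), return_tensors="pt")
#     batch["pixel_values"].shape  # torch.Size([1, 3, 224, 224])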