# 229nagibator229's picture
# Upload processor
# 19e3d6a verified
import torch
from typing import List, Tuple, Union, Optional
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import resize, center_crop, normalize
from transformers.utils.generic import TensorType
from transformers.image_processing_utils import BatchFeature
from PIL import Image
import torchvision.transforms
import numpy as np
class EncoderImageProcessor(BaseImageProcessor):
    """Image processor for the EncoderModel.

    Turns PIL images (or channels-last numpy arrays) into normalized,
    float32, CHW ``pixel_values`` via: resize -> center-crop -> rescale
    to [0, 1] -> normalize with ``mean``/``std``.
    """

    model_input_names: List[str] = ["pixel_values"]

    def __init__(
        self,
        input_size: Tuple[int, int] = (224, 224),
        mean: List[float] = [0.5, 0.5, 0.5],
        std: List[float] = [0.5, 0.5, 0.5],
        resize_strategy: str = "letterbox",
        **kwargs,
    ):
        """
        Initialize an image processor for the EncoderModel.
        Args:
            input_size (Tuple[int, int]): Size (height, width) of input images.
            mean (List[float]): Per-channel mean for normalization, in [0, 1] units.
            std (List[float]): Per-channel std for normalization, in [0, 1] units.
            resize_strategy (str): Resize strategy ("letterbox", "resize-crop", or "resize-naive").
                NOTE(review): currently stored but not consulted anywhere in this
                class — all inputs go through resize + center-crop regardless.
        """
        self.input_size = input_size
        # Copy the lists so callers (and the shared default objects) cannot be
        # mutated through this instance.
        self.mean = list(mean)
        self.std = list(std)
        self.resize_strategy = resize_strategy
        super().__init__(**kwargs)

    def apply_transform(self, image: Image.Image) -> torch.Tensor:
        """
        Apply transformations: Resize -> CenterCrop -> Rescale -> Normalize.
        Args:
            image (Image.Image): A PIL image (or HWC numpy array) to transform.
        Returns:
            torch.Tensor: Transformed float32 image tensor in CHW layout.
        """
        image = np.array(image)
        # Resize then center-crop to the target (height, width).
        image = resize(image, size=self.input_size)
        image = center_crop(image, size=self.input_size)
        # BUGFIX: rescale to [0, 1] BEFORE normalizing, so mean/std (given in
        # [0, 1] units) apply on the right scale. The previous code normalized
        # raw 0-255 values and divided by 255 afterwards, producing wrong stats.
        image = image.astype(np.float32) / 255.0
        image = normalize(image, mean=self.mean, std=self.std)
        # HWC -> CHW; from_numpy avoids the copy/reinterpretation pitfalls of
        # torch.Tensor(ndarray).
        return torch.from_numpy(image).to(torch.float32).permute(2, 0, 1)

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image]],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a batch of images.
        Args:
            images (Union[Image.Image, List[Image.Image]]): A single image or a list of images.
            return_tensors (Optional[Union[str, TensorType]]): Format for the output tensors ("pt" for PyTorch).
        Returns:
            BatchFeature: A batch feature with preprocessed images.
        Raises:
            TypeError: If any input is neither a PIL image nor a numpy array.
            ValueError: If ``return_tensors`` is neither "pt" nor None.
        """
        if not isinstance(images, list):
            images = [images]
        # Explicit validation instead of `assert`, which is stripped under -O.
        for item in images:
            if not isinstance(item, (np.ndarray, Image.Image)):
                raise TypeError(
                    f"Expected PIL.Image.Image or np.ndarray, got {type(item)}"
                )
        # BUGFIX: the old `isinstance(images, Image.Image)` check ran after the
        # list wrap and never fired, so RGB conversion was silently skipped.
        # Convert per-item; numpy inputs are passed through unchanged.
        images = [
            img.convert("RGB") if isinstance(img, Image.Image) else img
            for img in images
        ]
        pixel_values = torch.stack([self.apply_transform(image) for image in images])
        # Handle tensor output type
        if return_tensors == "pt":
            return BatchFeature(data={"pixel_values": pixel_values})
        elif return_tensors is None:
            return BatchFeature(data={"pixel_values": pixel_values.numpy()})
        else:
            raise ValueError(f"Unsupported tensor type: {return_tensors}")

    def __call__(
        self, images: Union[Image.Image, List[Image.Image]], **kwargs
    ) -> BatchFeature:
        """
        Callable interface for preprocessing images.
        Args:
            images (Union[Image.Image, List[Image.Image]]): A single image or a list of images.
        Returns:
            BatchFeature: Preprocessed image batch.
        """
        return self.preprocess(images, **kwargs)