File size: 3,750 Bytes
ef29d72
 
 
 
 
 
 
19e3d6a
a990c06
ef29d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a990c06
ef29d72
a990c06
ef29d72
 
 
a990c06
ef29d72
19e3d6a
 
 
ef29d72
a990c06
ef29d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e3d6a
 
 
 
 
 
 
ef29d72
 
 
 
 
 
 
 
 
19e3d6a
 
 
ef29d72
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from typing import List, Tuple, Union, Optional
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_transforms import resize, center_crop, normalize
from transformers.utils.generic import TensorType
from transformers.image_processing_utils import BatchFeature
from PIL import Image
import torchvision.transforms
import numpy as np


class EncoderImageProcessor(BaseImageProcessor):
    """
    Image processor that turns PIL images / NumPy arrays into pixel tensors.

    Each image is converted to RGB (PIL inputs), resized to ``input_size``,
    rescaled from the 0-255 range to [0, 1], normalized channel-wise with
    ``mean`` / ``std`` and returned in channels-first layout. With the default
    mean/std of 0.5 the normalized values lie in [-1, 1].
    """

    model_input_names: List[str] = ["pixel_values"]

    def __init__(
        self,
        input_size: Tuple[int, int] = (224, 224),
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        resize_strategy: str = "letterbox",
        **kwargs,
    ):
        """
        Initialize an image processor for the EncoderModel.

        Args:
            input_size (Tuple[int, int]): Size (height, width) of input images.
            mean (List[float], optional): Per-channel mean for normalization.
                Defaults to ``[0.5, 0.5, 0.5]``.
            std (List[float], optional): Per-channel std for normalization.
                Defaults to ``[0.5, 0.5, 0.5]``.
            resize_strategy (str): Resize strategy ("letterbox", "resize-crop",
                or "resize-naive"). It is stored on the instance, but
                NOTE(review): ``apply_transform`` does not consult it yet and
                always performs a plain resize to ``input_size``.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        self.input_size = input_size
        # Build fresh lists per instance. A list literal as the default value
        # would be one object shared (and mutable) across all instances.
        self.mean = list(mean) if mean is not None else [0.5, 0.5, 0.5]
        self.std = list(std) if std is not None else [0.5, 0.5, 0.5]
        self.resize_strategy = resize_strategy

        super().__init__(**kwargs)

    def apply_transform(
        self, image: Union[Image.Image, np.ndarray]
    ) -> torch.Tensor:
        """
        Apply transformations: RGB conversion -> Resize -> Rescale -> Normalize.

        Args:
            image (Union[Image.Image, np.ndarray]): A PIL image, or an array in
                (height, width, channels) layout (the final permute assumes
                channels-last). Pixel values are assumed to be in the 0-255
                range, since they are divided by 255.

        Returns:
            torch.Tensor: float32 tensor of shape (channels, height, width).
        """
        # Grayscale / RGBA / palette PIL images would otherwise yield arrays
        # whose channel count does not match the 3-element mean/std.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        pixels = np.array(image)

        # Resize straight to the target size. A center crop to the same size
        # afterwards would be a no-op, so none is performed.
        pixels = resize(pixels, size=tuple(self.input_size))

        # Rescale to [0, 1] *before* normalizing: mean/std are expressed in
        # that range. Normalizing raw 0-255 values and dividing by 255 later
        # gives roughly [0, 2] instead of [-1, 1] for mean=std=0.5.
        pixels = pixels.astype(np.float32) / 255.0
        pixels = normalize(pixels, mean=self.mean, std=self.std)

        # HWC -> CHW (channels-first).
        tensor = torch.from_numpy(np.ascontiguousarray(pixels)).to(torch.float32)
        return tensor.permute(2, 0, 1)

    def preprocess(
        self,
        images: Union[
            Image.Image, np.ndarray, List[Union[Image.Image, np.ndarray]]
        ],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess a batch of images.

        Args:
            images: A single PIL image / NumPy array, or a list or tuple of them.
            return_tensors (Optional[Union[str, TensorType]]): "pt" for a
                PyTorch tensor; None or "np" for a NumPy array.
            **kwargs: Accepted for API compatibility; currently unused.

        Returns:
            BatchFeature: A batch feature whose "pixel_values" entry has shape
            (batch, channels, height, width).

        Raises:
            ValueError: If ``return_tensors`` is unsupported or ``images`` is
                empty.
            TypeError: If an item is neither a PIL image nor a NumPy array.
        """
        # Check the cheap argument first so no image work is wasted.
        # (TensorType is a str enum, so TensorType.PYTORCH == "pt".)
        if return_tensors is not None and return_tensors not in ("pt", "np"):
            raise ValueError(f"Unsupported tensor type: {return_tensors}")

        if isinstance(images, (list, tuple)):
            images = list(images)
        else:
            images = [images]

        if not images:
            raise ValueError("`images` must contain at least one image.")
        # Explicit checks instead of `assert`, which is stripped under -O.
        for item in images:
            if not isinstance(item, (np.ndarray, Image.Image)):
                raise TypeError(
                    "Expected PIL.Image.Image or np.ndarray, "
                    f"got {type(item).__name__}"
                )

        # RGB conversion of PIL inputs happens per image in apply_transform.
        pixel_values = torch.stack([self.apply_transform(image) for image in images])

        if return_tensors == "pt":
            return BatchFeature(data={"pixel_values": pixel_values})
        return BatchFeature(data={"pixel_values": pixel_values.numpy()})

    def __call__(
        self,
        images: Union[
            Image.Image, np.ndarray, List[Union[Image.Image, np.ndarray]]
        ],
        **kwargs,
    ) -> BatchFeature:
        """
        Callable interface for preprocessing images.

        Args:
            images: A single PIL image / NumPy array, or a list or tuple of them.
            **kwargs: Forwarded to ``preprocess`` (e.g. ``return_tensors``).

        Returns:
            BatchFeature: Preprocessed image batch.
        """
        return self.preprocess(images, **kwargs)