from typing import Tuple, Union

import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.models.clip import CLIPProcessor

from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform

""" Jina CLIP processor implementation """


class JinaCLIPProcessor(CLIPProcessor):
    """CLIP processor whose tokenizer and image processor are resolved via
    the corresponding auto classes registered for the checkpoint."""

    image_processor_class = 'AutoImageProcessor'
    tokenizer_class = 'AutoTokenizer'
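

# A minimal loading sketch (assumption: this module ships as
# `trust_remote_code` custom code for a Jina CLIP checkpoint; the checkpoint
# id below is only an example):
#
#     from transformers import AutoProcessor
#
#     processor = AutoProcessor.from_pretrained(
#         'jinaai/jina-clip-v1', trust_remote_code=True
#     )
#     inputs = processor(
#         text=['a photo of a cat'], images=image, return_tensors='pt'
#     )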


""" Jina CLIP image processor implementation """


class JinaCLIPImageProcessor(BaseImageProcessor):
    """Image processor that wraps the `image_transform` pipeline from
    `.transform` and exposes it through the `transformers` processor API."""

    model_input_names = ['pixel_values']
    _valid_processor_keys = [
        'size',
        'mean',
        'std',
        'resize_mode',
        'interpolation',
        'fill_color',
    ]

    def __init__(
        self,
        size: Union[int, Tuple[int, int]] = 224,
        mean: Union[float, Tuple[float, ...]] = OPENAI_DATASET_MEAN,
        std: Union[float, Tuple[float, ...]] = OPENAI_DATASET_STD,
        resize_mode: str = 'shortest',
        interpolation: str = 'bicubic',
        fill_color: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.size = size
        self.mean = mean
        self.std = std
        self.resize_mode = resize_mode
        self.interpolation = interpolation
        self.fill_color = fill_color
        self.transform = self._build_transform()

    def _build_transform(self):
        return image_transform(
            image_size=self.size,
            is_train=False,
            mean=self.mean,
            std=self.std,
            resize_mode=self.resize_mode,
            interpolation=self.interpolation,
            fill_color=self.fill_color,
            aug_cfg=None,
        )
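
    # Note: `image_transform` (imported from `.transform`) is assumed to
    # follow the open_clip convention: with `is_train=False` it composes a
    # resize (governed by `resize_mode`), a center crop to `size`, tensor
    # conversion, and normalization with `mean`/`std`.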

    def to_dict(self):
        # Drop the composed transform callable: it is not JSON-serializable
        # and is rebuilt from the scalar settings when the config is loaded.
        output = super().to_dict()
        output.pop('transform', None)
        return output

    def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
        # Apply any per-call overrides of the valid processor settings and
        # rebuild the cached transform only if a value actually changed.
        _transform_needs_rebuild = False
        for k, v in kwargs.items():
            if k in self._valid_processor_keys and v != getattr(self, k):
                setattr(self, k, v)
                _transform_needs_rebuild = True

        if _transform_needs_rebuild:
            self.transform = self._build_transform()

        # Normalize the input to a list of images, transform each one, and
        # stack into a single (batch, channel, height, width) tensor.
        images = make_list_of_images(images)
        out = torch.stack([self.transform(image) for image in images], dim=0)
        return BatchFeature(data={'pixel_values': out})
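

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not shipped behavior): drives the image
# processor directly. The synthetic PIL image is a stand-in for real input;
# the expected output shapes assume the default `image_transform` pipeline.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from PIL import Image

    processor = JinaCLIPImageProcessor()
    image = Image.new('RGB', (32, 32), color=(127, 127, 127))

    # A single image is wrapped into a list and stacked along dim 0.
    batch = processor.preprocess(image)
    print(batch['pixel_values'].shape)  # expected: torch.Size([1, 3, 224, 224])

    # Per-call overrides of any `_valid_processor_keys` entry rebuild the
    # cached transform before the images are processed.
    batch = processor.preprocess(image, size=336)
    print(batch['pixel_values'].shape)  # expected: torch.Size([1, 3, 336, 336])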
|
|