229nagibator229 commited on
Commit
19e3d6a
·
verified ·
1 Parent(s): d3bb088

Upload processor

Browse files
Files changed (1) hide show
  1. processing_encoder.py +14 -5
processing_encoder.py CHANGED
@@ -5,7 +5,7 @@ from transformers.image_transforms import resize, center_crop, normalize
5
  from transformers.utils.generic import TensorType
6
  from transformers.image_processing_utils import BatchFeature
7
  from PIL import Image
8
- import torchvision.transforms
9
  import numpy as np
10
 
11
 
@@ -54,8 +54,9 @@ class EncoderImageProcessor(BaseImageProcessor):
54
  image = center_crop(image, size=self.input_size)
55
  image = normalize(image, mean=self.mean, std=self.std)
56
  # Convert to tensor and normalize
57
- image = torch.Tensor(image).to(torch.float32).permute(2,0,1) / 255.0 # Convert to CHW format
58
-
 
59
 
60
  return image
61
 
@@ -78,7 +79,13 @@ class EncoderImageProcessor(BaseImageProcessor):
78
  if not isinstance(images, list):
79
  images = [images]
80
 
81
- pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
 
 
 
 
 
 
82
 
83
  # Handle tensor output type
84
  if return_tensors == "pt":
@@ -88,7 +95,9 @@ class EncoderImageProcessor(BaseImageProcessor):
88
  else:
89
  raise ValueError(f"Unsupported tensor type: {return_tensors}")
90
 
91
- def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
 
 
92
  """
93
  Callable interface for preprocessing images.
94
 
 
5
  from transformers.utils.generic import TensorType
6
  from transformers.image_processing_utils import BatchFeature
7
  from PIL import Image
8
+ import torchvision.transforms
9
  import numpy as np
10
 
11
 
 
54
  image = center_crop(image, size=self.input_size)
55
  image = normalize(image, mean=self.mean, std=self.std)
56
  # Convert to tensor and normalize
57
+ image = (
58
+ torch.Tensor(image).to(torch.float32).permute(2, 0, 1) / 255.0
59
+ ) # Convert to CHW format
60
 
61
  return image
62
 
 
79
  if not isinstance(images, list):
80
  images = [images]
81
 
82
+ assert isinstance(images, list) and all(
83
+ isinstance(item, (np.ndarray, Image.Image)) for item in images
84
+ )
85
+ if isinstance(images, Image.Image):
86
+ images = [img.convert("RGB") for img in images]
87
+
88
+ pixel_values = torch.stack([self.apply_transform(image) for image in images])
89
 
90
  # Handle tensor output type
91
  if return_tensors == "pt":
 
95
  else:
96
  raise ValueError(f"Unsupported tensor type: {return_tensors}")
97
 
98
+ def __call__(
99
+ self, images: Union[Image.Image, List[Image.Image]], **kwargs
100
+ ) -> BatchFeature:
101
  """
102
  Callable interface for preprocessing images.
103