AdityasArsenal commited on
Commit
add38c3
·
verified ·
1 Parent(s): 1dbe669

Create feature_extractor.py

Browse files
Files changed (1) hide show
  1. feature_extractor.py +36 -0
feature_extractor.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedFeatureExtractor
2
+ import tensorflow as tf
3
+
4
+ class CustomImageProcessor(PreTrainedFeatureExtractor):
5
+ def __init__(self, image_size=(224, 224), normalize=True):
6
+ super().__init__()
7
+ self.image_size = image_size
8
+ self.normalize = normalize
9
+
10
+ def __call__(self, images, return_tensors="pt"):
11
+ """
12
+ Preprocesses the images (resize, normalize, and format for the model).
13
+ Args:
14
+ images: List or batch of images in PIL or Tensor format.
15
+ return_tensors: Specify the format to return ('pt' for PyTorch, 'tf' for TensorFlow).
16
+ Returns:
17
+ Processed image tensor ready for model inference.
18
+ """
19
+ # Resize images to the required shape
20
+ images = [tf.image.resize(image, self.image_size) for image in images]
21
+
22
+ # Normalize the image if required (scale pixel values to [0, 1])
23
+ if self.normalize:
24
+ images = [image / 255.0 for image in images]
25
+
26
+ # Convert to the required tensor format (tf.Tensor for TensorFlow models)
27
+ images = tf.stack(images)
28
+
29
+ if return_tensors == "tf":
30
+ return {"pixel_values": images}
31
+ elif return_tensors == "pt":
32
+ # Convert to PyTorch tensor if needed (for PyTorch models)
33
+ import torch
34
+ return {"pixel_values": torch.tensor(images.numpy())}
35
+ else:
36
+ return {"pixel_values": images}