izhx committed
Commit f70ad8d · verified · 1 Parent(s): 2df56dc

Update custom_st.py

Files changed (1)
  1. custom_st.py +149 -59
custom_st.py CHANGED
@@ -59,73 +59,163 @@ class MultiModalTransformer(BaseTransformer):
             )
             image_mask = features["input_ids"] == self.auto_model.config.image_token_id
             features["inputs_embeds"][image_mask] = image_embeds
-        features.pop("pixel_values")
-        features.pop("image_grid_thw")
-        features.pop("input_ids")
+        # features.pop("pixel_values")
+        # features.pop("image_grid_thw")
+        # features.pop("input_ids")
+        inputs = {k: v for k, v in features.items() if k in 'position_ids,attention_mask,inputs_embeds'}
         outputs = self.auto_model.model(
-            **features,
+            **inputs,
             return_dict=True,
             output_hidden_states=True,
             # **kwargs
         )
-        pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
-        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
-        if left_padding:
-            embeddings = outputs.last_hidden_state
-        else:
-            sequence_lengths = pooling_mask.sum(dim=1) - 1
-            embeddings = outputs.last_hidden_state[torch.arange(
-                outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
-            ), sequence_lengths]
-        features.update({"token_embeddings": embeddings})
+        # pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
+        # left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
+        # if left_padding:
+        #     embeddings = outputs.last_hidden_state
+        # else:
+        #     sequence_lengths = pooling_mask.sum(dim=1) - 1
+        #     embeddings = outputs.last_hidden_state[torch.arange(
+        #         outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
+        #     ), sequence_lengths]
+        features.update({"token_embeddings": outputs.last_hidden_state})
         return features
 
-    def tokenize(self, texts: List[List[Dict[str, Image.Image]]] | List[str]) -> Dict[str, torch.Tensor]:
-        split_token = "<|im_end|>\n"
-        def process_text_item(item):
+    def tokenize(self, texts: List[List[Dict[str, Any]]] | List[str]) -> Dict[str, torch.Tensor]:
+        default_instruction = 'You are a helpful assistant.'
+
+        all_texts, all_images = list(), list()
+        for item in texts:
             if isinstance(item, str):
-                return item, None
-
-            text, img = "", None
-            if "image" in item:
-                text += "<|vision_start|><|image_pad|><|vision_end|>"
-                img = item["image"]
-                if isinstance(img, bytes):
-                    img = Image.open(BytesIO(img)).convert("RGB")
-                elif isinstance(img, str):
-                    img = Image.open(img).convert("RGB")
-                elif not isinstance(img, Image):
-                    raise ValueError(f"Unknown image type {type(img)}")
-            if "text" in item:
-                text += item["text"].lstrip()
-            if split_token in text:
-                instruction, text = text.split(split_token, 1)
-                text = f'{instruction}{split_token}<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+                txt, img, inst = item, None, default_instruction
+            elif isinstance(item, dict):
+                txt = item.get('text', None)
+                img = item.get('image', None)
+                inst = item.get('prompt', default_instruction)
             else:
-                text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
-            return text, img
+                raise RuntimeError(f'Input format not supported! {item=}')
 
-        all_texts, all_images = [], []
-        for item in texts:
-            text, images = process_text_item(item)
-            all_texts.append(text)
-            all_images.append(images)
-
-        if all_images != [None] * len(all_images):
-            inputs = self.processor(
-                text=all_texts,
-                images=all_images,
-                padding="longest",
-                truncation=True,
-                max_length=self.max_seq_length,
-                return_tensors="pt"
-            )
-        else:
-            inputs = self.processor(
-                text=all_texts,
-                padding="longest",
-                truncation=True,
-                max_length=self.max_seq_length,
-                return_tensors="pt"
-            )
+            input_str = ''
+            if img is None:
+                all_images = None # All examples in the same batch are consistent
+                # or will have ValueError: Could not make a flat list of images from xxxx
+            else:
+                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                img = fetch_image(img)
+                all_images.append(img)
+            if txt is not None:
+                input_str += txt
+            msg = f'<|im_start|>system\n{inst}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+            all_texts.append(msg)
+
+        inputs = self.processor(
+            text=all_texts,
+            images=all_images,
+            padding="longest",
+            truncation=True,
+            max_length=self.max_seq_length,
+            return_tensors='pt'
+        )
         return inputs
+
+
+### Copied from qwen_vl_utils.vision_process.py
+import base64
+from io import BytesIO
+import requests
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+
+    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
+        logging.warning(
+            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
+        )
+        if h_bar > w_bar:
+            h_bar = w_bar * MAX_RATIO
+        else:
+            w_bar = h_bar * MAX_RATIO
+    return h_bar, w_bar
+
+
+def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    image_obj = None
+    if isinstance(image, Image.Image):
+        image_obj = image
+    elif image.startswith("http://") or image.startswith("https://"):
+        image_obj = Image.open(requests.get(image, stream=True).raw)
+    elif image.startswith("file://"):
+        image_obj = Image.open(image[7:])
+    elif image.startswith("data:image"):
+        if "base64," in image:
+            _, base64_data = image.split("base64,", 1)
+            data = base64.b64decode(base64_data)
+            image_obj = Image.open(BytesIO(data))
+    else:
+        image_obj = Image.open(image)
+    if image_obj is None:
+        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+    image = image_obj.convert("RGB")
+    ## resize
+    # if "resized_height" in ele and "resized_width" in ele:
+    #     resized_height, resized_width = smart_resize(
+    #         ele["resized_height"],
+    #         ele["resized_width"],
+    #         factor=size_factor,
+    #     )
+    # else:
+    width, height = image.size
+    # min_pixels = ele.get("min_pixels", MIN_PIXELS)
+    # max_pixels = ele.get("max_pixels", MAX_PIXELS)
+    resized_height, resized_width = smart_resize(
+        height,
+        width,
+        factor=size_factor,
+        min_pixels=MIN_PIXELS,
+        max_pixels=MAX_PIXELS,
+    )
+    image = image.resize((resized_width, resized_height))

+    return image
+###
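
For reference, a minimal usage sketch of the updated input format (not part of this commit). It assumes this custom_st.py is loaded as a module of a Sentence Transformers checkpoint, that the installed sentence-transformers version passes dict inputs through to the module's tokenize(), and that the repository id below is a placeholder.

    from sentence_transformers import SentenceTransformer

    # Placeholder repository id; substitute the checkpoint that ships this custom_st.py.
    model = SentenceTransformer("org/qwen2-vl-embedding-model", trust_remote_code=True)

    # Plain strings are wrapped with the default instruction 'You are a helpful assistant.'
    text_embs = model.encode([
        "What is the capital of France?",
        "Paris is the capital of France.",
    ])

    # Dicts may carry 'text', an 'image' (local path, file:// or http(s) URL, data URI,
    # or PIL.Image, resolved by fetch_image), and an optional 'prompt' that replaces the
    # default system instruction. Per the comment in tokenize(), keep each batch either
    # all-image or all-text; mixing the two resets all_images to None.
    image_embs = model.encode([
        {"text": "a photo of a cat", "image": "https://example.com/cat.jpg"},
        {"image": "file:///data/dog.png", "prompt": "Represent the image for retrieval."},
    ])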