DongHyunKim committed
Commit: 863f351 · Parent: a8b23e3

update code for transformers 4.53.3 compatibility

Files changed (1):
  1. processing_hyperclovax.py (+226, -10)

processing_hyperclovax.py CHANGED
@@ -9,23 +9,22 @@ import PIL
 from PIL import Image
 import torch
 from transformers.feature_extraction_utils import BatchFeature
-from transformers.image_utils import ImageInput
+from transformers.image_utils import ImageInput, load_image
 from transformers.processing_utils import (
     AllKwargsForChatTemplate,
     ChatTemplateLoadKwargs,
     ProcessingKwargs,
     ProcessorMixin,
-    TextInput,
     Unpack,
-    VideoInput,
 )
-from transformers.tokenization_utils_base import AudioInput
+from transformers.tokenization_utils_base import AudioInput, TextInput
 from transformers.utils import (
     is_torch_device,
     is_torch_dtype,
     logging,
     requires_backends,
 )
+from transformers.utils.chat_template_utils import render_jinja_template
 from transformers.video_utils import VideoInput, VideoMetadata, load_video

 logger = logging.get_logger(__name__)
@@ -131,13 +130,223 @@ class HCXProcessor(ProcessorMixin):
         chat_template: Optional[str] = None,
         **kwargs: Unpack[AllKwargsForChatTemplate],
     ) -> str:
-        model_inputs = super().apply_chat_template(conversation, chat_template, **kwargs)
+        """
+        Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
+        conversations to turn them into a single tokenizable string.
+
+        The input is expected to be in the following format, where each message content is a list consisting of text and
+        optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form
+        `pixel_values` when `return_dict=True`. If not provided, one will get only the formatted text, optionally tokenized text.
+
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+                    {"type": "text", "text": "Please describe this image in detail."},
+                ],
+            },
+        ]
+
+        Args:
+            conversation (`Union[List[Dict, [str, str]], List[List[Dict[str, str]]]]`):
+                The conversation to format.
+            chat_template (`Optional[str]`, *optional*):
+                The Jinja template to use for formatting the conversation. If not provided, the tokenizer's
+                chat template is used.
+        """
+
+        if chat_template is None:
+            if isinstance(self.chat_template, dict) and "default" in self.chat_template:
+                chat_template = self.chat_template["default"]
+            elif isinstance(self.chat_template, dict):
+                raise ValueError(
+                    'The processor has multiple chat templates but none of them are named "default". You need to specify'
+                    " which one to use by passing the `chat_template` argument. Available templates are: "
+                    f"{', '.join(self.chat_template.keys())}"
+                )
+            elif self.chat_template is not None:
+                chat_template = self.chat_template
+            else:
+                raise ValueError(
+                    "Cannot use apply_chat_template because this processor does not have a chat template."
+                )
+        else:
+            if isinstance(self.chat_template, dict) and chat_template in self.chat_template:
+                # It's the name of a template, not a full template string
+                chat_template = self.chat_template[chat_template]
+            else:
+                # It's a template string, render it directly
+                chat_template = chat_template
+
+        if kwargs.get("continue_final_message", False):
+            if kwargs.get("add_generation_prompt", False):
+                raise ValueError(
+                    "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
+                )
+            if kwargs.get("return_assistant_tokens_mask", False):
+                raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")

-        # vllm needs vision_query_lengths, but we don't need it
-        del model_inputs["vision_query_lengths_images"]
-        del model_inputs["vision_query_lengths_videos"]
-
-        return model_inputs
+        # Fill sets of kwargs that should be used by different parts of template
+        processed_kwargs = {
+            "mm_load_kwargs": {},
+            "template_kwargs": {},
+        }
+
+        for kwarg_type in processed_kwargs:
+            for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys():
+                kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
+                default_value = getattr(kwarg_type_defaults, key, None)
+                value = kwargs.pop(key, default_value)
+                if value is not None and not isinstance(value, dict):
+                    processed_kwargs[kwarg_type][key] = value
+
+        # Pass unprocessed custom kwargs
+        processed_kwargs["template_kwargs"].update(kwargs)
+
+        if isinstance(conversation, (list, tuple)) and (
+            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
+        ):
+            is_batched = True
+            conversations = conversation
+        else:
+            is_batched = False
+            conversations = [conversation]
+
+        tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False)
+        return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False)
+        mm_load_kwargs = processed_kwargs["mm_load_kwargs"]
+
+        if tokenize:
+            batch_images, batch_videos = [], []
+            batch_audios = []
+            batch_video_metadata = []
+            for conversation in conversations:
+                images, videos = [], []
+                video_metadata = []
+                for message in conversation:
+                    visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
+                    audio_fnames = [
+                        content[key]
+                        for content in message["content"]
+                        for key in ["audio", "url", "path"]
+                        if key in content and content["type"] == "audio"
+                    ]
+                    image_fnames = [
+                        vision_info[key]
+                        for vision_info in visuals
+                        for key in ["image", "url", "path", "base64"]
+                        if key in vision_info and vision_info["type"] == "image"
+                    ]
+                    video_fnames = [
+                        vision_info[key]
+                        for vision_info in visuals
+                        for key in ["video", "url", "path"]
+                        if key in vision_info and vision_info["type"] == "video"
+                    ]
+
+                    for fname in image_fnames:
+                        images.append(load_image(fname))
+
+                    # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
+                    if not mm_load_kwargs["load_audio_from_video"]:
+                        for fname in audio_fnames:
+                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+                    else:
+                        for fname in video_fnames:
+                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+
+                    for fname in video_fnames:
+                        if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
+                            video = [np.array(load_image(image_fname)) for image_fname in fname]
+                            # create a 4D video because `load_video` always returns a 4D array
+                            video = np.stack(video)
+                            metadata = None
+                            logger.warning(
+                                "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
+                                "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
+                            )
+                        else:
+                            # TODO: raushan, should be `self.video_processor.load_video_for_model` when API is added
+                            video, metadata = self._load_video_for_model(
+                                fname,
+                                num_frames=mm_load_kwargs.get("num_frames", None),
+                                fps=mm_load_kwargs.get("video_fps", None),
+                                backend=mm_load_kwargs["video_load_backend"],
+                                **kwargs,
+                            )
+                        videos.append(video)
+                        video_metadata.append(metadata)
+
+                # Currently all processors can accept nested list of batches, but not flat list of visuals
+                # So we'll make a batched list of images and let the processor handle it
+                if images:
+                    batch_images.append(images)
+                if videos:
+                    batch_videos.append(videos)
+                    batch_video_metadata.append(video_metadata)
+
+            # Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
+            conversations = self._process_messages_for_chat_template(
+                conversations,
+                batch_images=batch_images,
+                batch_videos=batch_videos,
+                batch_video_metadata=batch_video_metadata,
+                **processed_kwargs["mm_load_kwargs"],
+            )
+
+        prompt, generation_indices = render_jinja_template(
+            conversations=conversations,
+            chat_template=chat_template,
+            **processed_kwargs["template_kwargs"],  # different flags such as `return_assistant_mask`
+            **self.tokenizer.special_tokens_map,  # tokenizer special tokens are used by some templates
+        )
+
+        if not is_batched:
+            prompt = prompt[0]
+
+        if tokenize:
+            # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing
+            # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
+            # and pass it to the processor. Users thus never worried about special tokens relying on processor handling
+            # everything internally. The below line is to keep BC for that and be able to work with model that have
+            # special tokens in the template (consistent with tokenizers). We dont want to raise warning, it will flood command line
+            # without actionable solution for users
+            single_prompt = prompt[0] if is_batched else prompt
+            if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
+                kwargs["add_special_tokens"] = False
+
+            out = self(
+                text=prompt,
+                images=batch_images if batch_images else None,
+                videos=batch_videos if batch_videos else None,
+                audio=batch_audios if batch_audios else None,
+                **kwargs,
+            )
+            if return_dict:
+                if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
+                    assistant_masks = []
+                    input_ids = out["input_ids"]
+                    for i in range(len(input_ids)):
+                        current_mask = [0] * len(input_ids[i])
+                        for assistant_start_char, assistant_end_char in generation_indices[i]:
+                            start_token = out.char_to_token(i, assistant_start_char)
+                            end_token = out.char_to_token(i, assistant_end_char - 1)
+                            if start_token is None:
+                                # start_token is out of bounds maybe due to truncation.
+                                break
+                            for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
+                                current_mask[token_id] = 1
+                        assistant_masks.append(current_mask)
+                    out["assistant_masks"] = assistant_masks
+                    out.convert_to_tensors(tensor_type=kwargs.get("return_tensors", None))
+
+                # vllm needs vision_query_lengths, but hf model doesn't need it
+                del out["vision_query_lengths_images"]
+                del out["vision_query_lengths_videos"]
+                return out
+            else:
+                return out["input_ids"]

     def repeat_dummy_tokens(self, input_ids, target_token_id, vision_query_lengths):
         input_ids = input_ids.clone().detach()
@@ -202,6 +411,13 @@ class HCXProcessor(ProcessorMixin):
         logger.warning_once(f"fps control via argument is not supported yet. Ignored fps: {fps}.")
         logger.warning_once(f"backend control via argument is not supported yet. Ignored backend: {backend}.")

+        # video_loaded, video_metadata = load_video(
+        #     video, backend="decord", num_frames=32
+        # )
+        # frame_interval = int(video_metadata.total_num_frames / 32)
+        # time_interval = frame_interval / video_metadata.fps
+        # video_metadata.time_interval = time_interval
+
         def _hcx_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
             max_num_grids = output_kwargs["videos_kwargs"]["max_num_grids"]
             max_image_cnt = output_kwargs["videos_kwargs"]["max_image_cnt"]
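
For reference, a minimal usage sketch (not part of the commit) of the code path this rework enables: calling `apply_chat_template` with `tokenize=True` and `return_dict=True` now loads the referenced images, renders the prompt via `render_jinja_template`, runs the processor, and strips the `vision_query_lengths_*` keys before returning. Transformers 4.53.3 is assumed, and the repo id below is a placeholder for whichever checkpoint ships this `processing_hyperclovax.py`.

# Usage sketch (illustrative, not part of the diff). Assumes transformers==4.53.3 and a
# checkpoint that bundles this processing_hyperclovax.py; the repo id is a placeholder.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",  # placeholder repo id
    trust_remote_code=True,
)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
            {"type": "text", "text": "Please describe this image in detail."},
        ],
    },
]

# tokenize=True + return_dict=True returns a BatchFeature with input_ids and the vision tensors;
# vision_query_lengths_images / vision_query_lengths_videos are deleted before returning.
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
print(inputs.keys())

The returned BatchFeature can then be passed to the model's generate method.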