Munaza10 committed on
Commit c333b3a · verified · 1 Parent(s): 3417784

Delete tokenizer_wrapper.py

Files changed (1):
  1. tokenizer_wrapper.py +0 -1426
tokenizer_wrapper.py DELETED
@@ -1,1426 +0,0 @@
1
- # Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
2
- # you may not use this file except in compliance with the License.
3
- # You may obtain a copy of the License at
4
- #
5
- # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE
6
- #
7
- # Unless required by applicable law or agreed to in writing, software
8
- # distributed under the License is distributed on an "AS IS" BASIS,
9
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
- # See the License for the specific language governing permissions and
11
- # limitations under the License.
12
- # ==============================================================================
13
-
14
- import warnings
15
- import random
16
- from typing import List, Optional, Union, Dict, Any
17
- from collections import defaultdict
18
- from copy import deepcopy
19
-
20
- import numpy as np
21
- import torch
22
- import torch.nn.functional as F
23
- from transformers import AutoTokenizer
24
- from diffusers.utils import BaseOutput
25
-
26
-
27
- def default(value, default_value):
28
- return value if value is not None else default_value
29
-
30
-
31
- def ensure_list(value):
32
- if value is None:
33
- return []
34
- if isinstance(value, (list, tuple)):
35
- return list(value)
36
- return [value]
37
-
38
-
39
- class Resolution(object):
40
- def __init__(self, size, *args):
41
- if isinstance(size, str):
42
- if 'x' in size:
43
- size = size.split('x')
44
- size = (int(size[0]), int(size[1]))
45
- else:
46
- size = int(size)
47
- if len(args) > 0:
48
- size = (size, args[0])
49
- if isinstance(size, int):
50
- size = (size, size)
51
-
52
- self.h = self.height = size[0]
53
- self.w = self.width = size[1]
54
- self.r = self.ratio = self.height / self.width
55
-
56
- def __getitem__(self, idx):
57
- if idx == 0:
58
- return self.h
59
- elif idx == 1:
60
- return self.w
61
- else:
62
- raise IndexError(f'Index {idx} out of range')
63
-
64
- def __str__(self):
65
- return f'{self.h}x{self.w}'
66
-
67
-
68
- class ResolutionGroup(object):
69
- def __init__(self, base_size=None, step=None, align=1):
70
- self.align = align
71
- self.base_size = base_size
72
- if not isinstance(base_size, int):
73
- raise ValueError(f'base_size must be an int, but got {type(base_size)}')
74
- assert base_size % align == 0, f'base_size {base_size} is not divisible by align {align}'
75
- if step is None:
76
- step = base_size // 16
77
- if step is not None and step > base_size // 2:
78
- raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}')
79
-
80
- self.step = step
81
- self.data = self._calc_by_step()
82
-
83
- self.ratio = np.array([x.ratio for x in self.data])
84
- self.attr = ['' for _ in range(len(self.data))]
85
- self.prefix_space = 0
86
-
87
- def __len__(self):
88
- return len(self.data)
89
-
90
- def __getitem__(self, idx):
91
- return self.data[idx]
92
-
93
- def __repr__(self):
94
- prefix = self.prefix_space * ' '
95
- prefix_close = (self.prefix_space - 4) * ' '
96
- res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data='
97
- attr_maxlen = max([len(x) for x in self.attr] + [5])
98
- res_str += \
99
- f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}'
100
- res_str += \
101
- ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} '
102
- f'({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}'
103
- for i, x in enumerate(self.data)])
104
- res_str += f'\n{prefix_close})'
105
- return res_str
106
-
107
- def _calc_by_step(self):
108
- assert self.align <= self.step, f'align {self.align} must be smaller than step {self.step}'
109
-
110
- min_height = self.base_size // 2
111
- min_width = self.base_size // 2
112
- max_height = self.base_size * 2
113
- max_width = self.base_size * 2
114
-
115
- resolutions = [Resolution(self.base_size, self.base_size)]
116
-
117
- cur_height, cur_width = self.base_size, self.base_size
118
- while True:
119
- if cur_height >= max_height and cur_width <= min_width:
120
- break
121
-
122
- cur_height = min(cur_height + self.step, max_height)
123
- cur_width = max(cur_width - self.step, min_width)
124
- resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align))
125
-
126
- cur_height, cur_width = self.base_size, self.base_size
127
- while True:
128
- if cur_height <= min_height and cur_width >= max_width:
129
- break
130
-
131
- cur_height = max(cur_height - self.step, min_height)
132
- cur_width = min(cur_width + self.step, max_width)
133
- resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align))
134
-
135
- resolutions = sorted(resolutions, key=lambda x: x.ratio)
136
-
137
- return resolutions
138
-
139
- def get_target_size(self, width, height):
140
- ratio = height / width
141
- idx = np.argmin(np.abs(self.ratio - ratio))
142
- reso = self.data[idx]
143
- return reso.w, reso.h
144
-
145
- def get_base_size_and_ratio_index(self, width, height):
146
- ratio = height / width
147
- idx = np.argmin(np.abs(self.ratio - ratio))
148
- return self.base_size, idx
149
-
150
-
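# Illustrative sketch (hypothetical usage, not part of the deleted file).
# ResolutionGroup buckets aspect ratios around a square `base_size`: buckets are
# built by stepping height/width in opposite directions (clamped to
# [base_size // 2, base_size * 2]) and rounding down to a multiple of `align`.
# The concrete numbers below are arbitrary.
group = ResolutionGroup(base_size=1024, step=64, align=16)
target_w, target_h = group.get_target_size(width=1920, height=1080)  # nearest-ratio bucket
base, ratio_idx = group.get_base_size_and_ratio_index(1920, 1080)    # -> (1024, bucket index)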
151
- class ImageInfo:
152
- """ Class to store image information for processing and generation. """
153
-
154
- def __init__(
155
- self,
156
- image_type: str = None,
157
- image_tensor: torch.Tensor = None,
158
- image_width: int = None,
159
- image_height: int = None,
160
- token_width: int = None,
161
- token_height: int = None,
162
- image_token_length: int = None,
163
- base_size: int = None,
164
- ratio_index: int = None,
165
- **kwargs,
166
- ):
167
- self.image_type = image_type
168
- self.image_tensor = image_tensor
169
- self.image_width = image_width
170
- self.w = image_width
171
- self.image_height = image_height
172
- self.h = image_height
173
- self.token_width = token_width
174
- self.tk_w = token_width
175
- self.token_height = token_height
176
- self.tk_h = token_height
177
- self.image_token_length = default(
178
- image_token_length,
179
- token_width * token_height if token_width is not None and token_height is not None else None
180
- )
181
- self.base_size = base_size
182
- self.ratio_index = ratio_index
183
-
184
- self.add_timestep_token = kwargs.get("add_timestep_token", True)
185
- self.add_guidance_token = kwargs.get("add_guidance_token", False)
186
- self.use_front_boi_token = kwargs.get("use_front_boi_token", True)
187
- self.add_image_shape_token = kwargs.get("add_image_shape_token", True)
188
-
189
- def __getitem__(self, key: str) -> Any:
190
- """Allow dictionary-like access to attributes."""
191
- if hasattr(self, key):
192
- return getattr(self, key)
193
- raise KeyError(f"Key '{key}' not found in ImageInfo")
194
-
195
- def __setitem__(self, key: str, value: Any) -> None:
196
- """Allow dictionary-like assignment to attributes."""
197
- if hasattr(self, key):
198
- setattr(self, key, value)
199
- else:
200
- raise KeyError(f"Key '{key}' not found in ImageInfo")
201
-
202
- def __contains__(self, key: str) -> bool:
203
- """Check if the key exists in the ImageInfo object."""
204
- return hasattr(self, key)
205
-
206
- def __repr__(self):
207
- return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, "
208
- f"image_width={self.image_width}, image_height={self.image_height}, "
209
- f"token_width={self.token_width}, token_height={self.token_height}, "
210
- f"image_token_length={self.image_token_length}, "
211
- f"base_size={self.base_size}, ratio_index={self.ratio_index}")
212
-
213
- @property
214
- def meta_info(self):
215
- # Used for image sections of tkwrapper.encode_general()
216
- if self.image_type in ["vae", "gen_image"]:
217
- return dict(
218
- token_length=self.image_token_length,
219
- add_timestep_token=self.add_timestep_token,
220
- add_guidance_token=self.add_guidance_token,
221
- use_front_boi_token=self.use_front_boi_token,
222
- add_image_shape_token=self.add_image_shape_token,
223
- base_size=self.base_size,
224
- ratio_idx=self.ratio_index,
225
- # for rope 2d
226
- token_height=self.token_height,
227
- token_width=self.token_width,
228
- # for bc
229
- image_height=self.image_height,
230
- image_width=self.image_width,
231
- )
232
- elif self.image_type in ["vit"]:
233
- return dict(
234
- token_length=self.image_token_length,
235
- use_front_boi_token=self.use_front_boi_token,
236
- add_image_shape_token=self.add_image_shape_token,
237
- # for rope 2d
238
- token_height=self.token_height,
239
- token_width=self.token_width,
240
- # for bc
241
- image_height=self.image_height,
242
- image_width=self.image_width,
243
- )
244
- else:
245
- raise ValueError(f"Unknown image type '{self.image_type}'")
246
-
247
- @property
248
- def num_special_tokens(self):
249
- if self.image_type is None:
250
- raise ValueError("num_special_tokens requires `image_type` to be set.")
251
- if self.image_type in ["vae", "src_image", "gen_image"]:
252
- count = (
253
- 2 + # <boi> + <eoi> or <src_boi> + <src_eoi>
254
- (1 if self.add_timestep_token else 0) +
255
- (1 if self.add_guidance_token else 0) +
256
- (2 if self.add_image_shape_token else 0)
257
- )
258
- else:
259
- raise ValueError(f"Unknown image_type: {self.image_type}")
260
- return count
261
-
262
- def copy(self, copy_image_tensor=True):
263
- if copy_image_tensor and self.image_tensor is None:
264
- raise ValueError("image_tensor is None, cannot copy")
265
- return ImageInfo(
266
- image_type=self.image_type,
267
- image_tensor=self.image_tensor.clone() if copy_image_tensor else None,
268
- image_width=self.image_width,
269
- image_height=self.image_height,
270
- token_width=self.token_width,
271
- token_height=self.token_height,
272
- image_token_length=self.image_token_length,
273
- base_size=self.base_size,
274
- ratio_index=self.ratio_index,
275
- )
276
-
277
- def zeros_(self):
278
- self.image_tensor = torch.zeros_like(self.image_tensor)
279
-
280
-
281
- class ImageTensor(torch.Tensor):
282
- # This class is just for type hinting purposes. Attribute `i` should be defined
283
- # as an instance attribute of the torch.Tensor instance, like: tensor.i = ImageInfo(...)
284
- i: ImageInfo
285
- vision_encoder_kwargs: dict
286
-
287
-
288
- class JointImageInfo(object):
289
- def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None):
290
- self.vae_image_info = vae_image_info
291
- self.vision_image_info = vision_image_info
292
- self.vision_encoder_kwargs = vision_encoder_kwargs
293
-
294
- # Define key attributes to align with ImageInfo for uniformity
295
- self.image_type = "joint_image"
296
- self.image_token_length = vae_image_info.image_token_length + vision_image_info.image_token_length
297
-
298
- self.add_timestep_token = vae_image_info.add_timestep_token
299
- self.use_front_boi_token = vae_image_info.use_front_boi_token
300
- self.add_image_shape_token = vae_image_info.add_image_shape_token
301
-
302
- def __repr__(self):
303
- return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})"
304
-
305
- @property
306
- def meta_info(self):
307
- # Used for image sections of tkwrapper.encode_general()
308
- return dict(
309
- token_length=[self.vae_image_info.image_token_length, self.vision_image_info.image_token_length],
310
- add_timestep_token=self.add_timestep_token,
311
- use_front_boi_token=self.use_front_boi_token,
312
- add_image_shape_token=self.add_image_shape_token,
313
- base_size=self.vae_image_info.base_size,
314
- ratio_idx=self.vae_image_info.ratio_index,
315
- # for rope 2d
316
- token_height=[self.vae_image_info.token_height, self.vision_image_info.token_height],
317
- token_width=[self.vae_image_info.token_width, self.vision_image_info.token_width],
318
- # for bc
319
- image_height=[self.vae_image_info.image_height, self.vision_image_info.image_height],
320
- image_width=[self.vae_image_info.image_width, self.vision_image_info.image_width],
321
- )
322
-
323
- @property
324
- def num_special_tokens(self):
325
- return (
326
- 2 + # <boi> + <eoi>
327
- (1 if self.add_timestep_token else 0) +
328
- (2 if self.add_image_shape_token else 0) +
329
- 1 # <joint_image_sep>
330
- )
331
-
332
- def copy(self, copy_image_tensor=True):
333
- if copy_image_tensor and (
334
- self.vae_image_info.image_tensor is None or self.vision_image_info.image_tensor is None):
335
- raise ValueError("image_tensor is None, cannot copy")
336
- return JointImageInfo(
337
- self.vae_image_info.copy(copy_image_tensor),
338
- self.vision_image_info.copy(copy_image_tensor),
339
- self.vision_encoder_kwargs,
340
- )
341
-
342
- def zeros_(self):
343
- self.vae_image_info.zeros_()
344
- self.vision_image_info.zeros_()
345
-
346
-
347
- class JointImage(object):
348
- def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor):
349
- self.vae_image = vae_image
350
- self.vision_image = vision_image
351
- self.i = JointImageInfo(vae_image.i, vision_image.i)
352
-
353
-
354
- class TokenizerEncodeOutput(BaseOutput):
355
- tokens: torch.Tensor = None
356
- timestep_scatter_index: Optional[torch.Tensor] = None
357
- guidance_scatter_index: Optional[torch.Tensor] = None
358
- text_slices: Optional[List[slice]] = None
359
- gen_image_slices: Optional[List[slice]] = None
360
- joint_image_slices: Optional[List[slice]] = None
361
- cond_vae_image_slices: Optional[List[slice]] = None
362
- cond_vit_image_slices: Optional[List[slice]] = None
363
- text_mask: Optional[torch.Tensor] = None
364
- gen_image_mask: Optional[torch.Tensor] = None
365
- cond_vae_image_mask: Optional[torch.Tensor] = None
366
- cond_vit_image_mask: Optional[torch.Tensor] = None
367
- real_pos: Optional[torch.Tensor] = None
368
- all_image_slices: Optional[List[slice]] = None
369
- cond_timestep_scatter_index: Optional[torch.Tensor] = None
370
- gen_timestep_scatter_index: Optional[torch.Tensor] = None
371
-
372
-
373
- class Conversation:
374
- roles: List[str] = ["User", "Assistant"]
375
- sep: str = "\n\n"
376
-
377
-
378
- class TokenizerWrapper(object):
379
- def __init__(self, tokenizer):
380
- if isinstance(tokenizer, str):
381
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
382
- else:
383
- self.tokenizer = tokenizer
384
-
385
- # Define short names
386
- self.bos_token_id = self.tokenizer.bos_token_id
387
- self.eos_token_id = self.tokenizer.eos_token_id
388
- self.pad_token_id = self.tokenizer.pad_token_id
389
- self.boi_token_id = self.tokenizer.convert_tokens_to_ids("<boi>")
390
- self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("<eoi>")
391
- self.img_token_id = self.tokenizer.convert_tokens_to_ids("<img>")
392
- self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("<cfg>")
393
- self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("</answer>")
394
- self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("</recaption>")
395
- self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("<img_ratio_0>")
396
- self.special_token_map = self.tokenizer.added_tokens_encoder
397
-
398
- def pad(self, tensor_list, dim=0, pad_val=None):
399
- if pad_val is None:
400
- pad_val = self.pad_token_id
401
- max_len = max([t.shape[dim] for t in tensor_list])
402
- padded_tensor_list = []
403
- for t in tensor_list:
404
- if t.shape[dim] < max_len:
405
- assert pad_val is not False, "Not allowed pad."
406
- t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val)
407
- padded_tensor_list.append(t)
408
- return padded_tensor_list
409
-
410
- def encode(self, *args, **kwargs):
411
- return self.tokenizer.encode(*args, **kwargs)
412
-
413
- def decode(self, *args, **kwargs):
414
- return self.tokenizer.decode(*args, **kwargs)
415
-
416
- def encode_text(
417
- self,
418
- *texts,
419
- uncond_enabled: Optional[Union[bool, List[bool]]] = None,
420
- uncond_p: Optional[float] = None,
421
- max_length: Optional[int] = None,
422
- pad: Optional[str] = None,
423
- return_lengths: bool = False,
424
- ):
425
- """
426
- Encode text for AR-like model training of text-to-image/instruction tuning tasks.
427
- Supports encoding multiple texts at once. Each text can be separately conditioned or unconditioned
428
- based on the `uncond_enabled` flags and a uniform `uncond_p`.
429
- **The <bos> token is not added here; `encode_sequence` prepends it when `add_bos=True`.**
430
-
431
- Parameters
432
- ----------
433
- texts: str or List[str]
434
- List of texts to be encoded.
435
- uncond_enabled: bool or List[bool]
436
- List of flags to indicate whether the text should be unconditioned.
437
- If False, the text will never be unconditioned.
438
- If True, the text will be unconditioned with uncond_p.
439
- uncond_p: float
440
- Probability of replacing the text with <cfg> tokens (the unconditional input). Only applies when uncond_enabled is True.
441
- max_length: int
442
- Maximum length of the encoded text.
443
- pad: Optional[str]
444
- Padding method. Can be 'left' or 'right'.
445
- return_lengths: bool
446
- Whether to return the length of each encoded text.
447
- """
448
- if pad is not None:
449
- assert max_length is not None, "max_length should be provided when pad is not None."
450
-
451
- if uncond_enabled is None:
452
- uncond_enabled = [True] * len(texts)
453
- elif isinstance(uncond_enabled, bool):
454
- uncond_enabled = [uncond_enabled] * len(texts)
455
- if len(uncond_enabled) != len(texts):
456
- print(uncond_enabled, texts)
457
- assert len(uncond_enabled) == len(texts), (
458
- f"Length of uncond_flags should be equal to the number of texts, "
459
- f"but got {len(uncond_enabled)} and {len(texts)}."
460
- )
461
-
462
- # Prepare text/uncond tokens
463
- # TODO: If len(texts) > 1, such as instruction + prompt in inpainting, we need to determine how to do uncond.
464
- # Now all texts will be cond or uncond at the same time.
465
- do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p)
466
- text_tokens, lengths = [], []
467
- cum_length = 0
468
- for text, uncond_flag in zip(texts, uncond_enabled):
469
- # If reach the max_length and there still have unencoded texts, give a warning message and break the loop.
470
- if max_length is not None and cum_length >= max_length:
471
- warnings.warn(
472
- f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: "
473
- f"{text[:80]}..."
474
- )
475
- break
476
- # Set add_special_tokens=False to avoid adding <bos> token in some LLMs.
477
- if isinstance(text, str):
478
- text_token = self.tokenizer.encode(text, add_special_tokens=False)
479
- else:
480
- text_token = text
481
- if uncond_flag and do_uncond_drop:
482
- text_token = [self.cfg_token_id] * len(text_token)
483
- # Cutoff the text by max_length if necessary
484
- if max_length is not None and (cum_length + len(text_token)) > max_length:
485
- text_token = text_token[:max_length - cum_length]
486
- text_tokens.extend(text_token)
487
- lengths.append(len(text_token))
488
- cum_length += len(text_token)
489
-
490
- # Prepend/Append <pad> tokens if applicable
491
- if pad is not None and (pad_length := max_length - len(text_tokens)) > 0:
492
- if pad == 'left':
493
- text_tokens = [self.pad_token_id] * pad_length + text_tokens
494
- elif pad == 'right':
495
- text_tokens = text_tokens + [self.pad_token_id] * pad_length
496
- else:
497
- raise ValueError(f"Unsupported padding method: {pad}.")
498
-
499
- if return_lengths:
500
- return text_tokens, lengths
501
- return text_tokens
502
-
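# Illustrative sketch (hypothetical usage, not part of the deleted file).
# encode_text() concatenates several text pieces; each piece can be independently
# replaced by <cfg> tokens for classifier-free-guidance dropout. `wrapper` stands
# for a TokenizerWrapper instance.
tokens, lengths = wrapper.encode_text(
    "an instruction", "a prompt",
    uncond_enabled=[False, True],  # only the prompt may be dropped to <cfg>
    uncond_p=0.1,                  # 10% chance of dropping the enabled pieces
    max_length=256,
    pad="right",                   # right-pad with <pad> up to max_length
    return_lengths=True,
)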
503
- @staticmethod
504
- def _check_key_number_matched(keys, data):
505
- # Assert keys and token_source are matched
506
- assert set(keys) == set(data.keys()), (
507
- f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}."
508
- )
509
- key_counts = {k: 0 for k in keys}
510
- for key in keys:
511
- key_counts[key] += 1
512
- for key, count in key_counts.items():
513
- assert len(data[key]) == count, (
514
- f"Number of `{key}` in the token source should be matched with the template, but got "
515
- f"{data[key]}({len(data[key])}) and {count}."
516
- )
517
-
518
- def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False,
519
- add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None,
520
- add_guidance_token=False):
521
- if add_image_shape_token:
522
- token_seq.extend([
523
- self.special_token_map[f"<img_size_{base_size}>"],
524
- self.special_token_map[f"<img_ratio_{ratio_idx}>"]
525
- ])
526
- token_count += 2
527
- if add_timestep_token:
528
- token_seq.extend([self.special_token_map["<timestep>"]])
529
- extra_token_pos['timestep'].append(token_count)
530
- if image_type is not None:
531
- if image_type == "gen_image":
532
- extra_token_pos['gen_timestep'].append(token_count)
533
- elif image_type in ["joint_image"]:
534
- extra_token_pos['cond_timestep'].append(token_count)
535
- else:
536
- raise ValueError(f"Unsupported image type: {image_type}.")
537
- token_count += 1
538
- if add_guidance_token:
539
- token_seq.extend([self.special_token_map["<guidance>"]])
540
- extra_token_pos['guidance'].append(token_count)
541
- token_count += 1
542
- return token_count
543
-
544
- @staticmethod
545
- def _shorten_text(text):
546
- import re
547
- text = re.sub(r"(<img>)+", lambda m: f"[<img>]{{{len(m.group(0)) // 5}}}", text)
548
- text = re.sub(r"(<pad>)+", lambda m: f"[<pad>]{{{len(m.group(0)) // 5}}}", text)
549
- return text
550
-
551
- def encode_sequence(
552
- self,
553
- template: str,
554
- token_source: Dict[str, List],
555
- total_length=None,
556
- add_timestep_token=False,
557
- add_guidance_token=False,
558
- last_key_only_prefix=False,
559
- add_eos=True,
560
- use_front_boi_token=True,
561
- add_pad=True,
562
- add_bos=True,
563
- drop_last: Union[str, bool] = 'auto',
564
- add_image_shape_token=False,
565
- ):
566
- """
567
- Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning)
568
- and token source.
569
-
570
- Parameters
571
- ----------
572
- template: str
573
- Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image.
574
- "text-text-gen_image" means the sequence is composed of two sections of text and an image.
575
- token_source: Dict[str, List]
576
- Token source for each key in the template, in order.
577
- - text: List[Dict].
578
- - gen_image: List[Dict].
579
- - joint_image: List[Dict].
580
- total_length: int
581
- Total length of the encoded sequence, including padding tokens.
582
- add_timestep_token: bool
583
- Whether to add timestep token before the image tokens.
584
- (Right after the <img_ratio_*><img_size_*> tokens)
585
- add_guidance_token: bool
586
- Whether to add guidance token before the image tokens.
587
- last_key_only_prefix: bool
588
- Whether to only use the modal prefix in the last key.
589
- add_eos: bool or 'auto'
590
- Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto',
591
- add eos token only when the total_length is not reached and the last token is not <eos>.
592
- use_front_boi_token: bool
593
- Whether to put the <boi> token in front of the image size, ratio and timestep tokens.
594
- add_pad: bool or 'auto'
595
- Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens.
596
- add_bos: bool
597
- Whether to add bos token at the beginning of the sequence.
598
- drop_last: bool or 'auto'
599
- - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is
600
- in the middle of the image tokens, an error will be raised.
601
- - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens,
602
- all the successive image tokens will be dropped.
603
- - If False, keep the last tokens exceeding the total_length, even if the total_length is reached.
604
- add_image_shape_token: bool
605
- Whether to add image shape token before the image tokens. (Right before the <timestep> token)
606
-
607
- Returns
608
- -------
609
- token_seq: list
610
- Encoded token sequence.
611
- extra_token_pos: dict
612
- Positions of extra tokens.
613
- """
614
- if last_key_only_prefix:
615
- assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True."
616
- if drop_last is True and total_length is None:
617
- raise ValueError("total_length should be provided when drop_last is True.")
618
-
619
- keys = template.split('-')
620
- modal_length = len(keys)
621
- index_indicator = {k: 0 for k in token_source}
622
- for k, v in token_source.items():
623
- assert isinstance(v, (list, tuple)), (
624
- f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}."
625
- )
626
- self._check_key_number_matched(keys, token_source)
627
-
628
- token_seq = []
629
- token_count = 0
630
- extra_token_pos = defaultdict(list)
631
- if add_bos:
632
- token_seq.append(self.bos_token_id)
633
- token_count += 1
634
- # If drop_last is True, we check the token_count on the fly and exit the loop if the total_length is reached.
635
- # This check is only applied to the block tokens. Block tokens mean the tokens that are unsplittable, like
636
- # image tokens. Text tokens are splittable, so we don't need to check the token_count for text.
637
- # If the loop is broken by drop_last, we don't add the eos token at the end because the sequence is not
638
- # complete.
639
- drop_last_break = False
640
- for i, key in enumerate(keys):
641
- source = token_source[key][index_indicator[key]]
642
- if key == "text":
643
- token_seq.extend(source) # text token sequence
644
- extra_token_pos["<text>_start"].append(token_count)
645
- token_count += len(source)
646
- extra_token_pos["<text>_end"].append(token_count - 1)
647
-
648
- elif key == "gen_image":
649
- if isinstance(source, int):
650
- source = {'length': source}
651
- extra_count = 2 + (
652
- 1 if source.get('timestep', add_timestep_token) else 0) + (
653
- 1 if source.get('guidance', add_guidance_token) else 0) + (
654
- 2 if source.get('image_shape', add_image_shape_token) else 0
655
- )
656
- if drop_last is True and token_count + extra_count + source['length'] > total_length:
657
- drop_last_break = True
658
- break
659
- if source.get('front_boi', use_front_boi_token):
660
- token_seq.append(self.boi_token_id)
661
- extra_token_pos["boi"].append(token_count)
662
- token_count += 1
663
- token_count = self._add_image_meta_info_token(
664
- token_seq=token_seq,
665
- token_count=token_count,
666
- extra_token_pos=extra_token_pos,
667
- add_timestep_token=source.get('timestep', add_timestep_token),
668
- add_guidance_token=source.get('guidance', add_guidance_token),
669
- add_image_shape_token=source.get('image_shape', add_image_shape_token),
670
- base_size=source.get('base_size'),
671
- ratio_idx=source.get('ratio_idx'),
672
- image_type=key,
673
- )
674
- if not source.get('front_boi', use_front_boi_token):
675
- token_seq.append(self.boi_token_id)
676
- extra_token_pos["boi"].append(token_count)
677
- token_count += 1
678
- if last_key_only_prefix and i == modal_length - 1:
679
- pass # for AR inference
680
- else:
681
- token_seq.extend(
682
- [self.img_token_id] * source['length'] + # token number
683
- [self.eoi_token_id]
684
- )
685
- extra_token_pos["<img>_start"].append(token_count)
686
- extra_token_pos["<all_img>_start"].append(token_count)
687
- token_count += source['length']
688
- extra_token_pos["<img>_end"].append(token_count - 1)
689
- extra_token_pos["<all_img>_end"].append(token_count - 1)
690
- extra_token_pos["eoi"].append(token_count)
691
- token_count += 1 # <eoi>
692
-
693
- elif key == "joint_image":
694
- assert isinstance(source['length'], list) and len(
695
- source['length']) == 2, "joint_image length should be a list of two integers"
696
- extra_count = 2 + 1 + ( # boi, eoi, joint_img_sep
697
- 1 if source.get('timestep', add_timestep_token) else 0) + (
698
- 2 if source.get('image_shape', add_image_shape_token) else 0
699
- )
700
- if drop_last is True and token_count + extra_count + sum(source['length']) > total_length:
701
- drop_last_break = True
702
- break
703
- if source.get('front_boi', use_front_boi_token):
704
- token_seq.append(self.boi_token_id) # Use patched boi for Janus, otherwise use the default <boi>
705
- extra_token_pos["boi"].append(token_count)
706
- token_count += 1
707
- token_count = self._add_image_meta_info_token(
708
- token_seq=token_seq,
709
- token_count=token_count,
710
- extra_token_pos=extra_token_pos,
711
- add_timestep_token=source.get('timestep', add_timestep_token),
712
- add_image_shape_token=source.get('image_shape', add_image_shape_token),
713
- base_size=source.get('base_size'),
714
- ratio_idx=source.get('ratio_idx'),
715
- image_type=key,
716
- )
717
- if not source.get('front_boi', use_front_boi_token):
718
- token_seq.append(self.boi_token_id)
719
- extra_token_pos["boi"].append(token_count)
720
- token_count += 1
721
- if last_key_only_prefix and i == modal_length - 1:
722
- pass # for AR inference
723
- else:
724
- token_seq.extend(
725
- [self.img_token_id] * source['length'][0]
726
- )
727
- extra_token_pos["<vae_img>_start"].append(token_count)
728
- extra_token_pos["<joint_img>_start"].append(token_count)
729
- extra_token_pos["<all_img>_start"].append(token_count)
730
- token_count += source['length'][0]
731
- extra_token_pos["<vae_img>_end"].append(token_count - 1)
732
- extra_token_pos["<all_img>_end"].append(token_count - 1)
733
-
734
- token_seq.extend(
735
- [self.special_token_map["<joint_img_sep>"]]
736
- )
737
- extra_token_pos["joint_img_sep"].append(token_count)
738
- token_count += 1
739
-
740
- token_seq.extend(
741
- [self.img_token_id] * source['length'][1]
742
- )
743
- extra_token_pos["<vit_img>_start"].append(token_count)
744
- extra_token_pos["<all_img>_start"].append(token_count)
745
- token_count += source['length'][1]
746
- extra_token_pos["<vit_img>_end"].append(token_count - 1)
747
- extra_token_pos["<joint_img>_end"].append(token_count - 1)
748
- extra_token_pos["<all_img>_end"].append(token_count - 1)
749
-
750
- token_seq.extend(
751
- [self.eoi_token_id]
752
- )
753
- extra_token_pos["eoi"].append(token_count)
754
- token_count += 1 # <eoi>
755
-
756
- else:
757
- raise ValueError(f"Not supported key: {key}")
758
- index_indicator[key] += 1
759
-
760
- if add_eos is True and not drop_last_break:
761
- # Typically used for t2i task.
762
- token_seq.append(self.eos_token_id)
763
- extra_token_pos["eos"].append(token_count)
764
- token_count += 1
765
- elif add_eos == 'auto' and not drop_last_break:
766
- # Typically used for lm and mmu task.
767
- if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length):
768
- token_seq.append(self.eos_token_id)
769
- extra_token_pos["eos"].append(token_count)
770
- token_count += 1
771
-
772
- if total_length:
773
- # Check token count and clip sequence if necessary
774
- if token_count > total_length and drop_last:
775
- # Assert clip position is not in the middle of the block-wise tokens (gen_image, joint_image)
776
- for start_key, end_key in [
777
- ("<img>_start", "<img>_end"), ("<joint_img>_start", "<joint_img>_end"),
778
- ("<vae_img>_start", "<vae_img>_end"), ("<vit_img>_start", "<vit_img>_end"),
779
- ]:
780
- if start_key in extra_token_pos and end_key in extra_token_pos:
781
- assert all(
782
- (start > total_length or end + 1 < total_length)
783
- for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key])
784
- ), ("Clip position should not be in the middle of the image tokens.\n"
785
- f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}")
786
- token_seq = token_seq[:total_length]
787
-
788
- # Pad the sequence if necessary
789
- pad_num = max(0, total_length - len(token_seq))
790
- if add_pad and pad_num:
791
- token_seq.extend([self.pad_token_id] * pad_num)
792
- extra_token_pos["first_pad"].append(token_count)
793
-
794
- return token_seq, extra_token_pos
795
-
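# Illustrative sketch (hypothetical usage, not part of the deleted file).
# A text-to-image sequence built from the "text-gen_image" template. The special
# tokens (<img_size_1024>, <img_ratio_*>) are assumed to exist in the tokenizer;
# `wrapper` stands for a TokenizerWrapper instance.
text_tokens = wrapper.encode_text("a cat wearing a hat")
seq, pos = wrapper.encode_sequence(
    template="text-gen_image",
    token_source={
        "text": [text_tokens],
        "gen_image": [dict(length=1024, timestep=True, image_shape=True,
                           base_size=1024, ratio_idx=0)],
    },
    total_length=2048,
)
# pos["<img>_start"] / pos["<img>_end"] bracket the <img> placeholder tokens;
# pos["timestep"] holds the position where the timestep embedding is scattered.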
796
- def batch_gen_infer(
797
- self,
798
- infer_fn,
799
- prompt_list: list,
800
- negative_prompt_list: list = None,
801
- infer_fn_kwargs_list: List[Dict[str, int]] = None,
802
- do_classifier_free_guidance=False,
803
- condition_repeat_times: int = 1,
804
- uncondition_repeat_times: int = 1,
805
- ):
806
- """
807
- Batch inference helper for AR-like models on text-to-image/instruction tuning tasks.
808
-
809
- Parameters
810
- ----------
811
- infer_fn: callable
812
- Inference function to encode the prompt.
813
- prompt_list: list
814
- List of prompts. Each element can be a single prompt or a list of prompts passed to the infer_fn.
815
- negative_prompt_list: list
816
- List of negative prompts. Only used when do_classifier_free_guidance is True. If None, will use <cfg>
817
- token sequence as negative prompt.
818
- infer_fn_kwargs_list: List[Dict[str, int]]
819
- List of keyword arguments for the infer_fn.
820
- do_classifier_free_guidance: bool
821
- Whether to do classifier-free guidance.
822
- condition_repeat_times: int
823
- Support multi-condition.
824
- uncondition_repeat_times: int
825
- Support multi-uncondition.
826
- """
827
- if infer_fn_kwargs_list is None:
828
- infer_fn_kwargs_list = [{} for _ in prompt_list]
829
-
830
- # [n_output, bsz]
831
- cond_results_list = None
832
- uncond_results_list = None
833
- output_type_list = []
834
-
835
- for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)):
836
- if not isinstance(prompt, (list, tuple)):
837
- prompt = [prompt]
838
- cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {}
839
- results = infer_fn(
840
- *prompt,
841
- **infer_fn_kwargs,
842
- **cond_kwargs,
843
- )
844
- output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1))
845
- if isinstance(results, dict):
846
- raise ValueError("Make batch on dict is not supported. Please return list or tuple for infer_fn.")
847
- if not isinstance(results, (list, tuple)):
848
- results = (results,)
849
- if cond_results_list is None:
850
- cond_results_list = [[] for _ in results]
851
- uncond_results_list = [[] for _ in results]
852
- for i, result in enumerate(results):
853
- cond_results_list[i].append(result)
854
-
855
- if do_classifier_free_guidance:
856
- if negative_prompt_list is None:
857
- uncond_kwargs = {"uncond_p": 1.0}
858
- uncond_results = infer_fn(
859
- *prompt,
860
- **infer_fn_kwargs,
861
- **uncond_kwargs,
862
- )
863
- else:
864
- negative_prompt = negative_prompt_list[prompt_idx]
865
- if not isinstance(negative_prompt, (list, tuple)):
866
- negative_prompt = [negative_prompt]
867
- uncond_results = infer_fn(
868
- *negative_prompt,
869
- **infer_fn_kwargs,
870
- )
871
- if isinstance(uncond_results, TokenizerEncodeOutput):
872
- uncond_results_list[0].append(uncond_results)
873
- else:
874
- for i, result in enumerate(uncond_results):
875
- uncond_results_list[i].append(result)
876
-
877
- assert all(output_type_list[0] == n for n in output_type_list), \
878
- f"Number of outputs should be equal for all samples, but got {output_type_list}."
879
- output_type, output_num = output_type_list[0]
880
-
881
- def make_batch(batch_cond_item, batch_uncond_item):
882
- # Process each output item to make batch
883
- first = batch_cond_item[0] # The first element in the batch
884
- if isinstance(first, torch.Tensor):
885
- stacked_item = torch.stack(self.pad(
886
- batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times,
887
- ))
888
-
889
- elif first is None:
890
- assert all(item is None for item in batch_cond_item + batch_uncond_item), \
891
- (f"The first cond item is None, but some items are not None:\n\n"
892
- f"condition: {batch_cond_item}\n\n"
893
- f"uncondition: {batch_uncond_item}")
894
- stacked_item = None
895
-
896
- elif isinstance(first, (list, tuple)):
897
- # If the output item is a list or tuple, we treat it as a whole, and won't make nested batch any more.
898
- stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times
899
-
900
- elif isinstance(first, TokenizerEncodeOutput):
901
- stacked_item = {}
902
- # Traverse not-None attributes
903
- for key in list(first.keys()):
904
- merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \
905
- [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times
906
- if isinstance(first[key], torch.Tensor):
907
- if 'mask' in key:
908
- pad_val = 0.0
909
- elif key == 'tokens':
910
- pad_val = self.special_token_map["<pad>"]
911
- else:
912
- pad_val = False # Should not pad for other tensors
913
- stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0)
914
- elif isinstance(first[key], list):
915
- stacked_item[key] = merged_list
916
- elif first[key] is None:
917
- pass
918
- else:
919
- raise ValueError(f"Unsupported type of {key}: {type(first[key])}.")
920
- stacked_item = TokenizerEncodeOutput(stacked_item)
921
-
922
- else:
923
- raise TypeError(f"Making batch on type {type(first)} is not supported.")
924
-
925
- return stacked_item
926
-
927
- stacked_outputs = []
928
- for cond_results, uncond_results in zip(cond_results_list, uncond_results_list):
929
- stacked_outputs.append(make_batch(cond_results, uncond_results))
930
-
931
- if output_type == list:
932
- return stacked_outputs
933
- elif output_type == tuple:
934
- return tuple(stacked_outputs)
935
- elif output_num == 1:
936
- return stacked_outputs[0]
937
- else:
938
- raise ValueError(f"Unsupported output type: {output_type}.")
939
-
940
- @staticmethod
941
- def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None):
942
- if rng is None:
943
- rng = slice(None)
944
- image_slices = [
945
- slice(start, end + 1)
946
- for start, end in zip(extra_token_pos[f'<{prefix}>_start'][rng], extra_token_pos[f'<{prefix}>_end'][rng])
947
- ] if f'<{prefix}>_start' in extra_token_pos and f'<{prefix}>_end' in extra_token_pos else []
948
- if image_slices:
949
- image_mask = torch.zeros_like(tokens, dtype=torch.bool)
950
- for image_slice in image_slices:
951
- image_mask[image_slice] = True
952
- else:
953
- image_mask = None
954
- return image_slices, image_mask
955
-
956
- def encode_general(
957
- self,
958
- sections: Optional[List[Dict[str, Any]]] = None,
959
- max_token_length: Optional[int] = None,
960
- add_eos='auto',
961
- use_text_mask=True,
962
- add_pad='auto',
963
- add_bos=True,
964
- drop_last='auto',
965
- ):
966
- """
967
- General encode function to encode a sequence with multiple sections of text and images.
968
- Each section is a dict with a `type` key and other keys depending on the type.
969
- Supported section types:
970
- - text: dict with keys:
971
- - text: str or List[int], text to be encoded. Either `text` or `tokens` should be provided.
972
- - tokens: List[int], pre-encoded text tokens. Either `text` or `tokens` should be provided.
973
- - uncond_enabled: bool, whether to enable uncondition for this text section.
974
- - uncond_p: float, probability to drop the text section for uncondition.
975
- - max_length: int, maximum length of the text section.
976
- - ignore: bool, whether to ignore this text section in the text mask.
977
- - start_offset: int, start offset of the text mask.
978
- - end_offset: int, end offset of the text mask.
979
- - gen_image: dict with keys:
980
- - token_length: int, number of image tokens.
981
- - add_timestep_token: bool, whether to add timestep token before the image tokens.
982
- - add_guidance_token: bool, whether to add guidance token before the image tokens.
983
- - use_front_boi_token: bool, whether to put the <boi> token at the front of size, ratio and timestep tokens.
984
- - add_image_shape_token: bool, whether to add image shape token before the image tokens.
985
- - base_size: int, base size of the image.
986
- - ratio_idx: int, ratio index of the image.
987
- - joint_image: dict with keys:
988
- - token_length: List[int], number of image tokens for the two images.
989
- - add_timestep_token: bool, whether to add timestep token before the image tokens.
990
- - use_front_boi_token: bool, whether to put the <boi> token at the front of size, ratio and timestep tokens.
991
- - add_image_shape_token: bool, whether to add image shape token before the image tokens.
992
- - base_size: int, base size of the image.
993
- - ratio_idx: int, ratio index of the image.
994
-
995
- Parameters
996
- ----------
997
- sections: List[Dict[str, Any]]
998
- List of sections to be encoded.
999
- max_token_length: int
1000
- Maximum length of the encoded token sequence.
1001
- add_eos: bool or 'auto'
1002
- Whether to add eos token at the end of the sequence. If True, always add eos
1003
- token. If 'auto', add eos token only when the total_length is not reached and the last token is not <eos>.
1004
- use_text_mask: bool
1005
- Whether to generate text mask.
1006
- add_pad: bool or 'auto'
1007
- Whether to add padding tokens to the sequence. If True and total_length is not reached,
1008
- add padding tokens.
1009
- add_bos: bool
1010
- Whether to add bos token at the beginning of the sequence.
1011
- drop_last: bool or 'auto'
1012
- - If auto, drop last tokens exceeding the total_length if the total_length is provided.
1013
- If cut point is in the middle of the image tokens, an error will be raised.
1014
- - If True, drop last tokens exceeding the total_length. If cut point is in the
1015
- middle of the image tokens, all the successive image tokens will be dropped.
1016
- - If False, keep the last tokens exceeding the total_length, even if the total_length
1017
- is reached.
1018
-
1019
- Returns
1020
- -------
1021
- TokenizerEncodeOutput
1022
- Encoded token sequence and extra information.
1023
- """
1024
- if sections is None:
1025
- raise ValueError("sections must be provided.")
1026
- template = '-'.join([section['type'] for section in sections])
1027
-
1028
- sections = deepcopy(sections)
1029
- token_source = defaultdict(list)
1030
- text_mask_specs = []
1031
- for section in sections:
1032
- if section['type'] == 'text':
1033
- text = self.encode_text(
1034
- section['text'] if 'text' in section else section['tokens'],
1035
- uncond_enabled=section.get('uncond_enabled'),
1036
- uncond_p=section.get('uncond_p'),
1037
- max_length=section.get('max_length'),
1038
- )
1039
- token_source['text'].append(text)
1040
- text_mask_specs.append(dict(
1041
- ignore=section.get('ignore', False),
1042
- start_offset=section.get('start_offset', 0),
1043
- end_offset=section.get('end_offset', 0),
1044
- ))
1045
- elif section['type'] == 'gen_image':
1046
- token_source['gen_image'].append(dict(
1047
- length=section['token_length'],
1048
- timestep=section.get('add_timestep_token', False),
1049
- guidance=section.get('add_guidance_token', False),
1050
- front_boi=section.get('use_front_boi_token', False),
1051
- image_shape=section.get('add_image_shape_token', False),
1052
- base_size=section.get('base_size'),
1053
- ratio_idx=section.get('ratio_idx'),
1054
- ))
1055
- elif section['type'] == 'joint_image':
1056
- token_source['joint_image'].append(dict(
1057
- length=section['token_length'],
1058
- timestep=section.get('add_timestep_token', False),
1059
- front_boi=section.get('use_front_boi_token', False),
1060
- image_shape=section.get('add_image_shape_token', False),
1061
- base_size=section.get('base_size'),
1062
- ratio_idx=section.get('ratio_idx'),
1063
- ))
1064
- else:
1065
- raise ValueError(f"Invalid section type: {section['type']}")
1066
-
1067
- # Combine text and image tokens
1068
- full_token_seq, extra_token_pos = self.encode_sequence(
1069
- template=template,
1070
- token_source=dict(token_source),
1071
- total_length=max_token_length,
1072
- add_eos=add_eos,
1073
- add_pad=add_pad,
1074
- add_bos=add_bos,
1075
- drop_last=drop_last,
1076
- )
1077
- full_seq_token_tensor = torch.tensor(full_token_seq, dtype=torch.long)
1078
-
1079
- timestep_scatter_index = torch.tensor(extra_token_pos['timestep'], dtype=torch.long) \
1080
- if 'timestep' in extra_token_pos else None
1081
- guidance_scatter_index = torch.tensor(extra_token_pos['guidance'], dtype=torch.long) \
1082
- if 'guidance' in extra_token_pos else None
1083
- cond_timestep_scatter_index = torch.tensor(extra_token_pos['cond_timestep'], dtype=torch.long) \
1084
- if 'cond_timestep' in extra_token_pos else None
1085
- gen_timestep_scatter_index = torch.tensor(extra_token_pos['gen_timestep'], dtype=torch.long) \
1086
- if 'gen_timestep' in extra_token_pos else None
1087
-
1088
- # Gen image mask
1089
- gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', full_seq_token_tensor)
1090
- # Joint image
1091
- joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', full_seq_token_tensor)
1092
- # Conditional vae image
1093
- cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos(
1094
- extra_token_pos, 'vae_img', full_seq_token_tensor)
1095
- # Conditional vit image
1096
- cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos(
1097
- extra_token_pos, 'vit_img', full_seq_token_tensor)
1098
- # All image slices (gen_image, joint_image)
1099
- all_image_slices = [
1100
- slice(start, end + 1)
1101
- for start, end in zip(extra_token_pos['<all_img>_start'], extra_token_pos['<all_img>_end'])
1102
- ] if '<all_img>_start' in extra_token_pos and '<all_img>_end' in extra_token_pos else []
1103
-
1104
- # Text mask
1105
- text_slices = [
1106
- slice(start, end + 1)
1107
- for start, end in zip(extra_token_pos['<text>_start'], extra_token_pos['<text>_end'])
1108
- ] if '<text>_start' in extra_token_pos and '<text>_end' in extra_token_pos else []
1109
- assert len(text_slices) <= len(text_mask_specs), \
1110
- (f"Number of text slices ({len(text_slices)}) should be less than or equal to "
1111
- f"number of text mask specs ({len(text_mask_specs)})")
1112
- if use_text_mask:
1113
- text_mask = torch.zeros_like(full_seq_token_tensor, dtype=torch.float32)
1114
- for text_slice, mask_spec in zip(text_slices, text_mask_specs):
1115
- if not mask_spec['ignore']:
1116
- real_slice = slice(
1117
- text_slice.start + mask_spec['start_offset'],
1118
- text_slice.stop + mask_spec['end_offset']
1119
- )
1120
- text_mask[real_slice] = 1.0
1121
- else:
1122
- text_mask = None
1123
-
1124
- # real_pos is the first position of the <pad> token
1125
- real_pos = torch.tensor(extra_token_pos.get('first_pad', [full_seq_token_tensor.shape[0]]), dtype=torch.long)
1126
-
1127
- return TokenizerEncodeOutput(
1128
- tokens=full_seq_token_tensor,
1129
- timestep_scatter_index=timestep_scatter_index,
1130
- guidance_scatter_index=guidance_scatter_index,
1131
- text_slices=text_slices,
1132
- gen_image_slices=gen_image_slices,
1133
- joint_image_slices=joint_image_slices,
1134
- cond_vae_image_slices=cond_vae_image_slices,
1135
- cond_vit_image_slices=cond_vit_image_slices,
1136
- text_mask=text_mask,
1137
- gen_image_mask=gen_image_mask,
1138
- cond_vae_image_mask=cond_vae_image_mask,
1139
- cond_vit_image_mask=cond_vit_image_mask,
1140
- real_pos=real_pos,
1141
- all_image_slices=all_image_slices,
1142
- cond_timestep_scatter_index=cond_timestep_scatter_index,
1143
- gen_timestep_scatter_index=gen_timestep_scatter_index,
1144
- )
1145
-
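# Illustrative sketch (hypothetical usage, not part of the deleted file).
# The `sections` schema described above, for a simple caption -> image sequence.
# `wrapper` stands for a TokenizerWrapper; the sizes are arbitrary.
out = wrapper.encode_general(
    sections=[
        dict(type="text", text="a cat wearing a hat", uncond_enabled=True, uncond_p=0.1),
        dict(type="gen_image", token_length=1024, add_timestep_token=True,
             add_image_shape_token=True, base_size=1024, ratio_idx=0),
    ],
    max_token_length=2048,
)
out.tokens                  # LongTensor, padded up to max_token_length
out.gen_image_slices        # slices covering the <img> placeholder positions
out.timestep_scatter_index  # positions of the <timestep> token(s)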
1146
- def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False):
1147
- if not cot_text: # None or empty
1148
- return []
1149
- if '<think>' in cot_text and '</think>' in cot_text:
1150
- before_think_sec = cot_text.split('<think>')[0]
1151
- after_think_sec = cot_text.split('</think>')[1]
1152
- think_sec = cot_text.split('<think>')[1].split('</think>')[0]
1153
- return self.get_cot_sections(before_think_sec, uncond_kwargs, drop_think=drop_think) + \
1154
- ([
1155
- dict(type="text", text="<think>"),
1156
- dict(type="text", text=think_sec, max_length=cot_max_length, **uncond_kwargs),
1157
- dict(type="text", text="</think>")
1158
- ] if not drop_think else []) + \
1159
- self.get_cot_sections(after_think_sec, uncond_kwargs, drop_think=drop_think)
1160
-
1161
- if '<recaption>' in cot_text and '</recaption>' in cot_text:
1162
- before_recaption_sec = cot_text.split('<recaption>')[0]
1163
- after_recaption_sec = cot_text.split('</recaption>')[1]
1164
- recaption_sec = cot_text.split('<recaption>')[1].split('</recaption>')[0]
1165
- return self.get_cot_sections(before_recaption_sec, uncond_kwargs, drop_think=drop_think) + \
1166
- [
1167
- dict(type="text", text="<recaption>"),
1168
- dict(type="text", text=recaption_sec, max_length=cot_max_length, **uncond_kwargs),
1169
- dict(type="text", text="</recaption>")
1170
- ] + \
1171
- self.get_cot_sections(after_recaption_sec, uncond_kwargs, drop_think=drop_think)
1172
-
1173
- return [
1174
- dict(type="text", text=cot_text, **uncond_kwargs),
1175
- ]
1176
-
1177
- def apply_general_template(
1178
- self,
1179
- message_list,
1180
- max_length=None,
1181
- add_assistant_prefix=False,
1182
- answer="auto",
1183
- bot_task="auto",
1184
- sequence_template="instruct",
1185
- uncond_p=0.0,
1186
- cfg_factor=1,
1187
- batchify=False,
1188
- image_base_size=1024,
1189
- drop_think=False,
1190
- ):
1191
- # If cfg_factor > 1, we need to repeat the unconditioned part
1192
- if batchify:
1193
- assert isinstance(message_list[0], list), \
1194
- f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]."
1195
- return self.batch_gen_infer(
1196
- infer_fn=self.apply_general_template,
1197
- prompt_list=[[]],
1198
- infer_fn_kwargs_list=[dict(
1199
- message_list=message_list_i,
1200
- max_length=max_length,
1201
- add_assistant_prefix=add_assistant_prefix,
1202
- answer=answer,
1203
- bot_task=bot_task,
1204
- sequence_template=sequence_template,
1205
- image_base_size=image_base_size,
1206
- drop_think=drop_think,
1207
- ) for message_list_i in message_list],
1208
- do_classifier_free_guidance=cfg_factor > 1,
1209
- condition_repeat_times=1,
1210
- uncondition_repeat_times=cfg_factor - 1,
1211
- )
1212
-
1213
- conv = Conversation()
1214
- uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p)
1215
-
1216
- def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix,
1217
- answer_prefix="", answer_suffix=""):
1218
- _sub_sections = []
1219
- while _cur_message_idx < len(message_list) and _message_list[_cur_message_idx]['role'] == role:
1220
- message = _message_list[_cur_message_idx]
1221
- if message['type'] == 'text':
1222
- text = message['content']
1223
- if role == "system":
1224
- _sub_sections.append(dict(type="text", text=text))
1225
- elif role == "assistant":
1226
- if ("<recaption>" in text and "</recaption>" in text) or (
1227
- "<think>" in text and "</think>" in text):
1228
- _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think))
1229
- else:
1230
- _sub_sections.append(dict(type="text", text=text, **uncond_kwargs))
1231
- else:
1232
- _sub_sections.append(dict(
1233
- type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs))
1234
- elif message['type'] == 'gen_image':
1235
- info = message['content']
1236
- assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}"
1237
- if role == "assistant":
1238
- _sub_sections.append(dict(type="text", text=answer_prefix))
1239
- _sub_sections.append(dict(type=message['type'], **info.meta_info))
1240
- if role == "assistant":
1241
- _sub_sections.append(dict(type="text", text=answer_suffix))
1242
- elif message['type'] == 'joint_image':
1243
- info = message['content']
1244
- assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}"
1245
- _sub_sections.append(dict(type=message['type'], **info.meta_info))
1246
- else:
1247
- raise ValueError(f"Unknown message type: {message['type']}")
1248
- _cur_message_idx += 1
1249
- if len(_sub_sections) > 0:
1250
- # Add role prefix and suffix
1251
- _sub_sections.insert(0, dict(type='text', text=prefix))
1252
- _sub_sections.append(dict(type='text', text=suffix))
1253
- return _sub_sections, _cur_message_idx
1254
-
1255
- # Define assistant prefix and suffix
1256
- if (answer == "auto" and sequence_template == "instruct") or answer is True:
1257
- answer_prefix, answer_suffix = "<answer>", "</answer>"
1258
- else:
1259
- answer_prefix, answer_suffix = "", ""
1260
- if sequence_template == "pretrain":
1261
- system_suffix = ""
1262
- user_prefix = ""
1263
- user_suffix = ""
1264
- bot_prefix = ""
1265
- bot_suffix = ""
1266
- else:
1267
- system_suffix = f"{conv.sep}"
1268
- user_prefix = f"{conv.roles[0]}: "
1269
- user_suffix = f"{conv.sep}"
1270
- bot_prefix = f"{conv.roles[1]}: "
1271
- bot_suffix = f"{conv.sep}"
1272
-
1273
- # Process successive user and assistant messages
1274
- sections = []
1275
- cur_message_idx = 0
1276
- final_role = None
1277
- while cur_message_idx < len(message_list):
1278
- # Process successive system messages
1279
- sub_sections, cur_message_idx = process_successive_message(
1280
- message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix)
1281
- # Add to the template and sections
1282
- sections.extend(sub_sections)
1283
- if len(sub_sections) > 0:
1284
- final_role = "system"
1285
-
1286
- # Process successive user messages
1287
- sub_sections, cur_message_idx = process_successive_message(
1288
- message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix)
1289
- # Add to the template and sections
1290
- sections.extend(sub_sections)
1291
- if len(sub_sections) > 0:
1292
- final_role = "user"
1293
-
1294
- # Process successive assistant messages
1295
- sub_sections, cur_message_idx = process_successive_message(
1296
- message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix,
1297
- answer_prefix=answer_prefix, answer_suffix=answer_suffix,
1298
- )
1299
- # Add to the template and sections
1300
- sections.extend(sub_sections)
1301
- if len(sub_sections) > 0:
1302
- final_role = "assistant"
1303
-
1304
- if add_assistant_prefix:
1305
- if final_role == "assistant":
1306
- # Avoid adding prefix twice
1307
- _bot_prefix = ""
1308
- # Remove the final bot_suffix
1309
- if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix:
1310
- sections = sections[:-1]
1311
- else:
1312
- _bot_prefix = bot_prefix
1313
- # We can add special tokens for the bot's latest message according to the task
1314
- bot_response_prefix = dict(
1315
- auto=_bot_prefix,
1316
- image="",
1317
- think=f"{_bot_prefix}<think>",
1318
- recaption=f"{_bot_prefix}<recaption>",
1319
- img_ratio=f"{_bot_prefix}{answer_prefix}<boi><img_size_{image_base_size}>",
1320
- )[bot_task]
1321
- sections.append(dict(type='text', text=bot_response_prefix))
1322
-
1323
- output = self.encode_general(
1324
- sections=sections,
1325
- use_text_mask=False,
1326
- add_eos=False,
1327
- add_pad=False,
1328
- )
1329
-
1330
- if max_length is not None:
1331
- if output.tokens.shape[-1] > max_length:
1332
- raise ValueError(
1333
- f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n"
1334
- f"Please set a larger max_length or check the input messages:\n{message_list}"
1335
- )
1336
-
1337
- return output, sections
1338
-
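# Illustrative sketch (hypothetical usage, not part of the deleted file).
# The message-list schema consumed by apply_general_template(): each message is a
# dict with `role` (system/user/assistant), `type` (text/gen_image/joint_image)
# and `content`. `wrapper` and `gen_image_info` (an ImageInfo) are placeholders.
messages = [
    dict(role="system", type="text", content="You are a text-to-image model."),
    dict(role="user", type="text", content="a cat wearing a hat"),
    dict(role="assistant", type="gen_image", content=gen_image_info),
]
output, sections = wrapper.apply_general_template(
    message_list=messages,
    sequence_template="instruct",   # wraps the assistant answer in <answer>...</answer>
    add_assistant_prefix=False,
)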
1339
- def apply_chat_template(
1340
- self,
1341
- batch_prompt: Optional[List[str]] = None,
1342
- batch_message_list: Optional[List[List[Dict[str, Any]]]] = None,
1343
- mode: str = "gen_text",
1344
- batch_gen_image_info: Optional[List[ImageInfo]] = None,
1345
- batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None,
1346
- batch_system_prompt: Optional[List[str]] = None,
1347
- batch_cot_text: Optional[List[str]] = None,
1348
- max_length: Optional[int] = None,
1349
- bot_task: str = "auto", # auto/image/think/recaption/img_ratio
1350
- image_base_size: int = 1024,
1351
- sequence_template: str = "pretrain",
1352
- cfg_factor: int = 1,
1353
- add_assistant_prefix: Optional[bool] = None,
1354
- drop_think: bool = False,
1355
- ) -> Dict[str, Any]:
1356
- assert bot_task in ["image", "auto", "think", "recaption", "img_ratio"], \
1357
- f"bot_task should be one of ['image', 'auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}."
1358
-
1359
- if batch_message_list is None:
1360
- # Simple text-to-image or text-cot-to-image task
1361
- batch_size = len(batch_prompt)
1362
-
1363
- # Batchify inputs
1364
- if not isinstance(batch_system_prompt, list):
1365
- batch_system_prompt = [batch_system_prompt] * batch_size
1366
- if not isinstance(batch_gen_image_info, list):
1367
- batch_gen_image_info = [batch_gen_image_info] * batch_size
1368
- if batch_cot_text is not None:
1369
- assert len(batch_cot_text) == batch_size, \
1370
- (f"batch_cot_text should have the same length as batch_size ({batch_size}), "
1371
- f"but got {len(batch_cot_text)}.")
1372
- else:
1373
- batch_cot_text = [None] * batch_size
1374
- if batch_cond_image_info is not None:
1375
- assert len(batch_cond_image_info) == batch_size, \
1376
- (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), "
1377
- f"but got {len(batch_cond_image_info)}.")
1378
- batch_cond_image_info = [
1379
- cond_image_info if isinstance(cond_image_info, list) else [cond_image_info]
1380
- for cond_image_info in batch_cond_image_info
1381
- ]
1382
- else:
1383
- batch_cond_image_info = [[] for _ in range(batch_size)]
1384
-
1385
- # Convert single round materials into standard message list
1386
- batch_message_list = []
1387
- for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip(
1388
- batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info,
1389
- batch_cond_image_info,
1390
- ):
1391
- message_list = []
1392
- # 1. system prompt section
1393
- if system_prompt:
1394
- message_list.append(dict(
1395
- role="system", type="text", content=system_prompt, context_type="str"))
1396
- # 2. user inputs sections
1397
- # 2.1 image inputs
1398
- if len(cond_image_info_list) > 0:
1399
- message_list.extend([
1400
- dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info")
1401
- for cond_image_info in cond_image_info_list
1402
- ])
1403
- # 2.2 text inputs
1404
- message_list.append(dict(
1405
- role="user", type="text", content=prompt, context_type="str"))
1406
- # 3. assistant answer sections
1407
- if cot_text is not None:
1408
- message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str"))
1409
- if mode == "gen_image":
1410
- message_list.append(dict(
1411
- role="assistant", type="gen_image", content=gen_image_info, context_type="image_info"))
1412
- # ---
1413
- batch_message_list.append(message_list)
1414
-
1415
- output, sections = self.apply_general_template(
1416
- message_list=batch_message_list,
1417
- max_length=max_length,
1418
- add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"),
1419
- bot_task=bot_task,
1420
- sequence_template=sequence_template,
1421
- cfg_factor=cfg_factor,
1422
- batchify=True,
1423
- image_base_size=image_base_size,
1424
- drop_think=drop_think,
1425
- )
1426
- return dict(output=output, sections=sections)
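
A minimal end-to-end usage sketch of the deleted wrapper (illustrative only; `wrapper` stands for a TokenizerWrapper built from the model's tokenizer, and the image geometry is arbitrary):

# Text-to-image prompting with classifier-free guidance via apply_chat_template().
gen_info = ImageInfo(image_type="gen_image", image_width=1024, image_height=1024,
                     token_width=32, token_height=32, base_size=1024, ratio_index=0)
batch = wrapper.apply_chat_template(
    batch_prompt=["a cat wearing a hat"],
    mode="gen_image",
    batch_gen_image_info=[gen_info],
    image_base_size=1024,
    sequence_template="pretrain",
    cfg_factor=2,                    # row 0: conditional, row 1: unconditional
)
tokens = batch["output"].tokens      # shape [cfg_factor * batch_size, seq_len]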