# kanana-1.5-v-3b-instruct / tokenization.py
import logging
import re
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
# Role tokens
AI = "AI: "
HUMAN = "Human: "
_AI = "\n" + AI
_HUMAN = "\n" + HUMAN
# special media tokens
IMAGE = "<image>"
IMAGE_ROW_SEPARATOR = "\n"
IMAGE_GLOBAL_LOCAL_SEPARATOR = "\n"
MEDIA_TOKENS = {
"image": [IMAGE],
}
_INFINITE = int(1e12)  # effectively infinite length, used to disable truncation
logger = logging.getLogger("kanana-1.5-v")
def _pad_trunc(
x: list[list[int]],
padding: str,
padding_side: str,
pad_value: int,
max_length: int,
) -> torch.LongTensor:
"""Pad and truncate sequences to the same length
Args:
x (list[list[int]])
padding ("longest" or "max_length")
padding_side ("left" or "right")
pad_value (int)
max_length (int or None): if padding == "max_length", max_length should be given.
"""
assert padding in ["longest", "max_length"]
assert padding_side in ["left", "right"]
lengths = [len(sample) for sample in x]
if padding == "longest":
max_length = max(lengths)
new_x = []
for sample, length in zip(x, lengths):
if torch.is_tensor(sample):
sample = sample.tolist()
if length >= max_length:
new_x.append(sample[:max_length])
continue
padding_size = max_length - length
pads = [pad_value] * padding_size
if padding_side == "right":
new_x.append(sample + pads)
else:
new_x.append(pads + sample)
return torch.as_tensor(new_x, dtype=torch.long)
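# A minimal illustration of _pad_trunc (hypothetical values):
#
#   _pad_trunc([[1, 2, 3], [4]], padding="longest", padding_side="right",
#              pad_value=0, max_length=8)
#   # -> tensor([[1, 2, 3],
#   #            [4, 0, 0]])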
class KananaVTokenizerMixin:
def mllm_setup(self, num_visual_tokens: int):
self.num_visual_tokens = num_visual_tokens
        # Currently, only the image modality is supported among media modalities.
self.media_tokens = {k: -int(i + 1) for i, k in enumerate(MEDIA_TOKENS["image"])}
self.media_lengths = {MEDIA_TOKENS["image"][0]: num_visual_tokens}
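    # Illustration (hypothetical value): after mllm_setup(num_visual_tokens=576),
    #   self.media_tokens  == {"<image>": -1}   # negative id, cannot collide with vocab ids
    #   self.media_lengths == {"<image>": 576}  # placeholder count reserved per image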
def repeat_image_tokens(
self, hw_tokens, with_row_separator=True, add_global_local_separator=False
):
if len(hw_tokens) == 3:
T, H, W = hw_tokens
else:
H, W = hw_tokens
repeated_tokens = []
if add_global_local_separator:
global_local_separator = self(IMAGE_GLOBAL_LOCAL_SEPARATOR, add_special_tokens=False)[
"input_ids"
]
repeated_tokens += global_local_separator
if with_row_separator:
row_sep = self(IMAGE_ROW_SEPARATOR, add_special_tokens=False)["input_ids"]
for h_idx in range(H):
repeated_tokens += [self.media_tokens[IMAGE]] * W
if with_row_separator and h_idx != H - 1:
repeated_tokens += row_sep
return repeated_tokens
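    # Layout sketch for hw_tokens=(2, 3) with row separators and the global/local
    # separator enabled ("\n" stands for its tokenized ids):
    #   ["\n", IMG, IMG, IMG, "\n", IMG, IMG, IMG]
    # i.e., H rows of W image placeholders, separated by "\n" but with no trailing one.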
def encode_text_only(self, prompt: str, add_special_tokens: bool = False) -> list:
# Text-only Data
# split prompt into chunks by role tokens
tokens_to_split = [_AI, _HUMAN]
pattern = "|".join(map(re.escape, tokens_to_split))
chunk_strs = re.split(f"({pattern})", prompt)
chunk_strs = [x for x in chunk_strs if len(x) > 0]
enc_chunk = []
        # NOTE: chunks are always encoded without special tokens; add_special_tokens
        # is accepted for API compatibility only.
        for chunk_str in chunk_strs:
            enc_chunk += self(chunk_str, add_special_tokens=False)["input_ids"]
return enc_chunk
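    # Example (hypothetical prompt): "Human: hi\nAI: hello" is split into
    #   ["Human: hi", "\nAI: ", "hello"]
    # so the role marker "\nAI: " always maps to the same token ids regardless of
    # the surrounding text.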
def encode_prompt(
self, prompt: str, max_length: int | None = None, image_meta: dict | None = None
) -> dict:
"""Tokenize prompt which consists of image-text or text only, with role tokens.
Role pattern is "AI: " or "Human: ".
Args:
prompt
max_length (int or None): here, max_length is used for truncation.
If max_length is None, no truncation is applied.
"""
max_length = max_length or _INFINITE # if None, set to infinite for no-truncation
# output enc_chunk
enc_chunk = []
# Text-only or Image-Text Data
# split prompt into chunks by media and role tokens
tokens_to_split = list(self.media_tokens.keys()) + [_AI, _HUMAN]
pattern = "|".join(map(re.escape, tokens_to_split))
chunk_strs = re.split(f"({pattern})", prompt)
chunk_strs = [x for x in chunk_strs if len(x) > 0]
# tokenize chunks
img_idx = 0 # for sync with image_meta
for idx, chunk_str in enumerate(chunk_strs):
if chunk_str in self.media_tokens:
if chunk_str == IMAGE:
                    image_token_thw = (
                        image_meta["image_token_thw"][img_idx]
                        if image_meta and image_meta.get("image_token_thw")
                        else None
                    )
media_tokens = self.repeat_image_tokens(
image_token_thw,
with_row_separator=True,
add_global_local_separator=True,
)
# increment image index
img_idx += 1
else:
raise ValueError("Unknown chunk str", chunk_str)
enc_chunk += media_tokens
else:
curr_chunk = self(chunk_str, add_special_tokens=False)["input_ids"]
enc_chunk += curr_chunk
L = len(enc_chunk)
input_ids = torch.as_tensor(enc_chunk, dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
assert L <= max_length, (
f"[Length exceeded] Input sequence length ({L}) is greater than "
f"the allowed max_length ({max_length}). "
"Please truncate the sequence or increase max_length."
)
return {
"input_ids": input_ids, # [L]
"seq_length": L, # int
"attention_mask": attention_mask, # [L]
}
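    # Usage sketch (hypothetical values; assumes mllm_setup was called first):
    #   out = tokenizer.encode_prompt(
    #       "Human: <image>\nDescribe the image.\nAI: ",
    #       image_meta={"image_token_thw": [(2, 3)]},
    #   )
    #   out["input_ids"]   # 1-D LongTensor with <image> expanded to 2x3 placeholders
    #   out["seq_length"]  # int, length before any batching/padding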
def batch_collate_pad(
self,
batch: list,
padding: str,
padding_side: str,
max_length: int | None,
) -> dict[str, torch.LongTensor]:
"""Collate batch and pad/truncate to the same length
Args:
batch
padding ("longest" or "max_length")
padding_side ("left" or "right")
pad_value (int)
max_length (int or None): if padding == "max_length", max_length should be given
"""
if padding == "max_length":
assert max_length is not None, "max_length should be given if padding == 'max_length'"
else:
# if padding == 'longest' and max_length is None, set to infinite for no-truncation
max_length = max_length or _INFINITE
input_ids = [sample["input_ids"] for sample in batch]
attention_mask = [sample["attention_mask"] for sample in batch]
seq_length = [sample["seq_length"] for sample in batch]
input_ids = _pad_trunc(input_ids, padding, padding_side, self.pad_token_id, max_length)
attention_mask = _pad_trunc(attention_mask, padding, padding_side, 0, max_length)
seq_length = torch.as_tensor(seq_length, dtype=torch.long)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"seq_length": seq_length,
}
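    # Batching sketch (hypothetical samples from encode_prompt):
    #   batch = tokenizer.batch_collate_pad(
    #       [sample_a, sample_b], padding="longest", padding_side="left", max_length=None,
    #   )
    #   batch["input_ids"].shape  # [2, longest_sample_length]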
def get_chat_template(self) -> str:
"""Method for bw-compat: old HF transformers (e.g., 4.41.0) does not have get_chat_template
"""
return self.chat_template
class KananaVTokenizer(PreTrainedTokenizer, KananaVTokenizerMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def encode(self, text, add_special_tokens=False) -> list:
return self.encode_text_only(prompt=text, add_special_tokens=add_special_tokens)
class KananaVTokenizerFast(PreTrainedTokenizerFast, KananaVTokenizerMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def encode(self, text, add_special_tokens=False) -> list:
return self.encode_text_only(prompt=text, add_special_tokens=add_special_tokens)
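# End-to-end usage sketch. The checkpoint path and num_visual_tokens value are
# illustrative assumptions, not taken from this file:
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("path/to/kanana-1.5-v-3b-instruct",
#                                       trust_remote_code=True)
#   tok.mllm_setup(num_visual_tokens=576)
#   sample = tok.encode_prompt("Human: <image>\nWhat is shown?\nAI: ",
#                              image_meta={"image_token_thw": [(2, 3)]})
#   batch = tok.batch_collate_pad([sample], padding="longest",
#                                 padding_side="left", max_length=None)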