import logging
import torch
from PIL.Image import Image
from transformers.processing_utils import ProcessorMixin

logger = logging.getLogger("kanana-1.5-v")

HUMAN = "Human: "
AI = "AI: "
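# Jinja chat template: the <|USER|>/<|ASSISTANT|> placeholders are swapped below
# for the plain-text "Human: "/"AI: " speaker prefixes defined above.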
CHAT_TEMPLATE = (
"""
{%- if bos_token is defined and bos_token %}
{{- bos_token }}
{%- endif %}
{%- set intro %}
The following is a conversation between a curious human and AI assistant. 당신은 Kakao에서 개발된 인공지능 언어모델이고 이름은 kanana입니다.
Knowledge Cutoff Date: June 30, 2024.
Capabilities and Limitations:
- I cannot search for external content such as weather, news, or the current date and time.
- If a URL is provided, I cannot access it directly. Instead, please copy and provide the relevant content for me to process.
{%- endset %}
{{ intro }}
{{- '\n' }}
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{{- message['content'] }}
{%- elif message['role'] == 'user' %}
{{- '<|USER|>' + message['content'] }}
{%- elif message['role'] == 'assistant' %}
{{- '<|ASSISTANT|>' + message['content'] + eos_token }}
{%- endif %}
{%- if not loop.last %}
{{- '\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '\n<|ASSISTANT|>' }}
{%- endif %}
""".strip()
.replace("<|USER|>", HUMAN)
.replace("<|ASSISTANT|>", AI)
)


class KananaVProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
valid_kwargs = []
image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
self.image_processor = image_processor
self.tokenizer = tokenizer
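        # mllm_setup is a custom hook on the kanana tokenizer that prepares it for
        # multimodal input; "dynamic" appears to select dynamic-resolution image handling.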
        self.tokenizer.mllm_setup("dynamic")

    def conv2prompt(
self,
conv: list[dict] | str,
chat_template=CHAT_TEMPLATE,
add_generation_prompt=False,
) -> str:
"""Convert conversation to prompt"""
if isinstance(conv, list):
prompt = self.tokenizer.apply_chat_template(
conversation=conv,
tokenize=False,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)
elif isinstance(conv, str):
prompt = conv
else:
raise TypeError(f"conv must be list or str, but got {type(conv)}")
        return prompt

    def __call__(self, data: dict, max_length, add_generation_prompt=False):
        return self.encode(data, max_length, add_generation_prompt=add_generation_prompt)

    def encode(self, data: dict, max_length, add_generation_prompt=False) -> dict:
"""
Args:
data (dict): {
"conv": [
{"role": "system", "content": "The following is a conversation between a curious human and AI assistant."},
{"role": "user", "content": IMAGE},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
...
],
"image": [
PIL.Image,
...
]
}
Return:
data (dict): {
"text": text_tokens_from_tokenizer,
"text_raw": prompt,
"image": pixel_values,
"image_meta": image_meta (dict of list) includes image resolution, etc.
}
"""
assert "images" not in data
conv = data["conv"]
        images: list[Image] | None = data.get("image")  # PIL images, or None for text-only samples
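        # Output skeleton; the image fields stay None when the sample has no images.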
data = {
"text": None,
"text_raw": None,
"image": None,
"image_meta": None,
}
# image
if images:
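            # Run each (non-empty) image through the image processor.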
            processor_outputs = [
                self.image_processor(image) for image in images if image
            ]
            pixel_values = [output["pixel_values"] for output in processor_outputs]
            image_meta = [output["image_meta"] for output in processor_outputs]
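            # Stack per-image patch tensors along dim 0 and transpose the list of
            # per-image metadata dicts into a dict of lists.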
if pixel_values:
                pixel_values = torch.cat(pixel_values, dim=0)
data["image"] = pixel_values
data["image_meta"] = {k: [d[k] for d in image_meta] for k in image_meta[0]}
# text
prompt = self.conv2prompt(conv, add_generation_prompt=add_generation_prompt)
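        # encode_prompt is a custom tokenizer method: it tokenizes the prompt and,
        # given image_meta, accounts for the image placeholder tokens in the sequence.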
text_tokens = self.tokenizer.encode_prompt(
prompt,
max_length,
image_meta=data["image_meta"],
)
data["text"] = text_tokens
data["text_raw"] = prompt
        return data

    def batch_encode_collate(
self,
data_list: list[dict],
padding: str = "longest",
padding_side: str = "right",
max_length: int | None = None,
add_generation_prompt=False,
):
"""Encode batch and collate them"""
batch = [
self.encode(data, max_length, add_generation_prompt=add_generation_prompt)
for data in data_list
]
batch = self.collate(
batch,
padding=padding,
padding_side=padding_side,
max_length=max_length,
)
        return batch

    def collate(
self,
batch,
padding,
padding_side,
max_length,
):
"""Collate encoded results to model inputs"""
text_batch = [data["text"] for data in batch]
text_batch = self.tokenizer.batch_collate_pad(
text_batch,
padding=padding,
padding_side=padding_side,
max_length=max_length,
)
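        # Gather image tensors and metadata, skipping text-only samples.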
image_list = [data["image"] for data in batch if data["image"] is not None]
image_meta = [data["image_meta"] for data in batch if data["image_meta"] is not None]
if len(image_meta) > 0:
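            # Merge metadata across samples: flatten each field's per-sample lists
            # into one list per key.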
image_meta = {
k: sum([d[k] for d in image_meta], []) for k in image_meta[0]
}
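            # vision_grid_thw (per-image patch-grid shapes) is consumed as a tensor downstream.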
if image_meta.get("vision_grid_thw"):
image_meta["vision_grid_thw"] = torch.tensor(image_meta["vision_grid_thw"])
else:
image_meta = None
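        # Extend the padded text batch in place with image tensors and merged metadata.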
output_batch = text_batch
output_batch["pixel_values"] = torch.cat(image_list, dim=0) if len(image_list) > 0 else None
output_batch["image_metas"] = image_meta
        return output_batch

    def decode(self, *args, **kwargs):
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
return self.tokenizer.batch_decode(*args, **kwargs)
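# --- Usage sketch (illustrative, not part of the upstream file) ---
# Assumes an image processor and tokenizer loaded elsewhere (e.g. via
# AutoImageProcessor/AutoTokenizer with trust_remote_code=True); IMAGE stands
# for the module's image-placeholder content referenced in the encode() docstring.
#
#   processor = KananaVProcessor(image_processor, tokenizer)
#   data = {
#       "conv": [
#           {"role": "user", "content": IMAGE},
#           {"role": "user", "content": "Describe this image."},
#       ],
#       "image": [pil_image],
#   }
#   inputs = processor(data, max_length=2048, add_generation_prompt=True)
#   batch = processor.batch_encode_collate(
#       [data], max_length=2048, add_generation_prompt=True
#   )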