|
import logging |
|
|
|
import torch |
|
from PIL.Image import Image |
|
from transformers.processing_utils import ProcessorMixin |
|
|
|
|
|
logger = logging.getLogger("kanana-1.5-v") |
|
|
|
|
|
HUMAN = "Human: " |
|
AI = "AI: " |
|
CHAT_TEMPLATE = ( |
|
""" |
|
{%- if bos_token is defined and bos_token %} |
|
{{- bos_token }} |
|
{%- endif %} |
|
{%- set intro %} |
|
The following is a conversation between a curious human and AI assistant. You are an AI language model developed by Kakao, and your name is kanana.
|
Knowledge Cutoff Date: June 30, 2024.
|
Capabilities and Limitations: |
|
- I cannot search for external content such as weather, news, or the current date and time. |
|
- If a URL is provided, I cannot access it directly. Instead, please copy and provide the relevant content for me to process.
|
{%- endset %} |
|
{{ intro }} |
|
{{- '\n' }} |
|
{%- for message in messages %} |
|
{%- if message['role'] == 'system' %} |
|
{{- message['content'] }} |
|
{%- elif message['role'] == 'user' %} |
|
{{- '<|USER|>' + message['content'] }} |
|
{%- elif message['role'] == 'assistant' %} |
|
{{- '<|ASSISTANT|>' + message['content'] + eos_token }} |
|
{%- endif %} |
|
{%- if not loop.last %} |
|
{{- '\n' }} |
|
{%- endif %} |
|
{%- endfor %} |
|
{%- if add_generation_prompt %} |
|
{{- '\n<|ASSISTANT|>' }} |
|
{%- endif %} |
|
""".strip() |
|
.replace("<|USER|>", HUMAN) |
|
.replace("<|ASSISTANT|>", AI) |
|
) |
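
# Rendering sketch (assuming HF-style Jinja rendering via apply_chat_template):
# a single user turn "Hi" with add_generation_prompt=True yields roughly
#
#   <bos><intro block above>
#   Human: Hi
#   AI: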
|
|
|
|
|
class KananaVProcessor(ProcessorMixin): |
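    """Pair the Kanana-1.5-V image processor with its tokenizer, turning
    {"conv": [...], "image": [...]} samples into model-ready inputs."""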
|
attributes = ["image_processor", "tokenizer"] |
|
valid_kwargs = [] |
|
image_processor_class = "AutoImageProcessor" |
|
tokenizer_class = "AutoTokenizer" |
|
|
|
def __init__(self, image_processor, tokenizer): |
|
super().__init__(image_processor, tokenizer) |
|
self.image_processor = image_processor |
|
self.tokenizer = tokenizer |
|
        # Model-specific tokenizer hook: prepare the tokenizer for multimodal
        # prompts ("dynamic" image mode).
        self.tokenizer.mllm_setup("dynamic")
|
|
|
def conv2prompt( |
|
self, |
|
conv: list[dict] | str, |
|
chat_template=CHAT_TEMPLATE, |
|
add_generation_prompt=False, |
|
) -> str: |
|
"""Convert conversation to prompt""" |
|
if isinstance(conv, list): |
|
prompt = self.tokenizer.apply_chat_template( |
|
conversation=conv, |
|
tokenize=False, |
|
chat_template=chat_template, |
|
add_generation_prompt=add_generation_prompt, |
|
) |
|
elif isinstance(conv, str): |
|
prompt = conv |
|
else: |
|
raise TypeError(f"conv must be list or str, but got {type(conv)}") |
|
|
|
return prompt |
|
|
|
    def __call__(self, data: dict, max_length: int | None, add_generation_prompt=False):
|
return self.encode(data, max_length, add_generation_prompt=add_generation_prompt) |
|
|
|
    def encode(self, data: dict, max_length: int | None, add_generation_prompt=False) -> dict:
|
""" |
|
Args: |
|
data (dict): { |
|
"conv": [ |
|
{"role": "system", "content": "The following is a conversation between a curious human and AI assistant."}, |
|
{"role": "user", "content": IMAGE}, |
|
{"role": "user", "content": "Hello, how are you?"}, |
|
{"role": "assistant", "content": "I'm doing great. How can I help you today?"}, |
|
... |
|
], |
|
"image": [ |
|
PIL.Image, |
|
... |
|
] |
|
} |
|
|
|
        Returns:
|
data (dict): { |
|
"text": text_tokens_from_tokenizer, |
|
"text_raw": prompt, |
|
"image": pixel_values, |
|
"image_meta": image_meta (dict of list) includes image resolution, etc. |
|
} |
|
""" |
|
assert "images" not in data |
|
|
|
conv = data["conv"] |
|
        images: list[Image] | None = data.get("image")
|
|
|
data = { |
|
"text": None, |
|
"text_raw": None, |
|
"image": None, |
|
"image_meta": None, |
|
} |
|
|
|
|
|
if images: |
|
            processor_outputs = [
                self.image_processor(image) for image in images if image
            ]
            pixel_values = [out["pixel_values"] for out in processor_outputs]
            image_meta = [out["image_meta"] for out in processor_outputs]
|
if pixel_values: |
|
pixel_values = torch.concat(pixel_values, dim=0) |
|
data["image"] = pixel_values |
|
data["image_meta"] = {k: [d[k] for d in image_meta] for k in image_meta[0]} |
|
|
|
|
|
prompt = self.conv2prompt(conv, add_generation_prompt=add_generation_prompt) |
|
text_tokens = self.tokenizer.encode_prompt( |
|
prompt, |
|
max_length, |
|
image_meta=data["image_meta"], |
|
) |
|
|
|
data["text"] = text_tokens |
|
data["text_raw"] = prompt |
|
|
|
return data |
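
    # e.g. (roughly): encode({"conv": [{"role": "user", "content": "Hi"}]}, max_length=512)
    # -> {"text": <encoded prompt>, "text_raw": "...Human: Hi", "image": None, "image_meta": None}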
|
|
|
def batch_encode_collate( |
|
self, |
|
data_list: list[dict], |
|
padding: str = "longest", |
|
padding_side: str = "right", |
|
max_length: int | None = None, |
|
add_generation_prompt=False, |
|
): |
|
"""Encode batch and collate them""" |
|
batch = [ |
|
self.encode(data, max_length, add_generation_prompt=add_generation_prompt) |
|
for data in data_list |
|
] |
|
batch = self.collate( |
|
batch, |
|
padding=padding, |
|
padding_side=padding_side, |
|
max_length=max_length, |
|
) |
|
|
|
return batch |
|
|
|
def collate( |
|
self, |
|
batch, |
|
padding, |
|
padding_side, |
|
max_length, |
|
): |
|
"""Collate encoded results to model inputs""" |
|
text_batch = [data["text"] for data in batch] |
|
|
|
text_batch = self.tokenizer.batch_collate_pad( |
|
text_batch, |
|
padding=padding, |
|
padding_side=padding_side, |
|
max_length=max_length, |
|
) |
|
|
|
image_list = [data["image"] for data in batch if data["image"] is not None] |
|
image_meta = [data["image_meta"] for data in batch if data["image_meta"] is not None] |
|
if len(image_meta) > 0: |
|
            # flatten each key's per-sample lists into a single batch-level list
            image_meta = {
                k: sum([d[k] for d in image_meta], []) for k in image_meta[0]
            }
|
if image_meta.get("vision_grid_thw"): |
|
image_meta["vision_grid_thw"] = torch.tensor(image_meta["vision_grid_thw"]) |
|
else: |
|
image_meta = None |
|
|
|
output_batch = text_batch |
|
|
|
output_batch["pixel_values"] = torch.cat(image_list, dim=0) if len(image_list) > 0 else None |
|
output_batch["image_metas"] = image_meta |
|
return output_batch |
|
|
|
def decode(self, *args, **kwargs): |
|
return self.tokenizer.decode(*args, **kwargs) |
|
|
|
def batch_decode(self, *args, **kwargs): |
|
return self.tokenizer.batch_decode(*args, **kwargs) |
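
# Illustrative usage (a sketch, not part of the API; `image_processor`,
# `tokenizer`, and `pil_image` stand in for objects loaded from a
# Kanana-1.5-V checkpoint and your own data; image turns in `conv` follow
# the IMAGE placeholder convention from `encode`'s docstring):
#
#   processor = KananaVProcessor(image_processor, tokenizer)
#   batch = processor.batch_encode_collate(
#       [{"conv": [{"role": "user", "content": "Describe this image."}],
#         "image": [pil_image]}],
#       max_length=2048,
#       add_generation_prompt=True,
#   )
#   # batch holds padded text tensors plus "pixel_values" and "image_metas".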
|
|