|
from typing import List, Optional, Tuple |
|
|
|
import torch |
|
from PIL import Image |
|
from torch import nn |
|
from transformers import ( |
|
CLIPVisionModel, |
|
GenerationMixin, |
|
PreTrainedModel, |
|
PreTrainedTokenizer, |
|
) |
|
|
|
from .catty import split_image_with_catty |
|
from .configuration_points_chat import POINTSChatConfig |
|
from .dynamic_high_resolution import split_image |
|
from .modeling_llama import CustomLlamaForCausalLM |
|
|
|
|
|
class POINTSChatModel(PreTrainedModel, GenerationMixin):
    """Chat model for POINTS.

    Official implementation of the paper "POINTS: Improving Your
    Vision-language Model with Affordable Strategies".

    paper: https://huggingface.co/papers/2409.04828

    Two CLIP vision towers (``general_vit`` and ``ocr_vit``) encode the
    image crops; their projected features are averaged and written into the
    LLM input embeddings at the positions of image placeholder tokens.

    Args:
        config (POINTSChatConfig): The model config.
    """

    config_class = POINTSChatConfig
    # NOTE(review): "LLamaDecoderLayer" must match the decoder-layer class
    # name defined in .modeling_llama for this entry to take effect (the
    # usual transformers spelling is "LlamaDecoderLayer") — confirm.
    _no_split_modules = ["CLIPVisionModel", "LLamaDecoderLayer"]

    # Number of placeholder tokens that stand for one image crop in the
    # prompt. Must equal the per-crop sequence length produced by
    # ``extract_image_features`` (24x24 patch grid merged 2x2 -> 12x12).
    _IMAGE_TOKENS_PER_CROP = 144

    def __init__(self, config: POINTSChatConfig) -> None:
        super().__init__(config)
        self.general_vit = CLIPVisionModel(config.vision_config)
        self.ocr_vit = CLIPVisionModel(config.vision_config)
        self.llm = CustomLlamaForCausalLM(config.llm_config)
        # Input width is 4x the vision hidden size because ``pixel_shuffle``
        # concatenates 2x2 neighbouring patches along the channel dimension.
        self.vision_projector = nn.Sequential(
            nn.Linear(config.vision_config.hidden_size * 4,
                      config.llm_config.hidden_size),
            nn.GELU(),
            nn.Linear(config.llm_config.hidden_size,
                      config.llm_config.hidden_size),
        )

    def apply_chat_template(self, prompt: str, image_num: int) -> str:
        """Apply the Yi-1.5-Chat template to the prompt.

        Each image crop is represented by ``_IMAGE_TOKENS_PER_CROP``
        ``<|endoftext|>`` placeholder tokens placed before the user text;
        ``generate`` later overwrites their embeddings with image features.

        Args:
            prompt (str): The user prompt to wrap.
            image_num (int): The number of image crops to reserve
                placeholder tokens for.

        Returns:
            str: The prompt with the template applied.
        """
        image_tokens = ('<|endoftext|>' * self._IMAGE_TOKENS_PER_CROP) * image_num
        prompt = (f'<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n'
                  '<|im_start|>assistant\n')
        return prompt

    def pixel_shuffle(self, feature_map: torch.Tensor,
                      scale_factor: float = 0.5) -> torch.Tensor:
        """Implementation of pixel shuffle.

        Merge several patches into a single patch by concatenating them
        across the channel dimension, which reduces the image sequence
        length. In POINTS, 2x2 adjacent patches are merged into one patch
        (``scale_factor=0.5``).

        Args:
            feature_map (torch.Tensor): A 4-D feature map of shape
                ``(n, w, h, c)`` to be pixel shuffled.
            scale_factor (float, optional): The factor applied to the two
                spatial dimensions; channels are scaled by
                ``1 / scale_factor ** 2``. Defaults to 0.5.

        Returns:
            torch.Tensor: Tensor of shape
            ``(n, w * scale_factor, h * scale_factor,
            c / scale_factor ** 2)``.
        """
        n, w, h, c = feature_map.size()
        # (n, w, h, c) -> (n, w, h * s, c / s): fold neighbours along h
        # into channels.
        feature_map = feature_map.view(
            n, w, int(h * scale_factor), int(c / scale_factor))
        # -> (n, h * s, w, c / s) so that w becomes the inner axis.
        feature_map = feature_map.permute(0, 2, 1, 3).contiguous()
        # -> (n, h * s, w * s, c / s^2): fold neighbours along w as well.
        feature_map = feature_map.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        # Restore the (n, w', h', c') axis order.
        feature_map = feature_map.permute(0, 2, 1, 3).contiguous()
        return feature_map

    def extract_image_features(self, images: torch.Tensor,
                               vision_encoder: str = 'general_vit'
                               ) -> torch.Tensor:
        """Extract projected image features from one vision encoder.

        Args:
            images (torch.Tensor): Preprocessed image crops.
            vision_encoder (str, optional): ``'general_vit'`` selects the
                general tower; any other value selects ``ocr_vit``.
                Defaults to 'general_vit'.

        Returns:
            torch.Tensor: Features of shape
            ``(num_crops, _IMAGE_TOKENS_PER_CROP, llm_hidden_size)``.
        """
        if vision_encoder == 'general_vit':
            image_features = self.general_vit(
                images, output_hidden_states=True)
        else:
            image_features = self.ocr_vit(images, output_hidden_states=True)
        # Penultimate layer, dropping the first token (CLIP's class token).
        image_features = image_features.hidden_states[-2]
        image_features = image_features[:, 1:]
        # NOTE(review): hard-codes a 24x24 patch grid with 1024-d features;
        # presumably matches vision_config (ViT-L/14 @ 336px) — confirm.
        image_features = image_features.reshape(-1, 24, 24, 1024)
        # 2x2 merge: (N, 24, 24, 1024) -> (N, 12, 12, 4096).
        image_features = self.pixel_shuffle(image_features, 0.5)
        image_features = image_features.view(
            -1, self._IMAGE_TOKENS_PER_CROP, 4096)
        image_features = self.vision_projector(image_features)
        return image_features

    def get_pos_mapping(self, pos: List[list]) -> Tuple[dict, int]:
        """Count how many image crops each placeholder run holds.

        Args:
            pos (List[list]): ``[start, end)`` index pairs of consecutive
                image placeholder tokens in the prompt.

        Returns:
            Tuple[dict, int]: A mapping from run index to the number of
            crops in that run, and the total number of crops.
        """
        mapping = {}
        total_images = 0
        for i, (start, end) in enumerate(pos):
            num_image = (end - start) // self._IMAGE_TOKENS_PER_CROP
            mapping[i] = num_image
            total_images += num_image
        return mapping, total_images

    @torch.no_grad()
    def chat(self, pixel_values: Image, prompt: str,
             tokenizer: PreTrainedTokenizer,
             image_processor, catty: bool = True,
             generation_config: Optional[dict] = None,
             max_splits: int = 8) -> str:
        """Generate a response to the input prompt.

        Args:
            pixel_values (Image): The input image.
            prompt (str): The input prompt.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
            image_processor: The image processor to use.
            catty (bool, optional): Whether to split the image with
                ``split_image_with_catty`` (True) or ``split_image``
                (False). Defaults to True.
            generation_config (dict, optional): Extra keyword arguments
                forwarded to ``generate``. It is not mutated; its
                ``eos_token_id`` is overridden with ``<|im_end|>``.
                Defaults to None (no extra arguments).
            max_splits (int, optional): The maximum number of crops.
                Defaults to 8.

        Returns:
            str: The generated response.
        """
        if catty:
            cropped_images = split_image_with_catty(
                pixel_values, do_resize=True, max_crop_slices=max_splits)
        else:
            cropped_images = split_image(pixel_values, max_splits=max_splits)
        prompt = self.apply_chat_template(prompt, len(cropped_images))
        cropped_images = image_processor.preprocess(
            cropped_images, return_tensors='pt')['pixel_values']
        cropped_images = cropped_images.to(self.device)

        # Average the projected features of the two vision towers.
        general_vit_features = self.extract_image_features(
            cropped_images, vision_encoder='general_vit')
        ocr_vit_features = self.extract_image_features(
            cropped_images, vision_encoder='ocr_vit')
        image_features = 0.5 * general_vit_features + 0.5 * ocr_vit_features

        model_inputs = tokenizer(prompt, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)

        eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        image_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")

        # Copy so the caller's dict is untouched, and accept the documented
        # default of None (previously ``None.update`` raised).
        generation_config = dict(generation_config or {})
        generation_config['eos_token_id'] = eos_token_id

        outputs = self.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_features=[image_features],
            image_token_id=image_token_id,
            **generation_config
        )
        response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return response

    def generate(self,
                 input_ids: torch.LongTensor,
                 attention_mask: torch.LongTensor,
                 image_features: List[torch.Tensor],
                 image_token_id: int,
                 generation_config: Optional[dict] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 **generate_kwargs) -> torch.LongTensor:
        """Splice image features into the prompt embeddings and generate.

        Args:
            input_ids (torch.LongTensor): Token ids, indexed as
                ``input_ids[i]`` per batch row.
            attention_mask (torch.LongTensor): Attention mask for
                ``input_ids``.
            image_features (List[torch.Tensor]): One entry per batch row;
                ``image_features[i][k]`` holds the features of the k-th crop
                and replaces one run of ``_IMAGE_TOKENS_PER_CROP``
                placeholder embeddings.
            image_token_id (int): Id of the image placeholder token.
            generation_config (dict, optional): Forwarded to
                ``self.llm.generate``.
            output_hidden_states (bool, optional): Forwarded to
                ``self.llm.generate``.
            return_dict (bool, optional): Forwarded to
                ``self.llm.generate``.
            **generate_kwargs: Forwarded to ``self.llm.generate``.

        Returns:
            torch.LongTensor: Whatever ``self.llm.generate`` returns.
        """
        input_embeddings = self.llm.lm.embed_in(input_ids)
        batch_size = input_ids.shape[0]
        assert len(image_features) == batch_size
        for i in range(batch_size):
            special_pos = input_ids[i] == image_token_id
            seq_len = special_pos.shape[0]
            # Each index where the mask flips is the first index of a new
            # run (image or text).
            pos = (special_pos[:-1] != special_pos[1:]).nonzero() + 1
            # A run touching either edge of the sequence has no flip on that
            # side; add the missing boundary so pos pairs up as [start, end).
            # (Previously only a leading 0 was added, which mis-paired runs
            # that reach the end of the sequence.)
            if seq_len and special_pos[0]:
                pos = torch.cat([pos.new_zeros((1, 1)), pos])
            if seq_len and special_pos[-1]:
                pos = torch.cat([pos, pos.new_full((1, 1), seq_len)])
            pos = pos.reshape(-1, 2).tolist()
            pos_mapping, total_images = self.get_pos_mapping(pos)
            assert total_images == len(image_features[i])
            img_offset = 0
            for j, (start, end) in enumerate(pos):
                num_images = pos_mapping[j]
                input_embeddings[i, start:end] = torch.cat(
                    [image_features[i][img_offset + k]
                     for k in range(num_images)],
                    dim=0
                )
                img_offset += num_images
        outputs = self.llm.generate(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs
        )
        return outputs
|
|