from typing import List, Optional, Tuple import torch from PIL import Image from torch import nn from transformers import ( CLIPVisionModel, GenerationMixin, PreTrainedModel, PreTrainedTokenizer, ) from .catty import split_image_with_catty from .configuration_points_chat import POINTSChatConfig from .dynamic_high_resolution import split_image from .modeling_llama import CustomLlamaForCausalLM class POINTSChatModel(PreTrainedModel, GenerationMixin): config_class = POINTSChatConfig _no_split_modules = ["CLIPVisionModel", "LLamaDecoderLayer"] """Chat model for POINTS. Official implementation of the paper "POINTS: Improving Your Vision-language Model with Affordable Strategies" # noqa: E501 paper: https://huggingface.co/papers/2409.04828 Args: config (PretrainedConfig): The model config. """ def __init__(self, config: POINTSChatConfig) -> None: super().__init__(config) self.general_vit = CLIPVisionModel(config.vision_config) self.ocr_vit = CLIPVisionModel(config.vision_config) self.llm = CustomLlamaForCausalLM(config.llm_config) self.vision_projector = nn.Sequential( nn.Linear(config.vision_config.hidden_size * 4, config.llm_config.hidden_size), nn.GELU(), nn.Linear(config.llm_config.hidden_size, config.llm_config.hidden_size) ) def apply_chat_template(self, prompt: str, image_num: int) -> str: """Apply the Yi-1.5-Chat template to the prompt. Args: prompt (str): The prompt to apply the template to. image_num (int): The number of the image in the prompt. Returns: str: The prompt with the template applied. """ image_tokens = ('<|endoftext|>' * 144) * image_num prompt = f'<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n<|im_start|>assistant\n' # noqa: E501 return prompt def pixel_shuffle(self, feature_map: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor: """Implementation of pixel shuffle. Merge several patches into a single patch by concatenating them across the channel dimension. Therefore, we can reduce the image sequence length. In POINTS, we merge 2x2 adjacent patches into a single patch. Args: feature_map (torch.Tensor): The feature map to be pixel shuffled. scale_factor (float, optional): The scale factor for the """ # taken from https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5/blob/main/modeling_internvl_chat.py#L187 # noqa n, w, h, c = feature_map.size() # N, W, H, C --> N, W, H * scale, C // scale feature_map = feature_map.view( n, w, int(h * scale_factor), int(c / scale_factor)) # N, W, H * scale, C // scale --> N, H * scale, W, C // scale feature_map = feature_map.permute(0, 2, 1, 3).contiguous() # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) feature_map = feature_map.view( n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)), ) feature_map = feature_map.permute(0, 2, 1, 3).contiguous() return feature_map def extract_image_features(self, images: torch.Tensor, vision_encoder: str = 'general_vit') -> torch.Tensor: # noqa: E501 """Extract the image features from the vision encoder. Args: images (torch.Tensor): The images to extract the features from. vision_encoder (str, optional): The vision encoder to use. Defaults to 'general_vit'. Returns: torch.Tensor: The extracted image features. """ if vision_encoder == 'general_vit': image_features = self.general_vit( images, output_hidden_states=True ) else: image_features = self.ocr_vit( images, output_hidden_states=True ) image_features = image_features.hidden_states[-2] image_features = image_features[:, 1:] image_features = image_features.reshape(-1, 24, 24, 1024) image_features = self.pixel_shuffle(image_features, 0.5) image_features = image_features.view(-1, 144, 4096) image_features = self.vision_projector(image_features) return image_features def get_pos_mapping(self, pos: List[list]) -> Tuple[dict, int]: """Get the position mapping for the images. Args: pos (List[list]): The position of the images in the prompt. Returns: Tuple[dict, int]: The position mapping and the total number of images. """ mapping = {} total_images = 0 for i, (start, end) in enumerate(pos): num_image = int((end - start) / 144) mapping[i] = num_image total_images += num_image return mapping, total_images @torch.no_grad() def chat(self, pixel_values: Image, prompt: str, tokenizer: PreTrainedTokenizer, image_processor, catty: bool = True, generation_config: dict = None, max_splits: int = 8) -> str: """Generate a response to the input prompt. Args: pixel_values (Image): The input image. prompt (str): The input prompt. tokenizer (PreTrainedTokenizer): The tokenizer to use. image_processor: The image processor to use. catty (bool, optional): Whether to use catty. Defaults to True. generation_config (dict, optional): The generation config. Defaults to None. max_splits (int, optional): The maximum number of splits. Defaults to 8. Returns: str: The generated response. """ if catty: cropped_images = split_image_with_catty(pixel_values, do_resize=True, max_crop_slices=max_splits) else: cropped_images = split_image(pixel_values, max_splits=max_splits) prompt = self.apply_chat_template(prompt, len(cropped_images)) cropped_images = image_processor.preprocess( cropped_images, return_tensors='pt')['pixel_values'] cropped_images = cropped_images.to(self.device) # extract features with general_vit general_vit_features = self.extract_image_features( cropped_images, vision_encoder='general_vit') # extract features with ocr_vit ocr_vit_features = self.extract_image_features( cropped_images, vision_encoder='ocr_vit') image_features = 0.5 * general_vit_features + 0.5 * ocr_vit_features model_inputs = tokenizer(prompt, return_tensors='pt') input_ids = model_inputs['input_ids'].to(self.device) attention_mask = model_inputs['attention_mask'].to(self.device) # stop token eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") # image token image_token_id = tokenizer.convert_tokens_to_ids("<|endoftext|>") generation_config.update( { 'eos_token_id': eos_token_id, } ) outputs = self.generate( input_ids=input_ids, attention_mask=attention_mask, image_features=[image_features], image_token_id=image_token_id, **generation_config ) response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] return response def generate(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, image_features: List[torch.Tensor], image_token_id: int, generation_config: Optional[dict] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **generate_kwargs) -> torch.LongTensor: input_embeddings = self.llm.lm.embed_in(input_ids) batch_size = input_ids.shape[0] assert len(image_features) == batch_size for i in range(batch_size): special_pos = input_ids[i] == image_token_id pos = (special_pos[:-1] != special_pos[1:]).nonzero() + 1 if pos.shape[0] % 2 != 0: # when the sequence is # we need to add a dummy token pos = torch.cat([torch.tensor([[0]]).to(pos.device), pos]) pos = pos.reshape(-1, 2).tolist() pos_mapping, total_images = self.get_pos_mapping(pos) assert total_images == len(image_features[i]) img_offset = 0 for j, (start, end) in enumerate(pos): num_images = pos_mapping[j] input_embeddings[i, start:end] = torch.cat( [image_features[i][img_offset+k] for k in range(num_images)], dim=0 ) img_offset += num_images outputs = self.llm.generate( inputs_embeds=input_embeddings, attention_mask=attention_mask, generation_config=generation_config, output_hidden_states=output_hidden_states, return_dict=return_dict, use_cache=True, **generate_kwargs ) return outputs