Keye-VL-1_5-8B / processing_keye_vl_1_5.py
# coding=utf-8
# Copyright 2025 The Kwai Keye Team and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union, Optional
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from .image_processing_keye_vl_1_5 import KeyeVL1_5ImageProcessor
import torch
import torch.nn as nn
import numpy as np
from itertools import chain
class KeyeVL1_5VideosProcessorKwargs(VideosKwargs, total=False):
fps: Optional[Union[List[float], float]]
    # Target width to resize to (slow frames)
    width: Optional[Union[List[int], int]]
    # Target height to resize to (slow frames)
    height: Optional[Union[List[int], int]]
    # Target width to resize to (fast frames)
    fast_width: Optional[Union[List[int], int]]
    # Target height to resize to (fast frames)
    fast_height: Optional[Union[List[int], int]]
    # Timestamp of each frame; the number of entries equals the number of frames
    timestamps: Optional[Union[List[torch.Tensor], torch.Tensor]]
    # Type of each frame, slow or fast: slow=0, fast=1
    frame_types: Optional[Union[List[torch.Tensor], torch.Tensor]]
class KeyeVL1_5ProcessorKwargs(ProcessingKwargs, total=False):
videos_kwargs: KeyeVL1_5VideosProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
"videos_kwargs": {"fps": 2.0},
}
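# Illustrative sketch of how the per-video kwargs above can be supplied, assuming the
# standard `ProcessorMixin` kwarg routing (names and values below are made-up examples):
#
#     processor(
#         text=[prompt],
#         videos=[video_tensor],  # (nframes, C, H, W)
#         videos_kwargs={
#             "frame_types": [torch.tensor([0, 1, 1, 0])],          # slow=0, fast=1
#             "timestamps": [torch.tensor([0.0, 0.5, 1.0, 1.5])],   # one per frame
#         },
#     )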
def select_slow_fast_frames(frames: torch.Tensor, frame_types: torch.Tensor):
"""
    Selects frames from a tensor based on a per-frame type tensor.
    Args:
        frames (torch.Tensor): A tensor of shape (nframes, c, h, w).
        frame_types (torch.Tensor): An int tensor of shape (nframes,); 0 marks a slow frame, 1 a fast frame.
    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing two tensors:
            - slow_frames: frames whose type is 0.
            - fast_frames: frames whose type is 1.
"""
nframes, _, _, _ = frames.shape
if frame_types.shape[-1] != nframes:
        raise ValueError("Length of frame_types must equal the number of frames.")
mask = (frame_types == 0)
slow_frames = frames[mask]
fast_frames = frames[~mask]
return slow_frames, fast_frames
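# Minimal illustrative example of select_slow_fast_frames (shapes and values are made up):
#
#     frames = torch.zeros(4, 3, 32, 32)
#     frame_types = torch.tensor([0, 1, 1, 0])
#     slow, fast = select_slow_fast_frames(frames, frame_types)
#     slow.shape  # torch.Size([2, 3, 32, 32])  (frames 0 and 3)
#     fast.shape  # torch.Size([2, 3, 32, 32])  (frames 1 and 2)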
def split_thw(tensor):
    """Split grid_thw along the t dimension; each row of the result has the form [1, h, w]."""
repeats = tensor[:, 0]
new_thw = torch.cat([
torch.ones(tensor.shape[0], 1, dtype=tensor.dtype,
device=tensor.device),
tensor[:, 1:]
], dim=1)
return torch.repeat_interleave(new_thw, repeats, dim=0)
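# Illustrative example of split_thw: one grid entry spanning t frames expands into t
# per-frame rows (numbers are made up):
#
#     split_thw(torch.tensor([[3, 6, 8]]))
#     # tensor([[1, 6, 8],
#     #         [1, 6, 8],
#     #         [1, 6, 8]])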
def merge_hws(hws):
"""
    Merge consecutive grid rows that share the same (h, w) by accumulating their
    temporal counts, e.g. [[1, h, w], [1, h, w]] -> [[2, h, w]].
"""
merged = []
last_hw = [-1, -1]
for hw in hws:
        # Fold consecutive rows with the same (h, w) into the previous entry
if hw[1:] == last_hw:
merged[-1][0] += 1
else:
merged.append(hw)
last_hw = hw[1:]
return torch.tensor(merged)
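# Illustrative example of merge_hws, operating on plain [t, h, w] lists (numbers are made up):
#
#     merge_hws([[1, 6, 8], [1, 6, 8], [1, 4, 4]])
#     # tensor([[2, 6, 8],
#     #         [1, 4, 4]])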
class KeyeVL1_5Processor(ProcessorMixin):
r"""
[`KeyeVL1_5Processor`] offers all the functionalities of [`KeyeVL1_5ImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~KeyeVL1_5Processor.__call__`] and [`~KeyeVL1_5Processor.decode`] for more information.
Args:
image_processor ([`KeyeVL1_5ImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
    valid_kwargs = [
        "chat_template", "image_std", "min_pixels", "image_mean", "merge_size", "image_processor_type",
        "temporal_patch_size", "patch_size", "max_pixels"
    ]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
self.frame_token = "<|frame|>" if not hasattr(tokenizer, "frame_token") else tokenizer.frame_token
self.fast_video_token = "<|fast_video_pad|>" if not hasattr(tokenizer, "fast_video_token") else tokenizer.fast_video_token
self.fast_start = "<|fast_start|>" if not hasattr(tokenizer, "fast_start") else tokenizer.fast_start
self.fast_end = "<|fast_end|>" if not hasattr(tokenizer, "fast_end") else tokenizer.fast_end
super().__init__(image_processor, tokenizer, chat_template=chat_template)
self.slowfast = True
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
videos: VideoInput = None,
**kwargs: Unpack[KeyeVL1_5ProcessorKwargs],
) -> BatchFeature:
"""
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `images`, `videos` and `kwargs` arguments to
        KeyeVL1_5ImageProcessor's [`~KeyeVL1_5ImageProcessor.__call__`] if `images` or `videos` is not `None`.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of 3D image grids (t, h, w) fed to the LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of 3D video grids (t, h, w) fed to the LLM. Returned when `videos` is not `None`.
            - **num_frames** -- Number of frames per video. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
KeyeVL1_5ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
            image_inputs = self.image_processor(images=images, return_tensors="pt")
            image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
image_grid_thw = None
num_frames = []
if videos is not None:
batch_slow_frames = []
batch_fast_frames = []
videos_kwargs = output_kwargs["videos_kwargs"]
num_videos = len(videos)
batch_frame_types = videos_kwargs.get("frame_types", [None] * num_videos)
batch_timestamps = videos_kwargs.get("timestamps", [None] * num_videos)
batch_width = videos_kwargs.get("width", [None] * num_videos)
batch_height = videos_kwargs.get("height", [None] * num_videos)
batch_fast_width = videos_kwargs.get("fast_width", [None] * num_videos)
batch_fast_height = videos_kwargs.get("fast_height", [None] * num_videos)
for index, frames in enumerate(videos):
if isinstance(frames, np.ndarray):
frames = torch.from_numpy(frames)
nframes = frames.shape[0]
num_frames.append(nframes)
assert nframes > 0, "No frames in video"
if batch_frame_types[index] is None:
# default to all slow frames
batch_frame_types[index] = torch.zeros((nframes, ), dtype=torch.long)
frame_types = batch_frame_types[index]
slow_frames, fast_frames = select_slow_fast_frames(frames, frame_types)
has_fast_frames = fast_frames.shape[0] > 0
# resize slow frames
resized_width = batch_width[index]
resized_height = batch_height[index]
if resized_width is not None and resized_height is not None:
slow_frames = nn.functional.interpolate(
slow_frames,
[resized_height, resized_width],
mode="bilinear",
antialias=True,
).float()
do_resize = False
else:
slow_frames = slow_frames.float()
do_resize = True
# Tensor(N, C, H, W) -> Tuple[Tensor(1, C, H, W)]
                # slow_frames = list(slow_frames.split(1, dim=0))  # no split here; splitting happens inside the model
slow_video_inputs = self.image_processor(
images=None, videos=[slow_frames], **output_kwargs["images_kwargs"], do_resize=do_resize)
slow_video_grid_thw = slow_video_inputs["video_grid_thw"]
batch_slow_frames.append(slow_video_inputs)
                # # Number of tokens per frame for the current video
# slow_frames_patch_nums[index] = int(slow_video_inputs["pixel_values_videos"].shape[0] / \
# slow_video_grid_thw.squeeze()[0])
if has_fast_frames:
# TODO: shrink fast_frames
fast_resized_width = batch_fast_width[index]
fast_resized_height = batch_fast_height[index]
if fast_resized_width is not None and fast_resized_height is not None:
fast_frames = nn.functional.interpolate(
fast_frames,
[fast_resized_height, fast_resized_width],
mode="bilinear",
antialias=True,
).float()
do_fast_resize = False
else:
fast_frames = fast_frames.float()
do_fast_resize = True
# Tensor(N, C, H, W) -> Tuple[Tensor(1, C, H, W)]
# fast_frames = list(fast_frames.split(1, dim=0))
fast_video_inputs = self.image_processor(
images=None, videos=[fast_frames], **output_kwargs["images_kwargs"], do_resize=do_fast_resize)
fast_video_grid_thw = fast_video_inputs["video_grid_thw"]
batch_fast_frames.append(fast_video_inputs)
                    # # Total number of tokens for the current video
# fast_frames_token_nums[index] = int(fast_video_inputs["pixel_values_videos"].shape[0] / \
# fast_video_grid_thw.squeeze()[0])
assert len(batch_slow_frames) > 0, "Slow frames should not be empty."
slow_pixel_values_videos_list = [
video["pixel_values_videos"] for video in batch_slow_frames if video is not None]
slow_video_grid_thw_list = [
video["video_grid_thw"] for video in batch_slow_frames if video is not None]
slow_pixel_values_videos = torch.concat(slow_pixel_values_videos_list, dim=0)
slow_video_grid_thw = torch.concat(slow_video_grid_thw_list, dim=0)
            # Fast frames were collected for at least one video in the batch
            if len(batch_fast_frames) > 0:
fast_pixel_values_videos_list = [
video["pixel_values_videos"] for video in batch_fast_frames \
if video is not None]
fast_video_grid_thw_list = [
video["video_grid_thw"] for video in batch_fast_frames \
if video is not None]
fast_pixel_values_videos = \
torch.concat(fast_pixel_values_videos_list, dim=0)
fast_video_grid_thw = \
torch.concat(fast_video_grid_thw_list, dim=0)
else:
fast_video_grid_thw = None
else:
slow_video_grid_thw = None
fast_video_grid_thw = None
if not isinstance(text, list):
text = [text]
if image_grid_thw is not None:
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
                    image_place_holder_template = "<|placeholder|>" * (
                        image_grid_thw[index].prod() // self.image_processor.merge_size ** 2)
                    text[i] = text[i].replace(
                        self.image_token,
                        image_place_holder_template,
                        1,
                    )
index += 1
text[i] = text[i].replace("<|placeholder|>", self.image_token)
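        # Each image contributes grid_t * grid_h * grid_w patches, merged merge_size x merge_size
        # into LLM tokens. Illustrative arithmetic with made-up numbers: an image_grid_thw entry of
        # (1, 28, 28) with merge_size=2 expands to 1 * 28 * 28 // 2 ** 2 = 196 image tokens.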
pixel_values_videos = []
video_grid_thw = []
videos_inputs = {}
if slow_video_grid_thw is not None:
slow_video_grid_thw = split_thw(slow_video_grid_thw)
if fast_video_grid_thw is not None:
fast_video_grid_thw = split_thw(fast_video_grid_thw)
index = 0
slow_index = 0
fast_index = 0
slow_pixels_index = 0
fast_pixels_index = 0
for i in range(len(text)):
while self.video_token in text[i]:
                video_place_holder_template = ""
                for j in range(batch_frame_types[index].shape[-1]):
                    if batch_timestamps[index] is not None:  # timestamps were provided for this video
                        video_place_holder_template += self.frame_token + format(batch_timestamps[index][j], ".1f")
                    else:
                        video_place_holder_template += self.frame_token
                    # the current frame is a slow frame
                    if batch_frame_types[index][j] == 0:
                        num_patches = int(slow_video_grid_thw[slow_index].prod())
                        video_place_holder_template += "<|placeholder|>" * (
                            num_patches // self.image_processor.merge_size ** 2)
                        pixel_values_videos.append(
                            slow_pixel_values_videos[slow_pixels_index:slow_pixels_index + num_patches])
                        slow_pixels_index = slow_pixels_index + num_patches
                        video_grid_thw.append(slow_video_grid_thw[slow_index].tolist())
                        slow_index += 1
                    # the current frame is a fast frame
                    elif batch_frame_types[index][j] == 1:
                        num_patches = int(fast_video_grid_thw[fast_index].prod())
                        video_place_holder_template += self.fast_start + "<|placeholder|>" * (
                            num_patches // self.image_processor.merge_size ** 2) + \
                            self.fast_end
                        pixel_values_videos.append(
                            fast_pixel_values_videos[fast_pixels_index:fast_pixels_index + num_patches])
                        fast_pixels_index = fast_pixels_index + num_patches
                        video_grid_thw.append(fast_video_grid_thw[fast_index].tolist())
                        fast_index += 1
                text[i] = text[i].replace(
                    self.video_token,
                    video_place_holder_template,
                    1,
                )
                index += 1
            text[i] = text[i].replace("<|placeholder|>", self.video_token)
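        # After expansion, each original video token becomes a per-frame sequence roughly of the form
        # (illustrative): "<|frame|>{timestamp}" followed by h * w // merge_size ** 2 video-pad tokens for a
        # slow frame, or the same pad tokens wrapped in <|fast_start|> ... <|fast_end|> for a fast frame.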
        # Only emit video features when at least one video placeholder was expanded
        if len(pixel_values_videos) > 0:
            videos_inputs["pixel_values_videos"] = torch.cat(pixel_values_videos, dim=0)
            videos_inputs["video_grid_thw"] = merge_hws(video_grid_thw)
            videos_inputs["num_frames"] = torch.tensor(num_frames)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode` method.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
return names_from_processor
__all__ = ["KeyeVL1_5Processor"]
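# Illustrative usage sketch (the repo id, prompt format, and variable names below are assumptions,
# not taken from this file):
#
#     from transformers import AutoProcessor
#     processor = AutoProcessor.from_pretrained("Kwai-Keye/Keye-VL-1_5-8B", trust_remote_code=True)
#     inputs = processor(
#         text=["<|image_pad|> Describe this image."],
#         images=[pil_image],
#         return_tensors="pt",
#     )
#     # inputs now holds input_ids, attention_mask, pixel_values and image_grid_thw.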