lixin4ever's picture
init (#1)
44d8da2 verified
raw
history blame
28 kB
import ast
import os
import re
import math
import base64
import traceback
from io import BytesIO
from typing import Optional
import torch
import torchvision.transforms.functional as VF
import torch.nn.functional as F
import numpy as np
from transformers import StoppingCriteria
import cv2
import imageio
import ffmpeg
from PIL import Image
from decord import VideoReader, cpu
from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
from pycocotools import mask as maskUtils
def resize_image_mask(images, masks, mask_ids, patch_size=14):
    """Crop every mask (and its image) to a padded bounding box and rescale both.

    Args:
        images: Indexable collection of images. Each one is sliced as
            ``image[r0:r1, c0:c1, :]`` and handed to ``cv2.resize``
            (presumably H x W x C numpy arrays -- TODO confirm).
        masks: Iterable of 2-D masks. They are fed to ``F.interpolate``, so
            presumably float torch tensors -- TODO confirm.
        mask_ids: ``mask_ids[i]`` is the index in ``images`` of the image that
            ``masks[i]`` belongs to.
        patch_size: Patch edge length; output sizes are rounded up to a
            multiple of it and the crop margin is twice this value.

    Returns:
        Tuple ``(resize_images, resize_masks, mask_nums)`` with one entry per
        mask: the resized image crop, the mask interpolated to the patch grid,
        and ``min(10, int(resized_mask.sum()))``.
    """
    margin = patch_size * 2
    out_images, out_masks, out_counts = [], [], []

    for idx, mask in enumerate(masks):
        image = images[mask_ids[idx]]
        h, w = image.shape[:2]

        # An empty mask is replaced by one covering the whole image.
        if mask.sum() == 0:
            print('mask is none...')
            mask = torch.ones((h, w))

        # Bounding box of the mask, grown by the margin and clamped to the image.
        rows, cols = np.where(mask == 1)
        top = max(0, rows.min() - margin)
        left = max(0, cols.min() - margin)
        bottom = min(h - 1, rows.max() + margin)
        right = min(w - 1, cols.max() + margin)
        box_h = bottom - top
        box_w = right - left

        image_crop = image[top:bottom, left:right, :]
        mask_crop = mask[top:bottom, left:right]

        # Upscale small masks; shrink very large ones (more than 100 * 196 pixels).
        mask_area = mask.sum()
        scale_rate = math.ceil(math.sqrt(1960 / mask_area))
        if scale_rate == 1 and (mask_area / 196) > 100:
            scale_rate = 1 / math.sqrt((mask_area / 196) / 100)

        # Round the target size up to a whole number of patches.
        resize_h = math.ceil((box_h * scale_rate) / patch_size) * patch_size
        resize_w = math.ceil((box_w * scale_rate) / patch_size) * patch_size

        resized_image = cv2.resize(image_crop, (resize_w, resize_h))
        resized_mask = F.interpolate(
            mask_crop[None, None],
            size=(resize_h // patch_size, resize_w // patch_size),
            mode='bilinear',
            align_corners=False,
        )[0, 0]

        out_counts.append(min(10, int(resized_mask.sum())))
        out_images.append(resized_image)
        out_masks.append(resized_mask)

    return out_images, out_masks, out_counts
def reshape_images_to_raw_grid(mm_features_raw, grid_thws):
    """Split a flat feature sequence into one ``(H, W, -1)`` view per grid entry.

    Args:
        mm_features_raw: Tensor whose first dimension concatenates the
            features of all grids.
        grid_thws: Nested iterable (groups of tensors); each tensor squeezes to
            three values ``(_, H, W)``. Only ``H`` and ``W`` are used.

    Returns:
        List of tensors, each a ``view(H, W, -1)`` of consecutive rows of
        ``mm_features_raw``.

    Raises:
        AssertionError: If the grids do not consume exactly
            ``len(mm_features_raw)`` rows.
    """
    grids = []
    cursor = 0
    for group in grid_thws:
        for thw in group:
            _, rows, cols = thw.squeeze().tolist()
            stop = cursor + rows * cols
            grids.append(mm_features_raw[cursor:stop].view(rows, cols, -1))
            cursor = stop
    assert len(mm_features_raw)==cursor
    return grids
def annToMask(mask_ann, h=None, w=None):
    """Decode a segmentation annotation into a mask via ``pycocotools``.

    Args:
        mask_ann: Either a list (presumably polygons -- converted with
            ``frPyObjects`` and merged), a dict whose ``'counts'`` is a list
            (uncompressed RLE, converted with ``frPyObjects``), or an
            already-compressed RLE that is decoded as is.
        h: Height passed to ``frPyObjects`` for the first two forms.
        w: Width passed to ``frPyObjects`` for the first two forms.

    Returns:
        The result of ``maskUtils.decode`` on the resulting RLE.
    """
    if isinstance(mask_ann, list):
        # Several objects: convert each one, then merge into a single RLE.
        rle = maskUtils.merge(maskUtils.frPyObjects(mask_ann, h, w))
    elif isinstance(mask_ann['counts'], list):
        # Uncompressed RLE.
        rle = maskUtils.frPyObjects(mask_ann, h, w)
    else:
        # Already an RLE.
        rle = mask_ann
    return maskUtils.decode(rle)
def chunk_list(input_list, chunk_size):
    """Split a sequence into consecutive slices of at most ``chunk_size`` items.

    The last chunk holds the remainder and may be shorter. An empty input
    yields an empty list.
    """
    chunks = []
    for start in range(0, len(input_list), chunk_size):
        chunks.append(input_list[start:start + chunk_size])
    return chunks
def load_image_from_base64(image):
    """Base64-decode ``image`` and open the resulting bytes with ``Image.open``."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring it along the shorter axis.

    Args:
        pil_img: Image exposing ``size`` (width, height) and ``mode``.
        background_color: Fill colour for the new canvas.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer edge, with the original pasted in the
        middle of the shorter dimension.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        # Landscape: centre vertically.
        offset = (0, (width - height) // 2)
    else:
        # Portrait: centre horizontally.
        offset = ((height - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas
def grid_divide(image, cell_size):
    """
    Cut an image into square cells, row by row.
    Args:
        image (PIL.Image.Image): The input image.
        cell_size (int): Edge length of each cell, used for both axes.
    Returns:
        list: A list of rows (top to bottom); each row is a list of the crops
        returned by ``image.crop`` (left to right).
    """
    width, height = image.size
    return [
        [
            image.crop((left, top, left + cell_size, top + cell_size))
            for left in range(0, width, cell_size)
        ]
        for top in range(0, height, cell_size)
    ]
def load_images(image_path):
    """Normalise the supported image inputs into a list of images.

    Args:
        image_path: One of
            * path of a file -> read with OpenCV and converted BGR -> RGB;
            * path of a directory -> every entry, in sorted name order, read
              the same way;
            * list of path strings (decided by the first element) -> each read
              the same way;
            * list of ``PIL.Image.Image`` (decided by the first element) ->
              returned unchanged;
            * a single ``PIL.Image.Image`` -> wrapped in a list.

    Returns:
        List of images: arrays as produced by ``cv2.cvtColor`` for path
        inputs, PIL images for PIL inputs.

    Raises:
        ValueError: If ``image_path`` matches none of the forms above.
    """
    def read_rgb(path):
        # OpenCV loads BGR; the rest of the pipeline expects RGB.
        return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)

    if isinstance(image_path, str):
        if os.path.isfile(image_path):
            return [read_rgb(image_path)]
        if os.path.isdir(image_path):
            return [read_rgb(os.path.join(image_path, f)) for f in sorted(os.listdir(image_path))]
    elif isinstance(image_path, list):
        first = image_path[0]
        if isinstance(first, str):
            return [read_rgb(f) for f in image_path]
        if isinstance(first, Image.Image):
            return image_path
    elif isinstance(image_path, Image.Image):
        return [image_path]

    print('image_path: ', image_path)
    raise ValueError(f"Unsupported image path type: {image_path}")
def process_pad_image(image, padding_value=(0, 0, 0)):
    """Pad ``image`` to a square with ``expand2square`` and return it in a one-element list."""
    return [expand2square(image, padding_value)]
def find_closest_aspect_ratio(src_ratio, tgt_ratios, ori_size, tgt_size):
    """Pick the candidate layout whose aspect ratio is nearest to ``src_ratio``.

    Args:
        src_ratio: Aspect ratio of the source image.
        tgt_ratios: Iterable of pairs; each candidate's ratio is
            ``pair[0] / pair[1]``.
        ori_size: Two-element size of the source image; its product is the
            source area.
        tgt_size: Two-element size of one tile.

    Returns:
        The winning pair, or ``(1, 1)`` when ``tgt_ratios`` is empty. On an
        exact tie in ratio distance, a later candidate replaces the current
        best only if the source area exceeds half of that candidate's total
        area (tile area times number of tiles).
    """
    best_ratio = (1, 1)
    best_ratio_diff = float('inf')
    area = ori_size[0] * ori_size[1]

    for ratio in tgt_ratios:
        ratio_diff = abs(src_ratio - ratio[0] / ratio[1])
        if ratio_diff < best_ratio_diff:
            best_ratio_diff, best_ratio = ratio_diff, ratio
            continue
        if ratio_diff != best_ratio_diff:
            continue
        # Tie: prefer this layout only when the source is big enough to fill it.
        if area > 0.5 * tgt_size[0] * tgt_size[1] * ratio[0] * ratio[1]:
            best_ratio = ratio

    return best_ratio
def process_dynamic_image(image, image_size=384, use_thumbnail=True):
    """Resize an image to the best-matching tile layout and cut it into tiles.

    Args:
        image: Image exposing ``size`` and ``resize`` (a PIL image).
        image_size: Tile size; an int is expanded to ``(size, size)``.
        use_thumbnail: When true, a one-tile row holding the whole image
            resized to a single tile is prepended to the grid.

    Returns:
        List of rows of tiles as produced by ``grid_divide``.
    """
    # Grid params: the number of tiles (columns * rows) must lie in this range.
    min_tiles, max_tiles = 1, 12

    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    ori_size = image.size
    aspect_ratio = ori_size[0] / ori_size[1]

    # Enumerate every (columns, rows) layout within the tile budget,
    # de-duplicate, and order by total tile count.
    candidates = []
    for n in range(min_tiles, max_tiles + 1):
        for cols in range(1, n + 1):
            for rows in range(1, n + 1):
                if cols * rows <= max_tiles and cols * rows >= min_tiles:
                    candidates.append((cols, rows))
    tgt_ratios = sorted(set(candidates), key=lambda layout: layout[0] * layout[1])

    # Choose the layout closest to the source aspect ratio and resize to it.
    best_cols, best_rows = find_closest_aspect_ratio(aspect_ratio, tgt_ratios, ori_size, image_size)
    resized_img = image.resize((image_size[0] * best_cols, image_size[1] * best_rows))

    # Tiles are cut as squares whose edge is the tile width.
    image_grid = grid_divide(resized_img, image_size[0])

    if use_thumbnail:
        thumbnail_img = image.resize((image_size[0], image_size[1]))
        image_grid = [[thumbnail_img]] + image_grid

    return image_grid
def process_highres_image(image_path, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
    """Pad an image to a square, resize it to a 1x/2x/3x tile canvas and tile it.

    Args:
        image_path: The image to process. Despite the name it is used as an
            image object (``size``, ``resize``), not as a filesystem path; the
            name is kept so existing keyword callers keep working.
        image_size: Edge length of one tile.
        use_thumbnail: When true, a one-tile row holding the unpadded image
            resized to a single tile is prepended to the grid.
        padding_value: Fill colour used when padding to a square.

    Returns:
        List of rows of tiles as produced by ``grid_divide``.
    """
    # The body previously referenced an undefined name ``image``; bind it here.
    image = image_path

    # Grid params: candidate canvas widths, in multiples of the tile size.
    grid_width = [1, 2, 3]
    grid_width_real = [x * image_size for x in grid_width]

    # Smallest canvas that fits the longest side, else the largest available.
    longest_side = max(image.size)
    fit_grid_width_real = [x for x in grid_width_real if x >= longest_side]
    if fit_grid_width_real:
        select_size = min(fit_grid_width_real)
    else:
        select_size = max(grid_width_real)

    image_padded = expand2square(image, padding_value)
    image_padded = image_padded.resize((select_size, select_size))
    image_grid = grid_divide(image_padded, image_size)

    if use_thumbnail:
        thumbnail_img = image.resize((image_size, image_size))
        image_grid = [[thumbnail_img]] + image_grid

    return image_grid
def select_best_resolution(original_size, possible_resolutions):
    """
    Choose the candidate resolution that keeps the most image content.
    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): Candidate resolutions in the format [(width1, height1), (width2, height2), ...].
    Returns:
        tuple: The best fit resolution in the format (width, height), or None
        when no candidate is given. The candidate with the largest effective
        resolution wins; ties go to the one wasting the fewest pixels.
    """
    original_width, original_height = original_size
    original_area = original_width * original_height

    best_fit = None
    best_effective = 0
    least_wasted = float('inf')

    for width, height in possible_resolutions:
        # Largest aspect-preserving scale that fits inside the candidate.
        scale = min(width / original_width, height / original_height)
        scaled_area = int(original_width * scale) * int(original_height * scale)

        # Upscaling adds no information, so cap at the original area.
        effective = min(scaled_area, original_area)
        wasted = (width * height) - effective

        improves = effective > best_effective
        ties_with_less_waste = effective == best_effective and wasted < least_wasted
        if improves or ties_with_less_waste:
            best_effective = effective
            least_wasted = wasted
            best_fit = (width, height)

    return best_fit
def process_anyres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
    """
    Fit an image into the best of a fixed set of tile layouts and tile it.
    Args:
        image (PIL.Image.Image): The input image to be processed.
        image_size (int): Edge length of one tile.
        use_thumbnail (bool): When true, a one-tile row holding the whole
            image resized to a single tile is prepended to the grid.
        padding_value (tuple): Fill colour of the padded canvas.
    Returns:
        list: Rows of tiles as produced by ``grid_divide``.
    """
    # Grid params: allowed layouts, scaled to pixel resolutions.
    possible_grids = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
    possible_resolutions = [(x * image_size, y * image_size) for x, y in possible_grids]
    canvas_w, canvas_h = select_best_resolution(image.size, possible_resolutions)

    # Scale the image to fit the canvas without changing its aspect ratio.
    src_w, src_h = image.size
    scale_factor = min(canvas_w / src_w, canvas_h / src_h)
    new_size = (int(src_w * scale_factor), int(src_h * scale_factor))

    # Paste the scaled image centred on a padded canvas.
    canvas = Image.new("RGB", (canvas_w, canvas_h), padding_value)
    paste_at = ((canvas_w - new_size[0]) // 2, (canvas_h - new_size[1]) // 2)
    canvas.paste(image.resize(new_size), paste_at)

    image_grid = grid_divide(canvas, image_size)

    if use_thumbnail:
        thumbnail_img = image.resize((image_size, image_size))
        image_grid = [[thumbnail_img]] + image_grid

    return image_grid
def process_adares_image(image_path, image_size=384, use_thumbnail=True):
    """Resize an image to the best-fitting tile layout (up to 12 tiles) and tile it.

    Args:
        image_path: The image to process. Despite the name it is used as an
            image object (``size``, ``resize``), not as a filesystem path; the
            name is kept so existing keyword callers keep working.
        image_size: Tile size; an int is expanded to ``(size, size)``.
        use_thumbnail: When true, a one-tile row holding the whole image
            resized to a single tile is prepended to the grid.

    Returns:
        List of rows of tiles as produced by ``grid_divide``.
    """
    # The body previously referenced an undefined name ``image``; bind it here.
    image = image_path

    # Grid params: the number of tiles (columns * rows) must lie in this range.
    min_num = 1
    max_num = 12

    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    ori_size = image.size

    # Enumerate every (columns, rows) layout within the tile budget.
    tgt_ratios = []
    for n in range(min_num, max_num + 1):
        tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
    tgt_ratios = set(tgt_ratios)
    possible_resolutions = [(x * image_size[0], y * image_size[1]) for x, y in tgt_ratios]

    # Pick the resolution that preserves the most content, then resize to it.
    best_resolution = select_best_resolution(ori_size, possible_resolutions)
    resized_img = image.resize((best_resolution[0], best_resolution[1]))

    # Tiles are cut as squares whose edge is the tile width.
    image_grid = grid_divide(resized_img, image_size[0])

    if use_thumbnail:
        thumbnail_img = image.resize((image_size[0], image_size[1]))
        image_grid = [[thumbnail_img]] + image_grid

    return image_grid
def process_images(image_path, processor, aspect_ratio='pad', image_size=384, use_thumbnail=True):
    """Load images, split each into a grid per ``aspect_ratio``, and preprocess every row.

    Args:
        image_path: Anything accepted by ``load_images``.
        processor: Object providing ``image_mean`` and
            ``preprocess(row, return_tensors='pt', num_images=...)``.
        aspect_ratio: One of ``'pad'``, ``'dynamic'``, ``'highres'``,
            ``'anyres'``, ``'adares'``; any other value leaves the image as a
            single-element grid.
        image_size: Tile size forwarded to the grid strategies.
        use_thumbnail: Forwarded to the grid strategies that support it.

    Returns:
        One list per input image, holding the ``processor.preprocess`` output
        of each grid row.
    """
    images = load_images(image_path)
    # Padding colour: the processor's per-channel mean scaled to 0-255.
    padding_value = tuple(int(x*255) for x in processor.image_mean)

    def to_grid(image):
        if aspect_ratio == 'pad':
            return process_pad_image(image, padding_value=padding_value)
        if aspect_ratio == 'dynamic':
            return process_dynamic_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
        if aspect_ratio == 'highres':
            return process_highres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
        if aspect_ratio == 'anyres':
            return process_anyres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
        if aspect_ratio == 'adares':
            return process_adares_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
        return [image]

    num_images = len(images)
    image_grids = []
    for image in images:
        rows = to_grid(image)
        image_grids.append(
            [processor.preprocess(row, return_tensors='pt', num_images=num_images) for row in rows]
        )
    return image_grids
def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None):
    """Return integer frame positions sampled from ``range(duration)``.

    Args:
        duration: Number of frames available.
        mode: ``'uniform'`` picks ``num_frames`` evenly spaced positions (all
            positions when ``duration <= num_frames``); ``'fps'`` picks the
            middle of every segment of ``vid_fps // fps`` frames.
        num_frames: Required for ``'uniform'``.
        vid_fps: Native frame rate, required for ``'fps'``.
        fps: Target sampling rate for ``'fps'``; defaults to
            ``NUM_FRAMES_PER_SECOND``.

    Returns:
        numpy integer array of positions.

    Raises:
        ImportError: For an unknown ``mode``.
    """
    if mode == 'fps':
        assert vid_fps is not None, "FPS must be provided for FPS sampling."
        if fps is None:
            fps = NUM_FRAMES_PER_SECOND
        # One sample per segment, taken at the segment's midpoint.
        segment_len = min(vid_fps // fps, duration)
        return np.arange(segment_len // 2, duration, segment_len, dtype=int)

    if mode != 'uniform':
        raise ImportError(f'Unsupported frame sampling mode: {mode}')

    assert num_frames is not None, "Number of frames must be provided for uniform sampling."
    if duration <= num_frames:
        # Fewer frames than requested: keep them all.
        return np.arange(duration).astype(int)
    # Evenly spaced positions including the first and last frame.
    return np.linspace(0, duration-1, num_frames, dtype=int)
def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=None, temporal_factor=1, frame_ids=None):
    """Load sampled frames from a frame directory, a GIF, or a video file.

    Args:
        video_path: Directory of frame images, a ``.gif`` file, or any file
            readable by decord's ``VideoReader``.
        s: Optional start time in seconds.
        e: Optional end time in seconds.
        fps: Optional target sampling rate. When None, or when the clip lasts
            ``max_frames`` seconds or more, frames are sampled uniformly.
        max_frames: Upper bound for uniform sampling; defaults to
            ``MAX_FRAMES``.
        temporal_factor: When greater than 1, the last frame is repeated so
            the frame count becomes a multiple of this value.
        frame_ids: Optional frame indices to load in addition.

    Returns:
        Tuple ``(frames, timestamps, additional_frames)``. ``frames`` is a
        list of RGB arrays for directories and GIFs and a numpy batch for
        video files; ``timestamps`` is a list of ``frame_index / vid_fps``
        values; ``additional_frames`` is empty unless ``frame_ids`` is given.
    """
    if s is not None and e is not None:
        # Clamp to non-negative, order the bounds, and avoid an empty span.
        s = s if s >= 0. else 0.
        e = e if e >= 0. else 0.
        if s > e:
            s, e = e, s
        elif s == e:
            e = s + 1

    # 1. Loading Video
    is_dir = os.path.isdir(video_path)
    is_gif = (not is_dir) and video_path.endswith('.gif')
    if is_dir:
        frame_files = sorted(os.listdir(video_path))
        vid_fps = 3  # fixed nominal rate used for frame directories
        num_frames_of_video = len(frame_files)
    elif is_gif:
        gif_reader = imageio.get_reader(video_path)
        vid_fps = 25  # fixed nominal rate used for GIFs
        num_frames_of_video = len(gif_reader)
    else:
        vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        vid_fps = vreader.get_avg_fps()
        num_frames_of_video = len(vreader)

    # 2. Determine frame range & Calculate frame indices
    f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
    f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
    frame_indices = list(range(f_start, f_end + 1))
    duration = len(frame_indices)

    # 3. Sampling frame indices
    max_frames = max_frames if max_frames is not None else MAX_FRAMES
    if fps is not None and duration / vid_fps < max_frames:
        try:
            sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps)]
        except Exception:
            # FPS sampling can fail (e.g. a zero step when vid_fps < fps);
            # fall back to uniform sampling. Only ``Exception`` is caught so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print('sampled_frame_indices error: ', )
            sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
    else:
        sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]

    # 4. Acquire frame data
    if is_dir:
        frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices]
    elif is_gif:
        # Set membership keeps the scan over the GIF linear.
        wanted = set(sampled_frame_indices)
        frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in wanted]
    else:
        frames = vreader.get_batch(sampled_frame_indices).asnumpy()

    timestamps = [x / vid_fps for x in sampled_frame_indices]

    if temporal_factor > 1 and len(frames) > 0:
        # NOTE(review): a full extra group is appended even when the count is
        # already a multiple of temporal_factor; kept as in the original.
        pad_length = temporal_factor - len(frames) % temporal_factor
        if isinstance(frames, np.ndarray):
            frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
        else:
            # Directory/GIF frames are plain lists, which have no ``repeat``.
            frames = list(frames) + [frames[-1]] * pad_length
        # Without a target fps, extrapolate by one native frame interval,
        # matching how the timestamps above are derived.
        step = 1 / fps if fps is not None else 1 / vid_fps
        for _ in range(pad_length):
            timestamps.append(timestamps[-1] + step)

    additional_frames = []
    if frame_ids is not None:
        if is_dir:
            additional_frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in frame_ids]
        elif is_gif:
            wanted_ids = set(frame_ids)
            additional_frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in wanted_ids]
        else:
            additional_frames = vreader.get_batch(frame_ids).asnumpy()

    if is_gif:
        # Release the file handle held by the GIF reader.
        gif_reader.close()

    return frames, timestamps, additional_frames
def load_video(
    video_path: str,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    fps: Optional[float] = None,
    max_frames: Optional[float] = None,
    size: Optional[int] = None,
    size_divisible: int = 1,
    precise_time: bool = False,
    verbose: bool = False,
    temporal_factor: int = 1,
    frame_ids = None
):
    """
    Load a video with ffmpeg and return its frames and their timestamps.
    Clips shorter than one second, frame directories and GIFs are delegated
    to ``load_video_from_ids`` and returned in that function's format.
    Args:
        video_path (str): Path to the video file.
        start_time (float, optional): Start time in seconds. Defaults to None.
        end_time (float, optional): End time in seconds. Defaults to None.
        fps (float, optional): Frames per second. Defaults to None.
        max_frames (float, optional): Maximum number of frames to keep; extra
            frames are dropped by uniform subsampling. Defaults to MAX_FRAMES.
        size (int, optional): Size of the shortest side. Defaults to None.
        size_divisible (int, optional): Size divisible by this number. Defaults to 1.
        precise_time (bool, optional): Apply the trim as an output option
            instead of an input option. Defaults to False.
        verbose (bool, optional): Print ffmpeg output. Defaults to False.
        temporal_factor (int, optional): When > 1, the last frame is repeated
            so the frame count becomes a multiple of it. Defaults to 1.
        frame_ids (optional): Frame indices to load additionally with decord.
    Returns:
        frames (list): One uint8 array of shape (3, H, W) per frame.
        timestamps: One timestamp in seconds per frame; a numpy array unless
            subsampling or padding took place, in which case it is a list.
        additional_frames: Frames for ``frame_ids`` (empty list if None).
    Raises:
        ValueError: If the file contains no video stream.
    """
    # NOTE(review): the delegated call does not forward temporal_factor, so
    # no padding happens on these paths; kept as in the original.
    short_clip = start_time is not None and end_time is not None and end_time - start_time < 1
    if short_clip or os.path.isdir(video_path) or video_path.endswith('.gif'):
        return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)

    probe = ffmpeg.probe(video_path)
    duration = float(probe['format']['duration'])
    video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
    if video_stream is None:
        raise ValueError(f'No video stream found in: {video_path}')
    w, h = int(video_stream['width']), int(video_stream['height'])
    stream_start = float(video_stream['start_time'])

    kwargs, input_kwargs, output_kwargs = {}, {}, {}
    do_trim = start_time is not None or end_time is not None
    if start_time is not None:
        # Never start before the stream does; shorten the duration to match.
        new_start_time = max(stream_start, start_time)
        duration -= new_start_time - start_time
        start_time = new_start_time
    else:
        start_time = stream_start
    if end_time is not None:
        duration = min(duration, end_time - start_time)

    if do_trim:
        kwargs = {'ss': start_time, 't': duration}
    if precise_time:
        output_kwargs.update(kwargs)
    else:
        input_kwargs.update(kwargs)

    # Scale so the shortest side equals ``size``, then round down to a
    # multiple of ``size_divisible``.
    if size is not None:
        scale_factor = size / min(w, h)
        new_w, new_h = round(w * scale_factor), round(h * scale_factor)
    else:
        new_w, new_h = w, h
    new_w = new_w // size_divisible * size_divisible
    new_h = new_h // size_divisible * size_divisible

    # NOTE: It may result in unexpected number of frames in ffmpeg
    # if calculate the fps directly according to max_frames
    stream = ffmpeg.input(video_path, **input_kwargs)
    if fps is not None:
        stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
    if new_w != w or new_h != h:
        stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
    stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
    out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)

    # Raw RGB bytes -> (N, H, W, 3) -> (N, 3, H, W).
    frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])

    if fps is not None:
        timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
    else:
        timestamps = np.linspace(start_time, start_time + duration, len(frames))

    max_frames = max_frames if max_frames is not None else MAX_FRAMES
    if max_frames is not None and len(frames) > max_frames:
        indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
        frames = frames[indices]
        timestamps = [timestamps[i] for i in indices]

    if temporal_factor > 1 and len(frames) > 0:
        # NOTE(review): a full extra group is appended even when the count is
        # already a multiple of temporal_factor; kept as in the original.
        pad_length = temporal_factor - len(frames) % temporal_factor
        frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
        # ``timestamps`` may still be an ndarray here, which has no append().
        timestamps = list(timestamps)
        if fps is not None:
            step = 1 / fps
        elif len(timestamps) > 1:
            # No target fps: continue with the spacing of the last two frames.
            step = timestamps[-1] - timestamps[-2]
        else:
            step = 0.0
        for _ in range(pad_length):
            timestamps.append(timestamps[-1] + step)

    frames = list(frames)

    additional_frames = []
    if frame_ids is not None:
        vr = VideoReader(video_path, ctx=cpu(0))
        additional_frames = vr.get_batch(frame_ids).asnumpy()

    return frames, timestamps, additional_frames
def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=None):
    """Load a video and run its frames through ``processor.preprocess``.

    Args:
        video_path: Path forwarded to ``load_video``.
        processor: Object providing ``image_mean`` and
            ``preprocess(frames, return_tensors='pt', image_num=...)``.
        s: Optional start time in seconds.
        e: Optional end time in seconds.
        aspect_ratio: ``'pad'`` squares every frame with ``expand2square``
            first; ``'qwen2vl'`` preprocesses frame by frame; anything else
            preprocesses all frames in one call.
        num_frames: Maximum number of frames. When None, frames are sampled at
            1 fps instead.

    Returns:
        Tuple ``(grid_frames, timestamps)``: ``[frames]`` with one entry per
        frame for ``'qwen2vl'``, otherwise ``[[batch]]``.
    """
    fps = 1 if num_frames is None else None

    # FFmpeg. load_video returns (frames, timestamps, additional_frames);
    # the additional frames are not used here.
    frames, timestamps, _ = load_video(video_path, s, e, fps=fps, max_frames=num_frames)
    # Decord
    # frames, timestamps = load_video_from_ids(video_path, s, e, fps=fps, max_frames=num_frames)

    assert len(frames) == len(timestamps), "Number of frames and timestamps must match."

    if aspect_ratio == 'pad':
        # NOTE(review): expand2square reads PIL-style ``size``/``mode``, but
        # load_video yields numpy arrays -- confirm this path is still used.
        frames = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in frames]

    if aspect_ratio == 'qwen2vl':
        frames = [processor.preprocess(frame, return_tensors='pt', image_num=len(frames)) for frame in frames]
        grid_frames = [frames]
    else:
        frames = processor.preprocess(frames, return_tensors='pt', image_num=len(frames))
        grid_frames = [[frames]]

    return grid_frames, timestamps
def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
    """Tokenize a prompt, replacing each multimodal tag with its token index.

    Args:
        prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
        tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
        multimodal_token (str): Multimodal tag to split the prompt on; its
            index is looked up in ``MODAL_INDEX_MAP``. When the tag has no
            entry, the prompt is tokenized as plain text.
        return_tensors (str, optional): ``'pt'`` returns a ``torch.long``
            tensor; None returns a list of ids.

    Raises:
        ValueError: If ``return_tensors`` is neither None nor ``'pt'``.
    """
    multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)

    if multimodal_token_index is None:
        input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
    else:
        # Tokenize the text between tags, then join the pieces with the
        # tag's token index.
        text_pieces = prompt.split(multimodal_token)
        prompt_chunks = [tokenizer(piece, add_special_tokens=False).input_ids for piece in text_pieces]
        input_ids = []
        for position, chunk_ids in enumerate(prompt_chunks):
            if position > 0:
                input_ids.append(multimodal_token_index)
            input_ids.extend(chunk_ids)

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
def get_model_name_from_path(model_path):
    """Derive a model name from a slash-separated path.

    Leading and trailing slashes are ignored. When the last component starts
    with ``checkpoint-`` and has a parent, the result is
    ``"<parent>_<checkpoint>"``; otherwise it is the last component.

    Args:
        model_path: Path or hub identifier using ``/`` separators.

    Returns:
        The derived model name.
    """
    parts = model_path.strip("/").split("/")
    last = parts[-1]
    # A bare "checkpoint-N" has no parent to prefix; return it unchanged
    # instead of failing on parts[-2].
    if last.startswith('checkpoint-') and len(parts) > 1:
        return parts[-2] + "_" + last
    return last
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once any keyword appears in the generated text.

    A keyword is detected either by an exact match of its token ids at the
    end of the sequence or by a substring match in the decoded tail.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: Iterable of keyword strings.
            tokenizer: Tokenizer used to encode the keywords and decode outputs.
            input_ids: Prompt ids; ``input_ids.shape[1]`` marks where the
                generated tokens start.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token added by the tokenizer.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the single sequence in ``output_ids`` ends with or contains a keyword."""
        # Number of trailing tokens worth decoding: at most the longest
        # keyword, and never more than what has been generated so far.
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]

        for keyword_id in self.keyword_ids:
            # torch.equal returns False on a length mismatch, whereas an
            # elementwise ``==`` can raise when the sequence is shorter
            # than the keyword.
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True

        if offset <= 0:
            # Nothing generated yet; ``output_ids[:, -0:]`` would decode the
            # whole sequence, prompt included.
            return False

        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every sequence in the batch has hit a keyword."""
        outputs = [
            self.call_for_batch(output_ids[i].unsqueeze(0), scores)
            for i in range(output_ids.shape[0])
        ]
        return all(outputs)