# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import copy import json import os import pathlib import random import re import sys import warnings import traceback from packaging import version from dataclasses import dataclass, field from typing import Dict, List, Optional, Sequence import numpy as np # torch-related packages # NOTE: torch must be imported before transformers. Otherwise, `Segmentation fault (core dumped)` will occur. import torch import transformers from packaging import version from datasets import load_dataset, concatenate_datasets from torch.utils.data import Dataset from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock sys.path.append('./') from videollama3.constants import (IGNORE_INDEX, MODAL_INDEX_MAP, NUM_FRAMES, DEFAULT_IMAGE_TOKEN, STREAM_MAX_FRAMES, STREAM_DOWNSAMPLING, STREAM_FPS, STREAM_IMAGE_SIZE, STREAM_START_TOKEN, STREAM_END_TOKEN, REGION_TOKEN) from videollama3.mm_utils import (load_images, load_video, tokenizer_multimodal_token, annToMask, resize_image_mask) from videollama3.model import * from videollama3.videollama3_trainer import ( VideoLLaMA3Trainer, find_all_linear_names, get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer) from videollama3.model.processor import Videollama3Processor # NOTE: fast tokenizer warning issue: https://github.com/huggingface/transformers/issues/5486 os.environ["TOKENIZERS_PARALLELISM"] = "true" local_rank = None def rank0_print(*args): if local_rank == 0: print(*args) def set_seed(seed=42): """ Set the random seed for reproducible results. :param seed: An integer value to be used as the random seed. """ torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # for multi-GPU setups torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def int_with_none(value): if value == 'None': return None return int(value) @dataclass class ModelArguments: # LLM Arguments model_type: Optional[str] = field(default="videollama3", metadata={"help": "Model type selected in the list: " + ", ".join(VLLMs.keys())}) model_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5") version: Optional[str] = field(default="v1", metadata={"help": "Version of the conversation template."}) freeze_backbone: bool = field(default=False, metadata={"help": "Whether to freeze the LLM backbone."}) # Connector Arguments mm_projector_type: Optional[str] = field(default='linear') pretrain_mm_projector: Optional[str] = field(default=None) # Vision tower Arguments vision_encoder: Optional[str] = field(default=None) mm_vision_select_layer: Optional[int] = field(default=-1) mm_vision_select_feature: Optional[str] = field(default="patch") mm_attn_implementation: Optional[str] = field(default="flash_attention_2") # Token downsampling Arguments spatial_merge_size: Optional[int] = field(default=1) mm_max_length: Optional[int] = field(default=9477) use_token_compression: Optional[bool] = field(default=False) @dataclass class DataArguments: # Path Arguments data_path: List[str] = field(default=None, metadata={"help": "Path to the training data."}) # image_folder: Optional[str] = field(default=None) # video_folder: Optional[str] = field(default=None) data_folder: Optional[str] = field(default=None) # Loading Arguments is_multimodal: bool = False fps: Optional[int] = field(default=None) max_frames: Optional[int_with_none] = field(default=None) # Preprocess Arguments image_aspect_ratio: str = 'square' use_batch_flattening: bool = field(default=True, metadata={"help": "Whether to flatten the in-batch sequences of variable lengths."}) dataset_cache_dir: Optional[str] = field(default=None) @dataclass class TrainingArguments(transformers.TrainingArguments): # shut auto processing (_remove_unused_columns) of transformers Trainer remove_unused_columns: bool = field(default=False) optim: str = field(default="adamw_torch") # Training learning rate Arguments vision_encoder_lr: Optional[float] = None mm_projector_lr: Optional[float] = None llm_lr: Optional[float] = None region_encoder_lr: Optional[float] = None # Training Data Arguments group_by_modality_length: bool = field(default=False) model_max_length: int = field( default=512, metadata={ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." }, ) # Lora or Quant Arguments double_quant: bool = field( default=True, metadata={"help": "Compress the quantization statistics through double quantization."} ) quant_type: str = field( default="nf4", metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} ) bits: int = field( default=16, metadata={"help": "How many bits to use."} ) lora_enable: bool = False lora_r: int = 64 lora_alpha: int = 16 lora_dropout: float = 0.05 lora_weight_path: str = "" lora_bias: str = "none" class LazySupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" def __init__(self, data_path: str, vlprocessor, data_args: DataArguments): super(LazySupervisedDataset, self).__init__() data_objs = [] # try: # for data in data_path: # # NOTE: load_dataset can process both json or jsonl files # if data.endswith(".json") or data.endswith(".jsonl"): # data_objs.append(load_dataset("json", data_files=data, cache_dir=data_args.dataset_cache_dir)["train"]) # else: # raise Exception(f"Unsupported file format (<{data}>)!") # list_data_dict = concatenate_datasets(data_objs) # except: traceback.print_exc() # NOTE: compatible with the old version list_data_dict = [] for data in data_path: if data.endswith(".json"): data = json.load(open(data, "r")) for i in data: i['id'] = len(list_data_dict) list_data_dict.append(i) elif data.endswith(".jsonl"): with open(data, "r", encoding="utf-8") as fp: for line in fp: line = line.strip() obj = json.loads(line) obj["id"] = len(list_data_dict) list_data_dict.append(obj) else: raise Exception(f"Unsupported file format (<{data}>)!!!") rank0_print("Formatting inputs...Skip in lazy mode") self.vlprocessor = vlprocessor self.list_data_dict = list_data_dict self.data_args = data_args def __len__(self): return len(self.list_data_dict) @property def lengths(self): length_list = [] for sample in self.list_data_dict: img_tokens = 576 if 'image' in sample else 0 length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) return length_list @property def modality_lengths(self): length_list = [] for sample in self.list_data_dict: cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) cur_len = cur_len if 'image' in sample else -cur_len length_list.append(cur_len) return length_list def _convert_normal(self, data_dict): data_folder = self.data_args.data_folder conversation = copy.deepcopy(data_dict["conversations"]) # data sanity check and repair start_idx = 0 for sentence in conversation: if sentence["from"] == "human" or sentence["from"] == "system": break start_idx += 1 if start_idx > 0: warnings.warn(f"Find {start_idx} non-user sentences at the beginning of the conversation, remove them automatically!") conversation = conversation[start_idx:] assert len(conversation) > 1, f"Invalid conversation" additional_frames = [] mask_ids = [] if 'image' in data_dict and data_dict['image'] is not None: modal = 'image' if all(not "" in sentence["value"] for sentence in conversation): warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!") conversation[0]["value"] = "" + conversation[0]["value"] image_file = data_dict['image'] if isinstance(image_file, list): image_file = [os.path.join(data_folder, f) for f in image_file] else: image_file = os.path.join(data_folder, image_file) images = load_images(image_file) masks = [] if 'masks' in data_dict and data_dict['masks'] is not None and len(data_dict['masks'])>0: if 'height' in data_dict: h = data_dict['height'] w = data_dict['width'] else: h = None w = None for ann in data_dict['masks']: mask = annToMask(ann, h, w) masks.append(mask) mask_ids.append(0) masks = np.stack(masks, axis=0) masks = torch.from_numpy(masks) additional_frames = images.copy() else: masks = None elif 'video' in data_dict and data_dict['video'] is not None: modal = 'video' if all(not "