Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 Alibaba DAMO Academy | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import math | |
import os | |
import re | |
import einops | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from timm.models.layers import LayerNorm, LayerNorm2d | |
from timm.models.regnet import RegStage | |
from transformers import TRANSFORMERS_CACHE | |
def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model", revision="main"):
    """Return the path of the local Hugging Face cache snapshot folder for ``repo_id``.

    The path is only computed: this function neither checks that the folder
    exists nor downloads anything.

    Args:
        repo_id: repository id such as ``"org/name"``.
        cache_dir: cache root directory; defaults to ``TRANSFORMERS_CACHE``.
        repo_type: repository type used as the cache folder prefix
            (``"model"`` -> ``models--org--name``).
        revision: branch/tag name resolved through ``refs/<revision>``. If no
            (non-empty) ref file exists, the name itself is used as the
            snapshot folder name.

    Returns:
        ``<cache_dir>/<repo_type>s--<org>--<name>/snapshots/<commit-or-revision>``.
    """
    # 1. Locate the repository's cache folder.
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    object_id = repo_id.replace("/", "--")
    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")

    # 2. Resolve refs (for instance, convert "main" to the commit sha it points to).
    revision_file = os.path.join(repo_cache, "refs", revision)
    if os.path.isfile(revision_file):
        with open(revision_file, encoding="utf-8") as f:
            # Strip whitespace so a trailing newline never leaks into the path;
            # fall back to the revision name if the ref file is empty.
            revision = f.read().strip() or revision

    # 3. Build the snapshot folder path.
    return os.path.join(repo_cache, "snapshots", revision)
def load_mm_projector(model_path, cache_dir=None, token=None):
    """Load ``mm_projector.bin`` weights and cast every tensor to float16 on CPU.

    Args:
        model_path: either a local directory containing ``mm_projector.bin`` or
            a Hugging Face repo id.
        cache_dir: Hugging Face cache root passed to the snapshot lookup/download.
        token: auth token forwarded to ``huggingface_hub.snapshot_download``.

    Returns:
        dict mapping parameter names to float16 tensors.
    """
    weights_name = 'mm_projector.bin'
    if os.path.exists(os.path.join(model_path, weights_name)):
        folder = model_path
    else:
        folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
        if not os.path.exists(os.path.join(folder, weights_name)):
            # Downloading from the remote repo. Use the folder snapshot_download
            # returns: before the first download there is no refs/main to
            # resolve, so the guessed folder above would be ".../snapshots/main"
            # rather than the commit-sha folder that actually gets populated.
            from huggingface_hub import snapshot_download
            folder = snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (consider weights_only=True where the torch version allows).
    mm_projector_weights = torch.load(os.path.join(folder, weights_name), map_location='cpu')
    return {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
class IdentityMap(nn.Module):
    """No-op projector: hands back its first argument untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        # Any additional positional or keyword arguments are accepted and ignored.
        del args, kwargs
        return x

    def config(self):
        """Return the projector-type descriptor for this module."""
        return dict(mm_projector_type='identity')
def build_mlp(depth, hidden_size, output_hidden_size):
    """Build ``Linear -> (GELU -> Linear) * (depth - 1)`` as an ``nn.Sequential``.

    The first Linear maps ``hidden_size`` to ``output_hidden_size``; every later
    Linear keeps ``output_hidden_size``. A ``depth`` of 1 or less yields a single
    Linear layer.
    """
    layers = [nn.Linear(hidden_size, output_hidden_size)]
    layers.extend(
        layer
        for _ in range(depth - 1)
        for layer in (nn.GELU(), nn.Linear(output_hidden_size, output_hidden_size))
    )
    return nn.Sequential(*layers)
class SimSpatialConv(nn.Module):
    """Downsample a square grid of visual tokens with a strided conv, then project with an MLP.

    Args:
        config: object exposing ``mm_hidden_size`` (input feature dim) and
            ``hidden_size`` (output feature dim).
        downsample: Conv2d kernel size and stride (int or ``(h, w)`` pair).
        padding: zero padding the Conv2d adds on every side (int or ``(h, w)`` pair).
        depth: unused; kept for signature compatibility.
        mlp_depth: number of Linear layers in the readout MLP (see ``build_mlp``).
    """

    def __init__(self, config, downsample=(2, 2), padding=1, depth=1, mlp_depth=2):
        super().__init__()
        self.encoder_hidden_size = config.mm_hidden_size
        self.output_hidden_size = config.hidden_size
        self.downsample = downsample
        self.padding = padding
        self.sampler = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder_hidden_size,
                out_channels=4 * self.encoder_hidden_size,
                kernel_size=self.downsample,
                stride=self.downsample,
                padding=self.padding,
                bias=True,
            ),
            nn.SiLU(),
        )
        self.readout = build_mlp(mlp_depth, 4 * self.encoder_hidden_size, self.output_hidden_size)

    def forward(self, x):
        """Project ``x`` of shape (batch, h*w, mm_hidden_size) to (batch, tokens, hidden_size).

        Raises:
            ValueError: if the token count is not a perfect square.
        """
        num_tokens = x.size(1)
        # Exact integer square root; the token sequence must form a square grid.
        hw = math.isqrt(num_tokens)
        if hw * hw != num_tokens:
            raise ValueError(f"SimSpatialConv expects a square token grid, got {num_tokens} tokens")
        x = einops.rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
        x = self.sampler(x)
        x = einops.rearrange(x, "b d h w -> b (h w) d")
        x = self.readout(x)
        return x

    def cal_proj_size(self, input_size):
        """Return how many tokens ``forward`` emits for an input grid of ``input_size``.

        Uses the Conv2d output-size formula
        ``floor((n + 2 * padding - kernel) / stride) + 1`` with kernel == stride == downsample.

        Args:
            input_size: grid side length (int) or ``(height, width)``.
        """
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        kernel = self.downsample
        if isinstance(kernel, int):
            kernel = (kernel, kernel)
        pad = self.padding
        if isinstance(pad, int):
            pad = (pad, pad)
        height = (input_size[0] + 2 * pad[0] - kernel[0]) // kernel[0] + 1
        width = (input_size[1] + 2 * pad[1] - kernel[1]) // kernel[1] + 1
        return height * width
class MlpGeluProjector(nn.Module):
    """MLP projector selected by a type string of the form ``mlp<depth>x_gelu``.

    Args:
        config: object exposing ``mm_hidden_size`` (input dim) and ``hidden_size`` (output dim).
        projector_type: e.g. ``"mlp2x_gelu"``; the integer is passed to ``build_mlp`` as depth.

    Raises:
        ValueError: if ``projector_type`` does not match ``mlp<depth>x_gelu``.
    """

    def __init__(self, config, projector_type):
        super().__init__()
        mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
        if mlp_gelu_match is None:
            # Without this check a malformed type crashes later with an opaque
            # AttributeError on None.group().
            raise ValueError(
                f"Invalid MLP projector type {projector_type!r}; "
                "expected the form 'mlp<depth>x_gelu', e.g. 'mlp2x_gelu'"
            )
        mlp_depth = int(mlp_gelu_match.group(1))
        self.readout = build_mlp(mlp_depth, config.mm_hidden_size, config.hidden_size)

    def forward(self, x):
        """Apply the readout MLP token-wise to ``x``."""
        x = self.readout(x)
        return x

    def cal_proj_size(self, input_size):
        """Return the output token count; the MLP keeps one token per input position.

        Args:
            input_size: grid side length (int) or ``(height, width)``.
        """
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height = input_size[0]
        width = input_size[1]
        return height * width
def build_vision_projector(config, delay_load=False, **kwargs):
    """Instantiate the vision-to-language projector named by ``config.mm_projector_type``.

    Recognised types:
      * ``'linear'`` (used when the attribute is absent) -> ``nn.Linear``
      * ``'simp_spatial_conv'`` -> ``SimSpatialConv``
      * any type starting with ``'mlp'`` -> ``MlpGeluProjector`` (e.g. ``'mlp2x_gelu'``)
      * ``'identity'`` -> ``IdentityMap``

    ``delay_load`` and any extra keyword arguments are accepted but unused.

    Raises:
        ValueError: if the projector type is not recognised.
    """
    # VideoLLaMA3 projectors currently operate image-wise only; temporal
    # aggregation across video frames is not done here.
    kind = getattr(config, 'mm_projector_type', 'linear')

    if kind == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    if kind == "simp_spatial_conv":
        return SimSpatialConv(config)
    if kind.startswith("mlp"):
        return MlpGeluProjector(config, kind)
    if kind == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {kind}')