Commit a15d355 by LimengQiao · 1 Parent(s): 5d10b26

add: UniViTAR models
Files changed (5):
  1. .gitattributes +1 -0
  2. README.md +85 -3
  3. config.json +19 -0
  4. modeling_univitar.py +604 -0
  5. pytorch_model.bin +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,85 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ datasets:
+ - mlfoundations/datacomp_1b
+ - kakaobrain/coyo-700m
+ - laion/laion400m
+ language:
+ - en
+ - zh
+ metrics:
+ - accuracy
+ - recall
+ pipeline_tag: feature-extraction
+ library_name: transformers
+ ---
+
+ <h1 align="center">Unified Vision Transformer with Native Resolution</h1>
+
+
+ ## 🌠 Introduction
+
+ We present **UniViTAR**, a family of homogeneous vision foundation models tailored **for unified visual modalities and native-resolution scenarios** in the multimodal era. We train the UniViTAR family at multiple model scales from **0.3B to 1.4B** parameters, exclusively on publicly accessible image-caption data (14.6B samples), and observe that performance increases consistently with parameter scaling. UniViTAR is a Transformer-based encoder that inherits the architecture of the conventional Vision Transformer and incorporates the following modifications: *Unified Patchify for Native Image and Video Modalities, 2D RoPE, SwiGLU, RMSNorm, and QK-Norm*.
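To make the native-resolution behavior concrete, the sketch below mirrors the `smart_resize` rounding used in the released `modeling_univitar.py` under this checkpoint's default config (patch size 14, rounding factor 28, token budget 256-16384). It is illustrative only and omits the guard for inputs smaller than the rounding factor.

```python
import math

def token_grid(height, width, patch=14, factor=28, min_tokens=256, max_tokens=16384):
    """Round an arbitrary resolution to the nearest valid native-resolution patch grid."""
    min_px, max_px = min_tokens * patch * patch, max_tokens * patch * patch
    h = round(height / factor) * factor
    w = round(width / factor) * factor
    if h * w > max_px:  # too many pixels: shrink while roughly keeping the aspect ratio
        beta = math.sqrt((height * width) / max_px)
        h = math.floor(height / beta / factor) * factor
        w = math.floor(width / beta / factor) * factor
    elif h * w < min_px:  # too few pixels: enlarge while roughly keeping the aspect ratio
        beta = math.sqrt(min_px / (height * width))
        h = math.ceil(height * beta / factor) * factor
        w = math.ceil(width * beta / factor) * factor
    return h // patch, w // patch  # patch-grid rows and columns

print(token_grid(1080, 1920))  # (78, 138) -> 10764 visual tokens for a 1080p frame
```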
+
+
+ ## 🛠️ Environment
+ ```bash
+ conda create -n univitar python=3.11 -y
+ conda activate univitar
+ pip3 install einops==0.8.0 ninja==1.11.1.1 numpy==1.26.4 pillow==10.4.0 psutil==6.0.0 torch==2.2.2 torchvision==0.17.2 transformers==4.49.0 timm==1.0.14
+ pip3 install flash-attn==2.6.3
+ ```
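A quick, optional sanity check of the installed stack can save debugging later, since the attention path in `modeling_univitar.py` asserts CUDA tensors in fp16/bf16:

```python
import torch
import flash_attn
import transformers

print(torch.__version__, transformers.__version__, flash_attn.__version__)
assert torch.cuda.is_available(), "the flash-attention path requires a CUDA device"
print("bf16 supported:", torch.cuda.is_bf16_supported())
```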
+
+
+ ## 🗝️ Model Usage
+
+ ```python
+ import torch
+ import numpy as np
+ from PIL import Image
+ from modeling_univitar import UniViTARVisionModel
+
+ # Prepare Model
+ model = UniViTARVisionModel("config.json")
+ _ = model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
+ model = model.to(torch.bfloat16).cuda()
+
+ # Prepare Data: [(3, H1, W1), ..., (3, Hn, Wn)] --> (N1+...+Nn, P)
+ images = [Image.open("xx1.jpg"), Image.open("xx2.jpg")]
+ data_inputs, grid_shapes = [], []
+ for image in images:
+     data_item = model.image_transform(image)
+     input_data, grid_shape = model.data_patchify(data_item)
+     data_inputs.append(input_data.to(torch.bfloat16).cuda())
+     grid_shapes.append(grid_shape)
+ data_inputs = torch.concatenate(data_inputs, dim=0)
+
+ # Forward: (N1+...+Nn, P) --> [(N1, D), ..., (Nn, D)]
+ data_embeds = model(pixel_values=data_inputs, grid_shapes=grid_shapes)
+ data_embeds = data_embeds.split([np.prod(grid_shape) for grid_shape in grid_shapes])
+ print(data_embeds[0].shape, data_embeds[1].shape)
+ ```
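The same interface also handles video clips: `UniViTARImageTransform` accepts a list of frames (the frame count should be 1 or a multiple of `temporal_patch_size=2`), and `data_patchify` then yields a single `(t, h, w)` grid for the whole clip. A minimal sketch continuing from the code above, with hypothetical frame files:

```python
# Video: a clip of 8 hypothetical frames -> one (t, h, w) grid with t = 8 / 2 = 4
frames = [Image.open(f"frame_{i}.jpg") for i in range(8)]
video_item = model.image_transform(frames)                 # (8, 3, H', W')
video_input, video_grid = model.data_patchify(video_item)  # (t*h*w, P), (t, h, w)
video_embeds = model(pixel_values=video_input.to(torch.bfloat16).cuda(),
                     grid_shapes=[video_grid])
print(video_embeds.shape)                                  # (t*h*w, D)
```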
+
+ ## 📈 Evaluation
+
+ | Model | Size | \#Seen | IN1K<sup>ZS</sup> | IN1K<sup>LP</sup> | Flickr<sup>T2I</sup> | Flickr<sup>I2T</sup> | K400<sup>ZS</sup> | ADE20K |
+ |--------|-----|----|------|------|------|------|------|------|
+ | [UniViTAR-0.3B](https://huggingface.co/MM-MVR/UniViTAR-0.3B) | 310M | 14.6B | 81.5 | 87.7 | 84.0 | 95.1 | 66.0 | 54.6 |
+ | [UniViTAR-0.6B](https://huggingface.co/MM-MVR/UniViTAR-0.6B) | 637M | 14.6B | 82.3 | 88.3 | 84.1 | 95.5 | 68.6 | 55.1 |
+ | [UniViTAR-1B](https://huggingface.co/MM-MVR/UniViTAR-1B) | 1419M | 14.6B | 82.9 | 89.2 | 83.5 | 95.1 | 69.0 | 56.2 |
+
+ <font size=1>*ZS: Zero-shot Classification, LP: Linear-Probe Classification, T2I/I2T: Text-to-Image/Image-to-Text Retrieval*</font>
+
+
+ ## ✏️ Reference
+
+ If you find UniViTAR useful in your research or applications, please consider citing the following BibTeX:
+
+ ```bibtex
+ @article{qiao2025univitar,
+   title={UniViTAR: Unified Vision Transformer with Native Resolution},
+   author={Qiao, Limeng and Gan, Yiyang and Wang, Bairui and Qin, Jie and Xu, Shuang and Yang, Siqi and Ma, Lin},
+   journal={arXiv preprint arXiv:2504.01792},
+   year={2025}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "resolution_mode": "native",
+     "min_tokens": 256,
+     "max_tokens": 16384,
+     "patch_size": 14,
+     "resize_factor": 2,
+     "spatial_merge_size": 1,
+     "temporal_patch_size": 2,
+     "num_hidden_layers": 24,
+     "num_attention_heads": 16,
+     "hidden_size": 1024,
+     "intermediate_size": 4224,
+     "pe_type": "rope2d",
+     "norm_type": "RMSNorm",
+     "hidden_act": "SwiGLU",
+     "init_method": "xavier",
+     "image_mean": [0.485, 0.456, 0.406],
+     "image_std": [0.229, 0.224, 0.225]
+ }
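A few runtime quantities follow directly from these fields. The short sketch below (assuming `config.json` and `modeling_univitar.py` sit side by side, as in this commit) loads the file through `UniViTARVisionConfig` and derives them:

```python
import json
from modeling_univitar import UniViTARVisionConfig

config = UniViTARVisionConfig.from_dict(json.load(open("config.json")))

patch_area = config.patch_size ** 2                      # 196 pixels per 14x14 patch
print(config.min_tokens * patch_area)                    # 50176   minimum pixels per image
print(config.max_tokens * patch_area)                    # 3211264 maximum pixels per image
print(config.hidden_size // config.num_attention_heads)  # 64      head_dim (2D-RoPE dim 32)
print(int(config.intermediate_size * 2 / 3))             # 2816    SwiGLU inner width
```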
modeling_univitar.py ADDED
@@ -0,0 +1,604 @@
+ from typing import Iterable, Optional, Tuple, Union, List
+
+ import os
+ import math
+ import json
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.utils.checkpoint
+ import torch.nn.functional as F
+
+ from PIL import Image
+ from einops import rearrange
+ from functools import partial
+ from timm.layers import DropPath
+ from dataclasses import dataclass
+ from torchvision import transforms
+ from transformers.utils import logging
+ from transformers.activations import ACT2FN
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
+ from flash_attn.bert_padding import pad_input
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+
+ logger = logging.get_logger(__name__)
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+     orig_dtype = tensor.dtype
+     tensor = tensor.float()
+     cos = freqs.cos()
+     sin = freqs.sin()
+     cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     output = output.to(orig_dtype)
+     return output
+
+
+ class VisionRotaryEmbedding2D(nn.Module):
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+     def forward_(self, seqlen: int) -> torch.Tensor:
+         seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(seq, self.inv_freq)
+         return freqs
+
+     def forward(self, grid_shapes, spatial_merge_size=2):
+         pos_ids = []
+         s = spatial_merge_size
+         for t, h, w in grid_shapes:
+             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+             hpos_ids = hpos_ids.reshape(h // s, s, w // s, s)
+             hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+             hpos_ids = hpos_ids.flatten()
+             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+             wpos_ids = wpos_ids.reshape(h // s, s, w // s, s)
+             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+             wpos_ids = wpos_ids.flatten()
+             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+         pos_ids = torch.cat(pos_ids, dim=0)
+         max_grid_size = torch.tensor(grid_shapes).max()
+         rotary_pos_emb_full = self.forward_(max_grid_size)
+         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+         return rotary_pos_emb
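# Shape note (illustrative, with the default config): head_dim = 1024 / 16 = 64, so the
# encoder builds VisionRotaryEmbedding2D(dim=32) and `inv_freq` holds 16 frequencies.
# For grid_shapes = [(1, 78, 138)], forward() returns a (1 * 78 * 138, 32) = (10764, 32)
# tensor of per-token angles, which apply_rotary_pos_emb_vision expands to cos/sin of
# shape (1, 10764, 1, 64) before rotating q and k.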
+
+
+ class FlashAttention(nn.Module):
+     # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
+     """Implement the scaled dot product attention with softmax.
+     Arguments
+     ---------
+         softmax_scale: The temperature to use for the softmax attention.
+                        (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
+         attention_dropout: The dropout rate to apply to the attention (default: 0.0)
+     """
+
+     def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+         super().__init__()
+         self.softmax_scale = softmax_scale
+         self.dropout_p = attention_dropout
+         self._deterministic = True
+
+     def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+                 max_s=None, need_weights=False):
+         """Implements the multihead softmax attention.
+         Arguments
+         ---------
+             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None,
+                 if unpadded: (nnz, 3, h, d)
+             key_padding_mask: a bool tensor of shape (B, S)
+         """
+         assert not need_weights
+         assert qkv.dtype in [torch.float16, torch.bfloat16]
+         assert qkv.is_cuda
+
+         if cu_seqlens is None:
+             batch_size = qkv.shape[0]
+             seqlen = qkv.shape[1]
+             if key_padding_mask is None:
+                 qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                 max_s = seqlen
+                 cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                           device=qkv.device)
+                 output = flash_attn_unpadded_qkvpacked_func(
+                     qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                     softmax_scale=self.softmax_scale, causal=causal
+                 )
+                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+             else:
+                 qkv = qkv.squeeze()  # [1, n, h, d] -> [n, h, d]
+                 seqlens_in_batch = key_padding_mask.sum(dim=-1, dtype=torch.int32)
+                 max_seqlen_in_batch = seqlens_in_batch.max().item()
+                 cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+                 output = flash_attn_unpadded_qkvpacked_func(
+                     qkv, cu_seqlens, max_seqlen_in_batch, self.dropout_p if self.training else 0.0,
+                     softmax_scale=self.softmax_scale, causal=causal, deterministic=self._deterministic
+                 )
+                 output = output.unsqueeze(0)
+         else:
+             assert max_s is not None
+             output = flash_attn_unpadded_qkvpacked_func(
+                 qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+                 softmax_scale=self.softmax_scale, causal=causal
+             )
+
+         return output, None
142
+
143
+ class RMSNorm(nn.Module):
144
+ def __init__(self, hidden_size, eps=1e-6):
145
+ super().__init__()
146
+ self.weight = nn.Parameter(torch.ones(hidden_size))
147
+ self.variance_epsilon = eps
148
+
149
+ def forward(self, hidden_states):
150
+ input_dtype = hidden_states.dtype
151
+ hidden_states = hidden_states.to(torch.float32)
152
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
153
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
154
+ return self.weight * hidden_states.to(input_dtype)
155
+
156
+
157
+ try:
158
+ from apex.normalization import FusedRMSNorm
159
+ RMSNorm = FusedRMSNorm # noqa
160
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of RMSNorm')
161
+ except ImportError: # using the normal RMSNorm
162
+ pass
163
+ except Exception:
164
+ logger.warning('discovered apex but it failed to load, falling back to RMSNorm')
165
+ pass
166
+
167
+
168
+ @dataclass
169
+ class BaseModelOutputWithKwargs(ModelOutput):
170
+ last_hidden_state: torch.FloatTensor = None
171
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
172
+ kwargs: Optional[dict] = None
173
+
174
+
175
+ class UniViTARVisionConfig(PretrainedConfig):
176
+ def __init__(
177
+ self,
178
+ resolution_mode="native",
179
+ init_method="xavier",
180
+ num_channels=3,
181
+ patch_size=14,
182
+ temporal_patch_size=2,
183
+ image_size=1792,
184
+ patch_dropout=0.0,
185
+ attention_dropout=0.0,
186
+ dropout=0.0,
187
+ drop_path_rate=0.0,
188
+ initializer_range=1e-10,
189
+ num_hidden_layers=24,
190
+ num_attention_heads=16,
191
+ hidden_size=1024,
192
+ intermediate_size=4224,
193
+ patch_embedding_bias=True,
194
+ qk_normalization=True,
195
+ qkv_bias=False,
196
+ initializer_factor=0.1,
197
+ use_pre_norm=False,
198
+ pe_type="rope2d",
199
+ rope_theta=10000,
200
+ spatial_merge_size=1,
201
+ norm_type="RMSNorm",
202
+ hidden_act='SwiGLU',
203
+ use_flash_attn=True,
204
+ layer_norm_eps=1e-6,
205
+ min_tokens=576,
206
+ max_tokens=16384,
207
+ image_mean=(0.485, 0.456, 0.406),
208
+ image_std=(0.229, 0.224, 0.225),
209
+ relarge_ratio=1.0,
210
+ **kwargs,
211
+ ):
212
+ super().__init__(**kwargs)
213
+
214
+ self.resolution_mode = resolution_mode
215
+ self.init_method = init_method
216
+ self.pe_type = pe_type
217
+ self.rope_theta = rope_theta
218
+ self.temporal_patch_size = temporal_patch_size
219
+ self.num_channels = num_channels
220
+ self.patch_size = patch_size
221
+ self.image_size = image_size
222
+ self.patch_dropout = patch_dropout
223
+ self.attention_dropout = attention_dropout
224
+ self.dropout = dropout
225
+ self.drop_path_rate = drop_path_rate
226
+ self.initializer_range = initializer_range
227
+ self.num_hidden_layers = num_hidden_layers
228
+ self.num_attention_heads = num_attention_heads
229
+ self.hidden_size = hidden_size
230
+ self.intermediate_size = intermediate_size
231
+ self.patch_embedding_bias = patch_embedding_bias
232
+ self.qk_normalization = qk_normalization
233
+ self.qkv_bias = qkv_bias
234
+ self.initializer_factor = initializer_factor
235
+ self.use_pre_norm = use_pre_norm
236
+ self.norm_type = norm_type
237
+ self.hidden_act = hidden_act
238
+ self.use_flash_attn = use_flash_attn
239
+ self.layer_norm_eps = layer_norm_eps
240
+ self.spatial_merge_size = spatial_merge_size
241
+ self.min_tokens = min_tokens
242
+ self.max_tokens = max_tokens
243
+ self.image_mean = image_mean
244
+ self.image_std = image_std
245
+ self.relarge_ratio = relarge_ratio
246
+
247
+ @classmethod
248
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
249
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
250
+
251
+ if 'vision_config' in config_dict:
252
+ config_dict = config_dict['vision_config']
253
+
254
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
255
+ logger.warning(
256
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
257
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
258
+ )
259
+
260
+ return cls.from_dict(config_dict, **kwargs)
261
+
262
+
263
+ class UniViTARImageTransform(object):
264
+ def __init__(self, config):
265
+ self.config = config
266
+ self.resolution_mode = config.resolution_mode
267
+
268
+ self.image_mean, self.image_std = config.image_mean, config.image_std
269
+ self.patch_size = config.patch_size
270
+ self.temporal_patch_size = config.temporal_patch_size
271
+ self.spatial_merge_size = config.spatial_merge_size
272
+ self.resize_factor = config.patch_size * config.spatial_merge_size * config.resize_factor
273
+ self.relarge_ratio = config.relarge_ratio
274
+
275
+ self.forced_transform = None
276
+ self.min_pixels, self.max_pixels = None, None
277
+ assert self.resolution_mode in ["native", "224", "378", "756"]
278
+ if self.resolution_mode == "native":
279
+ self.min_pixels = config.min_tokens * config.patch_size * config.patch_size
280
+ self.max_pixels = config.max_tokens * config.patch_size * config.patch_size
281
+ else:
282
+ image_size = int(self.resolution_mode)
283
+ self.forced_transform = transforms.Compose([
284
+ transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
285
+ self.convert_to_rgb,
286
+ transforms.ToTensor(),
287
+ transforms.Normalize(mean=self.image_mean, std=self.image_std)
288
+ ]
289
+ )
290
+
291
+ def __call__(self, images):
292
+
293
+ if not isinstance(images, List):
294
+ images = [images] # shape of each image is [h, w, c]
295
+ assert len(images) == 1 or len(images) % self.temporal_patch_size == 0
296
+
297
+ if self.resolution_mode == "native":
298
+ sample_num = 1 if len(images) == 1 else len(images) // self.temporal_patch_size
299
+ min_pixels, max_pixels = self.min_pixels // sample_num, self.max_pixels // sample_num
300
+ width, height = images[0].size # (w, h)
301
+ if self.relarge_ratio > 0 and self.relarge_ratio != 1:
302
+ height, width = int(height * self.relarge_ratio), int(width * self.relarge_ratio)
303
+ resized_height, resized_width = self.smart_resize(height, width, self.resize_factor, min_pixels, max_pixels)
304
+ processed_images = []
305
+ for image in images:
306
+ image = self.convert_to_rgb(image)
307
+ image = self.resize(image, size=(resized_height, resized_width), resample=Image.Resampling.BICUBIC)
308
+ image = self.rescale(image, scale=1/255)
309
+ image = self.normalize(image=image, mean=self.image_mean, std=self.image_std)
310
+ processed_images.append(image)
311
+ processed_images = np.array(processed_images) # (num, h, w, c)
312
+ processed_images = processed_images.transpose(0, 3, 1, 2) # (num, c, h, w)
313
+ else:
314
+ processed_images = [self.forced_transform(image).numpy() for image in images]
315
+ processed_images = np.array(processed_images)
316
+
317
+ if processed_images.shape[0] == 1:
318
+ processed_images = np.tile(processed_images, (self.temporal_patch_size, 1, 1, 1))
319
+
320
+ return torch.from_numpy(processed_images)
321
+
322
+ @staticmethod
323
+ def convert_to_rgb(image):
324
+ if not isinstance(image, Image.Image):
325
+ return image
326
+ # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
327
+ # for transparent images. The call to `alpha_composite` handles this case
328
+ if image.mode == "RGB":
329
+ return image
330
+ image_rgba = image.convert("RGBA")
331
+ background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
332
+ alpha_composite = Image.alpha_composite(background, image_rgba)
333
+ alpha_composite = alpha_composite.convert("RGB")
334
+ return alpha_composite
335
+
336
+ @staticmethod
337
+ def resize(image, size, resample, return_numpy: bool = True) -> np.ndarray:
338
+ """
339
+ Resizes `image` to `(height, width)` specified by `size` using the PIL library.
340
+ """
341
+ if not len(size) == 2:
342
+ raise ValueError("size must have 2 elements")
343
+ assert isinstance(image, Image.Image)
344
+ height, width = size
345
+ resample = resample if resample is not None else Image.Resampling.BILINEAR
346
+ # PIL images are in the format (width, height)
347
+ resized_image = image.resize((width, height), resample=resample, reducing_gap=None)
348
+ if return_numpy:
349
+ resized_image = np.array(resized_image)
350
+ resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
351
+ return resized_image
352
+
353
+ @staticmethod
354
+ def rescale(image: np.ndarray, scale: float, dtype: np.dtype = np.float32) -> np.ndarray:
355
+ if not isinstance(image, np.ndarray):
356
+ raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
357
+ rescaled_image = image * scale
358
+ rescaled_image = rescaled_image.astype(dtype)
359
+ return rescaled_image
360
+
361
+ @staticmethod
362
+ def normalize(image, mean, std) -> np.ndarray:
363
+ """
364
+ Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
365
+ image = (image - mean) / std
366
+ """
367
+ if not isinstance(image, np.ndarray):
368
+ raise ValueError("image must be a numpy array")
369
+ num_channels = image.shape[-1]
370
+ # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
371
+ # We preserve the original dtype if it is a float type to prevent upcasting float16.
372
+ if not np.issubdtype(image.dtype, np.floating):
373
+ image = image.astype(np.float32)
374
+ if isinstance(mean, Iterable):
375
+ if len(mean) != num_channels:
376
+ raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
377
+ else:
378
+ mean = [mean] * num_channels
379
+ mean = np.array(mean, dtype=image.dtype)
380
+ if isinstance(std, Iterable):
381
+ if len(std) != num_channels:
382
+ raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
383
+ else:
384
+ std = [std] * num_channels
385
+ std = np.array(std, dtype=image.dtype)
386
+ image = (image - mean) / std
387
+ return image
388
+
389
+ @staticmethod
390
+ def smart_resize(height, width, factor, min_pixels, max_pixels):
391
+ """
392
+ 1. Both dimensions (height and width) are divisible by 'factor'.
393
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
394
+ 3. The aspect ratio of the image is maintained as closely as possible.
395
+ """
396
+ if height < factor or width < factor:
397
+ if height < factor:
398
+ ratio = factor / height
399
+ height, width = factor, int(ratio * width) + 1
400
+ if width < factor:
401
+ ratio = factor / width
402
+ width, height = factor, int(ratio * height) + 1
403
+ h_bar = round(height / factor) * factor
404
+ w_bar = round(width / factor) * factor
405
+ if h_bar * w_bar > max_pixels:
406
+ beta = math.sqrt((height * width) / max_pixels)
407
+ h_bar = math.floor(height / beta / factor) * factor
408
+ w_bar = math.floor(width / beta / factor) * factor
409
+ elif h_bar * w_bar < min_pixels:
410
+ beta = math.sqrt(min_pixels / (height * width))
411
+ h_bar = math.ceil(height * beta / factor) * factor
412
+ w_bar = math.ceil(width * beta / factor) * factor
413
+ return h_bar, w_bar
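# Worked example (illustrative): a 1080x1920 input with the default native-resolution
# settings (factor = 14 * 1 * 2 = 28, min_pixels = 256 * 196, max_pixels = 16384 * 196)
# rounds to h_bar = round(1080 / 28) * 28 = 1092 and w_bar = round(1920 / 28) * 28 = 1932;
# 1092 * 1932 = 2,109,744 pixels is inside the budget, so no rescaling is applied and the
# image yields a 78 x 138 patch grid (10,764 visual tokens).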
+
+
+ class SwiGLU(nn.Module):
+     def __init__(self, config: UniViTARVisionConfig):
+         super().__init__()
+         self.config = config
+         self.inner_hidden_size = int(config.intermediate_size * 2 / 3)
+         self.act = ACT2FN['silu']
+         self.fc1 = nn.Linear(config.hidden_size, self.inner_hidden_size)
+         self.fc2 = nn.Linear(self.inner_hidden_size, config.hidden_size)
+         self.fc3 = nn.Linear(config.hidden_size, self.inner_hidden_size)
+         self.norm = RMSNorm(self.inner_hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(x)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.fc2(self.norm(hidden_states * self.fc3(x)))
+         return hidden_states
+
+
+ class Attention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+     def __init__(self, config: UniViTARVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         assert config.use_flash_attn is True, "FlashAttention must be used!"
+         assert self.head_dim * self.num_heads == self.embed_dim
+
+         self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+         self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
+         self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.proj_drop = nn.Dropout(config.dropout)
+         if self.config.qk_normalization:
+             self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+             self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
+         key_padding_mask = kwargs.get("key_padding_mask", None)
+         rotary_pos_emb = kwargs["rotary_pos_emb"]
+
+         qkv = self.qkv(hidden_states)
+         qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
+         bind_dim = qkv.dim() - 3
+         target_dtype = qkv.dtype
+         q, k, v = qkv.unbind(bind_dim)
+         q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+         if self.config.qk_normalization:
+             q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+             k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+         qkv = torch.stack([q, k, v], dim=bind_dim).to(target_dtype)
+         context, _ = self.inner_attn(qkv, key_padding_mask=key_padding_mask, causal=False)
+
+         outs = self.proj(rearrange(context, '... h d -> ... (h d)'))  # input expected to be: [b s h d] or [s h d]
+         outs = self.proj_drop(outs)
+
+         return outs
+
+
+ class UniViTARVisionEmbeddings(nn.Module):
+     def __init__(self, config: UniViTARVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.patch_size = config.patch_size
+         self.temporal_patch_size = config.temporal_patch_size
+         self.kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+         self.use_bias = config.patch_embedding_bias
+         self.patch_embedding = nn.Conv3d(
+             in_channels=3, out_channels=self.embed_dim, kernel_size=self.kernel_size, stride=self.kernel_size, bias=self.use_bias)
+
+     def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> torch.Tensor:
+         pixel_values = pixel_values.view(-1, 3, *self.kernel_size)
+         patch_embeds = self.patch_embedding(pixel_values)
+         embeddings = patch_embeds.view(1, -1, self.embed_dim)
+         self.num_patches = embeddings.shape[1]
+         return embeddings
+
+
+ class UniViTARVisionEncoderLayer(nn.Module):
+     def __init__(self, config: UniViTARVisionConfig, drop_path_rate: float):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         assert config.hidden_act == "SwiGLU"
+
+         self.attn = Attention(config)
+         self.norm1 = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.norm2 = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = SwiGLU(config)
+
+         self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+         self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+         self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+         self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+     def forward(self, hidden_states: torch.Tensor, **kwargs):
+         hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states), **kwargs) * self.ls1)
+         hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
+         return hidden_states
+
+
+ class UniViTARVisionEncoder(nn.Module):
+     """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. """
+     def __init__(self, config: UniViTARVisionConfig):
+         super().__init__()
+         self.config = config
+         self.gradient_checkpointing = True
+
+         # stochastic depth decay rule
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+         self.layers = nn.ModuleList([UniViTARVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+         if self.config.pe_type == "rope2d":
+             head_dim = config.hidden_size // config.num_attention_heads
+             self.rotary_pos_emb = VisionRotaryEmbedding2D(head_dim // 2, theta=self.config.rope_theta)
+         else:
+             raise NotImplementedError
+
+     def forward(self, inputs_embeds, output_hidden_states=False, **kwargs):
+         kwargs["rotary_pos_emb"] = self.rotary_pos_emb(kwargs["grid_shapes"], self.config.spatial_merge_size)
+
+         encoder_states = () if output_hidden_states else None
+         hidden_states = inputs_embeds
+         for idx, encoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+             if self.gradient_checkpointing and self.training:
+                 encoder_layer_forward = partial(encoder_layer, **kwargs)
+                 layer_outputs = torch.utils.checkpoint.checkpoint(encoder_layer_forward, hidden_states, use_reentrant=True)
+             else:
+                 layer_outputs = encoder_layer(hidden_states, **kwargs)
+             hidden_states = layer_outputs
+         if output_hidden_states:
+             encoder_states = encoder_states + (hidden_states,)
+
+         return BaseModelOutputWithKwargs(last_hidden_state=hidden_states, hidden_states=encoder_states, kwargs=kwargs)
+
+
+ class UniViTARVisionModel(PreTrainedModel):
+     main_input_name = 'pixel_values'
+     config_class = UniViTARVisionConfig
+     _no_split_modules = ['UniViTARVisionEncoderLayer']
+
+     def __init__(self, model_config_path, *args, **kwargs):
+
+         model_config_dict = json.load(open(model_config_path, "r", encoding="utf8"))
+         config = UniViTARVisionConfig.from_dict(model_config_dict)
+
+         super().__init__(config)
+         self.config = config
+         self.image_transform = UniViTARImageTransform(config)
+
+         self.embeddings = UniViTARVisionEmbeddings(config)
+         self.encoder = UniViTARVisionEncoder(config)
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     def get_padding_mask(self, grid_shapes):
+         seq_len = torch.tensor([int((np.prod(thw) - 1) + 1) for thw in grid_shapes])
+         max_len = torch.max(seq_len)
+         batch_size = len(grid_shapes)
+         mask = torch.zeros((batch_size, max_len), dtype=torch.long)
+         range_matrix = torch.arange(max_len).expand(batch_size, max_len)
+         mask = (range_matrix < seq_len.unsqueeze(1))
+         return mask.cuda()
+
+     def forward(self, pixel_values, output_hidden_states=False, **kwargs):
+         assert len(pixel_values.shape) == 2, "(batch_num_tokens, hidden_size)"
+         assert "grid_shapes" in kwargs, "grid_shapes: [(t, h, w), ..., (t, h, w)]"
+         kwargs["key_padding_mask"] = self.get_padding_mask(kwargs["grid_shapes"])
+         hidden_states = self.embeddings(pixel_values, **kwargs)
+         encoder_outputs = self.encoder(hidden_states, output_hidden_states, **kwargs)
+         last_hidden_state = encoder_outputs.last_hidden_state
+         return last_hidden_state.squeeze(0)
+
+     def data_patchify(self, input_data):
+         t, c, h, w = input_data.shape
+         grid_t, grid_h, grid_w = t // self.config.temporal_patch_size, h // self.config.patch_size, w // self.config.patch_size
+         grid_size = c * self.config.temporal_patch_size * self.config.patch_size * self.config.patch_size
+         input_data = input_data.reshape(
+             grid_t, self.config.temporal_patch_size, c,
+             grid_h // self.config.spatial_merge_size, self.config.spatial_merge_size, self.config.patch_size,
+             grid_w // self.config.spatial_merge_size, self.config.spatial_merge_size, self.config.patch_size
+         )
+         input_data = input_data.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
+         input_data = input_data.reshape(grid_t * grid_h * grid_w, grid_size).contiguous()
+         grid_shape = (grid_t, grid_h, grid_w)
+         return input_data, grid_shape
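As a minimal shape check of `data_patchify` (assuming `config.json` from this commit is in the working directory; no pretrained weights are needed for this step), a single native-resolution image that `image_transform` has tiled to two temporal frames packs as follows:

```python
import torch
from modeling_univitar import UniViTARVisionModel

model = UniViTARVisionModel("config.json")

# (t, c, h, w) = (2, 3, 1092, 1932): one image duplicated along the temporal axis
dummy = torch.randn(2, 3, 1092, 1932)
flat, grid = model.data_patchify(dummy)
print(grid)        # (1, 78, 138)              -> (grid_t, grid_h, grid_w)
print(flat.shape)  # torch.Size([10764, 1176]) -> 78*138 tokens of 3*2*14*14 values each
```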
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fed2e7e0bd1596c56fff2f6ca94ceb3f1f7a86a44d7a61e56b2c7a1daf06fc78
+ size 619815078