Update vision_tower_builder.py

vision_tower_builder.py (CHANGED, +18 -8)
```diff
@@ -2,9 +2,6 @@ from typing import Optional, Tuple, Union, Dict
 from dataclasses import dataclass
 from functools import partial, reduce
 from PIL import Image
-import torch
-import torch.utils.checkpoint
-from torch import nn
 import os
 from transformers.image_processing_utils import BatchFeature, get_size_dict
 from transformers.image_transforms import (
```
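The three `torch` imports dropped here look like de-duplication rather than removal of the dependency: the next hunk's context line, `import torch.utils.checkpoint as checkpoint`, shows torch is still imported further down the file.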
```diff
@@ -27,9 +24,15 @@ import torch.utils.checkpoint as checkpoint
 from functools import partial
 try:
     from flash_attn import flash_attn_qkvpacked_func
+    use_flash_attn = True
+except:
+    use_flash_attn = False
+    print("You need to install flash_attn to be faster!")
+
+try:
+    from timm.layers import drop_path, to_2tuple, trunc_normal_
 except:
-
-from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+    from timm.models.layers import drop_path, trunc_normal_, to_2tuple
 
 
 
```
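This hunk records whether flash-attn is importable in a module-level `use_flash_attn` flag, and makes the timm import version-tolerant (`timm.layers` on timm >= 0.9, `timm.models.layers` on older releases). One caveat: the bare `except:` also swallows non-import failures. A minimal sketch of the same probe, identical to the hunk except for the narrower `ImportError` catch:

```python
# Optional-dependency probe, as in the hunk above, but catching ImportError
# only, so unrelated failures inside flash_attn still surface.
try:
    from flash_attn import flash_attn_qkvpacked_func  # packed-QKV kernel
    use_flash_attn = True
except ImportError:
    use_flash_attn = False
    print("You need to install flash_attn to be faster!")

try:
    from timm.layers import drop_path, to_2tuple, trunc_normal_  # timm >= 0.9
except ImportError:
    from timm.models.layers import drop_path, to_2tuple, trunc_normal_  # older timm
```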
```diff
@@ -70,6 +73,14 @@ class Attention(nn.Module):
             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
             proj_drop=0., attn_head_dim=None,
             attn_type='flash_v2'):
+
+        if use_flash_attn:
+            attn_type = attn_type
+        else:
+            attn_type = 'origin'
+
+        print(attn_type)
+
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
```
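The added guard downgrades `attn_type` to `'origin'` whenever flash-attn failed to import. Two small observations: `attn_type = attn_type` in the first branch is a no-op, and the unconditional `print(attn_type)` fires once per `Attention` layer constructed. A self-contained sketch of the guard; only the `attn_type` handling mirrors the diff, and the rest of the real `__init__` (qkv/proj layers, dropout) is elided:

```python
import torch.nn as nn

use_flash_attn = False  # stand-in; the real flag comes from the import probe

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_type='flash_v2'):
        super().__init__()
        if not use_flash_attn:
            attn_type = 'origin'  # fall back to the plain PyTorch path
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
```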
```diff
@@ -516,7 +527,7 @@ def build_vit(config, pt_type='origin'):
         drop_path_rate=0.,
         num_frames=config.num_frames,
         tubelet_size=1,
-        use_checkpoint=
+        use_checkpoint=False,
         checkpoint_num=24,
         return_index=config.return_idx,
         with_ln=True, # merge vision_layernorm in it
```
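`build_vit` now passes `use_checkpoint=False` explicitly, so activation checkpointing is off by default; `checkpoint_num=24` only matters when it is enabled. A hedged sketch of how these two kwargs typically gate checkpointing inside a ViT forward pass; the loop below is illustrative, not the actual model internals:

```python
import torch.utils.checkpoint as checkpoint

def run_blocks(blocks, x, use_checkpoint=False, checkpoint_num=24):
    # Recompute activations during backward for the first checkpoint_num
    # blocks (trading compute for memory); run the rest normally.
    for idx, blk in enumerate(blocks):
        if use_checkpoint and idx < checkpoint_num:
            x = checkpoint.checkpoint(blk, x)
        else:
            x = blk(x)
    return x
```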
```diff
@@ -614,9 +625,8 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
 
 
     if "umt-hd" in vision_tower:
-        raise NotImplementedError
         return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, image_size=448, **kwargs)
     elif "umt" in vision_tower:
         return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
 
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
+    raise ValueError(f"Unknown vision tower: {vision_tower}")
```
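With `raise NotImplementedError` gone, the `umt-hd` branch now actually builds a 448-px `UMTVisionTower`; anything matching neither `umt-hd` nor `umt` still ends in `ValueError`. A hedged usage sketch — how the tower name is read off the config is not visible in this hunk, so the config shape and name strings below are assumptions:

```python
from types import SimpleNamespace

# Hypothetical configs; only the "umt-hd"/"umt" substring dispatch is from the diff.
hd_cfg = SimpleNamespace(vision_tower="umt-hd-large")
tower = build_vision_tower(hd_cfg)  # -> UMTVisionTower(..., image_size=448)

build_vision_tower(SimpleNamespace(vision_tower="clip-vit"))  # -> ValueError
```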
|