Update vision_tower_builder.py
Browse files- vision_tower_builder.py +8 -2
vision_tower_builder.py
CHANGED
@@ -28,7 +28,11 @@ try:
|
|
28 |
except:
|
29 |
use_flash_attn = False
|
30 |
print("You need to install flash_attn to be faster!")
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
|
33 |
|
34 |
|
@@ -74,7 +78,9 @@ class Attention(nn.Module):
|
|
74 |
attn_type = attn_type
|
75 |
else:
|
76 |
attn_type = 'origin'
|
77 |
-
|
|
|
|
|
78 |
super().__init__()
|
79 |
self.num_heads = num_heads
|
80 |
head_dim = dim // num_heads
|
|
|
28 |
except:
|
29 |
use_flash_attn = False
|
30 |
print("You need to install flash_attn to be faster!")
|
31 |
+
|
32 |
+
try:
|
33 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
34 |
+
except:
|
35 |
+
from timm.models.layers import drop_path, trunc_normal_, to_2tuple
|
36 |
|
37 |
|
38 |
|
|
|
78 |
attn_type = attn_type
|
79 |
else:
|
80 |
attn_type = 'origin'
|
81 |
+
|
82 |
+
print(attn_type)
|
83 |
+
|
84 |
super().__init__()
|
85 |
self.num_heads = num_heads
|
86 |
head_dim = dim // num_heads
|