lixinhao commited on
Commit
1a6395e
·
verified ·
1 Parent(s): 78403f6

Update vision_tower_builder.py

Browse files
Files changed (1) hide show
  1. vision_tower_builder.py +8 -2
vision_tower_builder.py CHANGED
@@ -28,7 +28,11 @@ try:
28
  except:
29
  use_flash_attn = False
30
  print("You need to install flash_attn to be faster!")
31
- from timm.layers import drop_path, to_2tuple, trunc_normal_
 
 
 
 
32
 
33
 
34
 
@@ -74,7 +78,9 @@ class Attention(nn.Module):
74
  attn_type = attn_type
75
  else:
76
  attn_type = 'origin'
77
-
 
 
78
  super().__init__()
79
  self.num_heads = num_heads
80
  head_dim = dim // num_heads
 
28
  except:
29
  use_flash_attn = False
30
  print("You need to install flash_attn to be faster!")
31
+
32
+ try:
33
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
34
+ except:
35
+ from timm.models.layers import drop_path, trunc_normal_, to_2tuple
36
 
37
 
38
 
 
78
  attn_type = attn_type
79
  else:
80
  attn_type = 'origin'
81
+
82
+ print(attn_type)
83
+
84
  super().__init__()
85
  self.num_heads = num_heads
86
  head_dim = dim // num_heads