lixinhao committed
Commit c7ecf88 · verified · 1 parent: a44e8cc

Update vision_tower_builder.py

Files changed (1)
  1. vision_tower_builder.py +18 -8
vision_tower_builder.py CHANGED
@@ -2,9 +2,6 @@ from typing import Optional, Tuple, Union, Dict
 from dataclasses import dataclass
 from functools import partial, reduce
 from PIL import Image
-import torch
-import torch.utils.checkpoint
-from torch import nn
 import os
 from transformers.image_processing_utils import BatchFeature, get_size_dict
 from transformers.image_transforms import (
@@ -27,9 +24,15 @@ import torch.utils.checkpoint as checkpoint
 from functools import partial
 try:
     from flash_attn import flash_attn_qkvpacked_func
+    use_flash_attn = True
+except:
+    use_flash_attn = False
+    print("You need to install flash_attn to be faster!")
+
+try:
+    from timm.layers import drop_path, to_2tuple, trunc_normal_
 except:
-    print("You need to install flash_attn")
-from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+    from timm.models.layers import drop_path, trunc_normal_, to_2tuple
 
 
 
@@ -70,6 +73,14 @@ class Attention(nn.Module):
             self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
             proj_drop=0., attn_head_dim=None,
             attn_type='flash_v2'):
+
+        if use_flash_attn:
+            attn_type = attn_type
+        else:
+            attn_type = 'origin'
+
+        print(attn_type)
+
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -516,7 +527,7 @@ def build_vit(config, pt_type='origin'):
         drop_path_rate=0.,
         num_frames=config.num_frames,
         tubelet_size=1,
-        use_checkpoint=True,
+        use_checkpoint=False,
         checkpoint_num=24,
         return_index=config.return_idx,
         with_ln=True, # merge vision_layernorm in it
@@ -614,9 +625,8 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
 
 
     if "umt-hd" in vision_tower:
-        raise NotImplementedError
         return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, image_size=448, **kwargs)
     elif "umt" in vision_tower:
         return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
 
-    raise ValueError(f"Unknown vision tower: {vision_tower}")
+    raise ValueError(f"Unknown vision tower: {vision_tower}")
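
Taken together, the change is an optional-dependency guard: flash_attn is imported in a try/except, a module-level use_flash_attn flag records whether the import succeeded, and Attention.__init__ downgrades attn_type to 'origin' when it did not (the committed `attn_type = attn_type` branch is a no-op kept for symmetry with the else). The paired timm fallback covers newer timm releases, which expose drop_path, to_2tuple, and trunc_normal_ at timm.layers and keep timm.models.layers only as a deprecated alias. Below is a minimal, self-contained sketch of the same pattern; the Attention body here is illustrative, not the file's real implementation:

# Sketch of the optional-dependency guard this commit introduces.
# flash_attn and the 'origin' fallback mirror the diff; the attention
# math below is a simplified stand-in for the file's real code.
import torch
from torch import nn

try:
    from flash_attn import flash_attn_qkvpacked_func
    use_flash_attn = True
except ImportError:
    use_flash_attn = False
    print("You need to install flash_attn to be faster!")

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_type='flash_v2'):
        super().__init__()
        # Downgrade to the plain PyTorch path when flash_attn is missing,
        # as the commit does inside Attention.__init__.
        if not use_flash_attn:
            attn_type = 'origin'
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        if self.attn_type == 'flash_v2':
            # flash_attn_qkvpacked_func takes (B, N, 3, heads, head_dim)
            # and requires fp16/bf16 tensors on CUDA.
            return flash_attn_qkvpacked_func(qkv).reshape(B, N, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2).reshape(B, N, C)

With flash_attn installed, a module built this way keeps the fused path; on a machine without the extension the same construction silently takes the plain PyTorch path, which is what lets the checkpoint load either way.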
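
The remaining flip, use_checkpoint=True to False in build_vit, turns activation checkpointing off. Given the file's `import torch.utils.checkpoint as checkpoint` (visible in the second hunk header) and the checkpoint_num=24 argument the call keeps, the ViT blocks are presumably gated along these lines; this loop is an assumed shape, not the file's actual code:

import torch.utils.checkpoint as checkpoint

def run_blocks(blocks, x, use_checkpoint=False, checkpoint_num=24):
    # Assumed sketch: with use_checkpoint=True, the first checkpoint_num
    # blocks drop their activations and recompute them during backward,
    # saving memory at the cost of extra compute. The commit sets it to
    # False, taking the faster plain forward pass instead.
    for idx, blk in enumerate(blocks):
        if use_checkpoint and idx < checkpoint_num:
            x = checkpoint.checkpoint(blk, x)
        else:
            x = blk(x)
    return x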