jiuhai committed on
Commit 781574f · verified · 1 Parent(s): 8bd60d3

Create florence_encoder.py

llava/model/multimodal_encoder/florence_encoder.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn as nn
+
+ from transformers import AutoProcessor, AutoModelForCausalLM
+
+
+ class FlorenceVisionTower(nn.Module):
+     def __init__(self, vision_tower, args, delay_load=False):
+         super().__init__()
+
+         self.is_loaded = False
+         self.vision_tower_name = vision_tower
+
+         # Every branch loads eagerly: `delay_load` is accepted for
+         # interface compatibility with the other vision towers but is
+         # effectively ignored here.
+         if not delay_load:
+             self.load_model()
+         elif getattr(args, 'unfreeze_mm_vision_tower', False):
+             self.load_model()
+         else:
+             self.load_model()
+
+     def load_model(self, device_map=None):
+         if self.is_loaded:
+             print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+             return
+
+         self.image_processor = AutoProcessor.from_pretrained(self.vision_tower_name, trust_remote_code=True)
+         self.vision_tower = AutoModelForCausalLM.from_pretrained(self.vision_tower_name, trust_remote_code=True).to(torch.bfloat16)
+         self.vision_tower.requires_grad_(False)
+
+         self.is_loaded = True
+
+     @torch.no_grad()
+     def forward(self, images):
+         # Hard-coded token ids for three fixed task prompts, tokenized
+         # once offline and baked in:
+         #   'Describe in detail what is shown in the image.'
+         #   'What is the text in the image?'
+         #   'Locate the objects in the image, with their descriptions.'
+         task_ids = torch.tensor([
+             [0, 47066, 21700, 11, 4617, 99, 16, 2343, 11, 5, 2274, 4, 2, 1],
+             [0, 2264, 16, 5, 2788, 11, 5, 2274, 116, 2, 1, 1, 1, 1],
+             [0, 574, 22486, 5, 8720, 11, 5, 2274, 6, 19, 49, 24173, 4, 2]
+         ]).to(device=self.device)
+
+         # This checkpoint's remote `generate` is expected to return the
+         # visual features and the encoder's last hidden state alongside
+         # the generated ids; only one token is generated because only
+         # the encoder-side outputs are used.
+         generated_ids, image_feature, encoder_last_hidden_state = self.vision_tower.generate(
+             input_ids=task_ids,
+             pixel_values=images,
+             max_new_tokens=1,
+             do_sample=False,
+             num_beams=1,
+         )
+         return image_feature, encoder_last_hidden_state
+
+     @property
+     def dummy_feature(self):
+         return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+     @property
+     def dtype(self):
+         return self.vision_tower.dtype
+
+     @property
+     def device(self):
+         return self.vision_tower.device
+
+     @property
+     def config(self):
+         if self.is_loaded:
+             return self.vision_tower.config
+         else:
+             # Unreachable in practice since __init__ always loads the
+             # model; kept for parity with the other tower classes.
+             return self.cfg_only
+
+     @property
+     def hidden_size(self):
+         return self.config.hidden_size
+
+     @property
+     def num_patches_per_side(self):
+         return self.config.image_size // self.config.patch_size
+
+     @property
+     def num_patches(self):
+         return (self.config.image_size // self.config.patch_size) ** 2
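
For reference, below is a minimal sketch (not part of the commit) of how the hard-coded task_ids could be regenerated from the commented-out prompts, and how the tower might be exercised. The checkpoint name microsoft/Florence-2-large, the 768x768 dummy batch, and the empty args object are illustrative assumptions. Florence-2 uses a BART-style tokenizer (bos=0, eos=2, pad=1), which is consistent with the leading 0s, trailing 2s, and 1-padding in the tensor above.

    # Sketch only; names marked below are assumptions, not part of the commit.
    import torch
    from types import SimpleNamespace
    from transformers import AutoProcessor

    ckpt = "microsoft/Florence-2-large"  # assumed checkpoint, for illustration
    processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)

    prompts = [
        'Describe in detail what is shown in the image.',
        'What is the text in the image?',
        'Locate the objects in the image, with their descriptions.',
    ]
    # Pad the batch to the longest prompt; this should reproduce the
    # hard-coded task_ids tensor, 1-padding included.
    batch = processor.tokenizer(prompts, return_tensors='pt', padding=True)
    print(batch.input_ids)

    tower = FlorenceVisionTower(ckpt, args=SimpleNamespace())
    # One dummy image per task prompt; 768x768 is the assumed input size.
    images = torch.zeros(3, 3, 768, 768, dtype=tower.dtype, device=tower.device)
    image_feature, encoder_state = tower(images)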