Spaces:
Running
on
Zero
Running
on
Zero
machuofan
commited on
Commit
·
7385f22
1
Parent(s):
bf8190d
init
Browse files- .gitattributes +0 -35
- constants.py +27 -0
- conversation.py +460 -0
- i2t.py +192 -0
- mm_utils.py +105 -0
- model/__init__.py +7 -0
- model/__pycache__/__init__.cpython-311.pyc +0 -0
- model/__pycache__/__init__.cpython-39.pyc +0 -0
- model/__pycache__/arhead.cpython-311.pyc +0 -0
- model/__pycache__/arhead.cpython-39.pyc +0 -0
- model/__pycache__/builder.cpython-311.pyc +0 -0
- model/__pycache__/builder.cpython-39.pyc +0 -0
- model/__pycache__/liquid.cpython-311.pyc +0 -0
- model/__pycache__/mini_gemini_arch.cpython-311.pyc +0 -0
- model/__pycache__/mini_gemini_arch.cpython-39.pyc +0 -0
- model/__pycache__/quant.cpython-311.pyc +0 -0
- model/__pycache__/quant.cpython-39.pyc +0 -0
- model/arhead.py +241 -0
- model/builder.py +138 -0
- model/language_model/__pycache__/mini_gemini_llama.cpython-311.pyc +0 -0
- model/language_model/__pycache__/mini_gemini_llama.cpython-39.pyc +0 -0
- model/language_model/mini_gemini_llama.py +488 -0
- model/liquid.py +669 -0
- model/multimodal_encoder/__pycache__/builder.cpython-311.pyc +0 -0
- model/multimodal_encoder/__pycache__/builder.cpython-39.pyc +0 -0
- model/multimodal_encoder/__pycache__/clip_encoder.cpython-311.pyc +0 -0
- model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc +0 -0
- model/multimodal_encoder/__pycache__/eva_encoder.cpython-311.pyc +0 -0
- model/multimodal_encoder/__pycache__/eva_encoder.cpython-39.pyc +0 -0
- model/multimodal_encoder/__pycache__/openclip_encoder.cpython-311.pyc +0 -0
- model/multimodal_encoder/__pycache__/openclip_encoder.cpython-39.pyc +0 -0
- model/multimodal_encoder/builder.py +33 -0
- model/multimodal_encoder/clip_encoder.py +89 -0
- model/multimodal_encoder/eva_encoder.py +551 -0
- model/multimodal_encoder/openclip_encoder.py +188 -0
- model/multimodal_projector/__pycache__/builder.cpython-311.pyc +0 -0
- model/multimodal_projector/__pycache__/builder.cpython-39.pyc +0 -0
- model/multimodal_projector/builder.py +50 -0
- model/processor/__pycache__/video_processor.cpython-311.pyc +0 -0
- model/processor/__pycache__/video_processor.cpython-39.pyc +0 -0
- model/processor/video_processor.py +74 -0
- model/quant.py +519 -0
- t2i.py +224 -0
- tools.py +126 -0
- unitok/config.py +243 -0
- unitok/dist.py +302 -0
- unitok/model.py +184 -0
- unitok/quant.py +185 -0
- unitok/vitamin.py +792 -0
- unitok/vqvae.py +175 -0
.gitattributes
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
constants.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
IMAGE_TOKEN_INDEX = -200
|
9 |
+
PREDICT_TOKEN_INDEX = -300
|
10 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
11 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
12 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
13 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
14 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
15 |
+
DEFAULT_PREDICT_TOKEN = "<predict>"
|
16 |
+
|
17 |
+
DESCRIPT_PROMPT = [
|
18 |
+
"Describe this image thoroughly.",
|
19 |
+
"Provide a detailed description in this picture.",
|
20 |
+
"Detail every aspect of what's in this picture.",
|
21 |
+
"Explain this image with precision and detail.",
|
22 |
+
"Give a comprehensive description of this visual.",
|
23 |
+
"Elaborate on the specifics within this image.",
|
24 |
+
"Offer a detailed account of this picture's contents.",
|
25 |
+
"Describe in detail what this image portrays.",
|
26 |
+
"Break down this image into detailed descriptions.",
|
27 |
+
"Provide a thorough description of the elements in this image."]
|
conversation.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
|
9 |
+
class SeparatorStyle(Enum):
|
10 |
+
"""Different separator style."""
|
11 |
+
SINGLE = auto()
|
12 |
+
TWO = auto()
|
13 |
+
MPT = auto()
|
14 |
+
PLAIN = auto()
|
15 |
+
LLAMA_2 = auto()
|
16 |
+
GEMMA = auto()
|
17 |
+
|
18 |
+
|
19 |
+
@dataclasses.dataclass
|
20 |
+
class Conversation:
|
21 |
+
"""A class that keeps all conversation history."""
|
22 |
+
system: str
|
23 |
+
roles: List[str]
|
24 |
+
messages: List[List[str]]
|
25 |
+
offset: int
|
26 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
27 |
+
sep: str = "###"
|
28 |
+
sep2: str = None
|
29 |
+
version: str = "Unknown"
|
30 |
+
|
31 |
+
skip_next: bool = False
|
32 |
+
|
33 |
+
def get_prompt(self):
|
34 |
+
messages = self.messages
|
35 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
36 |
+
messages = self.messages.copy()
|
37 |
+
init_role, init_msg = messages[0].copy()
|
38 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
39 |
+
if 'mmtag' in self.version:
|
40 |
+
messages[0] = (init_role, init_msg)
|
41 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
42 |
+
messages.insert(1, (self.roles[1], "Received."))
|
43 |
+
else:
|
44 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
45 |
+
|
46 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
47 |
+
ret = self.system + self.sep
|
48 |
+
for role, message in messages:
|
49 |
+
if message:
|
50 |
+
if type(message) is tuple:
|
51 |
+
message = message[0]
|
52 |
+
ret += role + ": " + message + self.sep
|
53 |
+
else:
|
54 |
+
ret += role + ":"
|
55 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
56 |
+
seps = [self.sep, self.sep2]
|
57 |
+
ret = self.system + seps[0]
|
58 |
+
for i, (role, message) in enumerate(messages):
|
59 |
+
if message:
|
60 |
+
if type(message) is tuple:
|
61 |
+
message = message[0]
|
62 |
+
ret += role + ": " + message + seps[i % 2]
|
63 |
+
else:
|
64 |
+
ret += role + ":"
|
65 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
66 |
+
ret = self.system + self.sep
|
67 |
+
for role, message in messages:
|
68 |
+
if message:
|
69 |
+
if type(message) is tuple:
|
70 |
+
message = message[0]
|
71 |
+
ret += role + message + self.sep
|
72 |
+
else:
|
73 |
+
ret += role
|
74 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
75 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
76 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
77 |
+
ret = ""
|
78 |
+
|
79 |
+
for i, (role, message) in enumerate(messages):
|
80 |
+
if i == 0:
|
81 |
+
assert message, "first message should not be none"
|
82 |
+
assert role == self.roles[0], "first message should come from user"
|
83 |
+
if message:
|
84 |
+
if type(message) is tuple:
|
85 |
+
message, _, _ = message
|
86 |
+
if i == 0: message = wrap_sys(self.system) + message
|
87 |
+
if i % 2 == 0:
|
88 |
+
message = wrap_inst(message)
|
89 |
+
ret += self.sep + message
|
90 |
+
else:
|
91 |
+
ret += " " + message + " " + self.sep2
|
92 |
+
else:
|
93 |
+
ret += ""
|
94 |
+
ret = ret.lstrip(self.sep)
|
95 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
96 |
+
seps = [self.sep, self.sep2]
|
97 |
+
ret = self.system + seps[0]
|
98 |
+
for i, (role, message) in enumerate(messages):
|
99 |
+
if message:
|
100 |
+
if type(message) is tuple:
|
101 |
+
message, _, _ = message
|
102 |
+
ret += "<start_of_turn>" + role + "\n" + message + "<end_of_turn>\n" + seps[i % 2]
|
103 |
+
else:
|
104 |
+
ret += "<start_of_turn>" + role + "\n"
|
105 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
106 |
+
seps = [self.sep, self.sep2]
|
107 |
+
ret = self.system
|
108 |
+
for i, (role, message) in enumerate(messages):
|
109 |
+
if message:
|
110 |
+
if type(message) is tuple:
|
111 |
+
message, _, _ = message
|
112 |
+
ret += message + seps[i % 2]
|
113 |
+
else:
|
114 |
+
ret += ""
|
115 |
+
else:
|
116 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
117 |
+
|
118 |
+
return ret
|
119 |
+
|
120 |
+
def append_message(self, role, message):
|
121 |
+
self.messages.append([role, message])
|
122 |
+
|
123 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
|
124 |
+
if image_process_mode == "Pad":
|
125 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
126 |
+
width, height = pil_img.size
|
127 |
+
if width == height:
|
128 |
+
return pil_img
|
129 |
+
elif width > height:
|
130 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
131 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
132 |
+
return result
|
133 |
+
else:
|
134 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
135 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
136 |
+
return result
|
137 |
+
image = expand2square(image)
|
138 |
+
elif image_process_mode in ["Default", "Crop"]:
|
139 |
+
pass
|
140 |
+
elif image_process_mode == "Resize":
|
141 |
+
image = image.resize((336, 336))
|
142 |
+
else:
|
143 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
144 |
+
if max(image.size) > max_len:
|
145 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
146 |
+
aspect_ratio = max_hw / min_hw
|
147 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
148 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
149 |
+
W, H = image.size
|
150 |
+
if H > W:
|
151 |
+
H, W = longest_edge, shortest_edge
|
152 |
+
else:
|
153 |
+
H, W = shortest_edge, longest_edge
|
154 |
+
image = image.resize((W, H))
|
155 |
+
if return_pil:
|
156 |
+
return image
|
157 |
+
else:
|
158 |
+
buffered = BytesIO()
|
159 |
+
image.save(buffered, format=image_format)
|
160 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
161 |
+
return img_b64_str
|
162 |
+
|
163 |
+
def get_images(self, return_pil=False):
|
164 |
+
images = []
|
165 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
166 |
+
if i % 2 == 0:
|
167 |
+
if type(msg) is tuple:
|
168 |
+
msg, image, image_process_mode = msg
|
169 |
+
image = self.process_image(image, image_process_mode, return_pil=return_pil)
|
170 |
+
images.append(image)
|
171 |
+
return images
|
172 |
+
|
173 |
+
def to_gradio_chatbot(self):
|
174 |
+
ret = []
|
175 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
176 |
+
if i % 2 == 0:
|
177 |
+
if type(msg) is tuple:
|
178 |
+
msg, image, image_process_mode = msg
|
179 |
+
img_b64_str = self.process_image(
|
180 |
+
image, "Default", return_pil=False,
|
181 |
+
image_format='JPEG')
|
182 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
183 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
184 |
+
ret.append([msg, None])
|
185 |
+
else:
|
186 |
+
ret.append([msg, None])
|
187 |
+
else:
|
188 |
+
if type(msg) is tuple and len(msg) == 2:
|
189 |
+
msg, img_b64_str = msg
|
190 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
|
191 |
+
msg = msg.strip() + img_str
|
192 |
+
ret[-1][-1] = msg
|
193 |
+
return ret
|
194 |
+
|
195 |
+
def copy(self):
|
196 |
+
return Conversation(
|
197 |
+
system=self.system,
|
198 |
+
roles=self.roles,
|
199 |
+
messages=[[x, y] for x, y in self.messages],
|
200 |
+
offset=self.offset,
|
201 |
+
sep_style=self.sep_style,
|
202 |
+
sep=self.sep,
|
203 |
+
sep2=self.sep2,
|
204 |
+
version=self.version)
|
205 |
+
|
206 |
+
def dict(self):
|
207 |
+
if len(self.get_images()) > 0:
|
208 |
+
return {
|
209 |
+
"system": self.system,
|
210 |
+
"roles": self.roles,
|
211 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
212 |
+
"offset": self.offset,
|
213 |
+
"sep": self.sep,
|
214 |
+
"sep2": self.sep2,
|
215 |
+
}
|
216 |
+
return {
|
217 |
+
"system": self.system,
|
218 |
+
"roles": self.roles,
|
219 |
+
"messages": self.messages,
|
220 |
+
"offset": self.offset,
|
221 |
+
"sep": self.sep,
|
222 |
+
"sep2": self.sep2,
|
223 |
+
}
|
224 |
+
|
225 |
+
|
226 |
+
conv_vicuna_v0 = Conversation(
|
227 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
228 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
229 |
+
roles=("Human", "Assistant"),
|
230 |
+
messages=(
|
231 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
232 |
+
("Assistant",
|
233 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
234 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
235 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
236 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
237 |
+
"renewable and non-renewable energy sources:\n"
|
238 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
239 |
+
"energy sources are finite and will eventually run out.\n"
|
240 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
241 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
242 |
+
"and other negative effects.\n"
|
243 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
244 |
+
"have lower operational costs than non-renewable sources.\n"
|
245 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
246 |
+
"locations than non-renewable sources.\n"
|
247 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
248 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
249 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
250 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
251 |
+
),
|
252 |
+
offset=2,
|
253 |
+
sep_style=SeparatorStyle.SINGLE,
|
254 |
+
sep="###",
|
255 |
+
)
|
256 |
+
|
257 |
+
conv_vicuna_v1 = Conversation(
|
258 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
259 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
260 |
+
roles=("USER", "ASSISTANT"),
|
261 |
+
version="v1",
|
262 |
+
messages=(),
|
263 |
+
offset=0,
|
264 |
+
sep_style=SeparatorStyle.TWO,
|
265 |
+
sep=" ",
|
266 |
+
sep2="</s>",
|
267 |
+
)
|
268 |
+
|
269 |
+
conv_llama_2 = Conversation(
|
270 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
271 |
+
|
272 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
273 |
+
roles=("USER", "ASSISTANT"),
|
274 |
+
version="llama_v2",
|
275 |
+
messages=(),
|
276 |
+
offset=0,
|
277 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
278 |
+
sep="<s>",
|
279 |
+
sep2="</s>",
|
280 |
+
)
|
281 |
+
|
282 |
+
conv_llava_llama_2 = Conversation(
|
283 |
+
system="You are a helpful language and vision assistant. "
|
284 |
+
"You are able to understand the visual content that the user provides, "
|
285 |
+
"and assist the user with a variety of tasks using natural language.",
|
286 |
+
roles=("USER", "ASSISTANT"),
|
287 |
+
version="llama_v2",
|
288 |
+
messages=(),
|
289 |
+
offset=0,
|
290 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
291 |
+
sep="<s>",
|
292 |
+
sep2="</s>",
|
293 |
+
)
|
294 |
+
|
295 |
+
conv_mpt = Conversation(
|
296 |
+
system="""<|im_start|>system
|
297 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
298 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
299 |
+
version="mpt",
|
300 |
+
messages=(),
|
301 |
+
offset=0,
|
302 |
+
sep_style=SeparatorStyle.MPT,
|
303 |
+
sep="<|im_end|>",
|
304 |
+
)
|
305 |
+
|
306 |
+
conv_llava_plain = Conversation(
|
307 |
+
system="",
|
308 |
+
roles=("", ""),
|
309 |
+
messages=(
|
310 |
+
),
|
311 |
+
offset=0,
|
312 |
+
sep_style=SeparatorStyle.PLAIN,
|
313 |
+
sep="\n",
|
314 |
+
)
|
315 |
+
|
316 |
+
conv_llava_v0 = Conversation(
|
317 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
318 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
319 |
+
roles=("Human", "Assistant"),
|
320 |
+
messages=(
|
321 |
+
),
|
322 |
+
offset=0,
|
323 |
+
sep_style=SeparatorStyle.SINGLE,
|
324 |
+
sep="###",
|
325 |
+
)
|
326 |
+
|
327 |
+
conv_llava_v0_mmtag = Conversation(
|
328 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
329 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
330 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
331 |
+
roles=("Human", "Assistant"),
|
332 |
+
messages=(
|
333 |
+
),
|
334 |
+
offset=0,
|
335 |
+
sep_style=SeparatorStyle.SINGLE,
|
336 |
+
sep="###",
|
337 |
+
version="v0_mmtag",
|
338 |
+
)
|
339 |
+
|
340 |
+
conv_llava_v1 = Conversation(
|
341 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
342 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
343 |
+
roles=("USER", "ASSISTANT"),
|
344 |
+
version="v1",
|
345 |
+
messages=(),
|
346 |
+
offset=0,
|
347 |
+
sep_style=SeparatorStyle.TWO,
|
348 |
+
sep=" ",
|
349 |
+
sep2="</s>",
|
350 |
+
)
|
351 |
+
|
352 |
+
conv_vicuna_imgsp_v1 = Conversation(
|
353 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
354 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
355 |
+
roles=("USER", "ASSISTANT"),
|
356 |
+
version="imgsp_v1",
|
357 |
+
messages=(),
|
358 |
+
offset=0,
|
359 |
+
sep_style=SeparatorStyle.TWO,
|
360 |
+
sep=" ",
|
361 |
+
sep2="</s>",
|
362 |
+
)
|
363 |
+
|
364 |
+
conv_llava_plain_guided = Conversation(
|
365 |
+
system="",
|
366 |
+
roles=("", ""),
|
367 |
+
version="plain_guided",
|
368 |
+
messages=(
|
369 |
+
),
|
370 |
+
offset=0,
|
371 |
+
sep_style=SeparatorStyle.PLAIN,
|
372 |
+
sep="\n",
|
373 |
+
)
|
374 |
+
|
375 |
+
conv_llava_v1_mmtag = Conversation(
|
376 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
377 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
378 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
379 |
+
roles=("USER", "ASSISTANT"),
|
380 |
+
messages=(),
|
381 |
+
offset=0,
|
382 |
+
sep_style=SeparatorStyle.TWO,
|
383 |
+
sep=" ",
|
384 |
+
sep2="</s>",
|
385 |
+
version="v1_mmtag",
|
386 |
+
)
|
387 |
+
|
388 |
+
conv_phi_2 = Conversation(
|
389 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
390 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
391 |
+
roles=("USER", "ASSISTANT"),
|
392 |
+
version="phi2",
|
393 |
+
messages=(),
|
394 |
+
offset=0,
|
395 |
+
sep_style=SeparatorStyle.TWO,
|
396 |
+
sep=" ",
|
397 |
+
sep2="<|endoftext|>",
|
398 |
+
)
|
399 |
+
|
400 |
+
conv_mistral_instruct = Conversation(
|
401 |
+
system="",
|
402 |
+
roles=("USER", "ASSISTANT"),
|
403 |
+
version="llama_v2",
|
404 |
+
messages=(),
|
405 |
+
offset=0,
|
406 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
407 |
+
sep="<s>",
|
408 |
+
sep2="</s>",
|
409 |
+
)
|
410 |
+
|
411 |
+
conv_gemma = Conversation(
|
412 |
+
system="",
|
413 |
+
roles=("user", "model"),
|
414 |
+
version="gemma",
|
415 |
+
messages=(),
|
416 |
+
offset=0,
|
417 |
+
sep_style=SeparatorStyle.GEMMA,
|
418 |
+
sep="",
|
419 |
+
sep2="<eos>",
|
420 |
+
)
|
421 |
+
|
422 |
+
conv_chatml_direct = Conversation(
|
423 |
+
system="""<|im_start|>system
|
424 |
+
Answer the questions.""",
|
425 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
426 |
+
version="mpt",
|
427 |
+
messages=(),
|
428 |
+
offset=0,
|
429 |
+
sep_style=SeparatorStyle.MPT,
|
430 |
+
sep="<|im_end|>",
|
431 |
+
)
|
432 |
+
|
433 |
+
default_conversation = conv_vicuna_v1
|
434 |
+
conv_templates = {
|
435 |
+
"default": conv_vicuna_v0,
|
436 |
+
"v0": conv_vicuna_v0,
|
437 |
+
"v1": conv_vicuna_v1,
|
438 |
+
"vicuna_v1": conv_vicuna_v1,
|
439 |
+
"phi_2": conv_phi_2,
|
440 |
+
"gemma": conv_gemma,
|
441 |
+
"llama_2": conv_llama_2,
|
442 |
+
"imgsp_v1": conv_vicuna_imgsp_v1,
|
443 |
+
"plain_guided": conv_llava_plain_guided,
|
444 |
+
"mistral_instruct": conv_mistral_instruct,
|
445 |
+
"chatml_direct": conv_chatml_direct,
|
446 |
+
"mistral_direct": conv_chatml_direct,
|
447 |
+
"plain": conv_llava_plain,
|
448 |
+
"v0_plain": conv_llava_plain,
|
449 |
+
"llava_v0": conv_llava_v0,
|
450 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
451 |
+
"llava_v1": conv_llava_v1,
|
452 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
453 |
+
"llava_llama_2": conv_llava_llama_2,
|
454 |
+
|
455 |
+
"mpt": conv_mpt,
|
456 |
+
}
|
457 |
+
|
458 |
+
|
459 |
+
if __name__ == "__main__":
|
460 |
+
print(default_conversation.get_prompt())
|
i2t.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import shortuuid
|
8 |
+
from tqdm import tqdm
|
9 |
+
from PIL import Image
|
10 |
+
from PIL import ImageFile
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
14 |
+
from conversation import conv_templates, SeparatorStyle
|
15 |
+
from model.builder import load_pretrained_model
|
16 |
+
from tools import disable_torch_init
|
17 |
+
from mm_utils import tokenizer_image_token, get_model_name_from_path
|
18 |
+
from torch.utils.data import Dataset, DataLoader
|
19 |
+
|
20 |
+
|
21 |
+
from unitok.config import Args
|
22 |
+
from unitok.model import UniTok
|
23 |
+
|
24 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = False
|
25 |
+
torch.set_grad_enabled(False)
|
26 |
+
|
27 |
+
|
28 |
+
def split_list(lst, n):
|
29 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
30 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
31 |
+
return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
|
32 |
+
|
33 |
+
|
34 |
+
def get_chunk(lst, n, k):
|
35 |
+
chunks = split_list(lst, n)
|
36 |
+
return chunks[k]
|
37 |
+
|
38 |
+
|
39 |
+
# Custom dataset class
|
40 |
+
class CustomDataset(Dataset):
|
41 |
+
def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
|
42 |
+
self.questions = questions
|
43 |
+
self.image_folder = image_folder
|
44 |
+
self.tokenizer = tokenizer
|
45 |
+
self.image_processor = image_processor
|
46 |
+
self.model_config = model_config
|
47 |
+
|
48 |
+
def __getitem__(self, index):
|
49 |
+
line = self.questions[index]
|
50 |
+
image_file = line["image"]
|
51 |
+
qs = line["text"]
|
52 |
+
|
53 |
+
qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
|
54 |
+
|
55 |
+
conv = conv_templates[args.conv_mode].copy()
|
56 |
+
conv.append_message(conv.roles[0], qs)
|
57 |
+
conv.append_message(conv.roles[1], None)
|
58 |
+
prompt = conv.get_prompt()
|
59 |
+
# prompt = prompt.replace('<image>','<boi><image><eoi>')
|
60 |
+
# import pdb;pdb.set_trace()
|
61 |
+
image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
|
62 |
+
# import pdb;pdb.set_trace()
|
63 |
+
pad_image = expand2square(image, (122, 116, 104) )
|
64 |
+
# import pdb;pdb.set_trace()
|
65 |
+
img = self.image_processor[0](pad_image).unsqueeze(0)
|
66 |
+
img = img.to('cuda')
|
67 |
+
# import pdb;pdb.set_trace()
|
68 |
+
with torch.no_grad():
|
69 |
+
vq_code = self.image_processor[1].img_to_idx(img)
|
70 |
+
vqcode = vq_code.cpu()
|
71 |
+
|
72 |
+
input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
|
73 |
+
|
74 |
+
|
75 |
+
return input_ids,vqcode,os.path.join(self.image_folder, image_file) #, image_tensor, image_tensor_aux
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self.questions)
|
79 |
+
|
80 |
+
|
81 |
+
# DataLoader
|
82 |
+
def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=0):
|
83 |
+
assert batch_size == 1, "batch_size must be 1"
|
84 |
+
dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
|
85 |
+
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
|
86 |
+
return data_loader
|
87 |
+
|
88 |
+
def expand2square(pil_img, background_color):
|
89 |
+
width, height = pil_img.size
|
90 |
+
if width == height:
|
91 |
+
return pil_img
|
92 |
+
elif width > height:
|
93 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
94 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
95 |
+
return result
|
96 |
+
else:
|
97 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
98 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
99 |
+
return result
|
100 |
+
|
101 |
+
def eval_model(args):
|
102 |
+
# Model
|
103 |
+
disable_torch_init()
|
104 |
+
model_path = os.path.expanduser(args.model_path)
|
105 |
+
model_name = get_model_name_from_path(model_path)
|
106 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, load_8bit=args.load_8bit)
|
107 |
+
|
108 |
+
ckpt = torch.load(args.tokenizer_path, map_location='cpu')
|
109 |
+
vae_cfg = Args()
|
110 |
+
vae_cfg.load_state_dict(ckpt['args'])
|
111 |
+
vq_model = UniTok(vae_cfg)
|
112 |
+
vq_model.load_state_dict(ckpt['trainer']['unitok'])
|
113 |
+
vq_model.to('cuda')
|
114 |
+
vq_model.eval()
|
115 |
+
del ckpt
|
116 |
+
|
117 |
+
crop_size = 256
|
118 |
+
transform = transforms.Compose([
|
119 |
+
transforms.Resize((crop_size, crop_size)),
|
120 |
+
transforms.ToTensor(),
|
121 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
122 |
+
])
|
123 |
+
image_processor = (transform, vq_model)
|
124 |
+
|
125 |
+
questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
|
126 |
+
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
|
127 |
+
answers_file = os.path.expanduser(args.answers_file)
|
128 |
+
os.makedirs(os.path.dirname(answers_file), exist_ok=True)
|
129 |
+
ans_file = open(answers_file, "w")
|
130 |
+
|
131 |
+
if 'plain' in args.conv_mode and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
|
132 |
+
args.conv_mode = args.conv_mode + '_mmtag'
|
133 |
+
print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
|
134 |
+
|
135 |
+
data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
|
136 |
+
|
137 |
+
for (input_ids, image_codes,imagepath), line in tqdm(zip(data_loader, questions), total=len(questions)):
|
138 |
+
idx = line["question_id"]
|
139 |
+
cur_prompt = line["text"]
|
140 |
+
|
141 |
+
input_ids = input_ids.to(device=model.device, non_blocking=True)
|
142 |
+
image_codes = image_codes.to(device=model.device, non_blocking=True)
|
143 |
+
if hasattr(model, "update_prompt"):
|
144 |
+
model.update_prompt([[cur_prompt]])
|
145 |
+
|
146 |
+
with torch.inference_mode():
|
147 |
+
output_ids = model.generate_mllm(
|
148 |
+
input_ids,
|
149 |
+
images=image_codes,
|
150 |
+
images_aux= None,
|
151 |
+
do_sample=True if args.temperature > 0 else False,
|
152 |
+
temperature=args.temperature,
|
153 |
+
top_p=args.top_p,
|
154 |
+
num_beams=args.num_beams,
|
155 |
+
max_new_tokens=args.max_new_tokens,
|
156 |
+
bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
|
157 |
+
eos_token_id=tokenizer.eos_token_id, # End of sequence token
|
158 |
+
pad_token_id=tokenizer.pad_token_id, # Pad token
|
159 |
+
use_cache=False
|
160 |
+
)
|
161 |
+
|
162 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
163 |
+
ans_id = shortuuid.uuid()
|
164 |
+
ans_file.write(json.dumps({
|
165 |
+
"question_id": idx,
|
166 |
+
"prompt": cur_prompt,
|
167 |
+
"text": outputs,
|
168 |
+
"answer_id": ans_id,
|
169 |
+
"model_id": model_name,
|
170 |
+
"metadata": {}
|
171 |
+
}) + "\n")
|
172 |
+
ans_file.close()
|
173 |
+
|
174 |
+
if __name__ == "__main__":
|
175 |
+
parser = argparse.ArgumentParser()
|
176 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
177 |
+
parser.add_argument("--tokenizer-path", type=str, required=True)
|
178 |
+
parser.add_argument("--model-base", type=str, default=None)
|
179 |
+
parser.add_argument("--image-folder", type=str, default="")
|
180 |
+
parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
|
181 |
+
parser.add_argument("--answers-file", type=str, default="answer.jsonl")
|
182 |
+
parser.add_argument("--conv-mode", type=str, default="llava_v1")
|
183 |
+
parser.add_argument("--num-chunks", type=int, default=1)
|
184 |
+
parser.add_argument("--chunk-idx", type=int, default=0)
|
185 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
186 |
+
parser.add_argument("--top_p", type=float, default=None)
|
187 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
188 |
+
parser.add_argument('--load_8bit', type=bool, default=False)
|
189 |
+
parser.add_argument("--max_new_tokens", type=int, default=128)
|
190 |
+
args = parser.parse_args()
|
191 |
+
|
192 |
+
eval_model(args)
|
mm_utils.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import StoppingCriteria
|
7 |
+
from constants import IMAGE_TOKEN_INDEX
|
8 |
+
|
9 |
+
|
10 |
+
def load_image_from_base64(image):
|
11 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
12 |
+
|
13 |
+
|
14 |
+
def expand2square(pil_img, background_color):
|
15 |
+
width, height = pil_img.size
|
16 |
+
if width == height:
|
17 |
+
return pil_img
|
18 |
+
elif width > height:
|
19 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
20 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
21 |
+
return result
|
22 |
+
else:
|
23 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
24 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
25 |
+
return result
|
26 |
+
|
27 |
+
|
28 |
+
def process_images(images, image_processor, model_cfg):
|
29 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
30 |
+
new_images = []
|
31 |
+
if image_aspect_ratio == 'pad':
|
32 |
+
for image in images:
|
33 |
+
image = expand2square(image.convert('RGB'), tuple(int(x*255) for x in image_processor.image_mean))
|
34 |
+
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
35 |
+
new_images.append(image)
|
36 |
+
else:
|
37 |
+
return image_processor(images, return_tensors='pt')['pixel_values']
|
38 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
39 |
+
new_images = torch.stack(new_images, dim=0)
|
40 |
+
return new_images
|
41 |
+
|
42 |
+
|
43 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
44 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
|
45 |
+
|
46 |
+
def insert_separator(X, sep):
|
47 |
+
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
|
48 |
+
|
49 |
+
input_ids = []
|
50 |
+
offset = 0
|
51 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
52 |
+
offset = 1
|
53 |
+
input_ids.append(prompt_chunks[0][0])
|
54 |
+
|
55 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
56 |
+
input_ids.extend(x[offset:])
|
57 |
+
|
58 |
+
if return_tensors is not None:
|
59 |
+
if return_tensors == 'pt':
|
60 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
61 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
62 |
+
return input_ids
|
63 |
+
|
64 |
+
|
65 |
+
def get_model_name_from_path(model_path):
|
66 |
+
model_path = model_path.strip("/")
|
67 |
+
model_paths = model_path.split("/")
|
68 |
+
if model_paths[-1].startswith('checkpoint-'):
|
69 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
70 |
+
else:
|
71 |
+
return model_paths[-1]
|
72 |
+
|
73 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
74 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
75 |
+
self.keywords = keywords
|
76 |
+
self.keyword_ids = []
|
77 |
+
self.max_keyword_len = 0
|
78 |
+
for keyword in keywords:
|
79 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
80 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
81 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
82 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
83 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
84 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
85 |
+
self.tokenizer = tokenizer
|
86 |
+
self.start_len = input_ids.shape[1]
|
87 |
+
|
88 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
89 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
90 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
91 |
+
for keyword_id in self.keyword_ids:
|
92 |
+
truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
|
93 |
+
if torch.equal(truncated_output_ids, keyword_id):
|
94 |
+
return True
|
95 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
96 |
+
for keyword in self.keywords:
|
97 |
+
if keyword in outputs:
|
98 |
+
return True
|
99 |
+
return False
|
100 |
+
|
101 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
102 |
+
outputs = []
|
103 |
+
for i in range(output_ids.shape[0]):
|
104 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
105 |
+
return all(outputs)
|
model/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .language_model.mini_gemini_llama import MiniGeminiLlamaForCausalLM
|
2 |
+
try:
|
3 |
+
from .language_model.mini_gemini_mistral import MiniGeminiMistralForCausalLM
|
4 |
+
from .language_model.mini_gemini_mixtral import MiniGeminiMixtralForCausalLM
|
5 |
+
from .language_model.mini_gemini_gemma import MiniGeminiGemmaForCausalLM
|
6 |
+
except:
|
7 |
+
ImportWarning("New model not imported. Try to update Transformers.")
|
model/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (714 Bytes). View file
|
|
model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (585 Bytes). View file
|
|
model/__pycache__/arhead.cpython-311.pyc
ADDED
Binary file (11.1 kB). View file
|
|
model/__pycache__/arhead.cpython-39.pyc
ADDED
Binary file (5.89 kB). View file
|
|
model/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (5.87 kB). View file
|
|
model/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (2.91 kB). View file
|
|
model/__pycache__/liquid.cpython-311.pyc
ADDED
Binary file (43.8 kB). View file
|
|
model/__pycache__/mini_gemini_arch.cpython-311.pyc
ADDED
Binary file (49.9 kB). View file
|
|
model/__pycache__/mini_gemini_arch.cpython-39.pyc
ADDED
Binary file (20.5 kB). View file
|
|
model/__pycache__/quant.cpython-311.pyc
ADDED
Binary file (38.9 kB). View file
|
|
model/__pycache__/quant.cpython-39.pyc
ADDED
Binary file (16.5 kB). View file
|
|
model/arhead.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils.checkpoint
|
3 |
+
from torch import nn
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
from transformers.cache_utils import Cache, DynamicCache, StaticCache
|
7 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaDecoderLayer
|
8 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
9 |
+
|
10 |
+
class AR_head(nn.Module):
|
11 |
+
"""
|
12 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
|
13 |
+
|
14 |
+
Args:
|
15 |
+
config: GemmaConfig
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, config, codebook_size, num_codebooks):
|
19 |
+
super().__init__()
|
20 |
+
# import pdb;pdb.set_trace()
|
21 |
+
self.num_codebooks = num_codebooks
|
22 |
+
vocab_size = codebook_size
|
23 |
+
self.sub_vocab_size = vocab_size // self.num_codebooks
|
24 |
+
|
25 |
+
# self.layers = nn.ModuleList(
|
26 |
+
# [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(3)]
|
27 |
+
# )
|
28 |
+
# self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
29 |
+
self.linear_head = nn.Linear(config.hidden_size, self.sub_vocab_size)
|
30 |
+
|
31 |
+
self.layers = nn.ModuleList(
|
32 |
+
[LlamaDecoderLayer(config, layer_idx) for layer_idx in range(3)]
|
33 |
+
)
|
34 |
+
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
35 |
+
self.gradient_checkpointing = False
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
# vocab_size 16384
|
41 |
+
self.codebooks = nn.ModuleList()
|
42 |
+
for _ in range(self.num_codebooks-1):
|
43 |
+
codebook = nn.Embedding(self.sub_vocab_size, config.hidden_size)
|
44 |
+
self.codebooks.append(codebook)
|
45 |
+
# import pdb;pdb.set_trace()
|
46 |
+
self.config = config
|
47 |
+
self.gradient_checkpointing = False
|
48 |
+
|
49 |
+
# Initialize weights and apply final processing
|
50 |
+
self._init_weights(self.layers)
|
51 |
+
|
52 |
+
def set_input_embeddings(self, value):
|
53 |
+
self.embed_tokens = value
|
54 |
+
|
55 |
+
def _init_weights(self, module):
|
56 |
+
std = self.config.initializer_range
|
57 |
+
if isinstance(module, nn.Linear):
|
58 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
59 |
+
if module.bias is not None:
|
60 |
+
module.bias.data.zero_()
|
61 |
+
elif isinstance(module, nn.Embedding):
|
62 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
63 |
+
if module.padding_idx is not None:
|
64 |
+
module.weight.data[module.padding_idx].zero_()
|
65 |
+
|
66 |
+
# Ignore copy
|
67 |
+
def forward(
|
68 |
+
self,
|
69 |
+
input_ids: torch.LongTensor = None,
|
70 |
+
attention_mask: Optional[torch.Tensor] = None,
|
71 |
+
position_ids: Optional[torch.LongTensor] = None,
|
72 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
73 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
74 |
+
use_cache: Optional[bool] = None,
|
75 |
+
output_attentions: Optional[bool] = None,
|
76 |
+
output_hidden_states: Optional[bool] = None,
|
77 |
+
return_dict: Optional[bool] = None,
|
78 |
+
cache_position: Optional[torch.LongTensor] = None,
|
79 |
+
) -> torch.tensor:
|
80 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
81 |
+
output_hidden_states = (
|
82 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
83 |
+
)
|
84 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
85 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
86 |
+
|
87 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
88 |
+
raise ValueError(
|
89 |
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
90 |
+
)
|
91 |
+
|
92 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
93 |
+
logger.warning_once(
|
94 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
95 |
+
)
|
96 |
+
use_cache = False
|
97 |
+
|
98 |
+
if inputs_embeds is None:
|
99 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
100 |
+
|
101 |
+
past_seen_tokens = 0
|
102 |
+
if use_cache: # kept for BC (cache positions)
|
103 |
+
if not isinstance(past_key_values, StaticCache):
|
104 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
105 |
+
past_seen_tokens = past_key_values.get_seq_length()
|
106 |
+
|
107 |
+
if cache_position is None:
|
108 |
+
if isinstance(past_key_values, StaticCache):
|
109 |
+
raise ValueError("cache_position is a required argument when using StaticCache.")
|
110 |
+
cache_position = torch.arange(
|
111 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
112 |
+
)
|
113 |
+
|
114 |
+
if position_ids is None:
|
115 |
+
position_ids = cache_position.unsqueeze(0)
|
116 |
+
|
117 |
+
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
118 |
+
|
119 |
+
# embed positions
|
120 |
+
hidden_states = inputs_embeds
|
121 |
+
|
122 |
+
# decoder layers
|
123 |
+
all_hidden_states = () if output_hidden_states else None
|
124 |
+
all_self_attns = () if output_attentions else None
|
125 |
+
next_decoder_cache = None
|
126 |
+
|
127 |
+
for decoder_layer in self.layers:
|
128 |
+
if output_hidden_states:
|
129 |
+
all_hidden_states += (hidden_states,)
|
130 |
+
|
131 |
+
if self.gradient_checkpointing and self.training:
|
132 |
+
layer_outputs = self._gradient_checkpointing_func(
|
133 |
+
decoder_layer.__call__,
|
134 |
+
hidden_states,
|
135 |
+
causal_mask,
|
136 |
+
position_ids,
|
137 |
+
past_key_values,
|
138 |
+
output_attentions,
|
139 |
+
use_cache,
|
140 |
+
cache_position,
|
141 |
+
)
|
142 |
+
else:
|
143 |
+
layer_outputs = decoder_layer(
|
144 |
+
hidden_states,
|
145 |
+
attention_mask=causal_mask,
|
146 |
+
position_ids=position_ids,
|
147 |
+
past_key_value=past_key_values,
|
148 |
+
output_attentions=output_attentions,
|
149 |
+
use_cache=use_cache,
|
150 |
+
cache_position=cache_position,
|
151 |
+
)
|
152 |
+
|
153 |
+
hidden_states = layer_outputs[0]
|
154 |
+
|
155 |
+
if use_cache:
|
156 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
157 |
+
|
158 |
+
if output_attentions:
|
159 |
+
all_self_attns += (layer_outputs[1],)
|
160 |
+
|
161 |
+
hidden_states = self.norm(hidden_states)
|
162 |
+
|
163 |
+
if output_hidden_states:
|
164 |
+
all_hidden_states += (hidden_states,)
|
165 |
+
|
166 |
+
next_cache = None
|
167 |
+
if use_cache:
|
168 |
+
next_cache = (
|
169 |
+
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
170 |
+
)
|
171 |
+
if not return_dict:
|
172 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
173 |
+
return BaseModelOutputWithPast(
|
174 |
+
last_hidden_state=hidden_states,
|
175 |
+
past_key_values=next_cache,
|
176 |
+
hidden_states=all_hidden_states,
|
177 |
+
attentions=all_self_attns,
|
178 |
+
)
|
179 |
+
|
180 |
+
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
181 |
+
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
182 |
+
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
183 |
+
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
184 |
+
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
|
185 |
+
if self.config._attn_implementation == "flash_attention_2":
|
186 |
+
if attention_mask is not None and 0.0 in attention_mask:
|
187 |
+
return attention_mask
|
188 |
+
return None
|
189 |
+
|
190 |
+
dtype, device = input_tensor.dtype, input_tensor.device
|
191 |
+
min_dtype = torch.finfo(dtype).min
|
192 |
+
sequence_length = input_tensor.shape[1]
|
193 |
+
if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
|
194 |
+
target_length = self.config.max_position_embeddings
|
195 |
+
else: # dynamic cache
|
196 |
+
target_length = (
|
197 |
+
attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
|
198 |
+
)
|
199 |
+
|
200 |
+
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
201 |
+
if sequence_length != 1:
|
202 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
203 |
+
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
204 |
+
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
205 |
+
if attention_mask is not None:
|
206 |
+
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
207 |
+
if attention_mask.dim() == 2:
|
208 |
+
mask_length = attention_mask.shape[-1]
|
209 |
+
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
210 |
+
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
211 |
+
elif attention_mask.dim() == 4:
|
212 |
+
# backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
|
213 |
+
# cache. In that case, the 4D attention mask attends to the newest tokens only.
|
214 |
+
if attention_mask.shape[-2] < cache_position[0] + sequence_length:
|
215 |
+
offset = cache_position[0]
|
216 |
+
else:
|
217 |
+
offset = 0
|
218 |
+
mask_shape = attention_mask.shape
|
219 |
+
mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
|
220 |
+
causal_mask[
|
221 |
+
: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
|
222 |
+
] = mask_slice
|
223 |
+
|
224 |
+
if (
|
225 |
+
self.config._attn_implementation == "sdpa"
|
226 |
+
and attention_mask is not None
|
227 |
+
and attention_mask.device.type == "cuda"
|
228 |
+
):
|
229 |
+
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
230 |
+
is_tracing = (
|
231 |
+
torch.jit.is_tracing()
|
232 |
+
or isinstance(input_tensor, torch.fx.Proxy)
|
233 |
+
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
234 |
+
)
|
235 |
+
if not is_tracing and torch.any(attention_mask != 1):
|
236 |
+
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
237 |
+
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
238 |
+
# Details: https://github.com/pytorch/pytorch/issues/110213
|
239 |
+
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
240 |
+
|
241 |
+
return causal_mask
|
model/builder.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
|
16 |
+
# Copyright 2024 Yanwei Li
|
17 |
+
# ------------------------------------------------------------------------
|
18 |
+
|
19 |
+
import os
|
20 |
+
import torch
|
21 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
22 |
+
|
23 |
+
from model import *
|
24 |
+
from constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
25 |
+
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
|
27 |
+
kwargs = {"device_map": device_map, **kwargs}
|
28 |
+
|
29 |
+
if device != "cuda":
|
30 |
+
kwargs['device_map'] = {"": device}
|
31 |
+
|
32 |
+
if load_8bit:
|
33 |
+
kwargs['load_in_8bit'] = True
|
34 |
+
elif load_4bit:
|
35 |
+
kwargs['load_in_4bit'] = True
|
36 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
37 |
+
load_in_4bit=True,
|
38 |
+
bnb_4bit_compute_dtype=torch.float16,
|
39 |
+
bnb_4bit_use_double_quant=True,
|
40 |
+
bnb_4bit_quant_type='nf4'
|
41 |
+
)
|
42 |
+
else:
|
43 |
+
kwargs['torch_dtype'] = torch.float16
|
44 |
+
|
45 |
+
if use_flash_attn:
|
46 |
+
kwargs['attn_implementation'] = 'flash_attention_2'
|
47 |
+
# import pdb;pdb.set_trace()
|
48 |
+
if 'mini-gemini' in model_name.lower():
|
49 |
+
# Load MiniGemini model
|
50 |
+
if model_base is not None:
|
51 |
+
# this may be mm projector only
|
52 |
+
print('Loading MiniGemini from base model...')
|
53 |
+
|
54 |
+
if "8x7b" in model_name.lower():
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base)
|
56 |
+
model = MiniGeminiMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
57 |
+
elif "2b" in model_name.lower():
|
58 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base)
|
59 |
+
model = MiniGeminiGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
60 |
+
else:
|
61 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
62 |
+
model = MiniGeminiLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
63 |
+
mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
64 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
65 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
66 |
+
else:
|
67 |
+
if "8x7b" in model_name.lower():
|
68 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
69 |
+
model = MiniGeminiMixtralForCausalLM.from_pretrained(model_path, **kwargs)
|
70 |
+
elif "2b" in model_name.lower():
|
71 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
72 |
+
model = MiniGeminiGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
73 |
+
else:
|
74 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
75 |
+
model = MiniGeminiLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
76 |
+
|
77 |
+
if 'gemma' in model_name.lower():
|
78 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
79 |
+
model = MiniGeminiGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
80 |
+
elif 'vicuna' in model_name.lower() or 'unitok' in model_name.lower():
|
81 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
82 |
+
model = MiniGeminiLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
83 |
+
else:
|
84 |
+
# Load language model
|
85 |
+
if model_base is not None:
|
86 |
+
# PEFT model
|
87 |
+
from peft import PeftModel
|
88 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
|
89 |
+
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
|
90 |
+
print(f"Loading LoRA weights from {model_path}")
|
91 |
+
model = PeftModel.from_pretrained(model, model_path)
|
92 |
+
print(f"Merging weights")
|
93 |
+
model = model.merge_and_unload()
|
94 |
+
print('Convert to FP16...')
|
95 |
+
model.to(torch.float16)
|
96 |
+
else:
|
97 |
+
if 'mpt' in model_name.lower():
|
98 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
|
99 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
|
100 |
+
else:
|
101 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
102 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
|
103 |
+
|
104 |
+
image_processor = None
|
105 |
+
# import pdb;pdb.set_trace()
|
106 |
+
|
107 |
+
|
108 |
+
# mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
|
109 |
+
# mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
|
110 |
+
# if mm_use_im_patch_token:
|
111 |
+
# tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
112 |
+
# if mm_use_im_start_end:
|
113 |
+
# tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
114 |
+
|
115 |
+
# model.resize_token_embeddings(len(tokenizer))
|
116 |
+
|
117 |
+
# vision_tower = model.get_vision_tower()
|
118 |
+
# if not vision_tower.is_loaded:
|
119 |
+
# vision_tower.load_model()
|
120 |
+
# vision_tower.to(device=device, dtype=torch.float16)
|
121 |
+
# image_processor = vision_tower.image_processor
|
122 |
+
|
123 |
+
if 'mini-gemini' in model_name.lower():
|
124 |
+
vision_tower_aux = model.get_vision_tower_aux()
|
125 |
+
if not vision_tower_aux.is_loaded:
|
126 |
+
vision_tower_aux.load_model()
|
127 |
+
vision_tower_aux.to(device=device, dtype=torch.float16)
|
128 |
+
|
129 |
+
# initialize attention modules
|
130 |
+
model.config.model_path = model_path
|
131 |
+
model.get_model().initialize_uni_modules(model.config, for_eval=True)
|
132 |
+
|
133 |
+
if hasattr(model.config, "max_sequence_length"):
|
134 |
+
context_len = model.config.max_sequence_length
|
135 |
+
else:
|
136 |
+
context_len = 2048
|
137 |
+
|
138 |
+
return tokenizer, model, image_processor, context_len
|
model/language_model/__pycache__/mini_gemini_llama.cpython-311.pyc
ADDED
Binary file (18 kB). View file
|
|
model/language_model/__pycache__/mini_gemini_llama.cpython-39.pyc
ADDED
Binary file (7.73 kB). View file
|
|
model/language_model/mini_gemini_llama.py
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
|
16 |
+
# Copyright 2024 Yanwei Li
|
17 |
+
# ------------------------------------------------------------------------
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from torch.nn import CrossEntropyLoss
|
23 |
+
from typing import List, Optional, Tuple, Union
|
24 |
+
|
25 |
+
from transformers.utils import logging
|
26 |
+
from transformers.generation.utils import GenerateOutput
|
27 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
28 |
+
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
|
29 |
+
|
30 |
+
from model.arhead import AR_head
|
31 |
+
from model.liquid import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
|
32 |
+
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
class MiniGeminiConfig(LlamaConfig):
|
38 |
+
model_type = "mini_gemini"
|
39 |
+
|
40 |
+
|
41 |
+
class MiniGeminiLlamaModel(MiniGeminiMetaModel, LlamaModel):
|
42 |
+
config_class = MiniGeminiConfig
|
43 |
+
|
44 |
+
def __init__(self, config: LlamaConfig):
|
45 |
+
super(MiniGeminiLlamaModel, self).__init__(config)
|
46 |
+
|
47 |
+
|
48 |
+
class MiniGeminiLlamaForCausalLM(LlamaForCausalLM, MiniGeminiMetaForCausalLM):
|
49 |
+
config_class = MiniGeminiConfig
|
50 |
+
|
51 |
+
def __init__(self, config):
|
52 |
+
super(LlamaForCausalLM, self).__init__(config)
|
53 |
+
self.model = MiniGeminiLlamaModel(config)
|
54 |
+
self.pretraining_tp = config.pretraining_tp
|
55 |
+
self.vocab_size = config.vocab_size
|
56 |
+
|
57 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
58 |
+
self.ar_head = AR_head(self.config, codebook_size=32768, num_codebooks=8)
|
59 |
+
|
60 |
+
# Initialize weights and apply final processing
|
61 |
+
self.post_init()
|
62 |
+
|
63 |
+
def get_model(self):
|
64 |
+
return self.model
|
65 |
+
|
66 |
+
def forward(
|
67 |
+
self,
|
68 |
+
input_ids: torch.LongTensor = None,
|
69 |
+
attention_mask: Optional[torch.Tensor] = None,
|
70 |
+
position_ids: Optional[torch.LongTensor] = None,
|
71 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
72 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
73 |
+
labels: Optional[torch.LongTensor] = None,
|
74 |
+
data_types: torch.LongTensor = None,
|
75 |
+
use_cache: Optional[bool] = None,
|
76 |
+
cache_position: Optional[torch.LongTensor] = None,
|
77 |
+
output_attentions: Optional[bool] = None,
|
78 |
+
output_hidden_states: Optional[bool] = None,
|
79 |
+
images: Optional[torch.FloatTensor] = None,
|
80 |
+
images_aux: Optional[torch.FloatTensor] = None,
|
81 |
+
return_dict: Optional[bool] = None,
|
82 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
83 |
+
|
84 |
+
|
85 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
86 |
+
output_hidden_states = (
|
87 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
88 |
+
)
|
89 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
90 |
+
|
91 |
+
additional_image_indexs = None
|
92 |
+
if inputs_embeds is None and past_key_values is None: # no in inference
|
93 |
+
(
|
94 |
+
input_ids,
|
95 |
+
position_ids,
|
96 |
+
attention_mask,
|
97 |
+
past_key_values,
|
98 |
+
inputs_embeds,
|
99 |
+
labels,
|
100 |
+
data_types,
|
101 |
+
additional_image_labels,
|
102 |
+
additional_image_indexs
|
103 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
104 |
+
input_ids,
|
105 |
+
position_ids,
|
106 |
+
attention_mask,
|
107 |
+
past_key_values,
|
108 |
+
labels,
|
109 |
+
images,
|
110 |
+
images_aux,
|
111 |
+
data_types
|
112 |
+
)
|
113 |
+
|
114 |
+
outputs = self.model(
|
115 |
+
input_ids=input_ids,
|
116 |
+
attention_mask=attention_mask,
|
117 |
+
position_ids=position_ids,
|
118 |
+
past_key_values=past_key_values,
|
119 |
+
inputs_embeds=inputs_embeds,
|
120 |
+
use_cache=use_cache,
|
121 |
+
output_attentions=output_attentions,
|
122 |
+
output_hidden_states=output_hidden_states,
|
123 |
+
return_dict=return_dict,
|
124 |
+
)
|
125 |
+
|
126 |
+
hidden_states = outputs[0]
|
127 |
+
|
128 |
+
if self.pretraining_tp > 1:
|
129 |
+
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
|
130 |
+
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
|
131 |
+
logits = torch.cat(logits, dim=-1)
|
132 |
+
else:
|
133 |
+
logits = self.lm_head(hidden_states)
|
134 |
+
logits = logits.float()
|
135 |
+
|
136 |
+
text_loss = None
|
137 |
+
if labels is not None:
|
138 |
+
# Shift so that tokens < n predict n
|
139 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
140 |
+
shift_labels = labels[..., 1:].contiguous()
|
141 |
+
# Flatten the tokens
|
142 |
+
loss_fct = CrossEntropyLoss()
|
143 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
144 |
+
shift_labels = shift_labels.view(-1)
|
145 |
+
# Enable model parallelism
|
146 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
147 |
+
text_loss = loss_fct(shift_logits, shift_labels)
|
148 |
+
num_text_tokens = (shift_labels != -100).sum().item()
|
149 |
+
|
150 |
+
if additional_image_indexs is None:
|
151 |
+
return CausalLMOutputWithPast(
|
152 |
+
loss=text_loss,
|
153 |
+
logits=logits,
|
154 |
+
past_key_values=outputs.past_key_values,
|
155 |
+
hidden_states=outputs.hidden_states,
|
156 |
+
attentions=outputs.attentions,
|
157 |
+
)
|
158 |
+
|
159 |
+
to_image_mask = data_types == 1 # where to get t2i loss in each batch [True, False, False, True....]
|
160 |
+
|
161 |
+
if len(additional_image_indexs) > 0 and len(to_image_mask) == len(hidden_states): # image generation loss
|
162 |
+
to_image_states = hidden_states[to_image_mask]
|
163 |
+
|
164 |
+
# assert len(to_image_states) == len(additional_image_indexs)
|
165 |
+
if len(to_image_states) != len(additional_image_indexs):
|
166 |
+
print('to_image_mask', to_image_mask)
|
167 |
+
print('additional_image_indexs', additional_image_indexs)
|
168 |
+
shift_image_states = torch.stack([state[start_id - 1:end_id - 1] for (start_id, end_id), state in
|
169 |
+
zip(additional_image_indexs, to_image_states)]) # Shift so that tokens < n predict n [bz, seq_len, hidden_dim]
|
170 |
+
base_tokens = shift_image_states
|
171 |
+
|
172 |
+
K = self.ar_head.num_codebooks
|
173 |
+
B, L, C = base_tokens.shape
|
174 |
+
base_tokens = base_tokens.reshape(B * L, 1, C)
|
175 |
+
|
176 |
+
targets = torch.cat(additional_image_labels, dim=0) # [B, K, L]
|
177 |
+
image_code_labels = targets
|
178 |
+
targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1]
|
179 |
+
index_embeddings = []
|
180 |
+
for i in range(K - 1):
|
181 |
+
index_embed = self.ar_head.codebooks[i](targets[:, i])
|
182 |
+
index_embeddings.append(index_embed)
|
183 |
+
index_embeddings = torch.stack(index_embeddings, dim=1)
|
184 |
+
# import pdb;pdb.set_trace()
|
185 |
+
h = torch.cat((base_tokens, index_embeddings), dim=1) # [B*L, K, C]
|
186 |
+
|
187 |
+
multicode_embedding = self.ar_head(
|
188 |
+
input_ids=None,
|
189 |
+
attention_mask=None,
|
190 |
+
position_ids=None,
|
191 |
+
past_key_values=None,
|
192 |
+
inputs_embeds=h,
|
193 |
+
use_cache=False,
|
194 |
+
output_attentions=False,
|
195 |
+
output_hidden_states=False,
|
196 |
+
return_dict=False,
|
197 |
+
cache_position=None,
|
198 |
+
)
|
199 |
+
image_logits = self.ar_head.linear_head(multicode_embedding)
|
200 |
+
image_logits = image_logits.reshape(B, L, K, -1).permute(0, 2, 1, 3) # [B, K, L, sub_vocab_size]
|
201 |
+
loss_fct = CrossEntropyLoss()
|
202 |
+
image_logits = image_logits.reshape(-1, self.ar_head.sub_vocab_size)
|
203 |
+
image_labels = image_code_labels.view(-1)
|
204 |
+
image_labels = image_labels.to(image_logits.device)
|
205 |
+
image_softmax_normalizer = image_logits.max(-1).values ** 2
|
206 |
+
image_z_loss = 0.00005 * image_softmax_normalizer.mean()
|
207 |
+
image_loss = loss_fct(image_logits, image_labels) + image_z_loss
|
208 |
+
num_image_tokens = image_labels.shape[0]
|
209 |
+
else:
|
210 |
+
if len(hidden_states) != len(to_image_mask):
|
211 |
+
print('to_image_mask', to_image_mask)
|
212 |
+
print('hidden_states', hidden_states.shape)
|
213 |
+
print('inputs_embeds', inputs_embeds.shape)
|
214 |
+
print('additional_image_indexs', additional_image_indexs)
|
215 |
+
fake_ids = torch.ones(1, self.model.multi_embedder.num_codebooks - 1).to(inputs_embeds).long()
|
216 |
+
index_embeddings = []
|
217 |
+
for i in range(self.model.multi_embedder.num_codebooks - 1):
|
218 |
+
index_embed = self.ar_head.codebooks[i](fake_ids[:, i])
|
219 |
+
index_embeddings.append(index_embed)
|
220 |
+
index_embeddings = torch.stack(index_embeddings, dim=1)
|
221 |
+
|
222 |
+
multicode_embedding = self.ar_head(
|
223 |
+
input_ids=None,
|
224 |
+
attention_mask=None,
|
225 |
+
position_ids=None,
|
226 |
+
past_key_values=None,
|
227 |
+
inputs_embeds=index_embeddings,
|
228 |
+
use_cache=False,
|
229 |
+
output_attentions=False,
|
230 |
+
output_hidden_states=False,
|
231 |
+
return_dict=False,
|
232 |
+
cache_position=None,
|
233 |
+
)
|
234 |
+
image_logits = self.ar_head.linear_head(multicode_embedding)
|
235 |
+
|
236 |
+
num_image_tokens = 0
|
237 |
+
image_loss = (image_logits * 0).sum() # + (base_tokens*0).sum()
|
238 |
+
pass
|
239 |
+
|
240 |
+
loss = image_loss * (num_image_tokens / (num_image_tokens + num_text_tokens)) + \
|
241 |
+
text_loss * (num_text_tokens / (num_image_tokens + num_text_tokens))
|
242 |
+
|
243 |
+
# t2i_ratio = to_image_mask.sum() / len(to_image_mask)
|
244 |
+
# loss = image_loss * t2i_ratio + text_loss * (1 - t2i_ratio)
|
245 |
+
|
246 |
+
if not return_dict:
|
247 |
+
output = (logits,) + outputs[1:]
|
248 |
+
return (loss,) + output if loss is not None else output
|
249 |
+
|
250 |
+
return CausalLMOutputWithPast(
|
251 |
+
loss=loss,
|
252 |
+
logits=logits,
|
253 |
+
past_key_values=outputs.past_key_values,
|
254 |
+
hidden_states=outputs.hidden_states,
|
255 |
+
attentions=outputs.attentions,
|
256 |
+
)
|
257 |
+
|
258 |
+
@torch.no_grad()
|
259 |
+
def generate_mllm(
|
260 |
+
self,
|
261 |
+
inputs: Optional[torch.Tensor] = None,
|
262 |
+
images: Optional[torch.Tensor] = None,
|
263 |
+
images_aux: Optional[torch.FloatTensor] = None,
|
264 |
+
**kwargs,
|
265 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
266 |
+
position_ids = kwargs.pop("position_ids", None)
|
267 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
268 |
+
if "inputs_embeds" in kwargs:
|
269 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
270 |
+
# import pdb;pdb.set_trace()
|
271 |
+
if images is not None:
|
272 |
+
(
|
273 |
+
inputs,
|
274 |
+
position_ids,
|
275 |
+
attention_mask,
|
276 |
+
_,
|
277 |
+
inputs_embeds,
|
278 |
+
_
|
279 |
+
) = self.prepare_inputs_for_multimodal(
|
280 |
+
inputs,
|
281 |
+
position_ids,
|
282 |
+
attention_mask,
|
283 |
+
None,
|
284 |
+
None,
|
285 |
+
images,
|
286 |
+
images_aux
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
290 |
+
# import pdb;pdb.set_trace()
|
291 |
+
return super().generate(
|
292 |
+
position_ids=position_ids,
|
293 |
+
attention_mask=attention_mask,
|
294 |
+
inputs_embeds=inputs_embeds,
|
295 |
+
**kwargs
|
296 |
+
)
|
297 |
+
|
298 |
+
@torch.no_grad()
|
299 |
+
def generate(
|
300 |
+
self,
|
301 |
+
inputs: Optional[torch.Tensor] = None,
|
302 |
+
images: Optional[torch.Tensor] = None,
|
303 |
+
images_aux: Optional[torch.FloatTensor] = None,
|
304 |
+
**kwargs,
|
305 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
306 |
+
position_ids = kwargs.pop("position_ids", None)
|
307 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
308 |
+
if "inputs_embeds" in kwargs:
|
309 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
310 |
+
|
311 |
+
if images is not None:
|
312 |
+
(
|
313 |
+
inputs,
|
314 |
+
position_ids,
|
315 |
+
attention_mask,
|
316 |
+
_,
|
317 |
+
inputs_embeds,
|
318 |
+
_
|
319 |
+
) = self.prepare_inputs_for_multimodal(
|
320 |
+
inputs,
|
321 |
+
position_ids,
|
322 |
+
attention_mask,
|
323 |
+
None,
|
324 |
+
None,
|
325 |
+
images,
|
326 |
+
images_aux
|
327 |
+
)
|
328 |
+
else:
|
329 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
330 |
+
|
331 |
+
return super().generate(
|
332 |
+
position_ids=position_ids,
|
333 |
+
attention_mask=attention_mask,
|
334 |
+
inputs_embeds=inputs_embeds,
|
335 |
+
**kwargs
|
336 |
+
)
|
337 |
+
|
338 |
+
def test_forward(
|
339 |
+
self,
|
340 |
+
input_ids: torch.LongTensor = None,
|
341 |
+
attention_mask: Optional[torch.Tensor] = None,
|
342 |
+
position_ids: Optional[torch.LongTensor] = None,
|
343 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
344 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
345 |
+
labels: Optional[torch.LongTensor] = None,
|
346 |
+
input_multi_ids: torch.LongTensor = None,
|
347 |
+
data_types: torch.LongTensor = None,
|
348 |
+
use_cache: Optional[bool] = None,
|
349 |
+
cache_position: Optional[torch.LongTensor] = None,
|
350 |
+
output_attentions: Optional[bool] = None,
|
351 |
+
output_hidden_states: Optional[bool] = None,
|
352 |
+
images: Optional[torch.FloatTensor] = None,
|
353 |
+
images_aux: Optional[torch.FloatTensor] = None,
|
354 |
+
return_dict: Optional[bool] = None,
|
355 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
356 |
+
# import pdb;pdb.set_trace()
|
357 |
+
if input_multi_ids is not None:
|
358 |
+
input_multi_ids = input_multi_ids.unsqueeze(-1) # [B,K,1]
|
359 |
+
input_ids = None # [B,1]
|
360 |
+
inputs_embeds = self.model.multi_embedder(input_multi_ids) # [B,1,C]
|
361 |
+
|
362 |
+
outputs = self.model(
|
363 |
+
input_ids=input_ids,
|
364 |
+
attention_mask=attention_mask,
|
365 |
+
position_ids=position_ids,
|
366 |
+
past_key_values=past_key_values,
|
367 |
+
inputs_embeds=inputs_embeds,
|
368 |
+
use_cache=use_cache,
|
369 |
+
output_attentions=output_attentions,
|
370 |
+
output_hidden_states=output_hidden_states,
|
371 |
+
return_dict=return_dict,
|
372 |
+
)
|
373 |
+
return outputs
|
374 |
+
|
375 |
+
def T2I_forward_nocache(
|
376 |
+
self,
|
377 |
+
input_ids: torch.LongTensor = None,
|
378 |
+
attention_mask: Optional[torch.Tensor] = None,
|
379 |
+
position_ids: Optional[torch.LongTensor] = None,
|
380 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
381 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
382 |
+
labels: Optional[torch.LongTensor] = None,
|
383 |
+
input_multi_ids: torch.LongTensor = None,
|
384 |
+
data_types: torch.LongTensor = None,
|
385 |
+
use_cache: Optional[bool] = None,
|
386 |
+
cache_position: Optional[torch.LongTensor] = None,
|
387 |
+
output_attentions: Optional[bool] = None,
|
388 |
+
output_hidden_states: Optional[bool] = None,
|
389 |
+
images: Optional[torch.FloatTensor] = None,
|
390 |
+
images_aux: Optional[torch.FloatTensor] = None,
|
391 |
+
return_dict: Optional[bool] = None,
|
392 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
393 |
+
# import pdb;pdb.set_trace()
|
394 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
395 |
+
output_hidden_states = (
|
396 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
397 |
+
)
|
398 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
399 |
+
|
400 |
+
if input_multi_ids is not None:
|
401 |
+
inputs_text_embeds = self.get_model().embed_tokens(input_ids)
|
402 |
+
input_ids = None # [B,1]
|
403 |
+
inputs_image_embeds = self.model.multi_embedder(input_multi_ids) # [B,1,C]
|
404 |
+
inputs_image_mask = torch.empty(inputs_image_embeds.shape[0], inputs_image_embeds.shape[1]).fill_(1).to(
|
405 |
+
attention_mask)
|
406 |
+
inputs_embeds = torch.cat([inputs_text_embeds, inputs_image_embeds], dim=1)
|
407 |
+
attention_mask = torch.cat([attention_mask, inputs_image_mask], dim=1)
|
408 |
+
position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).repeat(
|
409 |
+
inputs_embeds.shape[0], 1)
|
410 |
+
else:
|
411 |
+
inputs_embeds = self.get_model().embed_tokens(input_ids)
|
412 |
+
input_ids = None
|
413 |
+
|
414 |
+
outputs = self.model(
|
415 |
+
input_ids=input_ids,
|
416 |
+
attention_mask=attention_mask,
|
417 |
+
position_ids=position_ids,
|
418 |
+
past_key_values=past_key_values,
|
419 |
+
inputs_embeds=inputs_embeds,
|
420 |
+
use_cache=use_cache,
|
421 |
+
output_attentions=output_attentions,
|
422 |
+
output_hidden_states=output_hidden_states,
|
423 |
+
return_dict=return_dict,
|
424 |
+
)
|
425 |
+
return outputs
|
426 |
+
|
427 |
+
def T2I_forward_withcache(
|
428 |
+
self,
|
429 |
+
input_ids: torch.LongTensor = None,
|
430 |
+
attention_mask: Optional[torch.Tensor] = None,
|
431 |
+
position_ids: Optional[torch.LongTensor] = None,
|
432 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
433 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
434 |
+
labels: Optional[torch.LongTensor] = None,
|
435 |
+
input_multi_ids: torch.LongTensor = None,
|
436 |
+
data_types: torch.LongTensor = None,
|
437 |
+
use_cache: Optional[bool] = None,
|
438 |
+
cache_position: Optional[torch.LongTensor] = None,
|
439 |
+
output_attentions: Optional[bool] = None,
|
440 |
+
output_hidden_states: Optional[bool] = None,
|
441 |
+
images: Optional[torch.FloatTensor] = None,
|
442 |
+
images_aux: Optional[torch.FloatTensor] = None,
|
443 |
+
return_dict: Optional[bool] = None,
|
444 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
445 |
+
# import pdb;pdb.set_trace()
|
446 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
447 |
+
output_hidden_states = (
|
448 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
449 |
+
)
|
450 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
451 |
+
|
452 |
+
if input_multi_ids is not None:
|
453 |
+
inputs_image_embeds = self.model.multi_embedder(input_multi_ids[:, :, -1:]) # [B,1,C]
|
454 |
+
inputs_embeds = inputs_image_embeds
|
455 |
+
input_ids = None # [B,1]
|
456 |
+
else:
|
457 |
+
inputs_embeds = self.get_model().embed_tokens(input_ids)
|
458 |
+
input_ids = None
|
459 |
+
|
460 |
+
outputs = self.model(
|
461 |
+
input_ids=input_ids,
|
462 |
+
attention_mask=attention_mask,
|
463 |
+
position_ids=position_ids,
|
464 |
+
past_key_values=past_key_values,
|
465 |
+
inputs_embeds=inputs_embeds,
|
466 |
+
use_cache=use_cache,
|
467 |
+
output_attentions=output_attentions,
|
468 |
+
output_hidden_states=output_hidden_states,
|
469 |
+
return_dict=return_dict,
|
470 |
+
)
|
471 |
+
return outputs
|
472 |
+
|
473 |
+
|
474 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
475 |
+
images = kwargs.pop("images", None)
|
476 |
+
images_aux = kwargs.pop("images_aux", None)
|
477 |
+
_inputs = super().prepare_inputs_for_generation(
|
478 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
479 |
+
)
|
480 |
+
if images is not None:
|
481 |
+
_inputs['images'] = images
|
482 |
+
if images_aux is not None:
|
483 |
+
_inputs['images_aux'] = images_aux
|
484 |
+
return _inputs
|
485 |
+
|
486 |
+
|
487 |
+
AutoConfig.register("mini_gemini", MiniGeminiConfig)
|
488 |
+
AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiLlamaForCausalLM)
|
model/liquid.py
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ------------------------------------------------------------------------
|
15 |
+
# Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
|
16 |
+
# Copyright 2024 Yanwei Li
|
17 |
+
# ------------------------------------------------------------------------
|
18 |
+
# Modified from MiniGemini (https://github.com/dvlab-research/MGM)
|
19 |
+
# Copyright 2025 ByteDance
|
20 |
+
# ------------------------------------------------------------------------
|
21 |
+
|
22 |
+
import os
|
23 |
+
import json
|
24 |
+
import torch
|
25 |
+
import deepspeed
|
26 |
+
import safetensors
|
27 |
+
import transformers
|
28 |
+
import torch.nn as nn
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from abc import ABC, abstractmethod
|
31 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
32 |
+
|
33 |
+
from model.quant import VectorQuantizerM, AttnProjection
|
34 |
+
from model.multimodal_projector.builder import build_vision_projector
|
35 |
+
from model.multimodal_encoder.builder import build_vision_tower, build_vision_tower_aux
|
36 |
+
from constants import (
|
37 |
+
DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN,
|
38 |
+
IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
IS_NEW_TRANSFORMERS = transformers.__version__ >= "4.34.0"
|
43 |
+
|
44 |
+
|
45 |
+
class MiniGeminiMetaModel:
|
46 |
+
def __init__(self, config):
|
47 |
+
super(MiniGeminiMetaModel, self).__init__(config)
|
48 |
+
self.config = config
|
49 |
+
self.multi_embedder = TokenEmbedder(self.config.hidden_size)
|
50 |
+
if hasattr(config, "mm_vision_tower"):
|
51 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
52 |
+
self.mm_projector = build_vision_projector(config)
|
53 |
+
if hasattr(config, "mm_vision_tower_aux"):
|
54 |
+
self.vision_tower_aux = build_vision_tower_aux(config, delay_load=True)
|
55 |
+
|
56 |
+
def get_vision_tower(self):
|
57 |
+
vision_tower = getattr(self, 'vision_tower', None)
|
58 |
+
if type(vision_tower) is list:
|
59 |
+
vision_tower = vision_tower[0]
|
60 |
+
return vision_tower
|
61 |
+
|
62 |
+
def get_vision_tower_aux(self):
|
63 |
+
vision_tower_aux = getattr(self, 'vision_tower_aux', None)
|
64 |
+
if type(vision_tower_aux) is list:
|
65 |
+
vision_tower_aux = vision_tower_aux[0]
|
66 |
+
return vision_tower_aux
|
67 |
+
|
68 |
+
def initialize_embedder(self, unitok_pth, mm_projecter_pth=None):
|
69 |
+
self.multi_embedder = TokenEmbedder(self.config.hidden_size)
|
70 |
+
|
71 |
+
if unitok_pth is not None:
|
72 |
+
ckpt = torch.load(unitok_pth, map_location='cpu')
|
73 |
+
unitok_ckpt = ckpt['trainer']['unitok']
|
74 |
+
quantizer_weights = dict()
|
75 |
+
for k, v in unitok_ckpt.items():
|
76 |
+
if k.startswith('quantizer'):
|
77 |
+
new_k = k.replace('quantizer.', '')
|
78 |
+
quantizer_weights[new_k] = v
|
79 |
+
attn_proj_weights = dict()
|
80 |
+
for k, v in unitok_ckpt.items():
|
81 |
+
if k.startswith('post_quant_proj'):
|
82 |
+
new_k = k.replace('post_quant_proj.', '')
|
83 |
+
attn_proj_weights[new_k] = v
|
84 |
+
|
85 |
+
if is_deepspeed_zero3_enabled():
|
86 |
+
with deepspeed.zero.GatheredParameters(quantizer_weights, modifier_rank=0):
|
87 |
+
if torch.distributed.get_rank() == 0:
|
88 |
+
self.multi_embedder.quantizer.load_state_dict(quantizer_weights)
|
89 |
+
with deepspeed.zero.GatheredParameters(attn_proj_weights, modifier_rank=0):
|
90 |
+
if torch.distributed.get_rank() == 0:
|
91 |
+
self.multi_embedder.attn_projection.load_state_dict(attn_proj_weights)
|
92 |
+
else:
|
93 |
+
status = self.multi_embedder.quantizer.load_state_dict(quantizer_weights)
|
94 |
+
print('missing_keys:', status.missing_keys)
|
95 |
+
status = self.multi_embedder.attn_projection.load_state_dict(attn_proj_weights)
|
96 |
+
print('missing_keys:', status.missing_keys)
|
97 |
+
|
98 |
+
if mm_projecter_pth is not None:
|
99 |
+
mm_projector_weights = torch.load(mm_projecter_pth, map_location='cpu')
|
100 |
+
|
101 |
+
def get_w(weights, keyword):
|
102 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
|
103 |
+
|
104 |
+
named_parameters = get_w(mm_projector_weights, 'mm_projector')
|
105 |
+
|
106 |
+
if is_deepspeed_zero3_enabled():
|
107 |
+
with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
|
108 |
+
if torch.distributed.get_rank() == 0:
|
109 |
+
self.multi_embedder.mm_projector.load_state_dict(named_parameters)
|
110 |
+
else:
|
111 |
+
status = self.multi_embedder.mm_projector.load_state_dict(named_parameters)
|
112 |
+
print('missing_keys:', status.missing_keys)
|
113 |
+
|
114 |
+
self.multi_embedder = self.multi_embedder.to(device='cuda')
|
115 |
+
|
116 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
117 |
+
vision_tower = model_args.vision_tower
|
118 |
+
vision_tower_aux = model_args.vision_tower_aux
|
119 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
120 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
121 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
122 |
+
|
123 |
+
self.config.mm_vision_tower = vision_tower
|
124 |
+
self.config.mm_vision_tower_aux = vision_tower_aux
|
125 |
+
|
126 |
+
if self.get_vision_tower() is None:
|
127 |
+
vision_tower = build_vision_tower(model_args)
|
128 |
+
|
129 |
+
if fsdp is not None and len(fsdp) > 0:
|
130 |
+
self.vision_tower = [vision_tower]
|
131 |
+
else:
|
132 |
+
self.vision_tower = vision_tower
|
133 |
+
else:
|
134 |
+
if fsdp is not None and len(fsdp) > 0:
|
135 |
+
vision_tower = self.vision_tower[0]
|
136 |
+
else:
|
137 |
+
vision_tower = self.vision_tower
|
138 |
+
vision_tower.load_model()
|
139 |
+
|
140 |
+
if vision_tower_aux is not None:
|
141 |
+
if self.get_vision_tower_aux() is None:
|
142 |
+
vision_tower_aux = build_vision_tower_aux(model_args)
|
143 |
+
|
144 |
+
if fsdp is not None and len(fsdp) > 0:
|
145 |
+
self.vision_tower_aux = [vision_tower_aux]
|
146 |
+
else:
|
147 |
+
self.vision_tower_aux = vision_tower_aux
|
148 |
+
else:
|
149 |
+
if fsdp is not None and len(fsdp) > 0:
|
150 |
+
vision_tower_aux = self.vision_tower_aux[0]
|
151 |
+
else:
|
152 |
+
vision_tower_aux = self.vision_tower_aux
|
153 |
+
vision_tower_aux.load_model()
|
154 |
+
self.config.mm_hidden_size_aux = vision_tower_aux.hidden_size
|
155 |
+
|
156 |
+
self.config.use_mm_proj = True
|
157 |
+
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
|
158 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
159 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
160 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
161 |
+
|
162 |
+
if getattr(self, 'mm_projector', None) is None:
|
163 |
+
self.mm_projector = build_vision_projector(self.config)
|
164 |
+
else:
|
165 |
+
# In case it is frozen by LoRA
|
166 |
+
for p in self.mm_projector.parameters():
|
167 |
+
p.requires_grad = True
|
168 |
+
|
169 |
+
if pretrain_mm_mlp_adapter is not None:
|
170 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
171 |
+
|
172 |
+
def get_w(weights, keyword):
|
173 |
+
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
|
174 |
+
|
175 |
+
if 'model' in mm_projector_weights.keys():
|
176 |
+
mm_projector_weights = mm_projector_weights['model']
|
177 |
+
if is_deepspeed_zero3_enabled():
|
178 |
+
if len(mm_projector_weights) > 0:
|
179 |
+
with deepspeed.zero.GatheredParameters(mm_projector_weights, modifier_rank=0):
|
180 |
+
if torch.distributed.get_rank() == 0:
|
181 |
+
self.mm_projector.load_state_dict(mm_projector_weights)
|
182 |
+
else:
|
183 |
+
status = self.mm_projector.load_state_dict(mm_projector_weights, strict=False)
|
184 |
+
print('missing_keys:', status.missing_keys)
|
185 |
+
else:
|
186 |
+
if is_deepspeed_zero3_enabled():
|
187 |
+
named_parameters = get_w(mm_projector_weights, 'mm_projector')
|
188 |
+
if len(named_parameters) > 0:
|
189 |
+
with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
|
190 |
+
if torch.distributed.get_rank() == 0:
|
191 |
+
self.mm_projector.load_state_dict(named_parameters)
|
192 |
+
else:
|
193 |
+
status = self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'),
|
194 |
+
strict=False)
|
195 |
+
print('missing_keys:', status.missing_keys)
|
196 |
+
self.mm_projector = self.mm_projector.to(device='cuda')
|
197 |
+
|
198 |
+
def initialize_uni_modules(self, model_args, for_eval=False):
|
199 |
+
pretrain_mm_mlp_adapter = getattr(model_args, "pretrain_mm_mlp_adapter", None)
|
200 |
+
self.config.image_size_aux = getattr(model_args, 'image_size_aux', 320)
|
201 |
+
self.config.optimize_vision_tower = getattr(model_args, 'optimize_vision_tower', False)
|
202 |
+
self.config.optimize_vision_tower_aux = getattr(model_args, 'optimize_vision_tower_aux', False)
|
203 |
+
|
204 |
+
self.vlm_uni_query_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size),
|
205 |
+
nn.Linear(self.config.mm_hidden_size, self.config.mm_hidden_size))
|
206 |
+
self.vlm_uni_aux_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
|
207 |
+
nn.Linear(self.config.mm_hidden_size_aux,
|
208 |
+
self.config.mm_hidden_size))
|
209 |
+
self.vlm_uni_val_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
|
210 |
+
nn.Linear(self.config.mm_hidden_size_aux,
|
211 |
+
self.config.mm_hidden_size))
|
212 |
+
|
213 |
+
if pretrain_mm_mlp_adapter is not None:
|
214 |
+
projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
|
215 |
+
else:
|
216 |
+
trainable_module = ['vlm_uni', 'vision_fpn', 'vision_stages']
|
217 |
+
if hasattr(model_args, 'model_name_or_path'):
|
218 |
+
model_save_path = model_args.model_name_or_path
|
219 |
+
else:
|
220 |
+
model_save_path = model_args.model_path
|
221 |
+
model_idx_path = getattr(model_args, 'model_path', model_save_path)
|
222 |
+
if IS_NEW_TRANSFORMERS:
|
223 |
+
try:
|
224 |
+
weight_file = json.load(open(os.path.join(model_idx_path, 'model.safetensors.index.json'), 'r'))[
|
225 |
+
'weight_map']
|
226 |
+
except:
|
227 |
+
weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))[
|
228 |
+
'weight_map']
|
229 |
+
else:
|
230 |
+
weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))[
|
231 |
+
'weight_map']
|
232 |
+
model_path = set(
|
233 |
+
[weight_file[_key] for _key in weight_file if any([_module in _key for _module in trainable_module])])
|
234 |
+
projector_weights = {}
|
235 |
+
for _model in model_path:
|
236 |
+
if not IS_NEW_TRANSFORMERS:
|
237 |
+
projector_weights.update(torch.load(os.path.join(model_idx_path, _model), map_location='cpu'))
|
238 |
+
else:
|
239 |
+
with safetensors.safe_open(os.path.join(model_idx_path, _model), framework="pt", device='cpu') as f:
|
240 |
+
for _key in f.keys():
|
241 |
+
projector_weights.update({_key: f.get_tensor(_key)})
|
242 |
+
if len(projector_weights) == 0:
|
243 |
+
return
|
244 |
+
|
245 |
+
def get_w(weights, keyword, main_module, sub_module):
|
246 |
+
if getattr(main_module, sub_module, None) is None:
|
247 |
+
return
|
248 |
+
|
249 |
+
pretrain_weight = {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
|
250 |
+
if len(pretrain_weight) == 0:
|
251 |
+
return
|
252 |
+
if is_deepspeed_zero3_enabled():
|
253 |
+
named_parameters = [v for k, v in getattr(main_module, sub_module).named_parameters()]
|
254 |
+
if len(named_parameters) > 0:
|
255 |
+
# because zero3 puts placeholders in model params, this context
|
256 |
+
# manager gathers (unpartitions) the params of the current layer, then loads from
|
257 |
+
# the state dict and then re-partitions them again
|
258 |
+
with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
|
259 |
+
if torch.distributed.get_rank() == 0:
|
260 |
+
getattr(main_module, sub_module).load_state_dict(pretrain_weight)
|
261 |
+
with deepspeed.zero.GatheredParameters(self.mm_projector[0].weight, modifier_rank=None):
|
262 |
+
weight_type = self.mm_projector[0].weight.dtype
|
263 |
+
device_type = self.mm_projector[0].weight.device
|
264 |
+
else:
|
265 |
+
weight_type = self.mm_projector[0].weight.dtype
|
266 |
+
device_type = self.mm_projector[0].weight.device
|
267 |
+
getattr(main_module, sub_module).load_state_dict(pretrain_weight)
|
268 |
+
if weight_type == torch.uint8 or weight_type == torch.int8 or weight_type == torch.int16:
|
269 |
+
weight_type = torch.float16
|
270 |
+
getattr(main_module, sub_module).to(device=device_type, dtype=weight_type)
|
271 |
+
print(f"Loading {sub_module} weights...")
|
272 |
+
|
273 |
+
# load pretrained weights
|
274 |
+
get_w(projector_weights, 'vision_tower.vision_tower', self.vision_tower, 'vision_tower')
|
275 |
+
|
276 |
+
# load pretrained weights
|
277 |
+
if self.config.optimize_vision_tower_aux:
|
278 |
+
# not optimize vision stem, just used to check
|
279 |
+
get_w(projector_weights, 'vision_tower_aux.vision_stem', self.vision_tower_aux, 'vision_stem')
|
280 |
+
get_w(projector_weights, 'vision_tower_aux.vision_stages', self.vision_tower_aux, 'vision_stages')
|
281 |
+
get_w(projector_weights, 'vlm_uni_query_projector', self, 'vlm_uni_query_projector')
|
282 |
+
get_w(projector_weights, 'vlm_uni_aux_projector', self, 'vlm_uni_aux_projector')
|
283 |
+
get_w(projector_weights, 'vlm_uni_val_projector', self, 'vlm_uni_val_projector')
|
284 |
+
|
285 |
+
|
286 |
+
class TokenEmbedder(nn.Module):
|
287 |
+
def __init__(self, hidden_size):
|
288 |
+
super().__init__()
|
289 |
+
# hard coding for unitok, need to be fixed
|
290 |
+
self.num_codebooks = 8
|
291 |
+
self.quantizer = VectorQuantizerM(32768, 64, 0.25, False, 0.01, 8)
|
292 |
+
self.attn_projection = AttnProjection(64, 1024, 16)
|
293 |
+
self.mm_projector = nn.Sequential(
|
294 |
+
nn.LayerNorm(1024, eps=1e-6),
|
295 |
+
nn.Linear(1024, hidden_size),
|
296 |
+
nn.GELU(),
|
297 |
+
nn.Linear(hidden_size, hidden_size),
|
298 |
+
)
|
299 |
+
|
300 |
+
def forward(self, indices): # input [bz,num-codebook,256]
|
301 |
+
assert indices.shape[1] == self.num_codebooks
|
302 |
+
features = self.quantizer.idx_to_f(indices) # [bz,256,C]
|
303 |
+
features = self.attn_projection(features) # [bz,256,1024]
|
304 |
+
latent_features = self.mm_projector(features) # [bz,256,hidden_size]
|
305 |
+
return latent_features # [bz,256,hidden_size
|
306 |
+
|
307 |
+
|
308 |
+
class MiniGeminiMetaForCausalLM(ABC):
|
309 |
+
@abstractmethod
|
310 |
+
def get_model(self):
|
311 |
+
pass
|
312 |
+
|
313 |
+
def get_vision_tower(self):
|
314 |
+
return self.get_model().get_vision_tower()
|
315 |
+
|
316 |
+
def get_vision_tower_aux(self):
|
317 |
+
return self.get_model().get_vision_tower_aux()
|
318 |
+
|
319 |
+
def encode_images(self, images, images_aux=None, is_video=False):
|
320 |
+
image_grid = getattr(self.config, 'image_grid', 1)
|
321 |
+
image_global = getattr(self.config, 'image_global', False)
|
322 |
+
if image_grid > 1:
|
323 |
+
batch_size = images.shape[0]
|
324 |
+
if image_global:
|
325 |
+
global_images = images[:, -1:].flatten(0, 1).contiguous()
|
326 |
+
grid_images = images[:, :-1].flatten(0, 1).contiguous()
|
327 |
+
images = torch.cat([grid_images, global_images], dim=0)
|
328 |
+
else:
|
329 |
+
images = images.flatten(0, 1).contiguous()
|
330 |
+
|
331 |
+
image_features = self.get_model().get_vision_tower()(images)
|
332 |
+
|
333 |
+
if image_global:
|
334 |
+
image_feat_global = image_features[-len(global_images):]
|
335 |
+
image_features = image_features[:len(grid_images)]
|
336 |
+
|
337 |
+
if images_aux is not None:
|
338 |
+
image_aux_features_raw = self.get_model().get_vision_tower_aux()(images_aux).to(
|
339 |
+
dtype=image_features.dtype, device=image_features.device)
|
340 |
+
|
341 |
+
if image_global:
|
342 |
+
image_aux_features_global = F.interpolate(image_aux_features_raw.float(),
|
343 |
+
scale_factor=1 / image_grid,
|
344 |
+
mode='bilinear',
|
345 |
+
align_corners=False).to(dtype=image_aux_features_raw.dtype)
|
346 |
+
image_feat_global, image_aux_feat_global = self.unified_resampler(image_feat_global,
|
347 |
+
image_aux_features_global)
|
348 |
+
|
349 |
+
if image_grid > 1:
|
350 |
+
image_aux_features_raw = image_aux_features_raw.reshape(*image_aux_features_raw.shape[:2],
|
351 |
+
image_grid,
|
352 |
+
image_aux_features_raw.shape[-2] // image_grid,
|
353 |
+
image_grid,
|
354 |
+
image_aux_features_raw.shape[-1] // image_grid)
|
355 |
+
image_aux_features_raw = image_aux_features_raw.permute(0, 2, 4, 1, 3, 5).flatten(1, 2).flatten(0,
|
356 |
+
1).contiguous()
|
357 |
+
image_features, image_aux_features = self.unified_resampler(image_features, image_aux_features_raw)
|
358 |
+
|
359 |
+
if image_grid > 1:
|
360 |
+
image_features = image_features.reshape(batch_size, image_grid ** 2, *image_features.shape[1:])
|
361 |
+
image_features = image_features.flatten(1, 2).contiguous()
|
362 |
+
image_aux_features = image_aux_features.reshape(batch_size, image_grid ** 2,
|
363 |
+
*image_aux_features.shape[1:])
|
364 |
+
image_aux_features = image_aux_features.flatten(1, 2).contiguous()
|
365 |
+
|
366 |
+
# add global features, [global, local]
|
367 |
+
if image_global:
|
368 |
+
image_features = torch.cat([image_feat_global, image_features], dim=1)
|
369 |
+
image_aux_features = torch.cat([image_aux_feat_global, image_aux_features], dim=1)
|
370 |
+
|
371 |
+
# token generation
|
372 |
+
image_features = image_features + image_aux_features
|
373 |
+
|
374 |
+
# process image features after token generation
|
375 |
+
image_features = self.get_model().mm_projector(image_features)
|
376 |
+
|
377 |
+
return image_features
|
378 |
+
|
379 |
+
def unified_resampler(self, images, images_aux):
|
380 |
+
# patchwise with square images
|
381 |
+
patch_num = int(images.shape[1] ** 0.5)
|
382 |
+
patch_size = images_aux.shape[-1] // patch_num
|
383 |
+
# within patch attention
|
384 |
+
images_aux = images_aux.permute(0, 2, 3, 1)
|
385 |
+
images_aux = images_aux.reshape(len(images_aux), patch_num, patch_size, patch_num, patch_size,
|
386 |
+
images_aux.shape[-1])
|
387 |
+
images_aux = images_aux.permute(0, 1, 3, 2, 4, 5)
|
388 |
+
images_aux = images_aux.reshape(len(images_aux), patch_num ** 2, patch_size ** 2,
|
389 |
+
images_aux.shape[-1]).contiguous()
|
390 |
+
|
391 |
+
# token attention
|
392 |
+
embed_query = self.get_model().vlm_uni_query_projector(images)
|
393 |
+
embed_aux = self.get_model().vlm_uni_aux_projector(images_aux)
|
394 |
+
embed_value = self.get_model().vlm_uni_val_projector(images_aux)
|
395 |
+
embed_att = embed_query[:, :, None] @ (embed_aux.transpose(-1, -2) / (embed_aux.shape[-1] ** 0.5))
|
396 |
+
embed_att = embed_att.nan_to_num()
|
397 |
+
embed_feat = (embed_att.softmax(-1) @ embed_value).mean(2)
|
398 |
+
|
399 |
+
return images, embed_feat
|
400 |
+
|
401 |
+
def prepare_inputs_labels_for_multimodal(
|
402 |
+
self, input_ids, position_ids, attention_mask, past_key_values, labels, images=None, images_aux=None,
|
403 |
+
data_types=None,
|
404 |
+
):
|
405 |
+
vision_tower = self.get_vision_tower()
|
406 |
+
multi_embedder = self.model.multi_embedder
|
407 |
+
# import pdb;pdb.set_trace()
|
408 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
409 |
+
if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
|
410 |
+
1] == 1:
|
411 |
+
target_shape = past_key_values[-1][-1].shape[-2] + 1
|
412 |
+
attention_mask = torch.cat((attention_mask, torch.ones(
|
413 |
+
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
|
414 |
+
dtype=attention_mask.dtype,
|
415 |
+
device=attention_mask.device
|
416 |
+
)), dim=1)
|
417 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
418 |
+
|
419 |
+
if position_ids is None:
|
420 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
421 |
+
bug_flag = False
|
422 |
+
if images is not None:
|
423 |
+
_labels = labels
|
424 |
+
_position_ids = position_ids
|
425 |
+
_attention_mask = attention_mask
|
426 |
+
new_input_embeds = []
|
427 |
+
new_labels = []
|
428 |
+
additional_image_labels = []
|
429 |
+
additional_image_indexs = []
|
430 |
+
if attention_mask is not None:
|
431 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
|
432 |
+
zip(input_ids, attention_mask)]
|
433 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in
|
434 |
+
zip(labels, attention_mask)]
|
435 |
+
# import pdb;pdb.set_trace()
|
436 |
+
for image, cur_input_ids, cur_labels, data_type in zip(images, input_ids, labels, data_types):
|
437 |
+
# import pdb;pdb.set_trace()
|
438 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
439 |
+
# import pdb;pdb.set_trace()
|
440 |
+
if num_images == 0:
|
441 |
+
# import pdb;pdb.set_trace()
|
442 |
+
empty_image_embed = multi_embedder(
|
443 |
+
torch.zeros(1, self.model.multi_embedder.num_codebooks, 1).long().to(cur_input_ids))[0, :0]
|
444 |
+
new_input_embeds.append(
|
445 |
+
torch.cat([self.get_model().embed_tokens(cur_input_ids), empty_image_embed], dim=0))
|
446 |
+
new_labels.append(cur_labels)
|
447 |
+
continue # pure text data
|
448 |
+
assert len(image.shape) == 3 # [bz,num-codebook,256] image token id
|
449 |
+
if len(image) > num_images:
|
450 |
+
image = image[:num_images] # remove cutted images
|
451 |
+
image_embedding = multi_embedder(image) # get image embeddings
|
452 |
+
|
453 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
|
454 |
+
cur_input_ids.shape[0]]
|
455 |
+
cur_input_ids_noim = []
|
456 |
+
cur_labels_noim = []
|
457 |
+
for i in range(len(image_token_indices) - 1):
|
458 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
|
459 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
|
460 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
461 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
462 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
463 |
+
cur_new_input_embeds = []
|
464 |
+
cur_new_labels = []
|
465 |
+
# import pdb;pdb.set_trace()
|
466 |
+
max_pos_id = 0
|
467 |
+
for i in range(num_images + 1):
|
468 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
469 |
+
cur_new_labels.append(cur_labels_noim[i])
|
470 |
+
# import pdb;pdb.set_trace()
|
471 |
+
max_pos_id += cur_input_embeds_no_im[i].shape[0]
|
472 |
+
if i < num_images:
|
473 |
+
cur_image_features = image_embedding[i]
|
474 |
+
cur_new_input_embeds.append(cur_image_features)
|
475 |
+
|
476 |
+
if data_type == 1: # to Image, loss on 4x image tokens
|
477 |
+
additional_image_labels.append(image)
|
478 |
+
additional_image_indexs.append((cur_new_labels[-1].shape[0],
|
479 |
+
cur_new_labels[-1].shape[0] + cur_image_features.shape[
|
480 |
+
0]))
|
481 |
+
### input: describe xxxx: boi 8*[256] (256 embedding) eoi eos
|
482 |
+
### labels: -100 -100 -100 -100 -100 -100 -100 -100 -100 eoi eos
|
483 |
+
cur_new_labels.append(
|
484 |
+
torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
|
485 |
+
dtype=cur_labels.dtype))
|
486 |
+
max_pos_id += cur_image_features.shape[0]
|
487 |
+
|
488 |
+
cur_new_input_embeds = [x.to(device=cur_input_embeds.device) for x in cur_new_input_embeds]
|
489 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
490 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
491 |
+
|
492 |
+
new_input_embeds.append(cur_new_input_embeds)
|
493 |
+
new_labels.append(cur_new_labels)
|
494 |
+
|
495 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
496 |
+
|
497 |
+
if tokenizer_model_max_length is not None:
|
498 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
499 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
500 |
+
|
501 |
+
# Combine them
|
502 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
503 |
+
batch_size = len(new_input_embeds)
|
504 |
+
assert len(new_labels) == len(data_types)
|
505 |
+
new_input_embeds_padded = []
|
506 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
|
507 |
+
device=new_labels[0].device)
|
508 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype,
|
509 |
+
device=attention_mask.device)
|
510 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
511 |
+
# import pdb;pdb.set_trace()
|
512 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
513 |
+
cur_len = cur_new_embed.shape[0]
|
514 |
+
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
|
515 |
+
new_input_embeds_padded.append(torch.cat((
|
516 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
|
517 |
+
device=cur_new_embed.device),
|
518 |
+
cur_new_embed
|
519 |
+
), dim=0))
|
520 |
+
if cur_len > 0:
|
521 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
522 |
+
attention_mask[i, -cur_len:] = True
|
523 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
|
524 |
+
device=position_ids.device)
|
525 |
+
else:
|
526 |
+
new_input_embeds_padded.append(torch.cat((
|
527 |
+
cur_new_embed,
|
528 |
+
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
|
529 |
+
device=cur_new_embed.device)
|
530 |
+
), dim=0))
|
531 |
+
if cur_len > 0:
|
532 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
533 |
+
attention_mask[i, :cur_len] = True
|
534 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
|
535 |
+
device=position_ids.device)
|
536 |
+
|
537 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
538 |
+
|
539 |
+
if _labels is None:
|
540 |
+
new_labels = None
|
541 |
+
else:
|
542 |
+
new_labels = new_labels_padded
|
543 |
+
|
544 |
+
if _attention_mask is None:
|
545 |
+
attention_mask = None
|
546 |
+
else:
|
547 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
548 |
+
|
549 |
+
if _position_ids is None:
|
550 |
+
position_ids = None
|
551 |
+
# import pdb;pdb.set_trace()
|
552 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, data_types, additional_image_labels, additional_image_indexs
|
553 |
+
|
554 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
555 |
+
|
556 |
+
def prepare_inputs_for_multimodal(
|
557 |
+
self, input_ids, position_ids, attention_mask,
|
558 |
+
past_key_values, labels, images=None, images_aux=None, data_types=None,
|
559 |
+
):
|
560 |
+
multi_embedder = self.model.multi_embedder
|
561 |
+
# import pdb;pdb.set_trace()
|
562 |
+
_labels = labels
|
563 |
+
_position_ids = position_ids
|
564 |
+
_attention_mask = attention_mask
|
565 |
+
if images is not None:
|
566 |
+
new_input_embeds = []
|
567 |
+
for image, cur_input_ids in zip(images, input_ids):
|
568 |
+
# import pdb;pdb.set_trace()
|
569 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
570 |
+
if num_images == 0:
|
571 |
+
new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
|
572 |
+
continue # pure text data
|
573 |
+
image_embedding = multi_embedder(image)
|
574 |
+
# import pdb;pdb.set_trace()
|
575 |
+
|
576 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
|
577 |
+
cur_input_ids.shape[0]]
|
578 |
+
cur_input_ids_noim = []
|
579 |
+
for i in range(len(image_token_indices) - 1):
|
580 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
|
581 |
+
split_sizes = [x.shape[0] for x in cur_input_ids_noim]
|
582 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
583 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
584 |
+
cur_new_input_embeds = []
|
585 |
+
# import pdb;pdb.set_trace()
|
586 |
+
max_pos_id = 0
|
587 |
+
for i in range(num_images + 1):
|
588 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
589 |
+
# import pdb;pdb.set_trace()
|
590 |
+
max_pos_id += cur_input_embeds_no_im[i].shape[0]
|
591 |
+
if i < num_images:
|
592 |
+
cur_image_features = image_embedding[i]
|
593 |
+
cur_new_input_embeds.append(cur_image_features)
|
594 |
+
max_pos_id += cur_image_features.shape[0]
|
595 |
+
|
596 |
+
cur_new_input_embeds = [x.to(device=cur_input_embeds.device) for x in cur_new_input_embeds]
|
597 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
598 |
+
new_input_embeds.append(cur_new_input_embeds)
|
599 |
+
|
600 |
+
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
|
601 |
+
if tokenizer_model_max_length is not None:
|
602 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
603 |
+
# import pdb;pdb.set_trace()
|
604 |
+
# Combine them
|
605 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
606 |
+
batch_size = len(new_input_embeds)
|
607 |
+
new_input_embeds_padded = []
|
608 |
+
|
609 |
+
new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
610 |
+
# import pdb;pdb.set_trace()
|
611 |
+
if _labels is None:
|
612 |
+
new_labels = None
|
613 |
+
else:
|
614 |
+
new_labels = new_labels_padded
|
615 |
+
|
616 |
+
if _attention_mask is None:
|
617 |
+
attention_mask = None
|
618 |
+
else:
|
619 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
620 |
+
|
621 |
+
if _position_ids is None:
|
622 |
+
position_ids = None
|
623 |
+
# import pdb;pdb.set_trace()
|
624 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
625 |
+
|
626 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
627 |
+
if model_args.mm_use_im_patch_token:
|
628 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
629 |
+
self.resize_token_embeddings(len(tokenizer))
|
630 |
+
|
631 |
+
if model_args.mm_use_im_start_end:
|
632 |
+
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
633 |
+
self.resize_token_embeddings(len(tokenizer))
|
634 |
+
|
635 |
+
if num_new_tokens > 0:
|
636 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
637 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
638 |
+
|
639 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
640 |
+
dim=0, keepdim=True)
|
641 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
642 |
+
dim=0, keepdim=True)
|
643 |
+
|
644 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
645 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
646 |
+
|
647 |
+
if model_args.tune_mm_mlp_adapter:
|
648 |
+
for p in self.get_input_embeddings().parameters():
|
649 |
+
p.requires_grad = True
|
650 |
+
for p in self.get_output_embeddings().parameters():
|
651 |
+
p.requires_grad = False
|
652 |
+
|
653 |
+
if model_args.pretrain_mm_mlp_adapter:
|
654 |
+
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
|
655 |
+
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
|
656 |
+
assert num_new_tokens == 2
|
657 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
658 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
659 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
660 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
661 |
+
else:
|
662 |
+
raise ValueError(
|
663 |
+
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
|
664 |
+
elif model_args.mm_use_im_patch_token:
|
665 |
+
if model_args.tune_mm_mlp_adapter:
|
666 |
+
for p in self.get_input_embeddings().parameters():
|
667 |
+
p.requires_grad = False
|
668 |
+
for p in self.get_output_embeddings().parameters():
|
669 |
+
p.requires_grad = False
|
model/multimodal_encoder/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (2.35 kB). View file
|
|
model/multimodal_encoder/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (1.27 kB). View file
|
|
model/multimodal_encoder/__pycache__/clip_encoder.cpython-311.pyc
ADDED
Binary file (5.92 kB). View file
|
|
model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc
ADDED
Binary file (3.37 kB). View file
|
|
model/multimodal_encoder/__pycache__/eva_encoder.cpython-311.pyc
ADDED
Binary file (34.2 kB). View file
|
|
model/multimodal_encoder/__pycache__/eva_encoder.cpython-39.pyc
ADDED
Binary file (17.3 kB). View file
|
|
model/multimodal_encoder/__pycache__/openclip_encoder.cpython-311.pyc
ADDED
Binary file (12.3 kB). View file
|
|
model/multimodal_encoder/__pycache__/openclip_encoder.cpython-39.pyc
ADDED
Binary file (6.52 kB). View file
|
|
model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
from .eva_encoder import EVAVisionTower
|
4 |
+
from .openclip_encoder import OpenCLIPVisionTower
|
5 |
+
|
6 |
+
|
7 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
8 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
9 |
+
image_processor = getattr(vision_tower_cfg, 'image_processor', getattr(vision_tower_cfg, 'image_processor', "../processor/clip-patch14-224"))
|
10 |
+
|
11 |
+
if not os.path.exists(vision_tower):
|
12 |
+
raise ValueError(f'Not find vision tower: {vision_tower}')
|
13 |
+
|
14 |
+
if "openai" in vision_tower.lower() or "ShareGPT4V" in vision_tower:
|
15 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
16 |
+
elif "lavis" in vision_tower.lower() or "eva" in vision_tower.lower():
|
17 |
+
return EVAVisionTower(vision_tower, image_processor, args=vision_tower_cfg, **kwargs)
|
18 |
+
else:
|
19 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
20 |
+
|
21 |
+
|
22 |
+
def build_vision_tower_aux(vision_tower_cfg, **kwargs):
|
23 |
+
vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', getattr(vision_tower_cfg, 'vision_tower_aux', None))
|
24 |
+
|
25 |
+
if not os.path.exists(vision_tower_aux):
|
26 |
+
raise ValueError(f'Not find vision tower: {vision_tower_aux}')
|
27 |
+
|
28 |
+
if "openclip" in vision_tower_aux.lower():
|
29 |
+
return OpenCLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
|
30 |
+
elif "openai" in vision_tower_aux.lower():
|
31 |
+
return CLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
|
32 |
+
else:
|
33 |
+
raise ValueError(f'Unknown vision tower: {vision_tower_aux}')
|
model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
from ..processor.video_processor import VideoFramesProcessor
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.is_loaded = False
|
12 |
+
|
13 |
+
self.vision_tower_name = vision_tower
|
14 |
+
self.select_layer = args.mm_vision_select_layer
|
15 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
16 |
+
self.is_optimize = getattr(args, 'optimize_vision_tower', False)
|
17 |
+
|
18 |
+
if not delay_load:
|
19 |
+
self.load_model()
|
20 |
+
elif getattr(args, 'unfreeze_mm_vision_tower', False):
|
21 |
+
self.load_model()
|
22 |
+
else:
|
23 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
24 |
+
|
25 |
+
def load_model(self):
|
26 |
+
self.image_processor = VideoFramesProcessor.from_pretrained(self.vision_tower_name)
|
27 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
28 |
+
self.vision_tower.requires_grad_(False)
|
29 |
+
|
30 |
+
self.is_loaded = True
|
31 |
+
|
32 |
+
def feature_select(self, image_forward_outs):
|
33 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
34 |
+
if self.select_feature == 'patch':
|
35 |
+
image_features = image_features[:, 1:]
|
36 |
+
elif self.select_feature == 'cls_patch':
|
37 |
+
image_features = image_features
|
38 |
+
else:
|
39 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
40 |
+
return image_features
|
41 |
+
|
42 |
+
def image_forward(self, images):
|
43 |
+
if type(images) is list:
|
44 |
+
image_features = []
|
45 |
+
for image in images:
|
46 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
47 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
48 |
+
image_features.append(image_feature)
|
49 |
+
else:
|
50 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
51 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
52 |
+
|
53 |
+
return image_features
|
54 |
+
|
55 |
+
def forward(self, images):
|
56 |
+
if not self.is_optimize:
|
57 |
+
with torch.no_grad():
|
58 |
+
image_features = self.image_forward(images)
|
59 |
+
else:
|
60 |
+
image_features = self.image_forward(images)
|
61 |
+
|
62 |
+
return image_features
|
63 |
+
|
64 |
+
@property
|
65 |
+
def dummy_feature(self):
|
66 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
67 |
+
|
68 |
+
@property
|
69 |
+
def dtype(self):
|
70 |
+
return self.vision_tower.dtype
|
71 |
+
|
72 |
+
@property
|
73 |
+
def device(self):
|
74 |
+
return self.vision_tower.device
|
75 |
+
|
76 |
+
@property
|
77 |
+
def config(self):
|
78 |
+
if self.is_loaded:
|
79 |
+
return self.vision_tower.config
|
80 |
+
else:
|
81 |
+
return self.cfg_only
|
82 |
+
|
83 |
+
@property
|
84 |
+
def hidden_size(self):
|
85 |
+
return self.config.hidden_size
|
86 |
+
|
87 |
+
@property
|
88 |
+
def num_patches(self):
|
89 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
model/multimodal_encoder/eva_encoder.py
ADDED
@@ -0,0 +1,551 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Based on EVA, BEIT, timm and DeiT code bases
|
2 |
+
# https://github.com/baaivision/EVA
|
3 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
4 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
5 |
+
# https://github.com/facebookresearch/deit/
|
6 |
+
# https://github.com/facebookresearch/dino
|
7 |
+
# --------------------------------------------------------'
|
8 |
+
import math
|
9 |
+
from functools import partial
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import torch.utils.checkpoint as checkpoint
|
15 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
16 |
+
from timm.models.registry import register_model
|
17 |
+
from transformers import CLIPImageProcessor, CLIPVisionConfig
|
18 |
+
from ..processor.video_processor import VideoFramesProcessor
|
19 |
+
|
20 |
+
def _cfg(url='', **kwargs):
|
21 |
+
return {
|
22 |
+
'url': url,
|
23 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
24 |
+
'crop_pct': .9, 'interpolation': 'bicubic',
|
25 |
+
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
26 |
+
**kwargs
|
27 |
+
}
|
28 |
+
|
29 |
+
class DropPath(nn.Module):
|
30 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
31 |
+
"""
|
32 |
+
def __init__(self, drop_prob=None):
|
33 |
+
super(DropPath, self).__init__()
|
34 |
+
self.drop_prob = drop_prob
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
return drop_path(x, self.drop_prob, self.training)
|
38 |
+
|
39 |
+
def extra_repr(self) -> str:
|
40 |
+
return 'p={}'.format(self.drop_prob)
|
41 |
+
|
42 |
+
|
43 |
+
class Mlp(nn.Module):
|
44 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
45 |
+
super().__init__()
|
46 |
+
out_features = out_features or in_features
|
47 |
+
hidden_features = hidden_features or in_features
|
48 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
49 |
+
self.act = act_layer()
|
50 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
51 |
+
self.drop = nn.Dropout(drop)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
x = self.fc1(x)
|
55 |
+
x = self.act(x)
|
56 |
+
# x = self.drop(x)
|
57 |
+
# commit this for the orignal BERT implement
|
58 |
+
x = self.fc2(x)
|
59 |
+
x = self.drop(x)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class Attention(nn.Module):
|
64 |
+
def __init__(
|
65 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
66 |
+
proj_drop=0., window_size=None, attn_head_dim=None):
|
67 |
+
super().__init__()
|
68 |
+
self.num_heads = num_heads
|
69 |
+
head_dim = dim // num_heads
|
70 |
+
if attn_head_dim is not None:
|
71 |
+
head_dim = attn_head_dim
|
72 |
+
all_head_dim = head_dim * self.num_heads
|
73 |
+
self.scale = qk_scale or head_dim ** -0.5
|
74 |
+
|
75 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
76 |
+
if qkv_bias:
|
77 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
78 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
79 |
+
else:
|
80 |
+
self.q_bias = None
|
81 |
+
self.v_bias = None
|
82 |
+
|
83 |
+
if window_size:
|
84 |
+
self.window_size = window_size
|
85 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
86 |
+
self.relative_position_bias_table = nn.Parameter(
|
87 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
88 |
+
# cls to token & token 2 cls & cls to cls
|
89 |
+
|
90 |
+
# get pair-wise relative position index for each token inside the window
|
91 |
+
coords_h = torch.arange(window_size[0])
|
92 |
+
coords_w = torch.arange(window_size[1])
|
93 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
94 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
95 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
96 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
97 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
98 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
99 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
100 |
+
relative_position_index = \
|
101 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
102 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
103 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
104 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
105 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
106 |
+
|
107 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
108 |
+
else:
|
109 |
+
self.window_size = None
|
110 |
+
self.relative_position_bias_table = None
|
111 |
+
self.relative_position_index = None
|
112 |
+
|
113 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
114 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
115 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
116 |
+
|
117 |
+
def forward(self, x, rel_pos_bias=None):
|
118 |
+
B, N, C = x.shape
|
119 |
+
qkv_bias = None
|
120 |
+
if self.q_bias is not None:
|
121 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
122 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
123 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
124 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
125 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
126 |
+
|
127 |
+
q = q * self.scale
|
128 |
+
attn = (q @ k.transpose(-2, -1))
|
129 |
+
|
130 |
+
if self.relative_position_bias_table is not None:
|
131 |
+
relative_position_bias = \
|
132 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
133 |
+
self.window_size[0] * self.window_size[1] + 1,
|
134 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
135 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
136 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
137 |
+
|
138 |
+
if rel_pos_bias is not None:
|
139 |
+
attn = attn + rel_pos_bias
|
140 |
+
|
141 |
+
attn = attn.softmax(dim=-1)
|
142 |
+
attn = self.attn_drop(attn)
|
143 |
+
|
144 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
145 |
+
x = self.proj(x)
|
146 |
+
x = self.proj_drop(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class Block(nn.Module):
|
151 |
+
|
152 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
153 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
154 |
+
window_size=None, attn_head_dim=None):
|
155 |
+
super().__init__()
|
156 |
+
self.norm1 = norm_layer(dim)
|
157 |
+
self.attn = Attention(
|
158 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
159 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
|
160 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
161 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
162 |
+
self.norm2 = norm_layer(dim)
|
163 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
164 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
165 |
+
|
166 |
+
if init_values is not None and init_values > 0:
|
167 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
168 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
169 |
+
else:
|
170 |
+
self.gamma_1, self.gamma_2 = None, None
|
171 |
+
|
172 |
+
def forward(self, x, rel_pos_bias=None):
|
173 |
+
if self.gamma_1 is None:
|
174 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
175 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
176 |
+
else:
|
177 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
|
178 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class PatchEmbed(nn.Module):
|
183 |
+
""" Image to Patch Embedding
|
184 |
+
"""
|
185 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
186 |
+
super().__init__()
|
187 |
+
img_size = to_2tuple(img_size)
|
188 |
+
patch_size = to_2tuple(patch_size)
|
189 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
190 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
191 |
+
self.img_size = img_size
|
192 |
+
self.patch_size = patch_size
|
193 |
+
self.num_patches = num_patches
|
194 |
+
|
195 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
196 |
+
|
197 |
+
def forward(self, x, **kwargs):
|
198 |
+
B, C, H, W = x.shape
|
199 |
+
# FIXME look at relaxing size constraints
|
200 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
201 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
202 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
203 |
+
return x
|
204 |
+
|
205 |
+
|
206 |
+
class RelativePositionBias(nn.Module):
|
207 |
+
|
208 |
+
def __init__(self, window_size, num_heads):
|
209 |
+
super().__init__()
|
210 |
+
self.window_size = window_size
|
211 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
212 |
+
self.relative_position_bias_table = nn.Parameter(
|
213 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
214 |
+
# cls to token & token 2 cls & cls to cls
|
215 |
+
|
216 |
+
# get pair-wise relative position index for each token inside the window
|
217 |
+
coords_h = torch.arange(window_size[0])
|
218 |
+
coords_w = torch.arange(window_size[1])
|
219 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
220 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
221 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
222 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
223 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
224 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
225 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
226 |
+
relative_position_index = \
|
227 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
228 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
229 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
230 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
231 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
232 |
+
|
233 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
234 |
+
|
235 |
+
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
236 |
+
|
237 |
+
def forward(self):
|
238 |
+
relative_position_bias = \
|
239 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
240 |
+
self.window_size[0] * self.window_size[1] + 1,
|
241 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
242 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
243 |
+
|
244 |
+
|
245 |
+
class VisionTransformer(nn.Module):
|
246 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
247 |
+
"""
|
248 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
249 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
250 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
|
251 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
|
252 |
+
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
|
253 |
+
super().__init__()
|
254 |
+
self.image_size = img_size
|
255 |
+
self.num_classes = num_classes
|
256 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
257 |
+
|
258 |
+
self.patch_embed = PatchEmbed(
|
259 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
260 |
+
num_patches = self.patch_embed.num_patches
|
261 |
+
|
262 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
263 |
+
if use_abs_pos_emb:
|
264 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
265 |
+
else:
|
266 |
+
self.pos_embed = None
|
267 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
268 |
+
|
269 |
+
if use_shared_rel_pos_bias:
|
270 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
271 |
+
else:
|
272 |
+
self.rel_pos_bias = None
|
273 |
+
self.use_checkpoint = use_checkpoint
|
274 |
+
|
275 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
276 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
277 |
+
self.blocks = nn.ModuleList([
|
278 |
+
Block(
|
279 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
280 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
281 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
|
282 |
+
for i in range(depth)])
|
283 |
+
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
284 |
+
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
285 |
+
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
286 |
+
|
287 |
+
if self.pos_embed is not None:
|
288 |
+
trunc_normal_(self.pos_embed, std=.02)
|
289 |
+
trunc_normal_(self.cls_token, std=.02)
|
290 |
+
# trunc_normal_(self.mask_token, std=.02)
|
291 |
+
# if isinstance(self.head, nn.Linear):
|
292 |
+
# trunc_normal_(self.head.weight, std=.02)
|
293 |
+
self.apply(self._init_weights)
|
294 |
+
self.fix_init_weight()
|
295 |
+
# if isinstance(self.head, nn.Linear):
|
296 |
+
# self.head.weight.data.mul_(init_scale)
|
297 |
+
# self.head.bias.data.mul_(init_scale)
|
298 |
+
|
299 |
+
def fix_init_weight(self):
|
300 |
+
def rescale(param, layer_id):
|
301 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
302 |
+
|
303 |
+
for layer_id, layer in enumerate(self.blocks):
|
304 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
305 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
306 |
+
|
307 |
+
def _init_weights(self, m):
|
308 |
+
if isinstance(m, nn.Linear):
|
309 |
+
trunc_normal_(m.weight, std=.02)
|
310 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
311 |
+
nn.init.constant_(m.bias, 0)
|
312 |
+
elif isinstance(m, nn.LayerNorm):
|
313 |
+
nn.init.constant_(m.bias, 0)
|
314 |
+
nn.init.constant_(m.weight, 1.0)
|
315 |
+
|
316 |
+
def get_classifier(self):
|
317 |
+
return self.head
|
318 |
+
|
319 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
320 |
+
self.num_classes = num_classes
|
321 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
322 |
+
|
323 |
+
def forward_features(self, x):
|
324 |
+
x = self.patch_embed(x)
|
325 |
+
batch_size, seq_len, _ = x.size()
|
326 |
+
|
327 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
328 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
329 |
+
if self.pos_embed is not None:
|
330 |
+
x = x + self.pos_embed
|
331 |
+
x = self.pos_drop(x)
|
332 |
+
|
333 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
334 |
+
for blk in self.blocks:
|
335 |
+
if self.use_checkpoint:
|
336 |
+
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
|
337 |
+
else:
|
338 |
+
x = blk(x, rel_pos_bias)
|
339 |
+
return x
|
340 |
+
# x = self.norm(x)
|
341 |
+
|
342 |
+
# if self.fc_norm is not None:
|
343 |
+
# t = x[:, 1:, :]
|
344 |
+
# return self.fc_norm(t.mean(1))
|
345 |
+
# else:
|
346 |
+
# return x[:, 0]
|
347 |
+
|
348 |
+
def forward(self, x):
|
349 |
+
x = self.forward_features(x)
|
350 |
+
# x = self.head(x)
|
351 |
+
return x
|
352 |
+
|
353 |
+
def get_intermediate_layers(self, x):
|
354 |
+
x = self.patch_embed(x)
|
355 |
+
batch_size, seq_len, _ = x.size()
|
356 |
+
|
357 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
358 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
359 |
+
if self.pos_embed is not None:
|
360 |
+
x = x + self.pos_embed
|
361 |
+
x = self.pos_drop(x)
|
362 |
+
|
363 |
+
features = []
|
364 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
365 |
+
for blk in self.blocks:
|
366 |
+
x = blk(x, rel_pos_bias)
|
367 |
+
features.append(x)
|
368 |
+
|
369 |
+
return features
|
370 |
+
|
371 |
+
@property
|
372 |
+
def dtype(self):
|
373 |
+
return self.cls_token.dtype
|
374 |
+
|
375 |
+
@property
|
376 |
+
def device(self):
|
377 |
+
return self.cls_token.device
|
378 |
+
|
379 |
+
def get_num_layer(self, var_name=""):
|
380 |
+
if var_name in ("cls_token", "mask_token", "pos_embed"):
|
381 |
+
return 0
|
382 |
+
elif var_name.startswith("patch_embed"):
|
383 |
+
return 0
|
384 |
+
elif var_name.startswith("rel_pos_bias"):
|
385 |
+
return len(self.blocks) - 1
|
386 |
+
elif var_name.startswith("blocks"):
|
387 |
+
layer_id = int(var_name.split('.')[1])
|
388 |
+
return layer_id + 1
|
389 |
+
else:
|
390 |
+
return len(self.blocks)
|
391 |
+
|
392 |
+
|
393 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
394 |
+
if 'pos_embed' in checkpoint_model:
|
395 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
|
396 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
397 |
+
num_patches = model.patch_embed.num_patches
|
398 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
399 |
+
# height (== width) for the checkpoint position embedding
|
400 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
401 |
+
# height (== width) for the new position embedding
|
402 |
+
new_size = int(num_patches ** 0.5)
|
403 |
+
# class_token and dist_token are kept unchanged
|
404 |
+
if orig_size != new_size:
|
405 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
406 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
407 |
+
# only the position tokens are interpolated
|
408 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
409 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
410 |
+
pos_tokens = torch.nn.functional.interpolate(
|
411 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
412 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
413 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
414 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
415 |
+
|
416 |
+
|
417 |
+
def convert_weights_to_fp16(model: nn.Module):
|
418 |
+
"""Convert applicable model parameters to fp16"""
|
419 |
+
|
420 |
+
def _convert_weights_to_fp16(l):
|
421 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
422 |
+
l.weight.data = l.weight.data.half()
|
423 |
+
if l.bias is not None:
|
424 |
+
l.bias.data = l.bias.data.half()
|
425 |
+
|
426 |
+
# if isinstance(l, (nn.MultiheadAttention, Attention)):
|
427 |
+
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
428 |
+
# tensor = getattr(l, attr)
|
429 |
+
# if tensor is not None:
|
430 |
+
# tensor.data = tensor.data.half()
|
431 |
+
|
432 |
+
model.apply(_convert_weights_to_fp16)
|
433 |
+
|
434 |
+
class EVAVisionTower(nn.Module):
|
435 |
+
def __init__(self, vision_tower, image_processor, args, use_checkpoint=False, drop_path_rate=0.0, delay_load=False, dtype=torch.float32):
|
436 |
+
super().__init__()
|
437 |
+
|
438 |
+
self.is_loaded = False
|
439 |
+
self.use_checkpoint = use_checkpoint
|
440 |
+
self.vision_tower_name = vision_tower
|
441 |
+
self.image_processor_name = image_processor
|
442 |
+
self.drop_path_rate = drop_path_rate
|
443 |
+
self.patch_size = 14
|
444 |
+
self.out_channel = 1408
|
445 |
+
if not delay_load:
|
446 |
+
self.load_model()
|
447 |
+
|
448 |
+
self.vision_config = CLIPVisionConfig.from_pretrained(image_processor)
|
449 |
+
|
450 |
+
def load_model(self):
|
451 |
+
# self.image_processor = CLIPImageProcessor.from_pretrained(self.image_processor_name)
|
452 |
+
self.image_processor = VideoFramesProcessor.from_pretrained(self.image_processor_name)
|
453 |
+
self.vision_tower = VisionTransformer(
|
454 |
+
img_size=self.image_processor.size['shortest_edge'],
|
455 |
+
patch_size=self.patch_size,
|
456 |
+
use_mean_pooling=False,
|
457 |
+
embed_dim=self.out_channel,
|
458 |
+
depth=39,
|
459 |
+
num_heads=self.out_channel//88,
|
460 |
+
mlp_ratio=4.3637,
|
461 |
+
qkv_bias=True,
|
462 |
+
drop_path_rate=self.drop_path_rate,
|
463 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
464 |
+
use_checkpoint=self.use_checkpoint,
|
465 |
+
)
|
466 |
+
|
467 |
+
state_dict = torch.load(self.vision_tower_name, map_location="cpu")
|
468 |
+
interpolate_pos_embed(self.vision_tower, state_dict)
|
469 |
+
incompatible_keys = self.vision_tower.load_state_dict(state_dict, strict=False)
|
470 |
+
print(incompatible_keys)
|
471 |
+
self.vision_tower.requires_grad_(False)
|
472 |
+
|
473 |
+
self.is_loaded = True
|
474 |
+
|
475 |
+
@torch.no_grad()
|
476 |
+
def forward(self, images):
|
477 |
+
if type(images) is list:
|
478 |
+
image_features = []
|
479 |
+
for image in images:
|
480 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
|
481 |
+
image_feature = image_forward_out.to(image.dtype)
|
482 |
+
image_features.append(image_feature)
|
483 |
+
else:
|
484 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
|
485 |
+
image_features = image_forward_outs.to(images.dtype)
|
486 |
+
|
487 |
+
return image_features
|
488 |
+
|
489 |
+
def feature_select(self, image_features):
|
490 |
+
# image_features = image_features.hidden_states[self.select_layer]
|
491 |
+
if self.select_feature == 'patch':
|
492 |
+
image_features = image_features[:, 1:]
|
493 |
+
elif self.select_feature == 'cls_patch':
|
494 |
+
image_features = image_features
|
495 |
+
else:
|
496 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
497 |
+
return image_features
|
498 |
+
|
499 |
+
@property
|
500 |
+
def dummy_feature(self):
|
501 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
502 |
+
|
503 |
+
@property
|
504 |
+
def dtype(self):
|
505 |
+
return self.vision_tower.dtype
|
506 |
+
|
507 |
+
@property
|
508 |
+
def device(self):
|
509 |
+
return self.vision_tower.device
|
510 |
+
|
511 |
+
@property
|
512 |
+
def config(self):
|
513 |
+
return self.vision_config
|
514 |
+
|
515 |
+
@property
|
516 |
+
def hidden_size(self):
|
517 |
+
return self.out_channel
|
518 |
+
|
519 |
+
@property
|
520 |
+
def num_patches(self):
|
521 |
+
return (self.image_processor.size['shortest_edge'] // self.patch_size) ** 2
|
522 |
+
|
523 |
+
|
524 |
+
|
525 |
+
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,model_path=None,precision="fp16"):
|
526 |
+
model = VisionTransformer(
|
527 |
+
img_size=img_size,
|
528 |
+
patch_size=14,
|
529 |
+
use_mean_pooling=False,
|
530 |
+
embed_dim=1408,
|
531 |
+
depth=39,
|
532 |
+
num_heads=1408//88,
|
533 |
+
mlp_ratio=4.3637,
|
534 |
+
qkv_bias=True,
|
535 |
+
drop_path_rate=drop_path_rate,
|
536 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
537 |
+
use_checkpoint=use_checkpoint,
|
538 |
+
)
|
539 |
+
# url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
|
540 |
+
# cached_file = download_cached_file(
|
541 |
+
# url, check_hash=False, progress=True
|
542 |
+
# )
|
543 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
544 |
+
interpolate_pos_embed(model,state_dict)
|
545 |
+
|
546 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
547 |
+
print(incompatible_keys)
|
548 |
+
|
549 |
+
if precision == "fp16":
|
550 |
+
convert_weights_to_fp16(model)
|
551 |
+
return model
|
model/multimodal_encoder/openclip_encoder.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import deepspeed
|
8 |
+
from pathlib import Path
|
9 |
+
from open_clip.factory import load_state_dict, get_model_config
|
10 |
+
from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed
|
11 |
+
from typing import Dict, Optional
|
12 |
+
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
|
13 |
+
|
14 |
+
|
15 |
+
class OpenCLIPVisionTower(nn.Module):
|
16 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.is_loaded = False
|
20 |
+
self.vision_tower_name = vision_tower
|
21 |
+
self.vision_config = json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r'))
|
22 |
+
self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False)
|
23 |
+
|
24 |
+
if not delay_load:
|
25 |
+
self.load_model()
|
26 |
+
|
27 |
+
def load_model(self):
|
28 |
+
ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin')
|
29 |
+
if 'convnext' in self.vision_tower_name:
|
30 |
+
if 'large' in self.vision_tower_name and 'd-320' in self.vision_tower_name:
|
31 |
+
self.model_type = 'convnext_large_d_320'
|
32 |
+
self.model_channel = [192, 384, 768, 1536] # stage 0-3
|
33 |
+
elif 'base' in self.vision_tower_name and 'w-320' in self.vision_tower_name:
|
34 |
+
self.model_type = 'convnext_base_w_320'
|
35 |
+
self.model_channel = [128, 256, 512, 1024]
|
36 |
+
elif 'xxlarge' in self.vision_tower_name:
|
37 |
+
self.model_type = 'convnext_xxlarge'
|
38 |
+
self.model_channel = [384, 768, 1536, 3072]
|
39 |
+
|
40 |
+
clip_model = CLIP(**get_model_config(self.model_type))
|
41 |
+
clip_model.visual.trunk.norm_pre = None
|
42 |
+
clip_model.visual.trunk.head = None
|
43 |
+
clip_model.visual.head = None
|
44 |
+
print(f'Loading pretrained weights ({self.model_type}).')
|
45 |
+
load_checkpoint(clip_model, ckpt_path, strict=False)
|
46 |
+
|
47 |
+
self.is_loaded = True
|
48 |
+
# decompose stem and stages blocks in vision tower
|
49 |
+
self.vision_stem = clip_model.visual.trunk.stem
|
50 |
+
self.vision_stages = clip_model.visual.trunk.stages
|
51 |
+
self.vision_stem.requires_grad_(False)
|
52 |
+
self.vision_stages.requires_grad_(False)
|
53 |
+
|
54 |
+
def forward(self, images):
|
55 |
+
if type(images) is list:
|
56 |
+
image_features = []
|
57 |
+
for image in images:
|
58 |
+
image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
|
59 |
+
image_features.append(image_feature)
|
60 |
+
else:
|
61 |
+
image_features = self.backbone(images.to(device=self.device, dtype=self.dtype))
|
62 |
+
|
63 |
+
return image_features
|
64 |
+
|
65 |
+
def backbone(self, images):
|
66 |
+
if not self.is_optimize:
|
67 |
+
with torch.no_grad():
|
68 |
+
results = self.basic_forward(images)
|
69 |
+
else:
|
70 |
+
results = self.basic_forward(images)
|
71 |
+
|
72 |
+
target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1])
|
73 |
+
result_cat = []
|
74 |
+
for _stage in results:
|
75 |
+
if _stage == 'stage_0':
|
76 |
+
result_cat.append(results[_stage].contiguous())
|
77 |
+
else:
|
78 |
+
result_cat.append(F.interpolate(results[_stage].float().contiguous() ,
|
79 |
+
size=target_size,
|
80 |
+
mode='bilinear',
|
81 |
+
align_corners=False).to(dtype=results[_stage].dtype))
|
82 |
+
result_cat = torch.cat(result_cat, dim=1)
|
83 |
+
|
84 |
+
return result_cat.contiguous()
|
85 |
+
|
86 |
+
def basic_forward(self, images):
|
87 |
+
results = {}
|
88 |
+
x = self.vision_stem(images)
|
89 |
+
for _idx in range(len(self.vision_stages)):
|
90 |
+
x = self.vision_stages[_idx](x)
|
91 |
+
results[f'stage_{_idx}'] = x
|
92 |
+
return results
|
93 |
+
|
94 |
+
@property
|
95 |
+
def dummy_feature(self):
|
96 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
97 |
+
|
98 |
+
@property
|
99 |
+
def dtype(self):
|
100 |
+
return self.vision_stem[0].weight.dtype
|
101 |
+
|
102 |
+
@property
|
103 |
+
def device(self):
|
104 |
+
return self.vision_stem[0].weight.device
|
105 |
+
|
106 |
+
@property
|
107 |
+
def config(self):
|
108 |
+
return self.vision_config
|
109 |
+
|
110 |
+
@property
|
111 |
+
def hidden_size(self):
|
112 |
+
return sum(self.model_channel)
|
113 |
+
|
114 |
+
# modified function from open_clip to support zero3 stage
|
115 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
116 |
+
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
|
117 |
+
from open_clip.big_vision import load_big_vision_weights
|
118 |
+
load_big_vision_weights(model, checkpoint_path)
|
119 |
+
return {}
|
120 |
+
|
121 |
+
state_dict = load_state_dict(checkpoint_path)
|
122 |
+
# detect old format and make compatible with new format
|
123 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
124 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
125 |
+
# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
|
126 |
+
# if 'logit_bias' not in state_dict and model.logit_bias is not None:
|
127 |
+
# state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
|
128 |
+
# Certain text transformers no longer expect position_ids after transformers==4.31
|
129 |
+
position_id_key = 'text.transformer.embeddings.position_ids'
|
130 |
+
if position_id_key in state_dict and not hasattr(model, position_id_key):
|
131 |
+
del state_dict[position_id_key]
|
132 |
+
resize_pos_embed(state_dict, model)
|
133 |
+
# resize_text_pos_embed(state_dict, model)
|
134 |
+
#incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
135 |
+
if is_deepspeed_zero3_enabled():
|
136 |
+
|
137 |
+
error_msgs = []
|
138 |
+
|
139 |
+
def load(module: nn.Module, state_dict, prefix=""):
|
140 |
+
metadata = None
|
141 |
+
|
142 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
143 |
+
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
144 |
+
# Parameters of module and children will start with prefix. We can exit early if there are none in this
|
145 |
+
# state_dict
|
146 |
+
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
|
147 |
+
if is_deepspeed_zero3_enabled():
|
148 |
+
# In sharded models, each shard has only part of the full state_dict, so only gather
|
149 |
+
# parameters that are in the current state_dict.
|
150 |
+
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
|
151 |
+
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
|
152 |
+
if len(params_to_gather) > 0:
|
153 |
+
# because zero3 puts placeholders in model params, this context
|
154 |
+
# manager gathers (unpartitions) the params of the current layer, then loads from
|
155 |
+
# the state dict and then re-partitions them again
|
156 |
+
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
|
157 |
+
if torch.distributed.get_rank() == 0:
|
158 |
+
module._load_from_state_dict(*args)
|
159 |
+
else:
|
160 |
+
module._load_from_state_dict(*args)
|
161 |
+
|
162 |
+
for name, child in module._modules.items():
|
163 |
+
if child is not None:
|
164 |
+
load(child, state_dict, prefix + name + ".")
|
165 |
+
|
166 |
+
load(model, state_dict)
|
167 |
+
incompatible_keys = []
|
168 |
+
else:
|
169 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
170 |
+
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
|
171 |
+
return incompatible_keys
|
172 |
+
|
173 |
+
class CLIP(nn.Module):
|
174 |
+
output_dict: torch.jit.Final[bool]
|
175 |
+
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
embed_dim: int,
|
179 |
+
vision_cfg: CLIPVisionCfg,
|
180 |
+
text_cfg: CLIPTextCfg,
|
181 |
+
quick_gelu: bool = False,
|
182 |
+
cast_dtype: Optional[torch.dtype] = None,
|
183 |
+
output_dict: bool = False,
|
184 |
+
):
|
185 |
+
super().__init__()
|
186 |
+
self.output_dict = output_dict
|
187 |
+
|
188 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
model/multimodal_projector/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (3.59 kB). View file
|
|
model/multimodal_projector/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (2.02 kB). View file
|
|
model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import re
|
4 |
+
|
5 |
+
class IdentityMap(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
def forward(self, x, *args, **kwargs):
|
10 |
+
return x
|
11 |
+
|
12 |
+
@property
|
13 |
+
def config(self):
|
14 |
+
return {"mm_projector_type": 'identity'}
|
15 |
+
|
16 |
+
|
17 |
+
class SimpleResBlock(nn.Module):
|
18 |
+
def __init__(self, channels):
|
19 |
+
super().__init__()
|
20 |
+
self.pre_norm = nn.LayerNorm(channels)
|
21 |
+
|
22 |
+
self.proj = nn.Sequential(
|
23 |
+
nn.Linear(channels, channels),
|
24 |
+
nn.GELU(),
|
25 |
+
nn.Linear(channels, channels)
|
26 |
+
)
|
27 |
+
def forward(self, x):
|
28 |
+
x = self.pre_norm(x)
|
29 |
+
return x + self.proj(x)
|
30 |
+
|
31 |
+
|
32 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
33 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
34 |
+
|
35 |
+
if projector_type == 'linear':
|
36 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
37 |
+
|
38 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
39 |
+
if mlp_gelu_match:
|
40 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
41 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
42 |
+
for _ in range(1, mlp_depth):
|
43 |
+
modules.append(nn.GELU())
|
44 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
45 |
+
return nn.Sequential(*modules)
|
46 |
+
|
47 |
+
if projector_type == 'identity':
|
48 |
+
return IdentityMap()
|
49 |
+
|
50 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
model/processor/__pycache__/video_processor.cpython-311.pyc
ADDED
Binary file (4.86 kB). View file
|
|
model/processor/__pycache__/video_processor.cpython-39.pyc
ADDED
Binary file (2.84 kB). View file
|
|
model/processor/video_processor.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import CLIPImageProcessor
|
2 |
+
from transformers.image_processing_utils import BatchFeature, get_size_dict
|
3 |
+
from transformers.image_transforms import get_resize_output_image_size
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class VideoFramesProcessor(CLIPImageProcessor):
|
12 |
+
|
13 |
+
def __init__(self, **kwargs):
|
14 |
+
super().__init__(**kwargs)
|
15 |
+
|
16 |
+
def preprocess(self, images, **kwargs):
|
17 |
+
if not isinstance(images, np.ndarray):
|
18 |
+
return super().preprocess(images=images, **kwargs)
|
19 |
+
|
20 |
+
do_resize = kwargs.get('do_resize', self.do_resize)
|
21 |
+
size = kwargs.get('size', self.size)
|
22 |
+
size = get_size_dict(size, param_name="size", default_to_square=False)
|
23 |
+
do_center_crop = kwargs.get('do_center_crop', self.do_center_crop)
|
24 |
+
crop_size = kwargs.get('crop_size', self.crop_size)
|
25 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
|
26 |
+
do_rescale = kwargs.get('do_rescale', self.do_rescale)
|
27 |
+
rescale_factor = kwargs.get('rescale_factor', self.rescale_factor)
|
28 |
+
do_normalize = kwargs.get('do_normalize', self.do_normalize)
|
29 |
+
image_mean = kwargs.get('image_mean', self.image_mean)
|
30 |
+
image_std = kwargs.get('image_std', self.image_std)
|
31 |
+
return_tensors = kwargs.get('return_tensors', None)
|
32 |
+
|
33 |
+
def resize(images, output_size):
|
34 |
+
images = images.permute((0, 3, 1, 2))
|
35 |
+
images = F.interpolate(images, size=output_size, mode='bicubic')
|
36 |
+
images = images.permute((0, 2, 3, 1))
|
37 |
+
return images
|
38 |
+
|
39 |
+
def center_crop(images, crop_size):
|
40 |
+
crop_width, crop_height = crop_size["width"], crop_size["height"]
|
41 |
+
img_width, img_height = images.shape[1:3]
|
42 |
+
x = (img_width - crop_width) // 2
|
43 |
+
y = (img_height - crop_height) // 2
|
44 |
+
images = images[:, x:x+crop_width, y:y+crop_height]
|
45 |
+
return images
|
46 |
+
|
47 |
+
def rescale(images, rescale_factor):
|
48 |
+
images = images * rescale_factor
|
49 |
+
return images
|
50 |
+
|
51 |
+
def normalize(images, mean, std):
|
52 |
+
mean = torch.tensor(mean)
|
53 |
+
std = torch.tensor(std)
|
54 |
+
images = (images - mean) / std
|
55 |
+
return images
|
56 |
+
|
57 |
+
images = torch.from_numpy(images).float()
|
58 |
+
|
59 |
+
if do_resize:
|
60 |
+
output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False)
|
61 |
+
images = resize(images, output_size)
|
62 |
+
|
63 |
+
if do_center_crop:
|
64 |
+
images = center_crop(images, crop_size)
|
65 |
+
|
66 |
+
if do_rescale:
|
67 |
+
images = rescale(images, rescale_factor)
|
68 |
+
|
69 |
+
if do_normalize:
|
70 |
+
images = normalize(images, image_mean, image_std)
|
71 |
+
|
72 |
+
images = images.permute((0, 3, 1, 2))
|
73 |
+
data = {"pixel_values": images}
|
74 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
model/quant.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
import torch
|
3 |
+
from torch import distributed as tdist, nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
from torch.nn.functional import scaled_dot_product_attention
|
6 |
+
|
7 |
+
# from utils import dist
|
8 |
+
|
9 |
+
# this file only provides the VectorQuantizer2 used in VQVAE
|
10 |
+
__all__ = ['VectorQuantizer', ]
|
11 |
+
|
12 |
+
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
|
13 |
+
E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
|
14 |
+
E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
|
15 |
+
logits = -E_dist.float().mul_(inv_entropy_tau)
|
16 |
+
# calc per_sample_entropy
|
17 |
+
prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
|
18 |
+
per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
|
19 |
+
# calc codebook_entropy
|
20 |
+
avg_prob = prob.mean(dim=0) # (vocab_size,)
|
21 |
+
log_avg_prob = torch.log(avg_prob + 1e-7)
|
22 |
+
codebook_entropy = (-avg_prob * log_avg_prob).sum()
|
23 |
+
# calc entropy_loss
|
24 |
+
entropy_loss = per_sample_entropy - codebook_entropy
|
25 |
+
return entropy_loss
|
26 |
+
|
27 |
+
|
28 |
+
class NormalizedEmbedding(nn.Embedding):
|
29 |
+
def __init__(self, num_embeddings: int, embedding_dim: int):
|
30 |
+
super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
31 |
+
# self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
|
32 |
+
|
33 |
+
def forward(self, idx):
|
34 |
+
return F.embedding(
|
35 |
+
idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
|
36 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse
|
37 |
+
)
|
38 |
+
|
39 |
+
def get_norm_weight(self):
|
40 |
+
return F.normalize(self.weight, dim=1)
|
41 |
+
|
42 |
+
|
43 |
+
class ResConv(nn.Conv2d):
|
44 |
+
def __init__(self, embed_dim, quant_resi):
|
45 |
+
ks = 3 if quant_resi < 0 else 1
|
46 |
+
super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
|
47 |
+
self.resi_ratio = abs(quant_resi)
|
48 |
+
|
49 |
+
def forward(self, h_BChw):
|
50 |
+
return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
|
51 |
+
|
52 |
+
|
53 |
+
class VectorQuantizer(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self, vocab_size: int, vocab_width: int, vocab_norm: bool, beta: float = 0.25, quant_resi=-0.5,
|
56 |
+
using_entropy_loss=False, entropy_temp=0.01,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
self.vocab_size: int = vocab_size
|
60 |
+
self.vocab_width: int = vocab_width
|
61 |
+
self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
|
62 |
+
self.vocab_usage_record_times: int = 0
|
63 |
+
|
64 |
+
self.vocab_norm: bool = vocab_norm
|
65 |
+
# self.quant_resi = ResConv(self.vocab_width, quant_resi=quant_resi)
|
66 |
+
self.quant_resi = nn.Identity()
|
67 |
+
self.embedding = nn.Embedding(self.vocab_size, self.vocab_width)
|
68 |
+
self.beta: float = beta
|
69 |
+
|
70 |
+
self.using_entropy_loss, self.inv_entropy_tau = using_entropy_loss, 1 / entropy_temp
|
71 |
+
if not self.vocab_norm:
|
72 |
+
assert not self.using_entropy_loss, 'entropy loss without vocab norm is not supported'
|
73 |
+
|
74 |
+
def init_vocab(self, eini: float):
|
75 |
+
if eini > 0:
|
76 |
+
nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
|
77 |
+
elif eini < 0:
|
78 |
+
base = self.vocab_width ** -0.5
|
79 |
+
base /= 36
|
80 |
+
self.embedding.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
|
81 |
+
|
82 |
+
def extra_repr(self) -> str:
|
83 |
+
return f'beta={self.beta:g}'
|
84 |
+
|
85 |
+
# ===================== `forward` is only used in VAE training =====================
|
86 |
+
def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[
|
87 |
+
torch.Tensor, torch.Tensor, torch.Tensor, List[float]]:
|
88 |
+
f_BChw = f_BChw.float()
|
89 |
+
B, C, h, w = f_BChw.shape
|
90 |
+
if self.vocab_norm:
|
91 |
+
if self.using_entropy_loss:
|
92 |
+
# find the nearest neighbor
|
93 |
+
NxC = f_BChw.permute(0, 2, 3, 1).reshape(-1, C)
|
94 |
+
NxC_no_grad = NxC.detach()
|
95 |
+
NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
|
96 |
+
idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
97 |
+
# get logits
|
98 |
+
E_dist = NxC.square().sum(dim=1, keepdim=True) + self.embedding.weight.square().sum(dim=1,
|
99 |
+
keepdim=False)
|
100 |
+
E_dist.addmm_(NxC, self.embedding.weight.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
|
101 |
+
logits = -E_dist.float().mul_(self.inv_entropy_tau)
|
102 |
+
# calc per_sample_entropy
|
103 |
+
prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
|
104 |
+
per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
|
105 |
+
# calc codebook_entropy
|
106 |
+
avg_prob = prob.mean(dim=0) # (vocab_size,)
|
107 |
+
log_avg_prob = torch.log(avg_prob + 1e-7)
|
108 |
+
codebook_entropy = (-avg_prob * log_avg_prob).sum()
|
109 |
+
# calc entropy_loss
|
110 |
+
entropy_loss = per_sample_entropy - codebook_entropy
|
111 |
+
else:
|
112 |
+
NxC_no_grad = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C)
|
113 |
+
NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
|
114 |
+
idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
115 |
+
entropy_loss = 0
|
116 |
+
else: # not self.vocab_norm
|
117 |
+
NxC_no_grad = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C)
|
118 |
+
E_dist = NxC_no_grad.square().sum(dim=1, keepdim=True) + self.embedding.weight.data.square().sum(dim=1,
|
119 |
+
keepdim=False)
|
120 |
+
E_dist.addmm_(NxC_no_grad, self.embedding.weight.data.T, alpha=-2, beta=1) # E_dist: N x vocab_size
|
121 |
+
idx_N = torch.argmin(E_dist, dim=1)
|
122 |
+
entropy_loss = 0
|
123 |
+
|
124 |
+
prob_per_class_is_chosen = idx_N.bincount(minlength=self.vocab_size).float()
|
125 |
+
handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
|
126 |
+
self.training and dist.initialized()) else None
|
127 |
+
|
128 |
+
# look up
|
129 |
+
idx_Bhw = idx_N.view(B, h, w)
|
130 |
+
fhat_BChw = self.quant_resi(self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous())
|
131 |
+
|
132 |
+
# calc loss
|
133 |
+
vq_loss = F.mse_loss(fhat_BChw.detach(), f_BChw).mul_(self.beta) + F.mse_loss(fhat_BChw, f_BChw.detach())
|
134 |
+
fhat_BChw = (fhat_BChw.detach() - f_BChw.detach()).add_(f_BChw)
|
135 |
+
|
136 |
+
# update vocab_usage
|
137 |
+
if handler is not None:
|
138 |
+
handler.wait()
|
139 |
+
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
|
140 |
+
vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
|
141 |
+
|
142 |
+
if self.vocab_usage_record_times == 0:
|
143 |
+
self.vocab_usage.copy_(prob_per_class_is_chosen)
|
144 |
+
elif self.vocab_usage_record_times < 100:
|
145 |
+
self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
|
146 |
+
else:
|
147 |
+
self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
|
148 |
+
self.vocab_usage_record_times += 1
|
149 |
+
|
150 |
+
return fhat_BChw, vq_loss, entropy_loss, (vocab_usage if ret_usages else None)
|
151 |
+
|
152 |
+
def f_to_idx(self, f_BChw: torch.Tensor) -> torch.LongTensor:
|
153 |
+
f_BChw = f_BChw.float()
|
154 |
+
B, C, h, w = f_BChw.shape
|
155 |
+
with torch.cuda.amp.autocast(enabled=False):
|
156 |
+
# find the nearest embedding
|
157 |
+
query_NxC = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C)
|
158 |
+
if self.vocab_norm:
|
159 |
+
query_NxC = F.normalize(query_NxC, dim=-1)
|
160 |
+
idx_N = torch.argmax(query_NxC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
161 |
+
else:
|
162 |
+
E_dist = torch.sum(query_NxC.square(), dim=1, keepdim=True) + torch.sum(
|
163 |
+
self.embedding.weight.data.square(), dim=1, keepdim=False)
|
164 |
+
E_dist.addmm_(query_NxC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
|
165 |
+
idx_N = torch.argmin(E_dist, dim=1)
|
166 |
+
return idx_N.view(B, h, w)
|
167 |
+
|
168 |
+
|
169 |
+
class VectorQuantizerHybrid(nn.Module):
|
170 |
+
def __init__(
|
171 |
+
self, vocab_size: int, vocab_width: int, vocab_norm: bool, beta: float = 0.25, quant_resi=-0.5,
|
172 |
+
using_entropy_loss=False, entropy_temp=0.01,
|
173 |
+
):
|
174 |
+
super().__init__()
|
175 |
+
self.vocab_size: int = vocab_size
|
176 |
+
self.vocab_width: int = vocab_width
|
177 |
+
self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
|
178 |
+
self.vocab_usage_record_times: int = 0
|
179 |
+
|
180 |
+
self.vocab_norm: bool = vocab_norm
|
181 |
+
# self.quant_resi = ResConv(self.vocab_width, quant_resi=quant_resi)
|
182 |
+
self.embedding = nn.Embedding(self.vocab_size, self.vocab_width)
|
183 |
+
self.beta: float = beta
|
184 |
+
|
185 |
+
self.using_entropy_loss, self.inv_entropy_tau = using_entropy_loss, 1 / entropy_temp
|
186 |
+
if not self.vocab_norm:
|
187 |
+
assert not self.using_entropy_loss, 'entropy loss without vocab norm is not supported'
|
188 |
+
|
189 |
+
def init_vocab(self, eini: float):
|
190 |
+
if eini > 0:
|
191 |
+
nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
|
192 |
+
elif eini < 0:
|
193 |
+
base = self.vocab_width ** -0.5
|
194 |
+
base /= 36
|
195 |
+
self.embedding.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
|
196 |
+
|
197 |
+
def extra_repr(self) -> str:
|
198 |
+
return f'beta={self.beta:g}'
|
199 |
+
|
200 |
+
def forward(self, class_tokens, patch_tokens, ret_usages=False):
|
201 |
+
class_tokens = class_tokens.float()
|
202 |
+
patch_tokens = patch_tokens.float()
|
203 |
+
|
204 |
+
B, L, C = class_tokens.shape
|
205 |
+
B, C, H, W = patch_tokens.shape
|
206 |
+
patch_tokens = patch_tokens.flatten(start_dim=2).permute(0, 2, 1)
|
207 |
+
NxC = torch.cat((class_tokens, patch_tokens), dim=1).reshape(-1, C)
|
208 |
+
if self.vocab_norm:
|
209 |
+
if self.using_entropy_loss:
|
210 |
+
# find the nearest neighbor
|
211 |
+
NxC_no_grad = NxC.detach()
|
212 |
+
NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
|
213 |
+
idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
214 |
+
# get logits
|
215 |
+
E_dist = NxC.square().sum(dim=1, keepdim=True) + self.embedding.weight.square().sum(dim=1,
|
216 |
+
keepdim=False)
|
217 |
+
E_dist.addmm_(NxC, self.embedding.weight.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
|
218 |
+
logits = -E_dist.float().mul_(self.inv_entropy_tau)
|
219 |
+
# calc per_sample_entropy
|
220 |
+
prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
|
221 |
+
per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
|
222 |
+
# calc codebook_entropy
|
223 |
+
avg_prob = prob.mean(dim=0) # (vocab_size,)
|
224 |
+
log_avg_prob = torch.log(avg_prob + 1e-7)
|
225 |
+
codebook_entropy = (-avg_prob * log_avg_prob).sum()
|
226 |
+
# calc entropy_loss
|
227 |
+
entropy_loss = per_sample_entropy - codebook_entropy
|
228 |
+
else:
|
229 |
+
NxC_no_grad = NxC.detach()
|
230 |
+
NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
|
231 |
+
idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
232 |
+
entropy_loss = 0
|
233 |
+
else: # not self.vocab_norm
|
234 |
+
NxC_no_grad = NxC.detach()
|
235 |
+
E_dist = NxC_no_grad.square().sum(dim=1, keepdim=True) + self.embedding.weight.data.square().sum(dim=1,
|
236 |
+
keepdim=False)
|
237 |
+
E_dist.addmm_(NxC_no_grad, self.embedding.weight.data.T, alpha=-2, beta=1) # E_dist: N x vocab_size
|
238 |
+
idx_N = torch.argmin(E_dist, dim=1)
|
239 |
+
entropy_loss = 0
|
240 |
+
|
241 |
+
prob_per_class_is_chosen = idx_N.bincount(minlength=self.vocab_size).float()
|
242 |
+
handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
|
243 |
+
self.training and dist.initialized()) else None
|
244 |
+
|
245 |
+
# look up
|
246 |
+
fhat = self.embedding(idx_N)
|
247 |
+
|
248 |
+
# calc loss
|
249 |
+
vq_loss = F.mse_loss(fhat.detach(), NxC).mul_(self.beta) + F.mse_loss(fhat, NxC.detach())
|
250 |
+
fhat = (fhat.detach() - NxC.detach()).add_(NxC)
|
251 |
+
|
252 |
+
# update vocab_usage
|
253 |
+
if handler is not None:
|
254 |
+
handler.wait()
|
255 |
+
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
|
256 |
+
vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
|
257 |
+
|
258 |
+
if self.vocab_usage_record_times == 0:
|
259 |
+
self.vocab_usage.copy_(prob_per_class_is_chosen)
|
260 |
+
elif self.vocab_usage_record_times < 100:
|
261 |
+
self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
|
262 |
+
else:
|
263 |
+
self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
|
264 |
+
self.vocab_usage_record_times += 1
|
265 |
+
|
266 |
+
fhat = fhat.view(B, -1, C)
|
267 |
+
fhat_class = fhat[:, :L, :]
|
268 |
+
fhat_patch = fhat[:, L:, :].view(B, H, W, C).permute(0, 3, 1, 2)
|
269 |
+
|
270 |
+
return fhat_class, fhat_patch, vq_loss, entropy_loss, (vocab_usage if ret_usages else None)
|
271 |
+
|
272 |
+
def f_to_idx(self, class_tokens, patch_tokens) -> torch.LongTensor:
|
273 |
+
B, L, C = class_tokens.shape
|
274 |
+
B, C, H, W = patch_tokens.shape
|
275 |
+
class_tokens = class_tokens.float()
|
276 |
+
patch_tokens = patch_tokens.float()
|
277 |
+
patch_tokens = patch_tokens.flatten(start_dim=2).permute(0, 2, 1)
|
278 |
+
NxC = torch.cat((class_tokens, patch_tokens), dim=1).reshape(-1, C)
|
279 |
+
with torch.cuda.amp.autocast(enabled=False):
|
280 |
+
# find the nearest embedding
|
281 |
+
if self.vocab_norm:
|
282 |
+
NxC = F.normalize(NxC, dim=-1)
|
283 |
+
idx_N = torch.argmax(NxC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
284 |
+
else:
|
285 |
+
E_dist = torch.sum(NxC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(),
|
286 |
+
dim=1, keepdim=False)
|
287 |
+
E_dist.addmm_(NxC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
|
288 |
+
idx_N = torch.argmin(E_dist, dim=1)
|
289 |
+
return idx_N
|
290 |
+
|
291 |
+
|
292 |
+
class VectorQuantizerX(nn.Module):
|
293 |
+
def __init__(
|
294 |
+
self,
|
295 |
+
vocab_size: int,
|
296 |
+
vocab_width: int,
|
297 |
+
beta: float = 0.25,
|
298 |
+
use_entropy_loss=False,
|
299 |
+
entropy_temp=0.01,
|
300 |
+
):
|
301 |
+
super().__init__()
|
302 |
+
self.beta = beta
|
303 |
+
self.vocab_size = vocab_size
|
304 |
+
self.vocab_width = vocab_width
|
305 |
+
self.vocab_usage_record_times: int = 0
|
306 |
+
self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
|
307 |
+
|
308 |
+
self.codebook = NormalizedEmbedding(self.vocab_size, self.vocab_width)
|
309 |
+
|
310 |
+
self.use_entropy_loss = use_entropy_loss
|
311 |
+
self.inv_entropy_tau = 1 / entropy_temp
|
312 |
+
|
313 |
+
def init_vocab(self, eini: float):
|
314 |
+
if eini > 0:
|
315 |
+
nn.init.trunc_normal_(self.codebook.weight.data, std=eini)
|
316 |
+
elif eini < 0:
|
317 |
+
base = self.vocab_width ** -0.5
|
318 |
+
base /= 36
|
319 |
+
self.codebook.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
|
320 |
+
|
321 |
+
def extra_repr(self) -> str:
|
322 |
+
return f'beta={self.beta:g}'
|
323 |
+
|
324 |
+
def forward(self, features):
|
325 |
+
B, L, C = features.shape
|
326 |
+
features = features.reshape(-1, C)
|
327 |
+
features = F.normalize(features, dim=-1).float()
|
328 |
+
codebook_embed = self.codebook.get_norm_weight()
|
329 |
+
indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
|
330 |
+
entropy_loss = get_entropy_loss(features, codebook_embed, self.inv_entropy_tau) if self.use_entropy_loss else 0
|
331 |
+
features_hat = self.codebook(indices)
|
332 |
+
|
333 |
+
# calc loss
|
334 |
+
vq_loss = F.mse_loss(features_hat.detach(), features).mul_(self.beta) + F.mse_loss(features_hat,
|
335 |
+
features.detach())
|
336 |
+
features_hat = (features_hat.detach() - features.detach()).add_(features)
|
337 |
+
|
338 |
+
# update vocab_usage
|
339 |
+
prob_per_class_is_chosen = indices.bincount(minlength=self.vocab_size).float()
|
340 |
+
handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
|
341 |
+
self.training and dist.initialized()) else None
|
342 |
+
if handler is not None:
|
343 |
+
handler.wait()
|
344 |
+
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
|
345 |
+
vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
|
346 |
+
if self.vocab_usage_record_times == 0:
|
347 |
+
self.vocab_usage.copy_(prob_per_class_is_chosen)
|
348 |
+
elif self.vocab_usage_record_times < 100:
|
349 |
+
self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
|
350 |
+
else:
|
351 |
+
self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
|
352 |
+
self.vocab_usage_record_times += 1
|
353 |
+
|
354 |
+
return features_hat.view(B, L, C), vq_loss, entropy_loss, vocab_usage
|
355 |
+
|
356 |
+
def f_to_idx(self, features):
|
357 |
+
B, L, C = features.shape
|
358 |
+
features = features.reshape(-1, C)
|
359 |
+
features = F.normalize(features, dim=-1).float()
|
360 |
+
codebook_embed = self.codebook.get_norm_weight().float()
|
361 |
+
indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
|
362 |
+
return indices.view(B, L)
|
363 |
+
|
364 |
+
|
365 |
+
class VectorQuantizerM(nn.Module):
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
vocab_size,
|
369 |
+
vocab_width,
|
370 |
+
beta=0.25,
|
371 |
+
use_entropy_loss=False,
|
372 |
+
entropy_temp=0.01,
|
373 |
+
num_codebooks=16
|
374 |
+
):
|
375 |
+
super().__init__()
|
376 |
+
self.num_codebooks = num_codebooks
|
377 |
+
self.codebooks = nn.ModuleList()
|
378 |
+
for _ in range(num_codebooks):
|
379 |
+
codebook = VectorQuantizerX(
|
380 |
+
vocab_size=vocab_size // num_codebooks,
|
381 |
+
vocab_width=vocab_width // num_codebooks,
|
382 |
+
beta=beta,
|
383 |
+
use_entropy_loss=use_entropy_loss,
|
384 |
+
entropy_temp=entropy_temp,
|
385 |
+
)
|
386 |
+
self.codebooks.append(codebook)
|
387 |
+
|
388 |
+
def init_vocab(self, eini: float):
|
389 |
+
for codebook in self.codebooks:
|
390 |
+
codebook.init_vocab(eini)
|
391 |
+
|
392 |
+
def f_to_idx(self, features):
|
393 |
+
indices = []
|
394 |
+
chunk_size = features.shape[-1] // self.num_codebooks
|
395 |
+
splited_features = features.split(chunk_size, dim=-1)
|
396 |
+
for i, codebook in enumerate(self.codebooks):
|
397 |
+
indices.append(codebook.f_to_idx(splited_features[i]))
|
398 |
+
indices = torch.stack(indices, dim=1)
|
399 |
+
return indices
|
400 |
+
|
401 |
+
def idx_to_f(self, indices):
|
402 |
+
assert indices.shape[1] == self.num_codebooks
|
403 |
+
latent_features = []
|
404 |
+
for i, codebook in enumerate(self.codebooks):
|
405 |
+
sub_indices = indices[:, i].flatten(start_dim=1)
|
406 |
+
latent_feature = codebook.codebook(sub_indices)
|
407 |
+
latent_features.append(latent_feature)
|
408 |
+
latent_features = torch.cat(latent_features, dim=-1)
|
409 |
+
return latent_features
|
410 |
+
|
411 |
+
def forward(self, features):
|
412 |
+
latent_features = []
|
413 |
+
global_vq_loss = 0.
|
414 |
+
global_entropy_loss = 0.
|
415 |
+
global_vocab_usage = 0.
|
416 |
+
chunk_size = features.shape[-1] // self.num_codebooks
|
417 |
+
splited_features = features.split(chunk_size, dim=-1)
|
418 |
+
for i, codebook in enumerate(self.codebooks):
|
419 |
+
latent_feature, vq_loss, entropy_loss, vocab_usage = codebook(splited_features[i])
|
420 |
+
latent_features.append(latent_feature)
|
421 |
+
global_vq_loss += vq_loss
|
422 |
+
global_entropy_loss += entropy_loss
|
423 |
+
global_vocab_usage += vocab_usage
|
424 |
+
latent_features = torch.cat(latent_features, dim=-1)
|
425 |
+
global_entropy_loss /= self.num_codebooks
|
426 |
+
global_vq_loss /= self.num_codebooks
|
427 |
+
global_vocab_usage /= self.num_codebooks
|
428 |
+
return latent_features, global_vq_loss, global_entropy_loss, global_vocab_usage
|
429 |
+
|
430 |
+
|
431 |
+
class CausalAttention(nn.Module):
|
432 |
+
def __init__(self, in_dim, out_dim, num_heads):
|
433 |
+
super().__init__()
|
434 |
+
if in_dim > out_dim:
|
435 |
+
# assert in_dim // num_heads == out_dim
|
436 |
+
self.head_dim = in_dim // num_heads
|
437 |
+
self.qkv = nn.Linear(in_dim, in_dim * 3, bias=False)
|
438 |
+
self.q_bias = nn.Parameter(torch.zeros(in_dim))
|
439 |
+
self.v_bias = nn.Parameter(torch.zeros(in_dim))
|
440 |
+
self.register_buffer('zero_k_bias', torch.zeros(in_dim))
|
441 |
+
else:
|
442 |
+
# assert out_dim // num_heads == in_dim
|
443 |
+
self.head_dim = out_dim // num_heads
|
444 |
+
self.qkv = nn.Linear(in_dim, out_dim * 3, bias=False)
|
445 |
+
self.q_bias = nn.Parameter(torch.zeros(out_dim))
|
446 |
+
self.v_bias = nn.Parameter(torch.zeros(out_dim))
|
447 |
+
self.register_buffer('zero_k_bias', torch.zeros(out_dim))
|
448 |
+
|
449 |
+
self.in_dim = in_dim
|
450 |
+
self.out_dim = out_dim
|
451 |
+
self.num_heads = num_heads
|
452 |
+
self.scale = self.head_dim ** -0.5
|
453 |
+
self.proj = nn.Linear(out_dim, out_dim)
|
454 |
+
|
455 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
456 |
+
B, N, C = x.shape
|
457 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias)))
|
458 |
+
q, k, v = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
|
459 |
+
|
460 |
+
x = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0., is_causal=True)
|
461 |
+
|
462 |
+
if self.in_dim > self.out_dim:
|
463 |
+
x = torch.mean(x, dim=1)
|
464 |
+
if self.in_dim // self.num_heads != self.out_dim:
|
465 |
+
x = nn.functional.adaptive_avg_pool1d(x, self.out_dim)
|
466 |
+
else:
|
467 |
+
x = x.transpose(1, 2).reshape(B, N, -1)
|
468 |
+
x = self.proj(x)
|
469 |
+
return x
|
470 |
+
|
471 |
+
|
472 |
+
class AttnProjection(nn.Module):
|
473 |
+
def __init__(self, in_dim, out_dim, num_heads, norm_layer=nn.LayerNorm, mlp_ratio=2):
|
474 |
+
super().__init__()
|
475 |
+
assert out_dim % in_dim == 0 or in_dim % out_dim == 0
|
476 |
+
self.in_dim = in_dim
|
477 |
+
self.out_dim = out_dim
|
478 |
+
self.norm1 = norm_layer(in_dim)
|
479 |
+
self.attn = CausalAttention(in_dim, out_dim, num_heads)
|
480 |
+
self.proj = nn.Linear(in_dim, out_dim)
|
481 |
+
self.norm3 = norm_layer(in_dim)
|
482 |
+
|
483 |
+
self.norm2 = norm_layer(out_dim)
|
484 |
+
hidden_dim = int(out_dim * mlp_ratio)
|
485 |
+
self.mlp = GeGluMlp(
|
486 |
+
in_features=out_dim,
|
487 |
+
hidden_features=hidden_dim
|
488 |
+
)
|
489 |
+
|
490 |
+
def forward(self, x):
|
491 |
+
x = self.proj(self.norm3(x)) + self.attn(self.norm1(x))
|
492 |
+
x = x + self.mlp(self.norm2(x))
|
493 |
+
return x
|
494 |
+
|
495 |
+
|
496 |
+
from functools import partial
|
497 |
+
from timm.models.layers import create_conv2d, get_norm_act_layer, get_norm_layer, make_divisible
|
498 |
+
|
499 |
+
class GeGluMlp(nn.Module):
|
500 |
+
def __init__(
|
501 |
+
self,
|
502 |
+
in_features,
|
503 |
+
hidden_features,
|
504 |
+
act_layer = None,
|
505 |
+
drop = 0.0,
|
506 |
+
):
|
507 |
+
super().__init__()
|
508 |
+
norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
|
509 |
+
self.norm = norm_layer(in_features)
|
510 |
+
self.act = nn.GELU(approximate='tanh')
|
511 |
+
self.w0 = nn.Linear(in_features, hidden_features)
|
512 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
513 |
+
self.w2 = nn.Linear(hidden_features, in_features)
|
514 |
+
|
515 |
+
def forward(self, x):
|
516 |
+
x = self.norm(x)
|
517 |
+
x = self.act(self.w0(x)) * self.w1(x)
|
518 |
+
x = self.w2(x)
|
519 |
+
return x
|
t2i.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from torchvision import transforms
|
8 |
+
from torch.nn import functional as F
|
9 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
10 |
+
|
11 |
+
from model import *
|
12 |
+
from unitok.config import Args
|
13 |
+
from unitok.model import UniTok
|
14 |
+
|
15 |
+
|
16 |
+
PILtransform = transforms.ToPILImage()
|
17 |
+
|
18 |
+
|
19 |
+
def top_k_top_p_filtering(
|
20 |
+
logits,
|
21 |
+
top_k: int = 0,
|
22 |
+
top_p: float = 1.0,
|
23 |
+
filter_value: float = -float("Inf"),
|
24 |
+
min_tokens_to_keep: int = 1,
|
25 |
+
):
|
26 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
27 |
+
Args:
|
28 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
29 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
30 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
31 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
32 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
33 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
34 |
+
"""
|
35 |
+
|
36 |
+
if top_k > 0:
|
37 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
38 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
39 |
+
|
40 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
41 |
+
logits[indices_to_remove] = filter_value
|
42 |
+
|
43 |
+
if top_p < 1.0:
|
44 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
45 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
46 |
+
|
47 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
48 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
49 |
+
if min_tokens_to_keep > 1:
|
50 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
51 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
52 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
53 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
54 |
+
sorted_indices_to_remove[..., 0] = 0
|
55 |
+
|
56 |
+
# scatter sorted tensors to original indexing
|
57 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
58 |
+
logits[indices_to_remove] = filter_value
|
59 |
+
# import pdb;pdb.set_trace()
|
60 |
+
return logits
|
61 |
+
|
62 |
+
|
63 |
+
def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
|
64 |
+
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
65 |
+
if top_k > 0 or top_p < 1.0:
|
66 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
67 |
+
probs = F.softmax(logits, dim=-1)
|
68 |
+
if sample_logits:
|
69 |
+
idx = torch.multinomial(probs, num_samples=1)
|
70 |
+
else:
|
71 |
+
_, idx = torch.topk(probs, k=1, dim=-1)
|
72 |
+
return idx, probs
|
73 |
+
|
74 |
+
|
75 |
+
def split_list(input_list, chunk_size):
|
76 |
+
return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
|
77 |
+
|
78 |
+
|
79 |
+
def get_args_parser():
|
80 |
+
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
81 |
+
parser.add_argument('--unitok_path', type=str, required=True)
|
82 |
+
parser.add_argument('--mllm_path', type=str, required=True)
|
83 |
+
parser.add_argument('--prompt_file', type=str, required=True)
|
84 |
+
parser.add_argument('--result_dir', type=str, required=True)
|
85 |
+
parser.add_argument('--idx', type=int, default=0)
|
86 |
+
parser.add_argument('--tau', type=float, default=0.9)
|
87 |
+
parser.add_argument('--topk', type=int, default=2048)
|
88 |
+
parser.add_argument('--topp', type=float, default=0.96)
|
89 |
+
parser.add_argument('--cfg_scale', type=float, default=5.0)
|
90 |
+
return parser
|
91 |
+
|
92 |
+
|
93 |
+
def main(args):
|
94 |
+
text_set_id = args.idx
|
95 |
+
tau = args.tau
|
96 |
+
topk = args.topk
|
97 |
+
topp = args.topp
|
98 |
+
cfg_scale = args.cfg_scale
|
99 |
+
|
100 |
+
print('loading vq model ...')
|
101 |
+
ckpt = torch.load(args.unitok_path, map_location='cpu')
|
102 |
+
vae_cfg = Args()
|
103 |
+
vae_cfg.load_state_dict(ckpt['args'])
|
104 |
+
vq_model = UniTok(vae_cfg)
|
105 |
+
vq_model.load_state_dict(ckpt['trainer']['unitok'])
|
106 |
+
vq_model.to('cuda')
|
107 |
+
vq_model.eval()
|
108 |
+
|
109 |
+
image_save_pth = '{}/GenAI-cfg_{}-topk_{}-topp_{}-tau_{}'.format(args.result_dir, str(cfg_scale), str(topk), str(topp), str(tau))
|
110 |
+
|
111 |
+
tokenizer = AutoTokenizer.from_pretrained(args.mllm_path, padding_side='left')
|
112 |
+
vqllm = AutoModelForCausalLM.from_pretrained(
|
113 |
+
args.mllm_path,
|
114 |
+
attn_implementation='flash_attention_2',
|
115 |
+
torch_dtype=torch.bfloat16
|
116 |
+
).to('cuda')
|
117 |
+
|
118 |
+
num_processes = 8
|
119 |
+
chunk_size = 8 # batchsize
|
120 |
+
num_codebooks = vae_cfg.num_codebooks
|
121 |
+
|
122 |
+
with open(args.prompt_file, 'r') as f:
|
123 |
+
lines = f.readlines()
|
124 |
+
all_prompts = []
|
125 |
+
for index, line in enumerate(lines):
|
126 |
+
all_prompts.append({'Index': str(index + 1).zfill(5), 'Prompt': line.strip()})
|
127 |
+
|
128 |
+
chunked_filenames = np.array_split(all_prompts, num_processes)
|
129 |
+
subset = chunked_filenames[text_set_id].tolist()
|
130 |
+
chunk_inputs = split_list(subset, chunk_size)
|
131 |
+
for chunk in tqdm(chunk_inputs):
|
132 |
+
text_inputs = [v['Prompt'] for v in chunk]
|
133 |
+
uncondition_text_inputs = ['<unconditional>'] * len(text_inputs)
|
134 |
+
for i in range(len(text_inputs)):
|
135 |
+
text_inputs[i] = text_inputs[i] + ' Generate an image based on this description.'
|
136 |
+
ori_batchsize = len(text_inputs)
|
137 |
+
|
138 |
+
save_list = []
|
139 |
+
if cfg_scale > 1:
|
140 |
+
model_inputs = tokenizer(text_inputs + uncondition_text_inputs, return_tensors="pt", padding=True).to('cuda')
|
141 |
+
total_batchsize = len(text_inputs + uncondition_text_inputs)
|
142 |
+
model_inputs['input_ids'] = torch.cat([
|
143 |
+
model_inputs['input_ids'],
|
144 |
+
torch.empty(total_batchsize, 1).fill_(3).to(model_inputs['input_ids'])
|
145 |
+
], dim=1)
|
146 |
+
model_inputs['attention_mask'] = torch.cat([
|
147 |
+
model_inputs['attention_mask'],
|
148 |
+
torch.empty(total_batchsize, 1).fill_(1).to(model_inputs['attention_mask'])
|
149 |
+
], dim=1)
|
150 |
+
else:
|
151 |
+
model_inputs = tokenizer(text_inputs, return_tensors="pt", padding=True).to('cuda')
|
152 |
+
total_batchsize = len(text_inputs)
|
153 |
+
model_inputs['input_ids'] = torch.cat([
|
154 |
+
model_inputs['input_ids'],
|
155 |
+
torch.empty(total_batchsize, 1).fill_(3).to(model_inputs['input_ids'])
|
156 |
+
], dim=1)
|
157 |
+
model_inputs['attention_mask'] = torch.cat([
|
158 |
+
model_inputs['attention_mask'],
|
159 |
+
torch.empty(total_batchsize, 1).fill_(1).to(model_inputs['attention_mask'])
|
160 |
+
], dim=1)
|
161 |
+
with torch.no_grad():
|
162 |
+
sampling_kwargs = {'temperature': tau, 'top_k': topk, 'top_p': topp, 'sample_logits': True}
|
163 |
+
pred_tokens = []
|
164 |
+
input_multi_ids = None
|
165 |
+
for _ in range(256):
|
166 |
+
outputs = vqllm.T2I_forward_nocache(
|
167 |
+
**model_inputs,
|
168 |
+
input_multi_ids=input_multi_ids,
|
169 |
+
use_cache=None,
|
170 |
+
return_dict=True,
|
171 |
+
output_attentions=False,
|
172 |
+
output_hidden_states=False,
|
173 |
+
)
|
174 |
+
next_embed = outputs['last_hidden_state'][:, -1:, :]
|
175 |
+
|
176 |
+
indices_arhead = []
|
177 |
+
for i_head in range(num_codebooks):
|
178 |
+
ar_next_embed = vqllm.ar_head(
|
179 |
+
inputs_embeds=next_embed,
|
180 |
+
use_cache=False,
|
181 |
+
output_attentions=False,
|
182 |
+
output_hidden_states=False,
|
183 |
+
return_dict=False,
|
184 |
+
)
|
185 |
+
next_token_logits = vqllm.ar_head.linear_head(ar_next_embed)
|
186 |
+
if cfg_scale > 1:
|
187 |
+
cond_logits, uncond_logits = torch.split(next_token_logits, len(next_token_logits) // 2, dim=0)
|
188 |
+
cfg_logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
|
189 |
+
half_next_token, _ = sample(cfg_logits, **sampling_kwargs)
|
190 |
+
next_token = torch.cat([half_next_token, half_next_token]) # [bz,1]
|
191 |
+
else:
|
192 |
+
next_token, next_prob = sample(next_token_logits, **sampling_kwargs)
|
193 |
+
|
194 |
+
indices_arhead.append(next_token)
|
195 |
+
if i_head < num_codebooks - 1:
|
196 |
+
predicted_embed = vqllm.ar_head.codebooks[i_head](next_token)
|
197 |
+
next_embed = torch.cat([next_embed, predicted_embed], dim=1)
|
198 |
+
|
199 |
+
# update generated ids, model inputs, and length for next step
|
200 |
+
pred_tokens.append(torch.cat(indices_arhead, dim=1)) # [numcodebook,bz*2]
|
201 |
+
input_multi_ids = torch.stack(pred_tokens, dim=-1)
|
202 |
+
|
203 |
+
del sampling_kwargs, model_inputs, outputs
|
204 |
+
|
205 |
+
image_vq_id = torch.stack(pred_tokens, dim=-1)[:ori_batchsize]
|
206 |
+
save_list.append(image_vq_id)
|
207 |
+
|
208 |
+
torch.cuda.empty_cache()
|
209 |
+
|
210 |
+
print('decoding images ...')
|
211 |
+
if not os.path.exists(image_save_pth):
|
212 |
+
os.makedirs(image_save_pth)
|
213 |
+
for datainfo, vq_code in zip(chunk, save_list[0]):
|
214 |
+
idx = datainfo['Index']
|
215 |
+
new_gen_ids = vq_code.unsqueeze(0).to('cuda')
|
216 |
+
rec_image = vq_model.idx_to_img(new_gen_ids)
|
217 |
+
rec_img = PILtransform(rec_image.squeeze(0).add(1).mul_(0.5).clamp_(0, 1))
|
218 |
+
rec_img.save('{}/{}.jpg'.format(image_save_pth, str(idx)))
|
219 |
+
|
220 |
+
|
221 |
+
if __name__ == '__main__':
|
222 |
+
parser = argparse.ArgumentParser('genai inference script', parents=[get_args_parser()])
|
223 |
+
args = parser.parse_args()
|
224 |
+
main(args)
|
tools.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from constants import LOGDIR
|
10 |
+
|
11 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
12 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
13 |
+
|
14 |
+
handler = None
|
15 |
+
|
16 |
+
|
17 |
+
def build_logger(logger_name, logger_filename):
|
18 |
+
global handler
|
19 |
+
|
20 |
+
formatter = logging.Formatter(
|
21 |
+
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
22 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
23 |
+
)
|
24 |
+
|
25 |
+
# Set the format of root handlers
|
26 |
+
if not logging.getLogger().handlers:
|
27 |
+
logging.basicConfig(level=logging.INFO)
|
28 |
+
logging.getLogger().handlers[0].setFormatter(formatter)
|
29 |
+
|
30 |
+
# Redirect stdout and stderr to loggers
|
31 |
+
stdout_logger = logging.getLogger("stdout")
|
32 |
+
stdout_logger.setLevel(logging.INFO)
|
33 |
+
sl = StreamToLogger(stdout_logger, logging.INFO)
|
34 |
+
sys.stdout = sl
|
35 |
+
|
36 |
+
stderr_logger = logging.getLogger("stderr")
|
37 |
+
stderr_logger.setLevel(logging.ERROR)
|
38 |
+
sl = StreamToLogger(stderr_logger, logging.ERROR)
|
39 |
+
sys.stderr = sl
|
40 |
+
|
41 |
+
# Get logger
|
42 |
+
logger = logging.getLogger(logger_name)
|
43 |
+
logger.setLevel(logging.INFO)
|
44 |
+
|
45 |
+
# Add a file handler for all loggers
|
46 |
+
if handler is None:
|
47 |
+
os.makedirs(LOGDIR, exist_ok=True)
|
48 |
+
filename = os.path.join(LOGDIR, logger_filename)
|
49 |
+
handler = logging.handlers.TimedRotatingFileHandler(
|
50 |
+
filename, when='D', utc=True, encoding='UTF-8')
|
51 |
+
handler.setFormatter(formatter)
|
52 |
+
|
53 |
+
for name, item in logging.root.manager.loggerDict.items():
|
54 |
+
if isinstance(item, logging.Logger):
|
55 |
+
item.addHandler(handler)
|
56 |
+
|
57 |
+
return logger
|
58 |
+
|
59 |
+
|
60 |
+
class StreamToLogger(object):
|
61 |
+
"""
|
62 |
+
Fake file-like stream object that redirects writes to a logger instance.
|
63 |
+
"""
|
64 |
+
def __init__(self, logger, log_level=logging.INFO):
|
65 |
+
self.terminal = sys.stdout
|
66 |
+
self.logger = logger
|
67 |
+
self.log_level = log_level
|
68 |
+
self.linebuf = ''
|
69 |
+
|
70 |
+
def __getattr__(self, attr):
|
71 |
+
return getattr(self.terminal, attr)
|
72 |
+
|
73 |
+
def write(self, buf):
|
74 |
+
temp_linebuf = self.linebuf + buf
|
75 |
+
self.linebuf = ''
|
76 |
+
for line in temp_linebuf.splitlines(True):
|
77 |
+
# From the io.TextIOWrapper docs:
|
78 |
+
# On output, if newline is None, any '\n' characters written
|
79 |
+
# are translated to the system default line separator.
|
80 |
+
# By default sys.stdout.write() expects '\n' newlines and then
|
81 |
+
# translates them so this is still cross platform.
|
82 |
+
if line[-1] == '\n':
|
83 |
+
self.logger.log(self.log_level, line.rstrip())
|
84 |
+
else:
|
85 |
+
self.linebuf += line
|
86 |
+
|
87 |
+
def flush(self):
|
88 |
+
if self.linebuf != '':
|
89 |
+
self.logger.log(self.log_level, self.linebuf.rstrip())
|
90 |
+
self.linebuf = ''
|
91 |
+
|
92 |
+
|
93 |
+
def disable_torch_init():
|
94 |
+
"""
|
95 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
96 |
+
"""
|
97 |
+
import torch
|
98 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
99 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
100 |
+
|
101 |
+
|
102 |
+
def violates_moderation(text):
|
103 |
+
"""
|
104 |
+
Check whether the text violates OpenAI moderation API.
|
105 |
+
"""
|
106 |
+
url = "https://api.openai.com/v1/moderations"
|
107 |
+
headers = {"Content-Type": "application/json",
|
108 |
+
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
|
109 |
+
text = text.replace("\n", "")
|
110 |
+
data = "{" + '"input": ' + f'"{text}"' + "}"
|
111 |
+
data = data.encode("utf-8")
|
112 |
+
try:
|
113 |
+
ret = requests.post(url, headers=headers, data=data, timeout=5)
|
114 |
+
flagged = ret.json()["results"][0]["flagged"]
|
115 |
+
except requests.exceptions.RequestException as e:
|
116 |
+
flagged = False
|
117 |
+
except KeyError as e:
|
118 |
+
flagged = False
|
119 |
+
|
120 |
+
return flagged
|
121 |
+
|
122 |
+
|
123 |
+
def pretty_print_semaphore(semaphore):
|
124 |
+
if semaphore is None:
|
125 |
+
return "None"
|
126 |
+
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
|
unitok/config.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
from tap import Tap
|
7 |
+
from typing import Optional, Union
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
from unitok import dist
|
11 |
+
|
12 |
+
|
13 |
+
class Args(Tap):
|
14 |
+
model: str = 'vitamin_large' # 'vitamin_base', 'vitamin_large', xxx
|
15 |
+
exp_name: str = 'unitok_large'
|
16 |
+
output_dir: str = 'local_output'
|
17 |
+
resume_from: str = '' # if specified, load this checkpoint; if not, load the latest checkpoint in output_dir (if exists)
|
18 |
+
lpips_path: str = 'external/lpips_with_vgg.pth'
|
19 |
+
dino_path: str = 'external/dinov2_vits14_pretrain.pth'
|
20 |
+
fid_eval_src: str = ''
|
21 |
+
fid_eval_dst: str = ''
|
22 |
+
vis_img_dir: str = 'asset/vis_imgs/'
|
23 |
+
fid_feature_extractor: str = 'external/weights-inception-2015-12-05-6726825d.pth'
|
24 |
+
clip_pretrain_path: str = ''
|
25 |
+
|
26 |
+
# speed-up
|
27 |
+
fp16: bool = False # whether to use FP16
|
28 |
+
bf16: bool = True # whether to use BF16
|
29 |
+
tf32: bool = True # whether to use TensorFloat32
|
30 |
+
compile_model: bool = False # whether to use torch.compile()
|
31 |
+
ddp_static: bool = False # whether to use static graph in DDP
|
32 |
+
grad_ckpt: bool = True # gradient checkpointing
|
33 |
+
grad_accu: int = 1 # gradient accumulation
|
34 |
+
device: str = 'cpu' # will be set automatically
|
35 |
+
dtype: torch.dtype = torch.float32 # will be set automatically
|
36 |
+
|
37 |
+
# data
|
38 |
+
train_data: str = None
|
39 |
+
val_data: str = None
|
40 |
+
dataset_type: str = 'webdataset'
|
41 |
+
imagenet_val: str = None
|
42 |
+
imagenet_v2: str = None
|
43 |
+
subset_ratio: float = 1.0
|
44 |
+
img_size: int = 256
|
45 |
+
resize_ratio: float = 1.125 # only applicable to 'img' dataset_type
|
46 |
+
hflip: bool = False
|
47 |
+
workers: int = 8 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
|
48 |
+
train_num_samples: int = 1280_000_000
|
49 |
+
train_data_upsampling_factors: str = None
|
50 |
+
dataset_resampled: bool = False
|
51 |
+
use_aug: bool = False
|
52 |
+
|
53 |
+
# quantizer
|
54 |
+
vocab_size: int = 32768
|
55 |
+
vocab_width: int = 64
|
56 |
+
vocab_norm: bool = True
|
57 |
+
vq_beta: float = 0.25 # commitment loss weight
|
58 |
+
num_codebooks: int = 8
|
59 |
+
quant_proj: str = 'attn'
|
60 |
+
|
61 |
+
# model
|
62 |
+
embed_dim: int = 768
|
63 |
+
num_query: int = 0
|
64 |
+
use_clip_pretrain: bool = False
|
65 |
+
patch_size: int = 16
|
66 |
+
drop_path: float = 0.1
|
67 |
+
text_width: int = 768
|
68 |
+
text_heads: int = 12
|
69 |
+
text_layers: int = 12
|
70 |
+
text_vocab_size: int = 49408
|
71 |
+
text_context_length: int = 77
|
72 |
+
|
73 |
+
# CLIP
|
74 |
+
local_loss: bool = True
|
75 |
+
gather_with_grad: bool = True
|
76 |
+
pretrained_clip: str = None
|
77 |
+
pretrained_clip_text: str = None
|
78 |
+
lock_text: bool = False
|
79 |
+
lock_text_unlocked_layers: int = 0
|
80 |
+
lock_text_freeze_layer_norm: bool = False
|
81 |
+
force_custom_text: bool = False
|
82 |
+
force_custom_vision: bool = False
|
83 |
+
zeroshot_eval_freq: int = 1
|
84 |
+
|
85 |
+
# discriminator
|
86 |
+
dino_depth: int = 12
|
87 |
+
dino_kernel_size: int = 9
|
88 |
+
disc_norm: str = 'gn' # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm
|
89 |
+
disc_aug_prob: float = 1.0
|
90 |
+
disc_specnorm: bool = False
|
91 |
+
step_disc_every: int = 1
|
92 |
+
|
93 |
+
# initialization
|
94 |
+
vae_init: float = -0.5 # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
|
95 |
+
vocab_init: float = -1 # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init)
|
96 |
+
disc_init: float = -0.5 # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
|
97 |
+
|
98 |
+
# optimization
|
99 |
+
epoch: int = 1 # number of epochs
|
100 |
+
local_bs: int = 64 # batch size per device; if this is specified, --global_bs will be ignored
|
101 |
+
vae_local_bs: int = 64 # sub-batch size for vae loss calculation
|
102 |
+
global_bs: int = 0 # global batch size (exclusive to --local_bs)
|
103 |
+
lr: float = 5e-4 # learning rate
|
104 |
+
wd: float = 0.02 # weight decay
|
105 |
+
disc_lr: float = 2e-5 # disc lr
|
106 |
+
disc_wd: float = 0.2
|
107 |
+
grad_clip: float = 10 # <=0 for not using grad clip
|
108 |
+
ema: float = 0.9999 # ema ratio
|
109 |
+
warmup_iter: int = None
|
110 |
+
warmup_ep: float = 0.01 # lr warmup: epochs
|
111 |
+
disc_start_ep: float = 0.375 # start using disc loss for VAE after xxx epochs;
|
112 |
+
disc_warmup_ep: float = 0.03 # disc loss warm up epochs;
|
113 |
+
schedule: str = 'cos' # lr schedule type
|
114 |
+
lr_start_ratio: float = 0. # lr warmup: initial lr ratio
|
115 |
+
lr_end_ratio: float = 0.1 # lr schedule: final lr ratio
|
116 |
+
disc_lr_end_ratio: float = 0.1
|
117 |
+
custom_lr_multiplier: float = None
|
118 |
+
optimizer: str = 'adamw'
|
119 |
+
optim_eps: float = 1e-6
|
120 |
+
fuse_opt: bool = False # whether to use fused optimizer
|
121 |
+
optim_beta: str = '0.9_0.95' # beta1, beta2 of optimizer
|
122 |
+
disc_optim_beta: str = '0.5_0.9' # beta1, beta2 of disc optimizer
|
123 |
+
|
124 |
+
# loss
|
125 |
+
l1: float = 0.2 # L1 rec loss weight
|
126 |
+
l2: float = 1.0 # L2 rec loss weight
|
127 |
+
lp: float = 1.0 # lpips loss weight
|
128 |
+
lpr: int = 48 # only calculate lpips >= this image resolution
|
129 |
+
ld: float = 0.4 # discriminator loss weight; if <0: NO ADAPTIVE WEIGHT
|
130 |
+
le: float = 0.0 # VQ entropy loss weight
|
131 |
+
lq: float = 1.0
|
132 |
+
lc: float = 1.0 # CLIP loss weight
|
133 |
+
e_temp: float = 0.01
|
134 |
+
gada: int = 1
|
135 |
+
bcr: float = 4. # balanced Consistency Regularization, used on small dataset with low reso, StyleSwin: 10.0
|
136 |
+
bcr_cut: float = 0.2 # cutout ratio (0.5: 50% width)
|
137 |
+
dcrit: str = 'hg' # hg hinge, sp softplus, ln linear
|
138 |
+
|
139 |
+
# wandb log
|
140 |
+
report_wandb: bool = True
|
141 |
+
wandb_notes: str = None
|
142 |
+
run_id: str = None
|
143 |
+
|
144 |
+
# debug
|
145 |
+
eval_per_epoch: int = 8
|
146 |
+
dbg_unused_param: bool = False
|
147 |
+
dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
|
148 |
+
seed: int = None
|
149 |
+
deterministic: bool = False
|
150 |
+
same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
|
151 |
+
|
152 |
+
|
153 |
+
def seed_everything(self):
|
154 |
+
torch.backends.cudnn.enabled = True
|
155 |
+
torch.backends.cudnn.benchmark = True
|
156 |
+
torch.backends.cudnn.deterministic = False
|
157 |
+
if self.seed is not None:
|
158 |
+
if self.deterministic:
|
159 |
+
torch.backends.cudnn.benchmark = False
|
160 |
+
torch.backends.cudnn.deterministic = True
|
161 |
+
torch.use_deterministic_algorithms(True)
|
162 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
|
163 |
+
seed = self.seed + dist.get_rank() * 10000
|
164 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
165 |
+
random.seed(seed)
|
166 |
+
np.random.seed(seed)
|
167 |
+
torch.manual_seed(seed)
|
168 |
+
torch.cuda.manual_seed(seed)
|
169 |
+
torch.cuda.manual_seed_all(seed)
|
170 |
+
|
171 |
+
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
|
172 |
+
if self.seed is None:
|
173 |
+
return None
|
174 |
+
g = torch.Generator()
|
175 |
+
g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
|
176 |
+
return g
|
177 |
+
|
178 |
+
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
|
179 |
+
d = (OrderedDict if key_ordered else dict)()
|
180 |
+
for k in self.class_variables.keys():
|
181 |
+
if k not in {'device'}: # these are not serializable
|
182 |
+
d[k] = getattr(self, k)
|
183 |
+
return d
|
184 |
+
|
185 |
+
def load_state_dict(self, state_dict):
|
186 |
+
for k, v in state_dict.items():
|
187 |
+
try:
|
188 |
+
setattr(self, k, v)
|
189 |
+
except Exception as e:
|
190 |
+
print(f'k={k}, v={v}')
|
191 |
+
raise e
|
192 |
+
|
193 |
+
@staticmethod
|
194 |
+
def set_tf32(tf32: bool):
|
195 |
+
if torch.cuda.is_available():
|
196 |
+
torch.backends.cudnn.allow_tf32 = bool(tf32)
|
197 |
+
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
|
198 |
+
if hasattr(torch, 'set_float32_matmul_precision'):
|
199 |
+
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
|
200 |
+
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
|
201 |
+
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
|
202 |
+
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
|
203 |
+
|
204 |
+
def __str__(self):
|
205 |
+
s = []
|
206 |
+
for k in self.class_variables.keys():
|
207 |
+
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
|
208 |
+
s.append(f' {k:20s}: {getattr(self, k)}')
|
209 |
+
s = '\n'.join(s)
|
210 |
+
return f'{{\n{s}\n}}\n'
|
211 |
+
|
212 |
+
|
213 |
+
def init_dist_and_get_args():
|
214 |
+
for i in range(len(sys.argv)):
|
215 |
+
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
|
216 |
+
del sys.argv[i]
|
217 |
+
break
|
218 |
+
|
219 |
+
args = Args(explicit_bool=True).parse_args(known_only=True)
|
220 |
+
# warn args.extra_args
|
221 |
+
if len(args.extra_args) > 0:
|
222 |
+
print(f'======================================================================================')
|
223 |
+
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
|
224 |
+
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
|
225 |
+
print(f'======================================================================================\n\n')
|
226 |
+
|
227 |
+
# init torch distributed
|
228 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
229 |
+
dist.init_distributed_mode(local_out_path=args.output_dir, timeout_minutes=30)
|
230 |
+
|
231 |
+
# set env
|
232 |
+
args.set_tf32(args.tf32)
|
233 |
+
args.seed_everything()
|
234 |
+
args.device = dist.get_device()
|
235 |
+
|
236 |
+
# update args
|
237 |
+
if args.local_bs == 0:
|
238 |
+
args.local_bs = max(1, round(args.global_bs / args.grad_accu / dist.get_world_size()))
|
239 |
+
args.global_bs = args.local_bs * dist.get_world_size()
|
240 |
+
if args.fp16 or args.bf16:
|
241 |
+
args.dtype = torch.float16 if args.fp16 else torch.bfloat16
|
242 |
+
|
243 |
+
return args
|
unitok/dist.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
from typing import Union
|
7 |
+
|
8 |
+
import pytz
|
9 |
+
import torch
|
10 |
+
import torch.distributed as tdist
|
11 |
+
import torch.multiprocessing as mp
|
12 |
+
|
13 |
+
__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
+
__rank_str_zfill = '0'
|
15 |
+
__initialized = False
|
16 |
+
|
17 |
+
|
18 |
+
def initialized():
|
19 |
+
return __initialized
|
20 |
+
|
21 |
+
|
22 |
+
def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
|
23 |
+
global __device
|
24 |
+
if not torch.cuda.is_available():
|
25 |
+
print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
|
26 |
+
return
|
27 |
+
elif 'RANK' not in os.environ:
|
28 |
+
torch.cuda.set_device(gpu_id_if_not_distibuted)
|
29 |
+
__device = torch.empty(1).cuda().device
|
30 |
+
print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
|
31 |
+
return
|
32 |
+
# then 'RANK' must exist
|
33 |
+
global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
|
34 |
+
local_rank = global_rank % num_gpus
|
35 |
+
torch.cuda.set_device(local_rank)
|
36 |
+
|
37 |
+
# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
|
38 |
+
if mp.get_start_method(allow_none=True) is None:
|
39 |
+
method = 'fork' if fork else 'spawn'
|
40 |
+
print(f'[dist initialize] mp method={method}')
|
41 |
+
mp.set_start_method(method)
|
42 |
+
tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
|
43 |
+
|
44 |
+
global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
|
45 |
+
__local_rank = local_rank
|
46 |
+
__rank, __world_size = tdist.get_rank(), tdist.get_world_size()
|
47 |
+
__rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
|
48 |
+
__device = torch.empty(1).cuda().device
|
49 |
+
__initialized = True
|
50 |
+
|
51 |
+
assert tdist.is_initialized(), 'torch.distributed is not initialized!'
|
52 |
+
print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
|
53 |
+
|
54 |
+
|
55 |
+
def get_rank():
|
56 |
+
return __rank
|
57 |
+
|
58 |
+
|
59 |
+
def get_rank_str_zfill():
|
60 |
+
return __rank_str_zfill
|
61 |
+
|
62 |
+
|
63 |
+
def get_local_rank():
|
64 |
+
return __local_rank
|
65 |
+
|
66 |
+
|
67 |
+
def get_world_size():
|
68 |
+
return __world_size
|
69 |
+
|
70 |
+
|
71 |
+
def get_device():
|
72 |
+
return __device
|
73 |
+
|
74 |
+
|
75 |
+
def set_gpu_id(gpu_id: int):
|
76 |
+
if gpu_id is None: return
|
77 |
+
global __device
|
78 |
+
if isinstance(gpu_id, (str, int)):
|
79 |
+
torch.cuda.set_device(int(gpu_id))
|
80 |
+
__device = torch.empty(1).cuda().device
|
81 |
+
else:
|
82 |
+
raise NotImplementedError
|
83 |
+
|
84 |
+
|
85 |
+
def is_master():
|
86 |
+
return __rank == 0
|
87 |
+
|
88 |
+
|
89 |
+
def is_local_master():
|
90 |
+
return __local_rank == 0
|
91 |
+
|
92 |
+
|
93 |
+
def new_group(ranks: List[int]):
|
94 |
+
if __initialized:
|
95 |
+
return tdist.new_group(ranks=ranks)
|
96 |
+
return None
|
97 |
+
|
98 |
+
|
99 |
+
def new_local_machine_group():
|
100 |
+
if __initialized:
|
101 |
+
cur_subgroup, subgroups = tdist.new_subgroups()
|
102 |
+
return cur_subgroup
|
103 |
+
return None
|
104 |
+
|
105 |
+
|
106 |
+
def barrier():
|
107 |
+
if __initialized:
|
108 |
+
tdist.barrier()
|
109 |
+
|
110 |
+
|
111 |
+
def allreduce(t: torch.Tensor, async_op=False):
|
112 |
+
if __initialized:
|
113 |
+
if not t.is_cuda:
|
114 |
+
cu = t.detach().cuda()
|
115 |
+
ret = tdist.all_reduce(cu, async_op=async_op)
|
116 |
+
t.copy_(cu.cpu())
|
117 |
+
else:
|
118 |
+
ret = tdist.all_reduce(t, async_op=async_op)
|
119 |
+
return ret
|
120 |
+
return None
|
121 |
+
|
122 |
+
|
123 |
+
def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
|
124 |
+
if __initialized:
|
125 |
+
if not t.is_cuda:
|
126 |
+
t = t.cuda()
|
127 |
+
ls = [torch.empty_like(t) for _ in range(__world_size)]
|
128 |
+
tdist.all_gather(ls, t)
|
129 |
+
else:
|
130 |
+
ls = [t]
|
131 |
+
if cat:
|
132 |
+
ls = torch.cat(ls, dim=0)
|
133 |
+
return ls
|
134 |
+
|
135 |
+
|
136 |
+
def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
|
137 |
+
if __initialized:
|
138 |
+
if not t.is_cuda:
|
139 |
+
t = t.cuda()
|
140 |
+
|
141 |
+
t_size = torch.tensor(t.size(), device=t.device)
|
142 |
+
ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
|
143 |
+
tdist.all_gather(ls_size, t_size)
|
144 |
+
|
145 |
+
max_B = max(size[0].item() for size in ls_size)
|
146 |
+
pad = max_B - t_size[0].item()
|
147 |
+
if pad:
|
148 |
+
pad_size = (pad, *t.size()[1:])
|
149 |
+
t = torch.cat((t, t.new_empty(pad_size)), dim=0)
|
150 |
+
|
151 |
+
ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
|
152 |
+
tdist.all_gather(ls_padded, t)
|
153 |
+
ls = []
|
154 |
+
for t, size in zip(ls_padded, ls_size):
|
155 |
+
ls.append(t[:size[0].item()])
|
156 |
+
else:
|
157 |
+
ls = [t]
|
158 |
+
if cat:
|
159 |
+
ls = torch.cat(ls, dim=0)
|
160 |
+
return ls
|
161 |
+
|
162 |
+
|
163 |
+
def broadcast(t: torch.Tensor, src_rank) -> None:
|
164 |
+
if __initialized:
|
165 |
+
if not t.is_cuda:
|
166 |
+
cu = t.detach().cuda()
|
167 |
+
tdist.broadcast(cu, src=src_rank)
|
168 |
+
t.copy_(cu.cpu())
|
169 |
+
else:
|
170 |
+
tdist.broadcast(t, src=src_rank)
|
171 |
+
|
172 |
+
|
173 |
+
def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
|
174 |
+
if not initialized():
|
175 |
+
return torch.tensor([val]) if fmt is None else [fmt % val]
|
176 |
+
|
177 |
+
ts = torch.zeros(__world_size)
|
178 |
+
ts[__rank] = val
|
179 |
+
allreduce(ts)
|
180 |
+
if fmt is None:
|
181 |
+
return ts
|
182 |
+
return [fmt % v for v in ts.cpu().numpy().tolist()]
|
183 |
+
|
184 |
+
|
185 |
+
def master_only(func):
|
186 |
+
@functools.wraps(func)
|
187 |
+
def wrapper(*args, **kwargs):
|
188 |
+
force = kwargs.pop('force', False)
|
189 |
+
if force or is_master():
|
190 |
+
ret = func(*args, **kwargs)
|
191 |
+
else:
|
192 |
+
ret = None
|
193 |
+
barrier()
|
194 |
+
return ret
|
195 |
+
return wrapper
|
196 |
+
|
197 |
+
|
198 |
+
def local_master_only(func):
|
199 |
+
@functools.wraps(func)
|
200 |
+
def wrapper(*args, **kwargs):
|
201 |
+
force = kwargs.pop('force', False)
|
202 |
+
if force or is_local_master():
|
203 |
+
ret = func(*args, **kwargs)
|
204 |
+
else:
|
205 |
+
ret = None
|
206 |
+
barrier()
|
207 |
+
return ret
|
208 |
+
return wrapper
|
209 |
+
|
210 |
+
|
211 |
+
def for_visualize(func):
|
212 |
+
@functools.wraps(func)
|
213 |
+
def wrapper(*args, **kwargs):
|
214 |
+
if is_master():
|
215 |
+
# with torch.no_grad():
|
216 |
+
ret = func(*args, **kwargs)
|
217 |
+
else:
|
218 |
+
ret = None
|
219 |
+
return ret
|
220 |
+
return wrapper
|
221 |
+
|
222 |
+
|
223 |
+
def finalize():
|
224 |
+
if __initialized:
|
225 |
+
tdist.destroy_process_group()
|
226 |
+
|
227 |
+
|
228 |
+
def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
|
229 |
+
try:
|
230 |
+
__initialize(fork=False, timeout_minutes=timeout_minutes)
|
231 |
+
barrier()
|
232 |
+
except RuntimeError as e:
|
233 |
+
print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
|
234 |
+
raise e
|
235 |
+
|
236 |
+
if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
|
237 |
+
_change_builtin_print(is_local_master())
|
238 |
+
if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
|
239 |
+
sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
|
240 |
+
|
241 |
+
|
242 |
+
def _change_builtin_print(is_master):
|
243 |
+
import builtins as __builtin__
|
244 |
+
|
245 |
+
builtin_print = __builtin__.print
|
246 |
+
if type(builtin_print) != type(open):
|
247 |
+
return
|
248 |
+
|
249 |
+
def prt(*args, **kwargs):
|
250 |
+
force = kwargs.pop('force', False)
|
251 |
+
clean = kwargs.pop('clean', False)
|
252 |
+
deeper = kwargs.pop('deeper', False)
|
253 |
+
if is_master or force:
|
254 |
+
if not clean:
|
255 |
+
f_back = sys._getframe().f_back
|
256 |
+
if deeper and f_back.f_back is not None:
|
257 |
+
f_back = f_back.f_back
|
258 |
+
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
|
259 |
+
time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
|
260 |
+
builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
|
261 |
+
else:
|
262 |
+
builtin_print(*args, **kwargs)
|
263 |
+
|
264 |
+
__builtin__.print = prt
|
265 |
+
|
266 |
+
|
267 |
+
class BackupStreamToFile(object):
|
268 |
+
def __init__(self, local_output_dir, for_stdout=True):
|
269 |
+
self.for_stdout = for_stdout
|
270 |
+
self.terminal_stream = sys.stdout if for_stdout else sys.stderr
|
271 |
+
fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
|
272 |
+
existing = os.path.exists(fname)
|
273 |
+
self.file_stream = open(fname, 'a')
|
274 |
+
if existing:
|
275 |
+
time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
|
276 |
+
self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
|
277 |
+
self.file_stream.flush()
|
278 |
+
self.enabled = True
|
279 |
+
|
280 |
+
def write(self, message):
|
281 |
+
self.terminal_stream.write(message)
|
282 |
+
self.file_stream.write(message)
|
283 |
+
|
284 |
+
def flush(self):
|
285 |
+
self.terminal_stream.flush()
|
286 |
+
self.file_stream.flush()
|
287 |
+
|
288 |
+
def close(self):
|
289 |
+
if not self.enabled:
|
290 |
+
return
|
291 |
+
self.enabled = False
|
292 |
+
self.file_stream.flush()
|
293 |
+
self.file_stream.close()
|
294 |
+
if self.for_stdout:
|
295 |
+
sys.stdout = self.terminal_stream
|
296 |
+
sys.stdout.flush()
|
297 |
+
else:
|
298 |
+
sys.stderr = self.terminal_stream
|
299 |
+
sys.stderr.flush()
|
300 |
+
|
301 |
+
def __del__(self):
|
302 |
+
self.close()
|
unitok/model.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from contextlib import nullcontext
|
7 |
+
|
8 |
+
from unitok.vitamin import GeGluMlp, ViTaminDecoder
|
9 |
+
from unitok.quant import VectorQuantizerM
|
10 |
+
from unitok.vqvae import AttnProjection
|
11 |
+
|
12 |
+
|
13 |
+
class UniTok(nn.Module):
|
14 |
+
def __init__(self, args):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
self.num_query = args.num_query
|
18 |
+
|
19 |
+
self.encoder = timm.create_model(
|
20 |
+
args.model,
|
21 |
+
patch_size=1,
|
22 |
+
fc_norm=False,
|
23 |
+
drop_rate=0.0,
|
24 |
+
num_classes=0,
|
25 |
+
global_pool='',
|
26 |
+
pos_embed='none',
|
27 |
+
class_token=False,
|
28 |
+
mlp_layer=GeGluMlp,
|
29 |
+
reg_tokens=args.num_query,
|
30 |
+
img_size=args.img_size,
|
31 |
+
drop_path_rate=args.drop_path,
|
32 |
+
)
|
33 |
+
self.encoder.pos_embed = nn.Parameter(torch.zeros(1, 1, self.encoder.embed_dim), requires_grad=False)
|
34 |
+
|
35 |
+
if args.quant_proj == 'linear':
|
36 |
+
self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
|
37 |
+
elif args.quant_proj == 'attn':
|
38 |
+
self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, self.encoder.embed_dim // args.vocab_width)
|
39 |
+
else:
|
40 |
+
raise NotImplementedError
|
41 |
+
|
42 |
+
self.quantizer = VectorQuantizerM(
|
43 |
+
vocab_size=args.vocab_size,
|
44 |
+
vocab_width=args.vocab_width,
|
45 |
+
beta=args.vq_beta,
|
46 |
+
use_entropy_loss=args.le > 0,
|
47 |
+
entropy_temp=args.e_temp,
|
48 |
+
num_codebooks=args.num_codebooks,
|
49 |
+
)
|
50 |
+
|
51 |
+
if args.quant_proj == 'linear':
|
52 |
+
self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
|
53 |
+
elif args.quant_proj == 'attn':
|
54 |
+
self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, self.encoder.embed_dim // args.vocab_width)
|
55 |
+
else:
|
56 |
+
raise NotImplementedError
|
57 |
+
|
58 |
+
self.decoder = ViTaminDecoder(
|
59 |
+
args.model,
|
60 |
+
num_query=args.num_query,
|
61 |
+
img_size=args.img_size,
|
62 |
+
drop_path=args.drop_path,
|
63 |
+
grad_ckpt=args.grad_ckpt,
|
64 |
+
)
|
65 |
+
|
66 |
+
text_cfg = {
|
67 |
+
"width": args.text_width,
|
68 |
+
"heads": args.text_heads,
|
69 |
+
"layers": args.text_layers,
|
70 |
+
"vocab_size": args.text_vocab_size,
|
71 |
+
"context_length": args.text_context_length,
|
72 |
+
}
|
73 |
+
from open_clip.model import _build_text_tower
|
74 |
+
self.text_encoder = _build_text_tower(args.embed_dim, text_cfg)
|
75 |
+
|
76 |
+
self.fc_norm = nn.LayerNorm(self.encoder.embed_dim, eps=1e-6)
|
77 |
+
self.projection = nn.Linear(self.encoder.embed_dim, args.embed_dim)
|
78 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
79 |
+
|
80 |
+
self.context_length = self.text_encoder.context_length
|
81 |
+
self.vocab_size = self.text_encoder.vocab_size
|
82 |
+
self.maybe_record_function = nullcontext
|
83 |
+
|
84 |
+
self.text_no_grad = False
|
85 |
+
self.encoder.set_grad_checkpointing(args.grad_ckpt)
|
86 |
+
self.text_encoder.set_grad_checkpointing(args.grad_ckpt)
|
87 |
+
|
88 |
+
def forward(self, img, vae_bs, text=None, ret_usages=False):
|
89 |
+
img_tokens = self.encoder(img).float()
|
90 |
+
with torch.cuda.amp.autocast(enabled=False):
|
91 |
+
img_tokens = torch.utils.checkpoint.checkpoint(self.quant_proj, img_tokens, use_reentrant=False)
|
92 |
+
img_tokens, vq_loss, entropy_loss, usages = self.quantizer(img_tokens)
|
93 |
+
img_tokens = torch.utils.checkpoint.checkpoint(self.post_quant_proj, img_tokens, use_reentrant=False)
|
94 |
+
img_rec = self.decoder(img_tokens[:vae_bs]).float()
|
95 |
+
|
96 |
+
clip_visual = img_tokens.mean(dim=1)
|
97 |
+
clip_visual = self.projection(self.fc_norm(clip_visual))
|
98 |
+
clip_visual = F.normalize(clip_visual, dim=-1)
|
99 |
+
if text is not None:
|
100 |
+
clip_text = self.text_encoder(text)
|
101 |
+
clip_text = F.normalize(clip_text, dim=-1)
|
102 |
+
else:
|
103 |
+
clip_text = None
|
104 |
+
|
105 |
+
output_dict = {
|
106 |
+
"img_rec": img_rec,
|
107 |
+
"vq_loss": vq_loss,
|
108 |
+
"entropy_loss": entropy_loss,
|
109 |
+
"codebook_usages": usages,
|
110 |
+
"clip_image_features": clip_visual,
|
111 |
+
"clip_text_features": clip_text,
|
112 |
+
"logit_scale": self.logit_scale.exp()
|
113 |
+
}
|
114 |
+
return output_dict
|
115 |
+
|
116 |
+
def encode_image(self, image, normalize: bool = False):
|
117 |
+
img_tokens = self.encoder(image)
|
118 |
+
img_tokens = self.quant_proj(img_tokens)
|
119 |
+
img_indices = self.quantizer.f_to_idx(img_tokens)
|
120 |
+
img_tokens = self.quantizer.idx_to_f(img_indices)
|
121 |
+
img_tokens = self.post_quant_proj(img_tokens)
|
122 |
+
features = img_tokens.mean(dim=1)
|
123 |
+
features = self.projection(self.fc_norm(features))
|
124 |
+
return F.normalize(features, dim=-1) if normalize else features
|
125 |
+
|
126 |
+
def encode_text(self, text, normalize: bool = False):
|
127 |
+
features = self.text_encoder(text)
|
128 |
+
return F.normalize(features, dim=-1) if normalize else features
|
129 |
+
|
130 |
+
def img_to_idx(self, img):
|
131 |
+
features = self.encoder(img).float()
|
132 |
+
features = self.quant_proj(features)
|
133 |
+
return self.quantizer.f_to_idx(features)
|
134 |
+
|
135 |
+
def idx_to_img(self, indices):
|
136 |
+
features = self.quantizer.idx_to_f(indices)
|
137 |
+
features = self.post_quant_proj(features)
|
138 |
+
img = self.decoder(features).clamp_(-1, 1)
|
139 |
+
return img
|
140 |
+
|
141 |
+
def img_to_reconstructed_img(self, image) -> torch.Tensor:
|
142 |
+
img_tokens = self.encoder(image)
|
143 |
+
img_tokens = self.quant_proj(img_tokens)
|
144 |
+
img_tokens, _, _, _ = self.quantizer(img_tokens)
|
145 |
+
img_tokens = self.post_quant_proj(img_tokens)
|
146 |
+
img_rec = self.decoder(img_tokens).clamp_(-1, 1)
|
147 |
+
return img_rec
|
148 |
+
|
149 |
+
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
|
150 |
+
self.text.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
|
151 |
+
self.text_no_grad = True
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == '__main__':
|
155 |
+
model = timm.create_model(
|
156 |
+
'vitamin_base',
|
157 |
+
patch_size=1,
|
158 |
+
fc_norm=True,
|
159 |
+
drop_rate=0.0,
|
160 |
+
num_classes=0,
|
161 |
+
global_pool='',
|
162 |
+
pos_embed='none',
|
163 |
+
class_token=False,
|
164 |
+
mlp_layer=GeGluMlp,
|
165 |
+
reg_tokens=0,
|
166 |
+
img_size=256,
|
167 |
+
drop_path_rate=0.1,
|
168 |
+
)
|
169 |
+
model.pos_embed = nn.Parameter(torch.zeros(1, 1, model.embed_dim), requires_grad=False)
|
170 |
+
|
171 |
+
model_dict = model.state_dict()
|
172 |
+
ckpt_dict = torch.load('ViTamin-B/pytorch_model.bin')
|
173 |
+
visual_dict = dict()
|
174 |
+
for k, v in ckpt_dict.items():
|
175 |
+
if k.startswith('visual.'):
|
176 |
+
if 'head' in k or 'pos_embed' in k:
|
177 |
+
continue
|
178 |
+
new_k = k.replace('visual.trunk.', '')
|
179 |
+
visual_dict[new_k] = v
|
180 |
+
|
181 |
+
model.load_state_dict(visual_dict, strict=False)
|
182 |
+
print(set(model_dict.keys()) - set(visual_dict.keys()))
|
183 |
+
print(set(visual_dict.keys() - set(model_dict.keys())))
|
184 |
+
|
unitok/quant.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import List, Tuple
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch import distributed as tdist, nn as nn
|
5 |
+
|
6 |
+
from unitok import dist
|
7 |
+
|
8 |
+
|
9 |
+
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
|
10 |
+
E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
|
11 |
+
E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
|
12 |
+
logits = -E_dist.float().mul_(inv_entropy_tau)
|
13 |
+
# calc per_sample_entropy
|
14 |
+
prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
|
15 |
+
per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
|
16 |
+
# calc codebook_entropy
|
17 |
+
avg_prob = prob.mean(dim=0) # (vocab_size,)
|
18 |
+
log_avg_prob = torch.log(avg_prob + 1e-7)
|
19 |
+
codebook_entropy = (-avg_prob * log_avg_prob).sum()
|
20 |
+
# calc entropy_loss
|
21 |
+
entropy_loss = per_sample_entropy - codebook_entropy
|
22 |
+
return entropy_loss
|
23 |
+
|
24 |
+
|
25 |
+
class NormalizedEmbedding(nn.Embedding):
|
26 |
+
def __init__(self, num_embeddings: int, embedding_dim: int):
|
27 |
+
super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
|
28 |
+
# self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
|
29 |
+
|
30 |
+
def forward(self, idx):
|
31 |
+
return F.embedding(
|
32 |
+
idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
|
33 |
+
self.norm_type, self.scale_grad_by_freq, self.sparse
|
34 |
+
)
|
35 |
+
|
36 |
+
def get_norm_weight(self):
|
37 |
+
return F.normalize(self.weight, dim=1)
|
38 |
+
|
39 |
+
|
40 |
+
class ResConv(nn.Conv2d):
|
41 |
+
def __init__(self, embed_dim, quant_resi):
|
42 |
+
ks = 3 if quant_resi < 0 else 1
|
43 |
+
super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
|
44 |
+
self.resi_ratio = abs(quant_resi)
|
45 |
+
|
46 |
+
def forward(self, h_BChw):
|
47 |
+
return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
|
48 |
+
|
49 |
+
|
50 |
+
class VectorQuantizer(nn.Module):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
vocab_size: int,
|
54 |
+
vocab_width: int,
|
55 |
+
beta: float = 0.25,
|
56 |
+
use_entropy_loss=False,
|
57 |
+
entropy_temp=0.01,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
self.beta = beta
|
61 |
+
self.vocab_size = vocab_size
|
62 |
+
self.vocab_width = vocab_width
|
63 |
+
self.vocab_usage_record_times: int = 0
|
64 |
+
self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
|
65 |
+
self.codebook = NormalizedEmbedding(self.vocab_size, self.vocab_width)
|
66 |
+
|
67 |
+
self.use_entropy_loss = use_entropy_loss
|
68 |
+
self.inv_entropy_tau = 1 / entropy_temp
|
69 |
+
|
70 |
+
def init_vocab(self, eini: float):
|
71 |
+
if eini > 0:
|
72 |
+
nn.init.trunc_normal_(self.codebook.weight.data, std=eini)
|
73 |
+
elif eini < 0:
|
74 |
+
base = self.vocab_width ** -0.5
|
75 |
+
base /= 36
|
76 |
+
self.codebook.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
|
77 |
+
|
78 |
+
def extra_repr(self) -> str:
|
79 |
+
return f'beta={self.beta:g}'
|
80 |
+
|
81 |
+
def forward(self, features):
|
82 |
+
B, L, C = features.shape
|
83 |
+
features = features.reshape(-1, C)
|
84 |
+
features = F.normalize(features, dim=-1).float()
|
85 |
+
codebook_embed = self.codebook.get_norm_weight()
|
86 |
+
indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
|
87 |
+
entropy_loss = get_entropy_loss(features, codebook_embed, self.inv_entropy_tau) if self.use_entropy_loss else 0
|
88 |
+
features_hat = self.codebook(indices)
|
89 |
+
|
90 |
+
# calc loss
|
91 |
+
vq_loss = F.mse_loss(features_hat.detach(), features).mul_(self.beta) + F.mse_loss(features_hat,
|
92 |
+
features.detach())
|
93 |
+
features_hat = (features_hat.detach() - features.detach()).add_(features)
|
94 |
+
|
95 |
+
# update vocab_usage
|
96 |
+
prob_per_class_is_chosen = indices.bincount(minlength=self.vocab_size).float()
|
97 |
+
handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
|
98 |
+
self.training and dist.initialized()) else None
|
99 |
+
if handler is not None:
|
100 |
+
handler.wait()
|
101 |
+
prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
|
102 |
+
vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
|
103 |
+
if self.vocab_usage_record_times == 0:
|
104 |
+
self.vocab_usage.copy_(prob_per_class_is_chosen)
|
105 |
+
elif self.vocab_usage_record_times < 100:
|
106 |
+
self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
|
107 |
+
else:
|
108 |
+
self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
|
109 |
+
self.vocab_usage_record_times += 1
|
110 |
+
|
111 |
+
return features_hat.view(B, L, C), vq_loss, entropy_loss, vocab_usage
|
112 |
+
|
113 |
+
def f_to_idx(self, features):
|
114 |
+
B, L, C = features.shape
|
115 |
+
features = features.reshape(-1, C)
|
116 |
+
features = F.normalize(features, dim=-1).float()
|
117 |
+
codebook_embed = self.codebook.get_norm_weight().float()
|
118 |
+
indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
|
119 |
+
return indices.view(B, L)
|
120 |
+
|
121 |
+
|
122 |
+
class VectorQuantizerM(nn.Module):
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
vocab_size,
|
126 |
+
vocab_width,
|
127 |
+
beta=0.25,
|
128 |
+
use_entropy_loss=False,
|
129 |
+
entropy_temp=0.01,
|
130 |
+
num_codebooks=16
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
self.num_codebooks = num_codebooks
|
134 |
+
self.codebooks = nn.ModuleList()
|
135 |
+
for _ in range(num_codebooks):
|
136 |
+
codebook = VectorQuantizer(
|
137 |
+
vocab_size=vocab_size // num_codebooks,
|
138 |
+
vocab_width=vocab_width // num_codebooks,
|
139 |
+
beta=beta,
|
140 |
+
use_entropy_loss=use_entropy_loss,
|
141 |
+
entropy_temp=entropy_temp,
|
142 |
+
)
|
143 |
+
self.codebooks.append(codebook)
|
144 |
+
|
145 |
+
def init_vocab(self, eini: float):
|
146 |
+
for codebook in self.codebooks:
|
147 |
+
codebook.init_vocab(eini)
|
148 |
+
|
149 |
+
def f_to_idx(self, features):
|
150 |
+
indices = []
|
151 |
+
chunk_size = features.shape[-1] // self.num_codebooks
|
152 |
+
splited_features = features.split(chunk_size, dim=-1)
|
153 |
+
for i, codebook in enumerate(self.codebooks):
|
154 |
+
indices.append(codebook.f_to_idx(splited_features[i]))
|
155 |
+
indices = torch.stack(indices, dim=1)
|
156 |
+
return indices
|
157 |
+
|
158 |
+
def idx_to_f(self, indices):
|
159 |
+
assert indices.shape[1] == self.num_codebooks
|
160 |
+
latent_features = []
|
161 |
+
for i, codebook in enumerate(self.codebooks):
|
162 |
+
sub_indices = indices[:, i].flatten(start_dim=1)
|
163 |
+
latent_feature = codebook.codebook(sub_indices)
|
164 |
+
latent_features.append(latent_feature)
|
165 |
+
latent_features = torch.cat(latent_features, dim=-1)
|
166 |
+
return latent_features
|
167 |
+
|
168 |
+
def forward(self, features):
|
169 |
+
latent_features = []
|
170 |
+
global_vq_loss = 0.
|
171 |
+
global_entropy_loss = 0.
|
172 |
+
global_vocab_usage = 0.
|
173 |
+
chunk_size = features.shape[-1] // self.num_codebooks
|
174 |
+
splited_features = features.split(chunk_size, dim=-1)
|
175 |
+
for i, codebook in enumerate(self.codebooks):
|
176 |
+
latent_feature, vq_loss, entropy_loss, vocab_usage = codebook(splited_features[i])
|
177 |
+
latent_features.append(latent_feature)
|
178 |
+
global_vq_loss += vq_loss
|
179 |
+
global_entropy_loss += entropy_loss
|
180 |
+
global_vocab_usage += vocab_usage
|
181 |
+
latent_features = torch.cat(latent_features, dim=-1)
|
182 |
+
global_entropy_loss /= self.num_codebooks
|
183 |
+
global_vq_loss /= self.num_codebooks
|
184 |
+
global_vocab_usage /= self.num_codebooks
|
185 |
+
return latent_features, global_vq_loss, global_entropy_loss, global_vocab_usage
|
unitok/vitamin.py
ADDED
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO: FIXME:
|
3 |
+
/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py:251: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
|
4 |
+
grad.sizes() = [256, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
|
5 |
+
bucket_view.sizes() = [256, 1024, 1, 1], strides() = [1024, 1, 1, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:334.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
|
6 |
+
/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py:251: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
|
7 |
+
grad.sizes() = [256, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
|
8 |
+
|
9 |
+
"""
|
10 |
+
|
11 |
+
""" ViTamin
|
12 |
+
|
13 |
+
Paper: Designing Scalable Vison Models in the Vision-Language Era
|
14 |
+
|
15 |
+
@misc{chen2023designing,
|
16 |
+
title={Designing Scalable Vison Models in the Vision-Language Era},
|
17 |
+
author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Cheih Chen},
|
18 |
+
year={2023},
|
19 |
+
archivePrefix={arXiv},
|
20 |
+
primaryClass={cs.CV}
|
21 |
+
}
|
22 |
+
|
23 |
+
Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin
|
24 |
+
|
25 |
+
Modifications and timm support by Jieneng Chen 2023
|
26 |
+
|
27 |
+
Reference:
|
28 |
+
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
|
29 |
+
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
|
30 |
+
"""
|
31 |
+
|
32 |
+
import math
|
33 |
+
from dataclasses import dataclass
|
34 |
+
from functools import partial
|
35 |
+
import torch.nn.functional as F
|
36 |
+
from typing import Optional, Tuple, Union
|
37 |
+
|
38 |
+
import timm
|
39 |
+
import torch
|
40 |
+
import torch.nn as nn
|
41 |
+
from timm.layers import to_2tuple
|
42 |
+
from timm.layers.norm_act import _create_act
|
43 |
+
from timm.models._builder import build_model_with_cfg
|
44 |
+
from timm.models._manipulate import checkpoint_seq, named_apply
|
45 |
+
from timm.models._registry import register_model
|
46 |
+
from timm.models.layers import DropPath
|
47 |
+
from timm.models.layers import create_conv2d, get_norm_act_layer, get_norm_layer, make_divisible
|
48 |
+
from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
|
49 |
+
from timm.models.vision_transformer_hybrid import HybridEmbed
|
50 |
+
from torch.utils.checkpoint import checkpoint
|
51 |
+
|
52 |
+
DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)'
|
53 |
+
|
54 |
+
|
55 |
+
@dataclass
|
56 |
+
class VitConvCfg:
|
57 |
+
expand_ratio: float = 4.0
|
58 |
+
expand_output: bool = True # calculate expansion channels from output (vs input chs)
|
59 |
+
kernel_size: int = 3
|
60 |
+
group_size: int = 1 # 1 == depthwise
|
61 |
+
pre_norm_act: bool = False # activation after pre-norm
|
62 |
+
stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
|
63 |
+
pool_type: str = 'avg2'
|
64 |
+
downsample_pool_type: str = 'avg2'
|
65 |
+
act_layer: str = 'gelu' # stem & stage 1234
|
66 |
+
act_layer1: str = 'gelu' # stage 1234
|
67 |
+
act_layer2: str = 'gelu' # stage 1234
|
68 |
+
norm_layer: str = ''
|
69 |
+
norm_layer_cl: str = ''
|
70 |
+
norm_eps: Optional[float] = None
|
71 |
+
down_shortcut: Optional[bool] = True
|
72 |
+
mlp: str = 'mlp'
|
73 |
+
|
74 |
+
def __post_init__(self):
|
75 |
+
# mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
|
76 |
+
use_mbconv = True
|
77 |
+
if not self.norm_layer:
|
78 |
+
self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
|
79 |
+
if not self.norm_layer_cl and not use_mbconv:
|
80 |
+
self.norm_layer_cl = 'layernorm'
|
81 |
+
if self.norm_eps is None:
|
82 |
+
self.norm_eps = 1e-5 if use_mbconv else 1e-6
|
83 |
+
self.downsample_pool_type = self.downsample_pool_type or self.pool_type
|
84 |
+
|
85 |
+
|
86 |
+
@dataclass
|
87 |
+
class VitCfg:
|
88 |
+
# embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
|
89 |
+
embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
|
90 |
+
depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
|
91 |
+
stem_width: int = 64
|
92 |
+
conv_cfg: VitConvCfg = None
|
93 |
+
weight_init: str = 'vit_eff'
|
94 |
+
head_type: str = ""
|
95 |
+
stem_type: str = "stem"
|
96 |
+
ln2d_permute: bool = True
|
97 |
+
# memory_format: str=""
|
98 |
+
|
99 |
+
|
100 |
+
def _init_conv(module, name, scheme=''):
|
101 |
+
if isinstance(module, nn.Conv2d):
|
102 |
+
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
|
103 |
+
fan_out //= module.groups
|
104 |
+
nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
|
105 |
+
if module.bias is not None:
|
106 |
+
nn.init.zeros_(module.bias)
|
107 |
+
|
108 |
+
|
109 |
+
class Stem(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
in_chs: int,
|
113 |
+
out_chs: int,
|
114 |
+
act_layer: str = 'gelu',
|
115 |
+
norm_layer: str = 'layernorm2d',
|
116 |
+
norm_eps: float = 1e-6,
|
117 |
+
bias: bool = True,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.grad_checkpointing=False
|
121 |
+
norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
|
122 |
+
self.out_chs = out_chs
|
123 |
+
self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
|
124 |
+
self.norm1 = norm_act_layer(out_chs)
|
125 |
+
self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
|
126 |
+
named_apply(_init_conv, self)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
if self.grad_checkpointing:
|
130 |
+
x = checkpoint(self.conv1, x)
|
131 |
+
x = self.norm1(x)
|
132 |
+
x = checkpoint(self.conv2, x)
|
133 |
+
else:
|
134 |
+
x = self.conv1(x)
|
135 |
+
x = self.norm1(x)
|
136 |
+
x = self.conv2(x)
|
137 |
+
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class Downsample2d(nn.Module):
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
dim: int,
|
145 |
+
dim_out: int,
|
146 |
+
pool_type: str = 'avg2',
|
147 |
+
bias: bool = True,
|
148 |
+
):
|
149 |
+
super().__init__()
|
150 |
+
self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
|
151 |
+
|
152 |
+
if dim != dim_out:
|
153 |
+
self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv
|
154 |
+
else:
|
155 |
+
self.expand = nn.Identity()
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
x = self.pool(x) # spatial downsample
|
159 |
+
x = self.expand(x) # expand chs
|
160 |
+
return x
|
161 |
+
|
162 |
+
|
163 |
+
class StridedConv(nn.Module):
|
164 |
+
""" downsample 2d as well
|
165 |
+
"""
|
166 |
+
def __init__(
|
167 |
+
self,
|
168 |
+
kernel_size=3,
|
169 |
+
stride=2,
|
170 |
+
padding=1,
|
171 |
+
in_chans=3,
|
172 |
+
embed_dim=768,
|
173 |
+
ln2d_permute=True
|
174 |
+
):
|
175 |
+
super().__init__()
|
176 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
177 |
+
self.permute = ln2d_permute # TODO: disable
|
178 |
+
norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
|
179 |
+
self.norm = norm_layer(in_chans) # affine over C
|
180 |
+
|
181 |
+
def forward(self, x):
|
182 |
+
x = self.norm(x)
|
183 |
+
x = self.proj(x)
|
184 |
+
return x
|
185 |
+
|
186 |
+
|
187 |
+
class MbConvLNBlock(nn.Module):
|
188 |
+
""" Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
|
189 |
+
"""
|
190 |
+
def __init__(
|
191 |
+
self,
|
192 |
+
in_chs: int,
|
193 |
+
out_chs: int,
|
194 |
+
stride: int = 1,
|
195 |
+
drop_path: float = 0.,
|
196 |
+
kernel_size: int = 3,
|
197 |
+
norm_layer: str = 'layernorm2d',
|
198 |
+
norm_eps: float = 1e-6,
|
199 |
+
act_layer: str = 'gelu',
|
200 |
+
expand_ratio: float = 4.0,
|
201 |
+
):
|
202 |
+
super(MbConvLNBlock, self).__init__()
|
203 |
+
self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
|
204 |
+
mid_chs = make_divisible(out_chs * expand_ratio)
|
205 |
+
prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
|
206 |
+
|
207 |
+
if stride == 2:
|
208 |
+
self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True)
|
209 |
+
elif in_chs != out_chs:
|
210 |
+
self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
|
211 |
+
else:
|
212 |
+
self.shortcut = nn.Identity()
|
213 |
+
|
214 |
+
self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
|
215 |
+
self.down = nn.Identity()
|
216 |
+
self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
|
217 |
+
self.act1 = _create_act(act_layer, inplace=True)
|
218 |
+
self.act2 = _create_act(act_layer, inplace=True)
|
219 |
+
|
220 |
+
self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
|
221 |
+
self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
|
222 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
223 |
+
|
224 |
+
def init_weights(self, scheme=''):
|
225 |
+
named_apply(partial(_init_conv, scheme=scheme), self)
|
226 |
+
|
227 |
+
def forward(self, x):
|
228 |
+
shortcut = self.shortcut(x)
|
229 |
+
|
230 |
+
x = self.pre_norm(x)
|
231 |
+
x = self.down(x) # nn.Identity()
|
232 |
+
|
233 |
+
# 1x1 expansion conv & act
|
234 |
+
x = self.conv1_1x1(x)
|
235 |
+
x = self.act1(x)
|
236 |
+
|
237 |
+
# (strided) depthwise 3x3 conv & act
|
238 |
+
x = self.conv2_kxk(x)
|
239 |
+
x = self.act2(x)
|
240 |
+
|
241 |
+
# 1x1 linear projection to output width
|
242 |
+
x = self.conv3_1x1(x)
|
243 |
+
x = self.drop_path(x) + shortcut
|
244 |
+
|
245 |
+
return x
|
246 |
+
|
247 |
+
|
248 |
+
class MbConvStages(nn.Module):
|
249 |
+
""" MobileConv for stage 1 and stage 2 of ViTamin
|
250 |
+
"""
|
251 |
+
def __init__(
|
252 |
+
self,
|
253 |
+
cfg: VitCfg,
|
254 |
+
img_size: Union[int, Tuple[int, int]] = 224, # place holder
|
255 |
+
in_chans: int = 3,
|
256 |
+
):
|
257 |
+
super().__init__()
|
258 |
+
self.grad_checkpointing = False
|
259 |
+
self.stem = Stem(
|
260 |
+
in_chs=in_chans,
|
261 |
+
out_chs=cfg.stem_width,
|
262 |
+
)
|
263 |
+
stages = []
|
264 |
+
self.num_stages = len(cfg.embed_dim)
|
265 |
+
for s, dim in enumerate(cfg.embed_dim[:2]): # stage
|
266 |
+
blocks = []
|
267 |
+
stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
|
268 |
+
for d in range(cfg.depths[s]):
|
269 |
+
blocks += [MbConvLNBlock(
|
270 |
+
in_chs = stage_in_chs if d==0 else dim,
|
271 |
+
out_chs = dim,
|
272 |
+
stride = 2 if d == 0 else 1,
|
273 |
+
# cfg = cfg.conv_cfg,
|
274 |
+
)]
|
275 |
+
blocks = nn.Sequential(*blocks)
|
276 |
+
stages += [blocks]
|
277 |
+
|
278 |
+
self.stages = nn.ModuleList(stages)
|
279 |
+
self.pool = StridedConv(
|
280 |
+
stride=2,
|
281 |
+
in_chans=cfg.embed_dim[1],
|
282 |
+
embed_dim=cfg.embed_dim[2]
|
283 |
+
)
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
x = self.stem(x)
|
287 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
288 |
+
for stage in self.stages:
|
289 |
+
x = checkpoint_seq(stage, x)
|
290 |
+
x = checkpoint(self.pool, x)
|
291 |
+
else:
|
292 |
+
for stage in self.stages:
|
293 |
+
x = stage(x)
|
294 |
+
x = self.pool(x)
|
295 |
+
|
296 |
+
return x
|
297 |
+
|
298 |
+
|
299 |
+
class GeGluMlp(nn.Module):
|
300 |
+
def __init__(
|
301 |
+
self,
|
302 |
+
in_features,
|
303 |
+
hidden_features,
|
304 |
+
act_layer = None,
|
305 |
+
drop = 0.0,
|
306 |
+
):
|
307 |
+
super().__init__()
|
308 |
+
norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
|
309 |
+
self.norm = norm_layer(in_features)
|
310 |
+
self.act = nn.GELU(approximate='tanh')
|
311 |
+
self.w0 = nn.Linear(in_features, hidden_features)
|
312 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
313 |
+
self.w2 = nn.Linear(hidden_features, in_features)
|
314 |
+
|
315 |
+
def forward(self, x):
|
316 |
+
x = self.norm(x)
|
317 |
+
x = self.act(self.w0(x)) * self.w1(x)
|
318 |
+
x = self.w2(x)
|
319 |
+
return x
|
320 |
+
|
321 |
+
|
322 |
+
class HybridEmbed(nn.Module):
|
323 |
+
""" CNN Feature Map Embedding
|
324 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
325 |
+
"""
|
326 |
+
def __init__(
|
327 |
+
self,
|
328 |
+
backbone,
|
329 |
+
img_size=256,
|
330 |
+
patch_size=1,
|
331 |
+
feature_size=None,
|
332 |
+
in_chans=3,
|
333 |
+
embed_dim=1024,
|
334 |
+
bias=True,
|
335 |
+
dynamic_img_pad=False,
|
336 |
+
):
|
337 |
+
super().__init__()
|
338 |
+
assert isinstance(backbone, nn.Module)
|
339 |
+
img_size = to_2tuple(img_size)
|
340 |
+
patch_size = to_2tuple(patch_size)
|
341 |
+
self.img_size = img_size
|
342 |
+
self.patch_size = patch_size
|
343 |
+
self.backbone = backbone
|
344 |
+
with torch.no_grad():
|
345 |
+
training = backbone.training
|
346 |
+
if training:
|
347 |
+
backbone.eval()
|
348 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
|
349 |
+
if isinstance(o, (list, tuple)):
|
350 |
+
o = o[-1] # last feature if backbone outputs list/tuple of features
|
351 |
+
feature_size = o.shape[-2:]
|
352 |
+
feature_dim = o.shape[1]
|
353 |
+
backbone.train(training)
|
354 |
+
|
355 |
+
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
356 |
+
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
|
357 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
358 |
+
self.proj = nn.Identity()
|
359 |
+
|
360 |
+
def forward(self, x):
|
361 |
+
x = self.backbone(x)
|
362 |
+
if isinstance(x, (list, tuple)):
|
363 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
364 |
+
x = self.proj(x)
|
365 |
+
x = x.flatten(2).transpose(1, 2)
|
366 |
+
return x
|
367 |
+
|
368 |
+
|
369 |
+
class Upsample2d(nn.Module):
|
370 |
+
def __init__(self, dim, dim_out):
|
371 |
+
super().__init__()
|
372 |
+
self.conv = torch.nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
|
373 |
+
|
374 |
+
def forward(self, x):
|
375 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
376 |
+
x = self.conv(x)
|
377 |
+
return x
|
378 |
+
|
379 |
+
|
380 |
+
class InvMbConvLNBlock(nn.Module):
|
381 |
+
""" Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
|
382 |
+
"""
|
383 |
+
def __init__(
|
384 |
+
self,
|
385 |
+
in_chs: int,
|
386 |
+
out_chs: int,
|
387 |
+
stride: int = 1,
|
388 |
+
drop_path: float = 0.,
|
389 |
+
kernel_size: int = 3,
|
390 |
+
norm_layer: str = 'layernorm2d',
|
391 |
+
norm_eps: float = 1e-6,
|
392 |
+
act_layer: str = 'gelu',
|
393 |
+
expand_ratio: float = 4.0,
|
394 |
+
):
|
395 |
+
super().__init__()
|
396 |
+
self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
|
397 |
+
mid_chs = make_divisible(out_chs * expand_ratio)
|
398 |
+
prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
|
399 |
+
|
400 |
+
if stride == 2:
|
401 |
+
self.shortcut = Upsample2d(in_chs, out_chs)
|
402 |
+
elif in_chs != out_chs:
|
403 |
+
self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
|
404 |
+
else:
|
405 |
+
self.shortcut = nn.Identity()
|
406 |
+
|
407 |
+
self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
|
408 |
+
|
409 |
+
self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
|
410 |
+
self.act1 = _create_act(act_layer, inplace=True)
|
411 |
+
self.act2 = _create_act(act_layer, inplace=True)
|
412 |
+
|
413 |
+
self.up = Upsample2d(mid_chs, mid_chs) if stride == 2 else nn.Identity()
|
414 |
+
self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=1, dilation=1, groups=mid_chs, bias=True)
|
415 |
+
self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
|
416 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
417 |
+
|
418 |
+
def init_weights(self, scheme=''):
|
419 |
+
named_apply(partial(_init_conv, scheme=scheme), self)
|
420 |
+
|
421 |
+
def forward(self, x):
|
422 |
+
shortcut = self.shortcut(x)
|
423 |
+
x = self.pre_norm(x)
|
424 |
+
|
425 |
+
# 1x1 expansion conv & act
|
426 |
+
x = self.conv1_1x1(x)
|
427 |
+
x = self.act1(x)
|
428 |
+
x = self.up(x)
|
429 |
+
|
430 |
+
# (strided) depthwise 3x3 conv & act
|
431 |
+
x = self.conv2_kxk(x)
|
432 |
+
x = self.act2(x)
|
433 |
+
|
434 |
+
# 1x1 linear projection to output width
|
435 |
+
x = self.conv3_1x1(x)
|
436 |
+
x = self.drop_path(x) + shortcut
|
437 |
+
|
438 |
+
return x
|
439 |
+
|
440 |
+
|
441 |
+
class InvStem(nn.Module):
|
442 |
+
def __init__(
|
443 |
+
self,
|
444 |
+
in_chs: int,
|
445 |
+
out_chs: int,
|
446 |
+
act_layer: str = 'gelu',
|
447 |
+
norm_layer: str = 'layernorm2d',
|
448 |
+
norm_eps: float = 1e-6,
|
449 |
+
bias: bool = True,
|
450 |
+
):
|
451 |
+
super().__init__()
|
452 |
+
self.grad_checkpointing=False
|
453 |
+
norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
|
454 |
+
self.out_chs = out_chs
|
455 |
+
self.conv1 = Upsample2d(in_chs, in_chs)
|
456 |
+
self.norm1 = norm_act_layer(in_chs)
|
457 |
+
self.conv2 = create_conv2d(in_chs, out_chs, 3, stride=1, bias=bias)
|
458 |
+
named_apply(_init_conv, self)
|
459 |
+
|
460 |
+
def forward(self, x):
|
461 |
+
if self.grad_checkpointing:
|
462 |
+
x = checkpoint(self.conv1, x)
|
463 |
+
x = self.norm1(x)
|
464 |
+
x = checkpoint(self.conv2, x)
|
465 |
+
else:
|
466 |
+
x = self.conv1(x)
|
467 |
+
x = self.norm1(x)
|
468 |
+
x = self.conv2(x)
|
469 |
+
|
470 |
+
return x
|
471 |
+
|
472 |
+
|
473 |
+
class ViTaminDecoder(nn.Module):
|
474 |
+
def __init__(
|
475 |
+
self,
|
476 |
+
model,
|
477 |
+
num_query=0,
|
478 |
+
img_size=256,
|
479 |
+
drop_path=0.,
|
480 |
+
depths=(4, 2),
|
481 |
+
grad_ckpt=False,
|
482 |
+
):
|
483 |
+
super().__init__()
|
484 |
+
|
485 |
+
self.num_query = num_query
|
486 |
+
vit = timm.create_model(
|
487 |
+
model,
|
488 |
+
fc_norm=False,
|
489 |
+
patch_size=1,
|
490 |
+
drop_rate=0.0,
|
491 |
+
num_classes=0,
|
492 |
+
global_pool='',
|
493 |
+
pos_embed='none',
|
494 |
+
mlp_layer=GeGluMlp,
|
495 |
+
class_token=False,
|
496 |
+
reg_tokens=num_query,
|
497 |
+
img_size=img_size,
|
498 |
+
drop_path_rate=drop_path,
|
499 |
+
)
|
500 |
+
self.blocks = vit.blocks
|
501 |
+
self.norm_pre = vit.norm_pre
|
502 |
+
self.norm = vit.norm
|
503 |
+
|
504 |
+
embed_dims = {
|
505 |
+
'vitamin_base': (768, 256, 128),
|
506 |
+
'vitamin_large': (1024, 320, 160)
|
507 |
+
}[model]
|
508 |
+
self.up_conv1 = Upsample2d(embed_dims[0], embed_dims[1])
|
509 |
+
self.up_conv2 = nn.Sequential(*[
|
510 |
+
InvMbConvLNBlock(
|
511 |
+
in_chs=embed_dims[1],
|
512 |
+
out_chs=embed_dims[1],
|
513 |
+
stride=2 if d == 0 else 1)
|
514 |
+
for d in range(depths[0])]
|
515 |
+
)
|
516 |
+
self.up_conv3 = nn.Sequential(*[
|
517 |
+
InvMbConvLNBlock(
|
518 |
+
in_chs=embed_dims[1] if d == 0 else embed_dims[2],
|
519 |
+
out_chs=embed_dims[2],
|
520 |
+
stride=2 if d == 0 else 1)
|
521 |
+
for d in range(depths[1])]
|
522 |
+
)
|
523 |
+
self.up_conv4 = InvStem(in_chs=embed_dims[2], out_chs=3)
|
524 |
+
|
525 |
+
self.grad_ckpt = grad_ckpt
|
526 |
+
|
527 |
+
def get_last_param(self):
|
528 |
+
return self.up_conv4.conv2.weight
|
529 |
+
|
530 |
+
def forward(self, x):
|
531 |
+
B, L, C = x.shape
|
532 |
+
H = W = int((L-self.num_query) ** 0.5)
|
533 |
+
x = self.norm_pre(x)
|
534 |
+
if self.grad_ckpt:
|
535 |
+
x = checkpoint_seq(self.blocks, x)
|
536 |
+
x = x[:, self.num_query:, :]
|
537 |
+
x = self.norm(x)
|
538 |
+
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
|
539 |
+
x = checkpoint(self.up_conv1, x)
|
540 |
+
x = checkpoint_seq(self.up_conv2, x)
|
541 |
+
x = checkpoint_seq(self.up_conv3, x)
|
542 |
+
else:
|
543 |
+
x = self.blocks(x)
|
544 |
+
x = x[:, self.num_query:, :]
|
545 |
+
x = self.norm(x)
|
546 |
+
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
|
547 |
+
x = self.up_conv1(x)
|
548 |
+
x = self.up_conv2(x)
|
549 |
+
x = self.up_conv3(x)
|
550 |
+
x = self.up_conv4(x)
|
551 |
+
return x
|
552 |
+
|
553 |
+
|
554 |
+
def _create_vision_transformer(variant, pretrained=False, grad_ckpt=False, **kwargs) -> VisionTransformer:
|
555 |
+
if kwargs.get('features_only', None):
|
556 |
+
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
557 |
+
|
558 |
+
if 'flexi' in variant:
|
559 |
+
# FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
|
560 |
+
# interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
|
561 |
+
_filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
|
562 |
+
else:
|
563 |
+
_filter_fn = checkpoint_filter_fn
|
564 |
+
|
565 |
+
return build_model_with_cfg(
|
566 |
+
VisionTransformer,
|
567 |
+
variant,
|
568 |
+
pretrained,
|
569 |
+
pretrained_filter_fn=_filter_fn,
|
570 |
+
**kwargs,
|
571 |
+
)
|
572 |
+
|
573 |
+
|
574 |
+
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
|
575 |
+
embed_layer = partial(HybridEmbed, backbone=backbone)
|
576 |
+
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
|
577 |
+
return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
|
578 |
+
|
579 |
+
|
580 |
+
@register_model
|
581 |
+
def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer:
|
582 |
+
stage_1_2 = MbConvStages(cfg=VitCfg(
|
583 |
+
embed_dim=(64, 128, 384),
|
584 |
+
depths=(2, 4, 1),
|
585 |
+
stem_width=64,
|
586 |
+
conv_cfg = VitConvCfg(
|
587 |
+
norm_layer='layernorm2d',
|
588 |
+
norm_eps=1e-6,
|
589 |
+
),
|
590 |
+
head_type='1d',
|
591 |
+
),
|
592 |
+
)
|
593 |
+
stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
594 |
+
model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
595 |
+
return model
|
596 |
+
|
597 |
+
|
598 |
+
@register_model
|
599 |
+
def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer:
|
600 |
+
stage_1_2 = MbConvStages(cfg=VitCfg(
|
601 |
+
embed_dim=(128, 256, 768),
|
602 |
+
depths=(2, 4, 1),
|
603 |
+
stem_width=128,
|
604 |
+
conv_cfg = VitConvCfg(
|
605 |
+
norm_layer='layernorm2d',
|
606 |
+
norm_eps=1e-6,
|
607 |
+
),
|
608 |
+
head_type='1d',
|
609 |
+
),
|
610 |
+
)
|
611 |
+
stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
612 |
+
model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
613 |
+
return model
|
614 |
+
|
615 |
+
|
616 |
+
@register_model
|
617 |
+
def vitamin_base_256(pretrained=False, **kwargs) -> VisionTransformer:
|
618 |
+
stage_1_2 = MbConvStages(cfg=VitCfg(
|
619 |
+
embed_dim=(128, 256, 768),
|
620 |
+
depths=(2, 4, 1),
|
621 |
+
stem_width=128,
|
622 |
+
conv_cfg = VitConvCfg(
|
623 |
+
norm_layer='layernorm2d',
|
624 |
+
norm_eps=1e-6,
|
625 |
+
),
|
626 |
+
head_type='1d',
|
627 |
+
),
|
628 |
+
)
|
629 |
+
stage3_args = dict(img_size=256, embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
630 |
+
model = _create_vision_transformer_hybrid('vitamin_base_256', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
631 |
+
return model
|
632 |
+
|
633 |
+
|
634 |
+
@register_model
|
635 |
+
def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
|
636 |
+
stage_1_2 = MbConvStages(cfg=VitCfg(
|
637 |
+
embed_dim=(160, 320, 1024),
|
638 |
+
depths=(2, 4, 1),
|
639 |
+
stem_width=160,
|
640 |
+
conv_cfg = VitConvCfg(
|
641 |
+
norm_layer='layernorm2d',
|
642 |
+
norm_eps=1e-6,
|
643 |
+
),
|
644 |
+
head_type='1d',
|
645 |
+
),
|
646 |
+
)
|
647 |
+
stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
648 |
+
model = _create_vision_transformer_hybrid(
|
649 |
+
'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
|
650 |
+
return model
|
651 |
+
|
652 |
+
# @register_model
|
653 |
+
# def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
|
654 |
+
# backbone = MbConvStages(cfg=VitCfg(
|
655 |
+
# embed_dim=(160, 320, 1024),
|
656 |
+
# depths=(2, 4, 1),
|
657 |
+
# stem_width=160,
|
658 |
+
# conv_cfg = VitConvCfg(
|
659 |
+
# norm_layer='layernorm2d',
|
660 |
+
# norm_eps=1e-6,
|
661 |
+
# ),
|
662 |
+
# head_type='1d',
|
663 |
+
# ),
|
664 |
+
# )
|
665 |
+
# model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
666 |
+
# model = _create_vision_transformer_hybrid(
|
667 |
+
# 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
668 |
+
# return model
|
669 |
+
|
670 |
+
# @register_model
|
671 |
+
# def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
|
672 |
+
# backbone = MbConvStages(cfg=VitCfg(
|
673 |
+
# embed_dim=(160, 320, 1024),
|
674 |
+
# depths=(2, 4, 1),
|
675 |
+
# stem_width=160,
|
676 |
+
# conv_cfg = VitConvCfg(
|
677 |
+
# norm_layer='layernorm2d',
|
678 |
+
# norm_eps=1e-6,
|
679 |
+
# ),
|
680 |
+
# head_type='1d',
|
681 |
+
# ),
|
682 |
+
# )
|
683 |
+
# model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
684 |
+
# model = _create_vision_transformer_hybrid(
|
685 |
+
# 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
686 |
+
# return model
|
687 |
+
|
688 |
+
# @register_model
|
689 |
+
# def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
|
690 |
+
# backbone = MbConvStages(cfg=VitCfg(
|
691 |
+
# embed_dim=(160, 320, 1024),
|
692 |
+
# depths=(2, 4, 1),
|
693 |
+
# stem_width=160,
|
694 |
+
# conv_cfg = VitConvCfg(
|
695 |
+
# norm_layer='layernorm2d',
|
696 |
+
# norm_eps=1e-6,
|
697 |
+
# ),
|
698 |
+
# head_type='1d',
|
699 |
+
# ),
|
700 |
+
# )
|
701 |
+
# model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
702 |
+
# model = _create_vision_transformer_hybrid(
|
703 |
+
# 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
704 |
+
# return model
|
705 |
+
|
706 |
+
# @register_model
|
707 |
+
# def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
|
708 |
+
# backbone = MbConvStages(cfg=VitCfg(
|
709 |
+
# embed_dim=(192, 384, 1152),
|
710 |
+
# depths=(2, 4, 1),
|
711 |
+
# stem_width=192,
|
712 |
+
# conv_cfg = VitConvCfg(
|
713 |
+
# norm_layer='layernorm2d',
|
714 |
+
# norm_eps=1e-6,
|
715 |
+
# ),
|
716 |
+
# head_type='1d',
|
717 |
+
# ),
|
718 |
+
# )
|
719 |
+
# model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
720 |
+
# model = _create_vision_transformer_hybrid(
|
721 |
+
# 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
722 |
+
# return model
|
723 |
+
|
724 |
+
# @register_model
|
725 |
+
# def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
|
726 |
+
# backbone = MbConvStages(cfg=VitCfg(
|
727 |
+
# embed_dim=(192, 384, 1152),
|
728 |
+
# depths=(2, 4, 1),
|
729 |
+
# stem_width=192,
|
730 |
+
# conv_cfg = VitConvCfg(
|
731 |
+
# norm_layer='layernorm2d',
|
732 |
+
# norm_eps=1e-6,
|
733 |
+
# ),
|
734 |
+
# head_type='1d',
|
735 |
+
# ),
|
736 |
+
# )
|
737 |
+
# model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
738 |
+
# model = _create_vision_transformer_hybrid(
|
739 |
+
# 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
740 |
+
# return model
|
741 |
+
|
742 |
+
# @register_model
|
743 |
+
# def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
|
744 |
+
# backbone = MbConvStages(cfg=VitCfg(
|
745 |
+
# embed_dim=(192, 384, 1152),
|
746 |
+
# depths=(2, 4, 1),
|
747 |
+
# stem_width=192,
|
748 |
+
# conv_cfg = VitConvCfg(
|
749 |
+
# norm_layer='layernorm2d',
|
750 |
+
# norm_eps=1e-6,
|
751 |
+
# ),
|
752 |
+
# head_type='1d',
|
753 |
+
# ),
|
754 |
+
# )
|
755 |
+
# model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
|
756 |
+
# model = _create_vision_transformer_hybrid(
|
757 |
+
# 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
|
758 |
+
# return model
|
759 |
+
|
760 |
+
|
761 |
+
def count_params(model: nn.Module):
|
762 |
+
return sum([m.numel() for m in model.parameters()])
|
763 |
+
|
764 |
+
|
765 |
+
def count_stage_params(model: nn.Module, prefix='none'):
|
766 |
+
collections = []
|
767 |
+
for name, m in model.named_parameters():
|
768 |
+
print(name)
|
769 |
+
if name.startswith(prefix):
|
770 |
+
collections.append(m.numel())
|
771 |
+
return sum(collections)
|
772 |
+
|
773 |
+
|
774 |
+
if __name__ == "__main__":
|
775 |
+
# ViTaminDecoder('vitamin_base', img_size=256, patch_size=16)
|
776 |
+
# model = timm.create_model(
|
777 |
+
# 'vitamin_base',
|
778 |
+
# fc_norm=True,
|
779 |
+
# drop_rate=0.0,
|
780 |
+
# num_classes=0,
|
781 |
+
# global_pool='',
|
782 |
+
# mlp_layer=GeGluMlp,
|
783 |
+
# class_token=False,
|
784 |
+
# reg_tokens=32,
|
785 |
+
# img_size=256,
|
786 |
+
# patch_size=1,
|
787 |
+
# drop_path_rate=0.1,
|
788 |
+
# )
|
789 |
+
# print(model.has_class_token)
|
790 |
+
# print(model.num_prefix_tokens)
|
791 |
+
# print(model.pos_embed.shape)
|
792 |
+
Stem(64, 64)
|
unitok/vqvae.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import timm
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from contextlib import nullcontext
|
6 |
+
from torch.nn.functional import scaled_dot_product_attention
|
7 |
+
|
8 |
+
from unitok.quant import VectorQuantizerM
|
9 |
+
from unitok.vitamin import ViTaminDecoder, GeGluMlp
|
10 |
+
|
11 |
+
|
12 |
+
class PlainAttention(nn.Module):
|
13 |
+
def __init__(self, in_dim, out_dim, num_heads):
|
14 |
+
super().__init__()
|
15 |
+
if in_dim > out_dim:
|
16 |
+
# assert in_dim // num_heads == out_dim
|
17 |
+
self.head_dim = in_dim // num_heads
|
18 |
+
self.qkv = nn.Linear(in_dim, in_dim * 3, bias=False)
|
19 |
+
self.q_bias = nn.Parameter(torch.zeros(in_dim))
|
20 |
+
self.v_bias = nn.Parameter(torch.zeros(in_dim))
|
21 |
+
self.register_buffer('zero_k_bias', torch.zeros(in_dim))
|
22 |
+
else:
|
23 |
+
# assert out_dim // num_heads == in_dim
|
24 |
+
self.head_dim = out_dim // num_heads
|
25 |
+
self.qkv = nn.Linear(in_dim, out_dim * 3, bias=False)
|
26 |
+
self.q_bias = nn.Parameter(torch.zeros(out_dim))
|
27 |
+
self.v_bias = nn.Parameter(torch.zeros(out_dim))
|
28 |
+
self.register_buffer('zero_k_bias', torch.zeros(out_dim))
|
29 |
+
|
30 |
+
self.in_dim = in_dim
|
31 |
+
self.out_dim = out_dim
|
32 |
+
self.num_heads = num_heads
|
33 |
+
self.scale = self.head_dim ** -0.5
|
34 |
+
self.proj = nn.Linear(out_dim, out_dim)
|
35 |
+
|
36 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
37 |
+
B, N, C = x.shape
|
38 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias)))
|
39 |
+
q, k, v = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
|
40 |
+
|
41 |
+
x = scaled_dot_product_attention(q, k, v)
|
42 |
+
|
43 |
+
if self.in_dim > self.out_dim:
|
44 |
+
x = torch.mean(x, dim=1)
|
45 |
+
if self.in_dim // self.num_heads != self.out_dim:
|
46 |
+
x = nn.functional.adaptive_avg_pool1d(x, self.out_dim)
|
47 |
+
else:
|
48 |
+
x = x.transpose(1, 2).reshape(B, N, -1)
|
49 |
+
x = self.proj(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class AttnProjection(nn.Module):
|
54 |
+
def __init__(self, in_dim, out_dim, num_heads, norm_layer=nn.LayerNorm, mlp_ratio=2):
|
55 |
+
super().__init__()
|
56 |
+
assert out_dim % in_dim == 0 or in_dim % out_dim == 0
|
57 |
+
self.in_dim = in_dim
|
58 |
+
self.out_dim = out_dim
|
59 |
+
self.norm1 = norm_layer(in_dim)
|
60 |
+
self.attn = PlainAttention(in_dim, out_dim, num_heads)
|
61 |
+
self.proj = nn.Linear(in_dim, out_dim)
|
62 |
+
self.norm3 = norm_layer(in_dim)
|
63 |
+
|
64 |
+
self.norm2 = norm_layer(out_dim)
|
65 |
+
hidden_dim = int(out_dim * mlp_ratio)
|
66 |
+
self.mlp = GeGluMlp(
|
67 |
+
in_features=out_dim,
|
68 |
+
hidden_features=hidden_dim
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.proj(self.norm3(x)) + self.attn(self.norm1(x))
|
73 |
+
x = x + self.mlp(self.norm2(x))
|
74 |
+
return x
|
75 |
+
|
76 |
+
|
77 |
+
class VQVAE(nn.Module):
|
78 |
+
def __init__(self, args):
|
79 |
+
super().__init__()
|
80 |
+
|
81 |
+
# 1. build encoder
|
82 |
+
self.encoder = timm.create_model(
|
83 |
+
args.model,
|
84 |
+
patch_size=1,
|
85 |
+
fc_norm=True,
|
86 |
+
drop_rate=0.0,
|
87 |
+
num_classes=0,
|
88 |
+
global_pool='',
|
89 |
+
pos_embed='none',
|
90 |
+
class_token=False,
|
91 |
+
mlp_layer=GeGluMlp,
|
92 |
+
img_size=args.img_size,
|
93 |
+
drop_path_rate=args.drop_path,
|
94 |
+
)
|
95 |
+
self.encoder.set_grad_checkpointing(args.grad_ckpt)
|
96 |
+
|
97 |
+
# 2. build conv before quant
|
98 |
+
if args.quant_proj == 'linear':
|
99 |
+
self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
|
100 |
+
elif args.quant_proj == 'attn':
|
101 |
+
self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, args.num_codebooks)
|
102 |
+
else:
|
103 |
+
raise NotImplementedError
|
104 |
+
|
105 |
+
# 3. build quant
|
106 |
+
self.quantize = VectorQuantizerM(
|
107 |
+
vocab_size=args.vocab_size,
|
108 |
+
vocab_width=args.vocab_width,
|
109 |
+
beta=args.vq_beta,
|
110 |
+
use_entropy_loss=args.le > 0,
|
111 |
+
entropy_temp=args.e_temp,
|
112 |
+
num_codebooks=args.num_codebooks,
|
113 |
+
)
|
114 |
+
|
115 |
+
# 4. build conv after quant
|
116 |
+
if args.quant_proj == 'linear':
|
117 |
+
self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
|
118 |
+
elif args.quant_proj == 'attn':
|
119 |
+
self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, args.num_codebooks)
|
120 |
+
else:
|
121 |
+
raise NotImplementedError
|
122 |
+
|
123 |
+
# 5. build decoder
|
124 |
+
self.decoder = ViTaminDecoder(
|
125 |
+
args.model,
|
126 |
+
depths=(4, 2),
|
127 |
+
img_size=args.img_size,
|
128 |
+
drop_path=args.drop_path,
|
129 |
+
grad_ckpt=args.grad_ckpt
|
130 |
+
)
|
131 |
+
|
132 |
+
self.maybe_record_function = nullcontext
|
133 |
+
|
134 |
+
def forward(self, img):
|
135 |
+
features = self.encoder(img).float()
|
136 |
+
with torch.cuda.amp.autocast(enabled=False):
|
137 |
+
features = self.quant_proj(features)
|
138 |
+
quant_out = self.quantize(features)
|
139 |
+
features, vq_loss, entropy_loss, usages = quant_out
|
140 |
+
features = self.post_quant_proj(features)
|
141 |
+
rec_img = self.decoder(features).float()
|
142 |
+
return rec_img, vq_loss, entropy_loss, usages
|
143 |
+
|
144 |
+
def img_to_idx(self, img):
|
145 |
+
features = self.encoder(img).float()
|
146 |
+
features = self.quant_proj(features)
|
147 |
+
return self.quantize.f_to_idx(features)
|
148 |
+
|
149 |
+
def idx_to_img(self, indices):
|
150 |
+
features = self.quantize.idx_to_f(indices)
|
151 |
+
features = self.post_quant_proj(features)
|
152 |
+
img = self.decoder(features).clamp_(-1, 1)
|
153 |
+
return img
|
154 |
+
|
155 |
+
def img_to_reconstructed_img(self, img) -> torch.Tensor:
|
156 |
+
features = self.encoder(img).float()
|
157 |
+
with torch.cuda.amp.autocast(enabled=False):
|
158 |
+
features = self.quant_proj(features)
|
159 |
+
quant_out = self.quantize(features)
|
160 |
+
features, _, _, _ = quant_out
|
161 |
+
features = self.post_quant_proj(features)
|
162 |
+
rec_img = self.decoder(features).float().clamp_(-1, 1)
|
163 |
+
return rec_img
|
164 |
+
|
165 |
+
|
166 |
+
if __name__ == '__main__':
|
167 |
+
for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d,
|
168 |
+
nn.ConvTranspose2d):
|
169 |
+
setattr(clz, 'reset_parameters', lambda self: None)
|
170 |
+
|
171 |
+
cnn = VQVAE(channel_num=64, vocab_norm=False)
|
172 |
+
from models import init_weights
|
173 |
+
|
174 |
+
init_weights(cnn, -0.5)
|
175 |
+
torch.save(cnn.state_dict(), r'C:\Users\16333\Desktop\PyCharm\vlip\local_output\cnn.pth')
|