Upload HfMoondream
- config.json +1 -1
- config.py +3 -0
- generation_config.json +1 -1
- hf_moondream.py +26 -7
- layers.py +2 -2
- model.safetensors +1 -1
- moondream.py +204 -136
- region.py +19 -12
- text.py +114 -76
- vision.py +21 -7
config.json CHANGED
@@ -9,5 +9,5 @@
   "config": {},
   "model_type": "moondream1",
   "torch_dtype": "float16",
-  "transformers_version": "4.…"
+  "transformers_version": "4.48.0"
 }
config.py CHANGED
@@ -5,10 +5,12 @@ from typing import Dict, List, Optional
 @dataclass(frozen=True)
 class TextConfig:
     dim: int = 2048
+    ff_dim: int = 8192
     n_layers: int = 24
     vocab_size: int = 51200
     max_context: int = 2048
     n_heads: int = 32
+    n_kv_heads: int = 32
     prefix_attn: int = 730
 
 
@@ -46,6 +48,7 @@ class TokenizerConfig:
         "caption": {
             "short": [198, 198, 16438, 8305, 25],
             "normal": [198, 198, 24334, 1159, 25],
+            "long": [198, 198, 14617, 8305, 25],
         },
         "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
         "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
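The two new TextConfig fields feed directly into the text model: ff_dim sets the MLP hidden width and n_kv_heads sets the key/value head count used to size the fused QKV projection. A quick sketch (plain Python, using the defaults above) of the derived dimensions:

dim, n_heads, n_kv_heads, ff_dim = 2048, 32, 32, 8192

head_dim = dim // n_heads                     # 64
qkv_dim = dim + 2 * n_kv_heads * head_dim     # 2048 + 2*32*64 = 6144
# Same value build_text_model computes as int(dim * (1 + 2 * n_kv_heads / n_heads))
assert qkv_dim == int(dim * (1 + 2 * n_kv_heads / n_heads))
print(head_dim, qkv_dim, ff_dim)              # 64 6144 8192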
generation_config.json CHANGED
@@ -1,4 +1,4 @@
 {
   "_from_model_config": true,
-  "transformers_version": "4.…"
+  "transformers_version": "4.48.0"
 }
hf_moondream.py CHANGED
@@ -14,7 +14,7 @@ from .utils import *
 def extract_question(text):
     prefix = "<image>\n\nQuestion: "
     suffix = "\n\nAnswer:"
-
+
     if text.startswith(prefix) and text.endswith(suffix):
         return text[len(prefix) : -len(suffix)]
     else:
@@ -36,30 +36,44 @@ class HfMoondream(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = MoondreamModel(…
+        self.model = MoondreamModel(
+            MoondreamConfig.from_dict(config.config), setup_caches=False
+        )
+        self._is_kv_cache_setup = False
+
+    def _setup_caches(self):
+        if not self._is_kv_cache_setup:
+            self.model._setup_caches()
+            self._is_kv_cache_setup = True
 
     @property
     def encode_image(self):
+        self._setup_caches()
         return self.model.encode_image
 
     @property
     def query(self):
+        self._setup_caches()
        return self.model.query
 
     @property
     def caption(self):
+        self._setup_caches()
         return self.model.caption
 
     @property
     def detect(self):
+        self._setup_caches()
         return self.model.detect
 
     @property
     def point(self):
+        self._setup_caches()
         return self.model.point
 
     @property
     def detect_gaze(self):
+        self._setup_caches()
         return self.model.detect_gaze
 
     def answer_question(
@@ -98,22 +112,27 @@ class HfMoondream(PreTrainedModel):
         """
         prompt_extracted = extract_question(prompt)
         if prompt_extracted is not None:
-            answer = self.model.query(
-                …
-            ]
+            answer = self.model.query(
+                image=image_embeds, question=prompt_extracted, stream=False
+            )["answer"]
         else:
             image_embeds = self.encode_image(image_embeds)
             prompt_tokens = torch.tensor(
                 [self.model.tokenizer.encode(prompt).ids],
                 device=self.device,
             )
+
             def generator():
                 for token in self.model._generate_text(
-                    prompt_tokens,
-                    …
+                    prompt_tokens,
+                    image_embeds.kv_cache,
+                    image_embeds.pos,
+                    max_new_tokens,
                 ):
                     yield token
+
             answer = "".join(list(generator()))
-
+
         return [answer]
 
     def get_input_embeddings(self):
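The wrapper now defers KV-cache allocation: __init__ builds the model with setup_caches=False, and each capability property calls _setup_caches() before handing back the bound method, so the caches are allocated exactly once on first use. A minimal standalone sketch of that lazy-initialization pattern (hypothetical class, not part of this repo):

class LazyCapability:
    def __init__(self):
        self._is_setup = False

    def _setup(self):
        # Allocate the expensive buffers only once, on first access.
        if not self._is_setup:
            print("allocating caches...")
            self._is_setup = True

    @property
    def query(self):
        self._setup()
        return self._query_impl

    def _query_impl(self, question):
        return f"answer to {question!r}"


m = LazyCapability()
print(m.query("what color is the cat?"))  # prints "allocating caches..." once
print(m.query("and the dog?"))            # caches already set up, no reallocation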
layers.py CHANGED
@@ -37,9 +37,9 @@ class MLPWeights:
 
 
 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-    x = …
+    x = w.fc1(x)
     x = gelu_approx(x)
-    x = …
+    x = w.fc2(x)
     return x
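The MLP now calls the fc1/fc2 weights directly as modules (w.fc1(x)) instead of going through a functional helper. A small self-contained sketch of the same fc1, approximate GELU, fc2 pattern; the dimensions are the TextConfig defaults, and the tanh-approximate GELU is an assumption about what gelu_approx does:

import torch
import torch.nn as nn
import torch.nn.functional as F

fc1 = nn.Linear(2048, 8192)   # dim -> ff_dim
fc2 = nn.Linear(8192, 2048)   # ff_dim -> dim

def mlp(x):
    x = fc1(x)
    x = F.gelu(x, approximate="tanh")  # assumed equivalent of gelu_approx
    x = fc2(x)
    return x

out = mlp(torch.randn(1, 4, 2048))
print(out.shape)  # torch.Size([1, 4, 2048])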
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:fadcffea8c17fe8a20ea68af3a013cf3184a63787ee4453cc9eb75206c7c1f9b
 size 3854538376
moondream.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import random
 
-from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional
+from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
 from PIL import Image
 from dataclasses import dataclass
 from tokenizers import Tokenizer
@@ -10,7 +10,7 @@ from tokenizers import Tokenizer
 from .config import MoondreamConfig
 from .image_crops import reconstruct_from_crops
 from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
-from .text import build_text_model, …
+from .text import build_text_model, text_encoder, lm_head, text_decoder
 from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
 from .utils import remove_outlier_points
 
@@ -21,53 +21,41 @@ SamplingSettings = TypedDict(
     total=False,
 )
 
-DEFAULT_MAX_TOKENS = …
+DEFAULT_MAX_TOKENS = 768
 
 
 @dataclass(frozen=True)
 class EncodedImage:
     pos: int
-    …
-    sorted_indices_to_remove = torch.gather(
-        tokens_to_remove, dim=-1, index=sorted_indices
-    )
-    if min_tokens_to_keep > 1:
-        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
-
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove
-    )
-    logits = logits.masked_fill(indices_to_remove, filter_value)
-    token = torch.multinomial(logits, num_samples=1)
-    return token.squeeze(0)
+    caches: List[Tuple[torch.Tensor, torch.Tensor]]
+
+
+class KVCache(nn.Module):
+
+    def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
+        super().__init__()
+        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
+        self.register_buffer(
+            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
+        )
+
+    def update(self, pos_ids, k, v):
+        kout, vout = self.k_cache, self.v_cache
+        kout[:, :, pos_ids, :] = k
+        vout[:, :, pos_ids, :] = v
+        return kout, vout
 
 
 class MoondreamModel(nn.Module):
-    def __init__(self, config: MoondreamConfig, dtype=torch.float16):
+    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
         super().__init__()
         self.config = config
 
         self.tokenizer = Tokenizer.from_pretrained(
-            "vikhyatk/moondream2", revision="…"
+            "vikhyatk/moondream2", revision="2025-01-09"
         )
         self.vision = build_vision_model(config.vision, dtype)
         self.text = build_text_model(config.text, dtype)
@@ -114,35 +102,65 @@
             torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
         )
 
-        …
+        attn_mask = torch.tril(
+            torch.ones(
+                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
+            )
+        )
+        patch_w = config.vision.crop_size // config.vision.enc_patch_size
+        prefix_attn_len = 1 + patch_w**2
+        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
+
+        # Initialize KV caches.
+        if setup_caches:
+            self._setup_caches()
+
+    def _setup_caches(self):
+        c = self.config.text
+        for b in self.text.blocks:
+            b.kv_cache = KVCache(
+                c.n_heads,
+                c.n_kv_heads,
+                c.max_context,
+                c.dim,
+                device=self.device,
+                dtype=self.vision.pos_emb.dtype,
+            )
 
     @property
     def device(self):
         return self.vision.pos_emb.device
 
+    def _vis_enc(self, x: torch.Tensor):
+        return vision_encoder(x, self.vision, self.config.vision)
+
+    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
+        return vision_projection(g, r, self.vision, self.config.vision)
+
+    def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
+        return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
+
+    def _decode_one_tok(
+        self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
+    ):
+        hidden = text_decoder(x[None], self.text, attn_mask, pos_ids, self.config.text)
+        logits = lm_head(hidden, self.text)
+        return logits, hidden
+
     def compile(self):
-        …
-        )
-        …
-        # self.ops["vision_projection"], fullgraph=True
-        # )
-        self.ops["prefill"] = torch.compile(self.ops["prefill"], fullgraph=True)
-        self.ops["decode_one_token"] = torch.compile(
-            self.ops["decode_one_token"], fullgraph=True
+        # TODO: vision_projection is not being compiled
+        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
+        self._prefill = torch.compile(self._prefill, fullgraph=True)
+        self._decode_one_tok = torch.compile(
+            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
         )
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
        all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
        torch._dynamo.mark_dynamic(all_crops, 0)
 
-        outputs = self.…
+        outputs = self._vis_enc(all_crops)
 
        global_features = outputs[0]
        local_features = outputs[1:].view(
@@ -159,9 +177,7 @@
             overlap_margin=self.config.vision.overlap_margin,
         )
 
-        return self.…
-            global_features, reconstructed, self.vision, self.config.vision
-        )
+        return self._vis_proj(global_features, reconstructed)
 
     def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
         if isinstance(image, EncodedImage):
@@ -171,34 +187,35 @@
 
         # Run through text model in addition to the vision encoder, to minimize
         # re-computation if multiple queries are performed on this image.
-        …
-            self.config.text.n_layers,
-            2,  # k, v
-            1,  # batch size
-            self.config.text.n_heads,
-            self.config.text.max_context,  # static cache
-            self.config.text.dim // self.config.text.n_heads,  # head dim
-            device=self.device,
-            dtype=torch.float16,
-        )
-        with torch.no_grad():
+        with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
                 torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
                 self.text,
             )
             inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
-            self.…
-            …
+            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
+            pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
+            self._prefill(inputs_embeds, mask, pos_ids)
+
+            return EncodedImage(
+                pos=inputs_embeds.size(1),
+                caches=[
+                    (
+                        b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
+                        b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
+                    )
+                    for b in self.text.blocks
+                ],
+            )
 
-    def _prefill_prompt(
-        …
-    ):
-        with torch.no_grad():
+    def _prefill_prompt(self, prompt_tokens: torch.Tensor, pos: int):
+        with torch.inference_mode():
             prompt_emb = text_encoder(prompt_tokens, self.text)
-            …
-            )
+            torch._dynamo.mark_dynamic(prompt_emb, 1)
+            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
+            pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
+            hidden = self._prefill(prompt_emb, mask, pos_ids)
             logits = lm_head(hidden, self.text)
             next_token = torch.argmax(logits, dim=-1)
             pos = pos + prompt_emb.size(1)
@@ -207,33 +224,67 @@
     def _generate_text(
         self,
         prompt_tokens: torch.Tensor,
-        kv_cache: torch.Tensor,
         pos: int,
         max_tokens: int,
     ):
-        …
-        _, _, next_token, pos = self._prefill_prompt(kv_cache, prompt_tokens, pos)
+        _, _, next_token, pos = self._prefill_prompt(prompt_tokens, pos)
 
         def generator(next_token, pos):
+            mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+            mask[:, :, :pos] = 1
+            pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
             generated_tokens = 0
 
+            # For properly handling token streaming with Unicode
+            token_cache = []
+            print_len = 0
+
             while (
                 next_token_id := next_token.item()
             ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
-                …
+                # Add token to our cache
+                token_cache.append(next_token_id)
+
+                # Decode all tokens collected so far
+                text = self.tokenizer.decode(token_cache)
+
+                # After a newline, we flush the cache completely
+                if text.endswith("\n"):
+                    printable_text = text[print_len:]
+                    token_cache = []
+                    print_len = 0
+                    if printable_text:
+                        yield printable_text
+                # If the last token is a CJK character, we can safely print it
+                elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
+                    printable_text = text[print_len:]
+                    print_len += len(printable_text)
+                    if printable_text:
+                        yield printable_text
+                # Otherwise, only print up to the last space to avoid cutting words
+                else:
+                    last_space_idx = text.rfind(" ", print_len)
+                    if last_space_idx >= print_len:
+                        printable_text = text[print_len : last_space_idx + 1]
+                        print_len += len(printable_text)
+                        if printable_text:
+                            yield printable_text
+
+                with torch.inference_mode():
                     next_emb = text_encoder(next_token, self.text)
-                …
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                    mask[:, :, pos], pos_ids[0] = 1, pos
+                    logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
                 pos += 1
                 next_token = torch.argmax(logits, dim=-1)
                 generated_tokens += 1
 
+            # Flush any remaining text in the cache
+            if token_cache:
+                text = self.tokenizer.decode(token_cache)
+                printable_text = text[print_len:]
+                if printable_text:
+                    yield printable_text
+
         return generator(next_token, pos)
 
     def query(
@@ -247,10 +298,12 @@
             raise NotImplementedError("Model does not support querying.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["query"]["prefix"]
-                + self.tokenizer.encode(question).ids
+                + self.tokenizer.encode(" " + question).ids
                 + self.config.tokenizer.templates["query"]["suffix"]
             ],
             device=self.device,
@@ -261,9 +314,7 @@
         max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
         def generator():
-            for token in self._generate_text(
-                prompt_tokens, image.kv_cache, image.pos, max_tokens
-            ):
+            for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
                 yield token
 
         if stream:
@@ -271,10 +322,15 @@
         else:
             return {"answer": "".join(list(generator()))}
 
+    def load_encoded_image(self, encoded_image: EncodedImage):
+        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
+            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
+            b.kv_cache.v_cache[:, :, : v.size(2), :] = v
+
     def caption(
         self,
         image: Union[Image.Image, EncodedImage],
-        length: Literal["normal", "short"] = "normal",
+        length: Literal["normal", "short", "long"] = "normal",
         stream: bool = False,
         settings: Optional[SamplingSettings] = None,
     ):
@@ -284,6 +340,8 @@
             raise ValueError(f"Model does not support caption length '{length}'.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [self.config.tokenizer.templates["caption"][length]], device=self.device
         )
@@ -293,9 +351,7 @@
         max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
         def generator():
-            for token in self._generate_text(
-                prompt_tokens, image.kv_cache, image.pos, max_tokens
-            ):
+            for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
                 yield token
 
         if stream:
@@ -306,15 +362,17 @@
     def _generate_points(
         self,
         hidden: torch.Tensor,
-        kv_cache: torch.Tensor,
         next_token: torch.Tensor,
         pos: int,
         include_size: bool = True,
         max_points: int = 50,
     ):
         out = []
+        mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+        mask[:, :, :pos] = 1
+        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
 
-        with torch.…
+        with torch.inference_mode():
             while (
                 next_token.item() != self.config.tokenizer.eos_id
                 and len(out) < max_points
@@ -326,12 +384,8 @@
             )
 
             # Decode y-coordinate
-            …
-            )
-            kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                kv_cache_update
-            )
+            mask[:, :, pos], pos_ids[0] = 1, pos
+            _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
             pos += 1
             y_logits = decode_coordinate(hidden, self.region)
             y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
@@ -341,16 +395,20 @@
 
             # Decode size
             if include_size:
-                …
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                mask[:, :, pos], pos_ids[0] = 1, pos
+                logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                 pos += 1
                 size_logits = decode_size(hidden, self.region)
-                …
+
+                # Get bin indices from the logits
+                w_bin = torch.argmax(size_logits[0], dim=-1)
+                h_bin = torch.argmax(size_logits[1], dim=-1)
+
+                # Convert from bin indices to actual size values using the inverse of the log-scale mapping
+                # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
+                w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
+                h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
+
                 next_emb = encode_size(
                     torch.tensor(
                         [w, h], device=self.device, dtype=size_logits.dtype
@@ -371,12 +429,8 @@
             out.append({"x": x_center.item(), "y": y_center.item()})
 
             # Decode next token (x-coordinate, or eos)
-            …
-            )
-            kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                kv_cache_update
-            )
+            mask[:, :, pos], pos_ids[0] = 1, pos
+            logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
             pos += 1
             next_token = torch.argmax(logits, dim=-1)
 
@@ -392,23 +446,22 @@
             raise NotImplementedError("Model does not support object detection.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["detect"]["prefix"]
-                + self.tokenizer.encode(object).ids
+                + self.tokenizer.encode(" " + object).ids
                 + self.config.tokenizer.templates["detect"]["suffix"]
             ],
             device=self.device,
         )
 
-        …
-        _, hidden, next_token, pos = self._prefill_prompt(
-            kv_cache, prompt_tokens, image.pos
-        )
+        _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
         hidden = hidden[:, -1:, :]
 
         objects = self._generate_points(
-            hidden, …
+            hidden, next_token, pos, include_size=True, max_points=50
         )
 
         return {"objects": objects}
@@ -423,23 +476,22 @@
             raise NotImplementedError("Model does not support pointing.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["point"]["prefix"]
-                + self.tokenizer.encode(object).ids
+                + self.tokenizer.encode(" " + object).ids
                 + self.config.tokenizer.templates["point"]["suffix"]
             ],
             device=self.device,
         )
 
-        …
-        _, hidden, next_token, pos = self._prefill_prompt(
-            kv_cache, prompt_tokens, image.pos
-        )
+        _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
         hidden = hidden[:, -1:, :]
 
         objects = self._generate_points(
-            hidden, …
+            hidden, next_token, pos, include_size=False, max_points=50
         )
 
         return {"points": objects}
@@ -450,7 +502,7 @@
         source: Tuple[float, float],
         force_detect: bool = False,
     ):
-        with torch.…
+        with torch.inference_mode():
             before_emb = text_encoder(
                 torch.tensor(
                     [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
@@ -474,10 +526,13 @@
 
             prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
 
-            …
+            self.load_encoded_image(image)
+
+            mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
+            pos_ids = torch.arange(
+                image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
             )
+            hidden = self._prefill(prompt_emb, mask, pos_ids)
             logits = lm_head(hidden, self.text)
             next_token = torch.argmax(logits, dim=-1)
             pos = image.pos + prompt_emb.size(1)
@@ -490,7 +545,7 @@
             return None
 
         gaze = self._generate_points(
-            hidden, …
+            hidden, next_token, pos, include_size=False, max_points=1
         )
         return gaze[0]
 
@@ -584,3 +639,16 @@
         )
 
         return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
+
+
+def _is_cjk_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    if (
+        (cp >= 0x4E00 and cp <= 0x9FFF)
+        or (cp >= 0x3400 and cp <= 0x4DBF)
+        or (cp >= 0x2F800 and cp <= 0x2FA1F)
+    ):
+        return True
+    return False
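The new KVCache writes keys and values at explicit position ids into preallocated static buffers and returns the full buffers, which is what lets prefill and single-token decode share one cache. A tiny sketch of that update semantics with toy shapes (standalone, mirroring KVCache.update above):

import torch

n_kv_heads, max_context, head_dim = 2, 8, 4
k_cache = torch.zeros(1, n_kv_heads, max_context, head_dim)
v_cache = torch.zeros(1, n_kv_heads, max_context, head_dim)

def update(pos_ids, k, v):
    # Scatter new keys/values into the static buffers at the given positions.
    k_cache[:, :, pos_ids, :] = k
    v_cache[:, :, pos_ids, :] = v
    return k_cache, v_cache

# "Prefill" three positions at once, then "decode" one more token at position 3.
k_new = torch.randn(1, n_kv_heads, 3, head_dim)
update(torch.arange(3), k_new, k_new)
k_one = torch.randn(1, n_kv_heads, 1, head_dim)
k_full, _ = update(torch.tensor([3]), k_one, k_one)

print(k_full.shape)                            # torch.Size([1, 2, 8, 4]) -- full static buffer
print(torch.equal(k_full[:, :, 3:4], k_one))   # True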
region.py CHANGED
@@ -1,7 +1,7 @@
 import torch
+import torch.nn as nn
 import math
 
-from .weights import RegionModel
 from .layers import linear, mlp
 
 
@@ -25,7 +25,7 @@ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
     return torch.cat([f.cos(), f.sin()], dim=-1)
 
 
-def encode_coordinate(coord: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
     Takes as input a tensor containing a single float coordinate value (x or y)
     and encodes it into hidden states for input to the text model.
@@ -39,7 +39,7 @@ def encode_coordinate(coord: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
 
 
-def decode_coordinate(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
     Takes as input the last hidden state from the text model and outputs a single logit
     representing either an x or y coordinate prediction.
@@ -53,13 +53,13 @@ def decode_coordinate(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return mlp(hidden_state, w.coord_decoder)
 
 
-def encode_size(size: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes a tensor containing …
+    Takes a tensor containing width and height values and encodes them into
+    hidden states for input to the text model.
 
     Args:
         size: Tensor with two floats for width and height
 
     Returns:
         Encoded hidden states tensor for input to text model
@@ -67,16 +67,23 @@ def encode_size(size: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return linear(fourier_features(size, w.size_features), w.size_encoder)
 
 
-def decode_size(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes as input the last hidden state from the text model and outputs …
-    for width and height…
+    Takes as input the last hidden state from the text model and outputs logits
+    for 1024 bins representing width and height in log-scale.
+
+    The bins are distributed according to the formula:
+        bin = (log2(size) + 10.0) / 10.0 * 1023.0
+    where size values are clamped to be at least 1/1024.
+
+    To convert from bin back to size:
+        size = 2^((bin / 1023.0) * 10.0 - 10.0)
 
     Args:
         hidden_state: The final hidden state tensor from the text model.
 
     Returns:
-        A tensor containing …
+        A tensor containing logits for 1024 bins for width and height.
+        Shape is (2, 1024) where the first dimension corresponds to width and height.
     """
     return mlp(hidden_state, w.size_decoder).view(2, -1)
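The decode_size docstring above documents a log-scale binning; the forward and inverse mappings round-trip as long as sizes are clamped to at least 1/1024. A quick numerical check of that formula (plain PyTorch, mirroring the docstring; the use of round() for the bin index is an assumption for illustration):

import torch

def size_to_bin(size):
    size = torch.clamp(size, min=1.0 / 1024)
    return torch.round((torch.log2(size) + 10.0) / 10.0 * 1023.0)

def bin_to_size(b):
    return torch.pow(2.0, (b / 1023.0) * 10.0 - 10.0)

sizes = torch.tensor([0.01, 0.25, 0.5, 1.0])
bins = size_to_bin(sizes)
print(bins)               # approximately tensor([343., 818., 921., 1023.])
print(bin_to_size(bins))  # approximately recovers the original sizes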
text.py CHANGED
@@ -1,10 +1,10 @@
 import torch
 import torch.nn as nn
+
 from torch.nn import functional as F
 
-from .layers import layer_norm, …
+from .layers import layer_norm, mlp
 from .rope import apply_rotary_emb, precompute_freqs_cis
-from .weights import AttentionWeights
 from .config import TextConfig
 
 
@@ -14,106 +14,153 @@ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
 
 def attn(
     x: torch.Tensor,
-    w: …
+    w: nn.Module,
     freqs_cis: torch.Tensor,
-    …
+    kv_cache: nn.Module,
     attn_mask: torch.Tensor,
     n_heads: int,
-    …
+    n_kv_heads: int,
+    position_ids: torch.Tensor,
 ):
     bsz, q_len, d_model = x.shape
     head_dim = d_model // n_heads
 
-    …
-    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
-    k = apply_rotary_emb(k, freqs_cis, position_ids, …
-    …
-    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).to(
-        # This type conversion isn't needed when running in PyTorch directly, but the
-        # ONNX export runs attention in float32 because the attention mask is cast to
-        # float32.
-        x.dtype
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+
+    if kv_cache is not None:
+        k, v = kv_cache.update(position_ids, k, v)
+
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
     )
     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-    out = …
+    out = w.proj(out)
     return out
 
 
-def …
-    …
-    w: …
-    …
+def _attn(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    attn_mask: torch.Tensor,
+    n_heads: int,
+    n_kv_heads: int,
 ):
+    bsz, q_len, d_model = x.shape
+    head_dim = d_model // n_heads
+    pos = 0
+
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
+    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+    )
+    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+    out = w.proj(out)
+    return out
+
+
+def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
     hidden_BTC = inputs_embeds
-    new_kv_cache = [torch.empty(0)] * len(w.blocks)
 
-    …
-    ]
+    bsz, q_len, d_model = inputs_embeds.shape
+    attn_mask = torch.zeros(q_len, q_len)
+    attn_mask[:730, :730] = 1
+    for i in range(730, q_len):
+        attn_mask[i, : i + 1] = 1
+    attn_mask = attn_mask.to(dtype=torch.bool)
 
     for i, block in enumerate(w.blocks):
         l_in = layer_norm(hidden_BTC, block.ln)
-        l_attn…
+        l_attn = _attn(
+            x=l_in,
+            w=block.attn,
+            freqs_cis=w.freqs_cis,
+            attn_mask=attn_mask,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+        )
+        l_mlp = mlp(l_in, block.mlp)
+        hidden_BTC = hidden_BTC + l_attn + l_mlp
+
+    return hidden_BTC
+
+
+def text_decoder(
+    x: torch.Tensor,
+    w: nn.Module,
+    attn_mask: torch.Tensor,
+    position_ids: torch.Tensor,
+    config: TextConfig,
+):
+    for i, block in enumerate(w.blocks):
+        l_in = layer_norm(x, block.ln)
+        l_attn = attn(
             l_in,
             block.attn,
             freqs_cis=w.freqs_cis,
-            …
+            kv_cache=block.kv_cache,
             attn_mask=attn_mask,
             n_heads=config.n_heads,
-            …
+            n_kv_heads=config.n_kv_heads,
+            position_ids=position_ids,
         )
         l_mlp = mlp(l_in, block.mlp)
-        …
+        x = x + l_attn + l_mlp
 
-    return …
+    return x
 
 
 def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     hidden_BC = hidden_BTC[:, -1, :]
     hidden_BC = layer_norm(hidden_BC, w.post_ln)
-    logits = …
+    logits = w.lm_head(hidden_BC)
     return logits
 
 
-def …
-    …
-    w: nn.Module,
-    config: TextConfig,
-):
-    # Updates kv_cache in-place
-    hidden, kv_cache[:, :, :, :, pos : pos + inputs_embeds.size(1), :] = text_decoder(
-        inputs_embeds, w, kv_cache, pos, config
-    )
-    return hidden
-
-
-def decode_one_token(
-    token_emb: torch.Tensor,
-    kv_cache: torch.Tensor,
-    pos: int,
-    w: nn.Module,
-    config: TextConfig,
-):
-    hidden, kv_cache_update = text_decoder(token_emb[None], w, kv_cache, pos, config)
-    logits = lm_head(hidden, w)
-    return logits, hidden, kv_cache_update
+def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
+    hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
+    logits = w.lm_head(hidden_BTC)
+    return logits
 
 
 def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+
     text = nn.ModuleDict(
         {
             "blocks": nn.ModuleList(
@@ -123,9 +170,7 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                         "ln": nn.LayerNorm(config.dim, dtype=dtype),
                         "attn": nn.ModuleDict(
                             {
-                                "qkv": nn.Linear(
-                                    config.dim, 3 * config.dim, dtype=dtype
-                                ),
+                                "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
                                 "proj": nn.Linear(
                                     config.dim, config.dim, dtype=dtype
                                 ),
@@ -134,10 +179,10 @@
                         "mlp": nn.ModuleDict(
                             {
                                 "fc1": nn.Linear(
-                                    config.dim, …
+                                    config.dim, config.ff_dim, dtype=dtype
                                 ),
                                 "fc2": nn.Linear(
-                                    …
+                                    config.ff_dim, config.dim, dtype=dtype
                                 ),
                             }
                         ),
@@ -157,11 +202,4 @@
         persistent=False,
     )
 
-    attn_mask = torch.tril(
-        torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
-    )
-    if config.prefix_attn != 0:
-        attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
-    text.register_buffer("attn_mask", attn_mask, persistent=False)
-
     return text
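With the fused QKV projection, the output is split along the last dimension into n_heads query heads and n_kv_heads key/value heads; when the two counts differ, the enable_gqa flag in the diff lets scaled_dot_product_attention broadcast the fewer KV heads across query groups. A shape-only sketch of that split (hypothetical GQA sizes, since the shipped config uses n_kv_heads == n_heads):

import torch

bsz, q_len, dim = 1, 5, 2048
n_heads, n_kv_heads = 32, 8   # hypothetical GQA setting; the shipped config uses 32/32
head_dim = dim // n_heads     # 64

qkv_dim = int(dim * (1 + 2 * n_kv_heads / n_heads))   # 2048 + 2*8*64 = 3072
qkv_out = torch.randn(bsz, q_len, qkv_dim)

q_dim, kv_dim = n_heads * head_dim, n_kv_heads * head_dim
q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
k = qkv_out[..., q_dim : q_dim + kv_dim].view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
v = qkv_out[..., q_dim + kv_dim :].view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

print(q.shape, k.shape, v.shape)
# torch.Size([1, 32, 5, 64]) torch.Size([1, 8, 5, 64]) torch.Size([1, 8, 5, 64])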
vision.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
 import numpy as np
 
 from typing import Union, Tuple
-from einops import rearrange
 from PIL import Image
 
 from .layers import attn, layer_norm, linear, mlp
@@ -42,13 +41,28 @@ def prepare_crops(
     return all_crops, overlap_crops["tiling"]
 
 
+def create_patches(x, patch_size):
+    # Original shape: [B, C, H, W]
+    B, C, H, W = x.shape
+    P1 = P2 = patch_size
+
+    # Step 1: Split H and W dimensions into patches
+    # [B, C, H/P1, P1, W/P2, P2]
+    x = x.reshape(B, C, H // P1, P1, W // P2, P2)
+
+    # Step 2: Rearrange dimensions to match target shape
+    # [B, H/P1, W/P2, C, P1, P2]
+    x = x.permute(0, 2, 4, 1, 3, 5)
+
+    # Step 3: Combine dimensions to get final shape
+    # [B, (H/P1)*(W/P2), C*P1*P2]
+    x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
+
+    return x
+
+
 def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
-    x = rearrange(
-        input_BCHW,
-        "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
-        p1=config.enc_patch_size,
-        p2=config.enc_patch_size,
-    )  # B3HW -> B(HxW)(3xP1xP2), aka BTC
+    x = create_patches(input_BCHW, config.enc_patch_size)
 
     x = linear(x, w.patch_emb)
     x = x + w.pos_emb
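create_patches replaces the removed einops rearrange call with plain reshape/permute operations, which drops the einops dependency; the two produce identical layouts. A quick standalone equivalence check (the commented einops lines only mirror the removed pattern, for anyone who still has einops installed):

import torch

def create_patches(x, patch_size):
    B, C, H, W = x.shape
    P1 = P2 = patch_size
    x = x.reshape(B, C, H // P1, P1, W // P2, P2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)

x = torch.randn(2, 3, 8, 8)
out = create_patches(x, 4)
print(out.shape)  # torch.Size([2, 4, 48])

# Optional check against the removed einops pattern:
# from einops import rearrange
# ref = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=4, p2=4)
# assert torch.equal(out, ref)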