vikhyatk committed
Commit 235555c · verified · 1 Parent(s): 4f59175

Upload HfMoondream

Files changed (11):
  1. config.json +2 -2
  2. config.py +16 -8
  3. hf_moondream.py +46 -5
  4. image_crops.py +36 -13
  5. layers.py +109 -6
  6. lora.py +82 -0
  7. model.safetensors +2 -2
  8. moondream.py +330 -61
  9. region.py +50 -3
  10. text.py +34 -18
  11. vision.py +3 -3
config.json CHANGED
@@ -8,6 +8,6 @@
   },
   "config": {},
   "model_type": "moondream1",
-  "torch_dtype": "float16",
-  "transformers_version": "4.44.0"
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.52.4"
 }
config.py CHANGED
@@ -12,6 +12,7 @@ class TextConfig:
     n_heads: int = 32
     n_kv_heads: int = 32
     prefix_attn: int = 730
+    group_size: Optional[int] = None
 
 
 @dataclass(frozen=True)
@@ -37,22 +38,29 @@ class RegionConfig:
     size_feat_dim: int = 512
     size_out_dim: int = 2048
     inner_dim: int = 8192
+    group_size: Optional[int] = None
 
 
 @dataclass(frozen=True)
 class TokenizerConfig:
-    bos_id: int = 50256
-    eos_id: int = 50256
+    bos_id: int = 0
+    eos_id: int = 0
+    answer_id: int = 3
+    thinking_id: int = 4
+    coord_id: int = 5
+    size_id: int = 6
+    start_ground_points_id: int = 7
+    end_ground_id: int = 9
     templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
         default_factory=lambda: {
            "caption": {
-                "short": [198, 198, 16438, 8305, 25],
-                "normal": [198, 198, 24334, 1159, 25],
-                "long": [198, 198, 14617, 8305, 25],
+                "short": [1, 32708, 2, 12492, 3],
+                "normal": [1, 32708, 2, 6382, 3],
+                "long": [1, 32708, 2, 4059, 3],
            },
-            "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
-            "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
-            "point": {"prefix": [198, 198, 12727, 25], "suffix": [628]},
+            "query": {"prefix": [1, 15381, 2], "suffix": [3]},
+            "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
+            "point": {"prefix": [1, 2581, 2], "suffix": [3]},
        }
    )
 
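Note: a quick sketch (not part of this commit) of how the new TokenizerConfig templates are consumed at inference time, mirroring the prompt assembly in moondream.py; it assumes TokenizerConfig above is importable, and the question token IDs are placeholders:

    # Prompts are just prefix + encoded question + suffix.
    cfg = TokenizerConfig()
    question_ids = [123, 456]  # stand-in for tokenizer.encode(question).ids
    prompt = (
        cfg.templates["query"]["prefix"]
        + question_ids
        + cfg.templates["query"]["suffix"]
    )
    assert prompt == [1, 15381, 2, 123, 456, 3]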
hf_moondream.py CHANGED
@@ -1,4 +1,8 @@
+import torch
+import torch.nn as nn
+
 from transformers import PreTrainedModel, PretrainedConfig
+from typing import Union
 
 from .config import MoondreamConfig
 from .moondream import MoondreamModel
@@ -123,7 +127,7 @@ class HfMoondream(PreTrainedModel):
         )
 
         def generator():
-            for token in self.model._generate_text(
+            for token in self.model._generate_answer(
                 prompt_tokens,
                 image_embeds.kv_cache,
                 image_embeds.pos,
@@ -135,8 +139,45 @@ class HfMoondream(PreTrainedModel):
 
         return [answer]
 
-    def get_input_embeddings(self):
-        return super().get_input_embeddings()
+    def get_input_embeddings(self) -> nn.Embedding:
+        """
+        Lazily wrap the raw parameter `self.model.text.wte` in a real
+        `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper
+        **shares** the weight tensor—no copy is made.
+        """
+        if not hasattr(self, "_input_embeddings"):
+            self._input_embeddings = nn.Embedding.from_pretrained(
+                self.model.text.wte,  # tensor created in text.py
+                freeze=True,  # set to False if you need it trainable
+            )
+        return self._input_embeddings
+
+    def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
+        """
+        Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
+        embeddings and keeps everything tied to `self.model.text.wte`.
+        """
+        # 1. point the low-level parameter to the new weight matrix
+        self.model.text.wte = value.weight
+        # 2. keep a reference for get_input_embeddings()
+        self._input_embeddings = value
 
-    def input_embeds(self, *args, **kwargs):
-        self._unsupported_exception()
+    def input_embeds(
+        self,
+        input_ids: Union[torch.LongTensor, list, tuple],
+        *,
+        device: torch.device | None = None
+    ) -> torch.FloatTensor:
+        """
+        Back-compat wrapper that turns token IDs into embeddings.
+
+        Example:
+            ids = torch.tensor([[1, 2, 3]])
+            embeds = model.input_embeds(ids)  # (1, 3, hidden_dim)
+        """
+        if not torch.is_tensor(input_ids):
+            input_ids = torch.as_tensor(input_ids)
+        if device is not None:
+            input_ids = input_ids.to(device)
 
+        return self.get_input_embeddings()(input_ids)
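Note: a hypothetical usage sketch (not from this commit) of the embedding accessors added above; the checkpoint id, remote-code loading, and hidden size are assumptions:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "vikhyatk/moondream2", trust_remote_code=True  # assumed checkpoint id
    )
    ids = torch.tensor([[1, 2, 3]])
    embeds = model.input_embeds(ids)  # looks tokens up in the shared wte table
    assert embeds.shape[:2] == (1, 3)  # (batch, seq, hidden_dim)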
image_crops.py CHANGED
@@ -1,10 +1,18 @@
 import math
 import numpy as np
 import torch
-import pyvips
 
 from typing import TypedDict
 
+try:
+    import pyvips
+
+    HAS_VIPS = True
+except:
+    from PIL import Image
+
+    HAS_VIPS = False
+
 
 def select_tiling(
     height: int, width: int, crop_size: int, max_crops: int
@@ -113,18 +121,33 @@ def overlap_crop_image(
         tiling[1] * crop_window_size + total_margin_pixels,
     )
 
-    # Convert to vips for resizing
-    vips_image = pyvips.Image.new_from_array(image)
-    scale_x = target_size[1] / image.shape[1]
-    scale_y = target_size[0] / image.shape[0]
-    resized = vips_image.resize(scale_x, vscale=scale_y)
-    image = resized.numpy()
-
-    # Create global crop
-    scale_x = base_size[1] / vips_image.width
-    scale_y = base_size[0] / vips_image.height
-    global_vips = vips_image.resize(scale_x, vscale=scale_y)
-    crops[0] = global_vips.numpy()
+    if HAS_VIPS:
+        # Convert to vips for resizing
+        vips_image = pyvips.Image.new_from_array(image)
+        scale_x = target_size[1] / image.shape[1]
+        scale_y = target_size[0] / image.shape[0]
+        resized = vips_image.resize(scale_x, vscale=scale_y)
+        image = resized.numpy()
+
+        # Create global crop
+        scale_x = base_size[1] / vips_image.width
+        scale_y = base_size[0] / vips_image.height
+        global_vips = vips_image.resize(scale_x, vscale=scale_y)
+        crops[0] = global_vips.numpy()
+    else:
+        # Fallback to PIL
+        pil_img = Image.fromarray(image)
+        resized = pil_img.resize(
+            (int(target_size[1]), int(target_size[0])),
+            resample=Image.Resampling.LANCZOS,
+        )
+        image = np.asarray(resized)
+
+        # Create global crop
+        global_pil = pil_img.resize(
+            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
+        )
+        crops[0] = np.asarray(global_pil)
 
     for i in range(tiling[0]):
         for j in range(tiling[1]):
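Note: a minimal standalone sketch (not from the repo) of the PIL fallback path added above: when pyvips is unavailable, the resize is done with PIL's LANCZOS filter on the numpy array and converted back with np.asarray; the input shape and target size below are toy values:

    import numpy as np
    from PIL import Image

    image = np.zeros((480, 640, 3), dtype=np.uint8)
    target_size = (378, 378)  # (height, width), toy value

    resized = np.asarray(
        Image.fromarray(image).resize(
            (int(target_size[1]), int(target_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
    )
    assert resized.shape == (378, 378, 3)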
layers.py CHANGED
@@ -1,8 +1,24 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Optional
 
-import torch
-from torch.nn import functional as F
+try:
+    from torchao import quantize_
+    from torchao.quantization import int4_weight_only
+except ImportError:
+
+    def quantize_(model, quant_mode):
+        raise ImportError(
+            "torchao is not installed. Please install it with `pip install torchao`."
+        )
+
+    def int4_weight_only(group_size):
+        raise ImportError(
+            "torchao is not installed. Please install it with `pip install torchao`."
+        )
 
 
 def gelu_approx(x):
@@ -19,6 +35,80 @@ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
     return F.linear(x, w.weight, w.bias)
 
 
+def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
+    _step = W_q.shape[0]
+    W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
+    W_r[:_step] = (W_q & 0b11110000) >> 4
+    W_r[_step:] = W_q & 0b00001111
+    W_r.sub_(zero).mul_(scale)
+    return W_r.reshape(orig_shape)
+
+
+class QuantizedLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        dtype: torch.dtype,
+    ):
+        # TODO: Take group_size as an input instead of hardcoding it here.
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = nn.ParameterDict(
+            {
+                "packed": nn.Parameter(
+                    torch.empty(
+                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
+                    ),
+                    requires_grad=False,
+                ),
+                "scale": nn.Parameter(
+                    torch.empty(out_features * in_features // 128, 1),
+                    requires_grad=False,
+                ),
+                "zero_point": nn.Parameter(
+                    torch.empty(out_features * in_features // 128, 1),
+                    requires_grad=False,
+                ),
+            }
+        )
+        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
+        self.unpacked = False
+
+    def unpack(self):
+        if self.unpacked:
+            return
+
+        self.weight = nn.Parameter(
+            dequantize_tensor(
+                self.weight["packed"],
+                self.weight["scale"],
+                self.weight["zero_point"],
+                (self.out_features, self.in_features),
+                torch.bfloat16,
+            )
+        )
+        with torch.device("meta"):
+            self.linear = nn.Linear(
+                self.in_features, self.out_features, dtype=torch.bfloat16
+            )
+        self.linear.weight = self.weight
+        self.linear.bias = nn.Parameter(
+            self.bias.to(torch.bfloat16), requires_grad=False
+        )
+
+        del self.weight, self.bias
+        quantize_(self, int4_weight_only(group_size=128))
+        self.unpacked = True
+        torch.cuda.empty_cache()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if not self.unpacked:
+            self.unpack()
+        return self.linear(x)
+
+
 @dataclass
 class LayerNormWeights:
     weight: torch.Tensor
@@ -36,10 +126,23 @@ class MLPWeights:
     act: Literal["gelu_approx"] = "gelu_approx"
 
 
-def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-    x = w.fc1(x)
+def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
+    x0 = w.fc1(x)
+    if lora is not None:
+        x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
+        x = x0 + x1
+    else:
+        x = x0
+
     x = gelu_approx(x)
-    x = w.fc2(x)
+
+    x0 = w.fc2(x)
+    if lora is not None:
+        x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
+        x = x0 + x1
+    else:
+        x = x0
+
     return x
 
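Note: an illustrative round-trip check (not part of the commit) for the 4-bit packing layout that dequantize_tensor() above expects: the high nibble of each uint8 byte holds one quantized value and the low nibble holds the value half a "step" further down the matrix. It assumes dequantize_tensor from the layers.py diff is in scope; pack_tensor and the toy shapes are made up for the example:

    import torch

    def pack_tensor(W_q_unpacked: torch.Tensor) -> torch.Tensor:
        # W_q_unpacked: integer values in [0, 15], shape (2 * step, cols)
        step = W_q_unpacked.shape[0] // 2
        high = W_q_unpacked[:step].to(torch.uint8) << 4
        low = W_q_unpacked[step:].to(torch.uint8)
        return high | low

    step, cols = 4, 8
    W_q_unpacked = torch.randint(0, 16, (2 * step, cols))
    packed = pack_tensor(W_q_unpacked)

    # With unit scale and zero offset, dequantize_tensor recovers the raw nibbles.
    scale = torch.ones(2 * step, 1)
    zero = torch.zeros(2 * step, 1)
    recovered = dequantize_tensor(
        packed, scale, zero, (2 * step, cols), dtype=torch.float32
    )
    assert torch.equal(recovered, W_q_unpacked.float())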
lora.py ADDED
@@ -0,0 +1,82 @@
+import functools
+import os
+import shutil
+import torch
+
+from pathlib import Path
+from urllib.request import Request, urlopen
+from typing import Optional
+
+
+def variant_cache_dir():
+    hf_hub_cache = os.environ.get("HF_HUB_CACHE")
+    if hf_hub_cache is not None:
+        return Path(hf_hub_cache) / "md_variants"
+
+    hf_home = os.environ.get("HF_HOME")
+    if hf_home is not None:
+        return Path(hf_home) / "hub" / "md_variants"
+
+    return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"
+
+
+def cached_variant_path(variant_id: str):
+    variant, *rest = variant_id.split("/", 1)
+    step = rest[0] if rest else "final"
+
+    cache_dir = variant_cache_dir() / variant
+    os.makedirs(cache_dir, exist_ok=True)
+    dest = cache_dir / f"{step}.pt"
+    if dest.exists():
+        return dest
+
+    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")
+
+    headers = {"User-Agent": "moondream-torch"}
+    api_key = os.getenv("MOONDREAM_API_KEY")
+    if api_key is not None:
+        headers["X-Moondream-Auth"] = api_key
+
+    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
+    with urlopen(req) as r, open(dest, "wb") as f:
+        shutil.copyfileobj(r, f)
+    return dest
+
+
+def nest(flat):
+    tree = {}
+    for k, v in flat.items():
+        parts = k.split(".")
+        d = tree
+        for p in parts[:-1]:
+            d = d.setdefault(p, {})
+        d[parts[-1]] = v
+    return tree
+
+
+@functools.lru_cache(maxsize=5)
+def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
+    if variant_id is None:
+        return None
+
+    state_dict = torch.load(
+        cached_variant_path(variant_id), map_location=device, weights_only=True
+    )
+
+    # TODO: Move these into the training code that saves checkpoints...
+    rename_rules = [
+        ("text_model.transformer.h", "text.blocks"),
+        (".mixer", ".attn"),
+        (".out_proj", ".proj"),
+        (".Wqkv", ".qkv"),
+        (".parametrizations.weight.0", ""),
+    ]
+    new_state_dict = {}
+    for key, tensor in state_dict.items():
+        new_key = key
+        for old, new in rename_rules:
+            if old in new_key:
+                new_key = new_key.replace(old, new)
+        new_state_dict[new_key] = tensor
+
+    return nest(new_state_dict)
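Note: an illustrative sketch (keys, ranks, and dimensions are made up) of how nest() above turns flat checkpoint keys into the nested dict that text_decoder() later indexes as lora["text"]["blocks"][str(i)]["mlp"]["fc1"]["A"]:

    import torch

    flat = {
        "text.blocks.0.mlp.fc1.A": torch.zeros(8, 2048),
        "text.blocks.0.mlp.fc1.B": torch.zeros(8192, 8),
    }
    tree = nest(flat)
    assert tree["text"]["blocks"]["0"]["mlp"]["fc1"]["A"].shape == (8, 2048)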
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:96dce588e4a319fde7af3c70fbf27e726f4850e22522d0fdc4b165d5e6003ad5
-size 3854538376
+oid sha256:70a7d94c0c8349eb58ed2d9e636ef2d0916960f321ecabeac6354b8ba3d7403f
+size 3854538968
moondream.py CHANGED
@@ -11,9 +11,23 @@ from .config import MoondreamConfig
11
  from .image_crops import reconstruct_from_crops
12
  from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
13
  from .text import build_text_model, text_encoder, lm_head, text_decoder
14
- from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
 
 
 
 
 
 
 
 
 
15
  from .utils import remove_outlier_points
16
 
 
 
 
 
 
17
 
18
  TextSamplingSettings = TypedDict(
19
  "TextSamplingSettings",
@@ -21,16 +35,18 @@ TextSamplingSettings = TypedDict(
21
  "max_tokens": int,
22
  "temperature": float,
23
  "top_p": float,
 
24
  },
25
  total=False,
26
  )
27
 
28
  ObjectSamplingSettings = TypedDict(
29
  "ObjectSamplingSettings",
30
- {"max_objects": int},
31
  total=False,
32
  )
33
 
 
34
  DEFAULT_MAX_TOKENS = 768
35
  DEFAULT_TEMPERATURE = 0.5
36
  DEFAULT_TOP_P = 0.3
@@ -63,43 +79,47 @@ class KVCache(nn.Module):
63
 
64
 
65
  class MoondreamModel(nn.Module):
66
- def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
 
 
 
67
  super().__init__()
68
  self.config = config
69
 
70
- self.tokenizer = Tokenizer.from_pretrained(
71
- "vikhyatk/moondream2", revision="2025-01-09"
72
- )
73
  self.vision = build_vision_model(config.vision, dtype)
74
  self.text = build_text_model(config.text, dtype)
75
 
76
  # Region Model
 
 
 
77
  self.region = nn.ModuleDict(
78
  {
79
- "coord_encoder": nn.Linear(
80
  config.region.coord_feat_dim, config.region.dim, dtype=dtype
81
  ),
82
  "coord_decoder": nn.ModuleDict(
83
  {
84
- "fc1": nn.Linear(
85
  config.region.dim, config.region.inner_dim, dtype=dtype
86
  ),
87
- "fc2": nn.Linear(
88
  config.region.inner_dim,
89
  config.region.coord_out_dim,
90
  dtype=dtype,
91
  ),
92
  }
93
  ),
94
- "size_encoder": nn.Linear(
95
  config.region.size_feat_dim, config.region.dim, dtype=dtype
96
  ),
97
  "size_decoder": nn.ModuleDict(
98
  {
99
- "fc1": nn.Linear(
100
  config.region.dim, config.region.inner_dim, dtype=dtype
101
  ),
102
- "fc2": nn.Linear(
103
  config.region.inner_dim,
104
  config.region.size_out_dim,
105
  dtype=dtype,
@@ -151,17 +171,31 @@ class MoondreamModel(nn.Module):
151
  def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
152
  return vision_projection(g, r, self.vision, self.config.vision)
153
 
154
- def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
155
- return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
 
 
 
 
 
 
156
 
157
  def _decode_one_tok(
158
- self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
 
 
 
 
159
  ):
160
- hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
161
  logits = lm_head(hidden, self.text)
162
  return logits, hidden
163
 
164
  def compile(self):
 
 
 
 
165
  # TODO: vision_projection is not being compiled
166
  self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
167
  self._prefill = torch.compile(self._prefill, fullgraph=True)
@@ -171,6 +205,7 @@ class MoondreamModel(nn.Module):
171
 
172
  def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
173
  all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
 
174
  torch._dynamo.mark_dynamic(all_crops, 0)
175
 
176
  outputs = self._vis_enc(all_crops)
@@ -192,12 +227,22 @@ class MoondreamModel(nn.Module):
192
 
193
  return self._vis_proj(global_features, reconstructed)
194
 
195
- def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
 
 
 
 
196
  if isinstance(image, EncodedImage):
197
  return image
198
  elif not isinstance(image, Image.Image):
199
  raise ValueError("image must be a PIL Image or EncodedImage")
200
 
 
 
 
 
 
 
201
  # Run through text model in addition to the vision encoder, to minimize
202
  # re-computation if multiple queries are performed on this image.
203
  with torch.inference_mode():
@@ -209,7 +254,7 @@ class MoondreamModel(nn.Module):
209
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
210
  mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
211
  pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
212
- self._prefill(inputs_embeds, mask, pos_ids)
213
 
214
  return EncodedImage(
215
  pos=inputs_embeds.size(1),
@@ -233,31 +278,167 @@ class MoondreamModel(nn.Module):
233
  return next_probs
234
 
235
  def _prefill_prompt(
236
- self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
 
 
 
 
 
 
 
237
  ):
238
  with torch.inference_mode():
239
  prompt_emb = text_encoder(prompt_tokens, self.text)
 
 
 
 
 
 
 
 
 
 
 
240
  torch._dynamo.mark_dynamic(prompt_emb, 1)
241
- mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
 
 
 
 
242
  pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
243
- hidden = self._prefill(prompt_emb, mask, pos_ids)
244
- logits = lm_head(hidden, self.text)
245
 
246
  if temperature == 0:
247
- next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
248
  else:
249
- probs = torch.softmax(logits / temperature, dim=-1)
250
  probs = self._apply_top_p(probs, top_p)
251
  next_token = torch.multinomial(probs, num_samples=1)
252
 
253
  pos = pos + prompt_emb.size(1)
254
- return logits, hidden, next_token, pos
255
 
256
- def _generate_text(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  self,
258
  prompt_tokens: torch.Tensor,
259
  pos: int,
260
  settings: Optional[TextSamplingSettings] = None,
 
 
 
261
  ):
262
  max_tokens = (
263
  settings.get("max_tokens", DEFAULT_MAX_TOKENS)
@@ -270,9 +451,21 @@ class MoondreamModel(nn.Module):
270
  else DEFAULT_TEMPERATURE
271
  )
272
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
 
 
 
 
 
 
273
 
274
  _, _, next_token, pos = self._prefill_prompt(
275
- prompt_tokens, pos, temperature, top_p
 
 
 
 
 
 
276
  )
277
 
278
  def generator(next_token, pos):
@@ -287,7 +480,7 @@ class MoondreamModel(nn.Module):
287
 
288
  while (
289
  next_token_id := next_token.item()
290
- ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
291
  # Add token to our cache
292
  token_cache.append(next_token_id)
293
 
@@ -307,7 +500,7 @@ class MoondreamModel(nn.Module):
307
  print_len += len(printable_text)
308
  if printable_text:
309
  yield printable_text
310
- # Otherwise, only print up to the last space to avoid cutting words
311
  else:
312
  last_space_idx = text.rfind(" ", print_len)
313
  if last_space_idx >= print_len:
@@ -319,13 +512,18 @@ class MoondreamModel(nn.Module):
319
  with torch.inference_mode():
320
  next_emb = text_encoder(next_token, self.text)
321
  mask[:, :, pos], pos_ids[0] = 1, pos
322
- logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
 
 
 
323
  pos += 1
324
 
325
  if temperature == 0:
326
- next_token = torch.argmax(logits, dim=-1).unsqueeze(1) # (1, 1)
 
 
327
  else:
328
- probs = torch.softmax(logits / temperature, dim=-1) # (1, V)
329
  probs = self._apply_top_p(probs, top_p)
330
  next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
331
 
@@ -342,34 +540,82 @@ class MoondreamModel(nn.Module):
342
 
343
  def query(
344
  self,
345
- image: Union[Image.Image, EncodedImage],
346
- question: str,
 
 
347
  stream: bool = False,
348
  settings: Optional[TextSamplingSettings] = None,
349
  ):
350
  if self.config.tokenizer.templates["query"] is None:
351
  raise NotImplementedError("Model does not support querying.")
352
 
353
- image = self.encode_image(image)
354
- self.load_encoded_image(image)
355
 
356
- prompt_tokens = torch.tensor(
357
- [
358
- self.config.tokenizer.templates["query"]["prefix"]
359
- + self.tokenizer.encode(" " + question).ids
360
- + self.config.tokenizer.templates["query"]["suffix"]
361
- ],
362
- device=self.device,
363
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  def generator():
366
- for token in self._generate_text(prompt_tokens, image.pos, settings):
 
 
367
  yield token
368
 
369
  if stream:
370
- return {"answer": generator()}
371
  else:
372
- return {"answer": "".join(list(generator()))}
373
 
374
  def load_encoded_image(self, encoded_image: EncodedImage):
375
  for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
@@ -388,7 +634,7 @@ class MoondreamModel(nn.Module):
388
  if length not in self.config.tokenizer.templates["caption"]:
389
  raise ValueError(f"Model does not support caption length '{length}'.")
390
 
391
- image = self.encode_image(image)
392
  self.load_encoded_image(image)
393
 
394
  prompt_tokens = torch.tensor(
@@ -396,7 +642,7 @@ class MoondreamModel(nn.Module):
396
  )
397
 
398
  def generator():
399
- for token in self._generate_text(prompt_tokens, image.pos, settings):
400
  yield token
401
 
402
  if stream:
@@ -411,6 +657,7 @@ class MoondreamModel(nn.Module):
411
  pos: int,
412
  include_size: bool = True,
413
  max_objects: int = DEFAULT_MAX_OBJECTS,
 
414
  ):
415
  out = []
416
  mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
@@ -430,7 +677,7 @@ class MoondreamModel(nn.Module):
430
 
431
  # Decode y-coordinate
432
  mask[:, :, pos], pos_ids[0] = 1, pos
433
- _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
434
  pos += 1
435
  y_logits = decode_coordinate(hidden, self.region)
436
  y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
@@ -441,7 +688,7 @@ class MoondreamModel(nn.Module):
441
  # Decode size
442
  if include_size:
443
  mask[:, :, pos], pos_ids[0] = 1, pos
444
- logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
445
  pos += 1
446
  size_logits = decode_size(hidden, self.region)
447
 
@@ -479,7 +726,7 @@ class MoondreamModel(nn.Module):
479
 
480
  # Decode next token (x-coordinate, or eos)
481
  mask[:, :, pos], pos_ids[0] = 1, pos
482
- logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
483
  pos += 1
484
  next_token = torch.argmax(logits, dim=-1)
485
 
@@ -494,7 +741,7 @@ class MoondreamModel(nn.Module):
494
  if self.config.tokenizer.templates["detect"] is None:
495
  raise NotImplementedError("Model does not support object detection.")
496
 
497
- image = self.encode_image(image)
498
  self.load_encoded_image(image)
499
 
500
  prompt_tokens = torch.tensor(
@@ -506,8 +753,14 @@ class MoondreamModel(nn.Module):
506
  device=self.device,
507
  )
508
 
 
 
 
 
 
 
509
  _, hidden, next_token, pos = self._prefill_prompt(
510
- prompt_tokens, image.pos, temperature=0, top_p=0
511
  )
512
  hidden = hidden[:, -1:, :]
513
 
@@ -517,7 +770,12 @@ class MoondreamModel(nn.Module):
517
  else DEFAULT_MAX_OBJECTS
518
  )
519
  objects = self._generate_points(
520
- hidden, next_token, pos, include_size=True, max_objects=max_objects
 
 
 
 
 
521
  )
522
 
523
  return {"objects": objects}
@@ -531,7 +789,7 @@ class MoondreamModel(nn.Module):
531
  if self.config.tokenizer.templates["point"] is None:
532
  raise NotImplementedError("Model does not support pointing.")
533
 
534
- image = self.encode_image(image)
535
  self.load_encoded_image(image)
536
 
537
  prompt_tokens = torch.tensor(
@@ -543,8 +801,14 @@ class MoondreamModel(nn.Module):
543
  device=self.device,
544
  )
545
 
 
 
 
 
 
 
546
  _, hidden, next_token, pos = self._prefill_prompt(
547
- prompt_tokens, image.pos, temperature=0, top_p=0
548
  )
549
  hidden = hidden[:, -1:, :]
550
 
@@ -554,7 +818,12 @@ class MoondreamModel(nn.Module):
554
  else DEFAULT_MAX_OBJECTS
555
  )
556
  objects = self._generate_points(
557
- hidden, next_token, pos, include_size=False, max_objects=max_objects
 
 
 
 
 
558
  )
559
 
560
  return {"points": objects}
@@ -579,11 +848,11 @@ class MoondreamModel(nn.Module):
579
  self.text,
580
  )
581
  x_emb = encode_coordinate(
582
- torch.tensor([[[source[0]]]], device=self.device, dtype=torch.float16),
583
  self.region,
584
  )
585
  y_emb = encode_coordinate(
586
- torch.tensor([[[source[1]]]], device=self.device, dtype=torch.float16),
587
  self.region,
588
  )
589
 
@@ -595,7 +864,7 @@ class MoondreamModel(nn.Module):
595
  pos_ids = torch.arange(
596
  image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
597
  )
598
- hidden = self._prefill(prompt_emb, mask, pos_ids)
599
  logits = lm_head(hidden, self.text)
600
  next_token = torch.argmax(logits, dim=-1)
601
  pos = image.pos + prompt_emb.size(1)
 
11
  from .image_crops import reconstruct_from_crops
12
  from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
13
  from .text import build_text_model, text_encoder, lm_head, text_decoder
14
+ from .region import (
15
+ decode_coordinate,
16
+ encode_coordinate,
17
+ decode_size,
18
+ encode_size,
19
+ encode_spatial_refs,
20
+ SpatialRefs,
21
+ )
22
+ from .layers import QuantizedLinear
23
+ from .lora import variant_state_dict
24
  from .utils import remove_outlier_points
25
 
26
+ ImageEncodingSettings = TypedDict(
27
+ "ImageEncodingSettings",
28
+ {"variant": str},
29
+ total=False,
30
+ )
31
 
32
  TextSamplingSettings = TypedDict(
33
  "TextSamplingSettings",
 
35
  "max_tokens": int,
36
  "temperature": float,
37
  "top_p": float,
38
+ "variant": str,
39
  },
40
  total=False,
41
  )
42
 
43
  ObjectSamplingSettings = TypedDict(
44
  "ObjectSamplingSettings",
45
+ {"max_objects": int, "variant": str},
46
  total=False,
47
  )
48
 
49
+
50
  DEFAULT_MAX_TOKENS = 768
51
  DEFAULT_TEMPERATURE = 0.5
52
  DEFAULT_TOP_P = 0.3
 
79
 
80
 
81
  class MoondreamModel(nn.Module):
82
+
83
+ def __init__(
84
+ self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
85
+ ):
86
  super().__init__()
87
  self.config = config
88
 
89
+ self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
 
 
90
  self.vision = build_vision_model(config.vision, dtype)
91
  self.text = build_text_model(config.text, dtype)
92
 
93
  # Region Model
94
+ linear_cls = (
95
+ QuantizedLinear if config.region.group_size is not None else nn.Linear
96
+ )
97
  self.region = nn.ModuleDict(
98
  {
99
+ "coord_encoder": linear_cls(
100
  config.region.coord_feat_dim, config.region.dim, dtype=dtype
101
  ),
102
  "coord_decoder": nn.ModuleDict(
103
  {
104
+ "fc1": linear_cls(
105
  config.region.dim, config.region.inner_dim, dtype=dtype
106
  ),
107
+ "fc2": linear_cls(
108
  config.region.inner_dim,
109
  config.region.coord_out_dim,
110
  dtype=dtype,
111
  ),
112
  }
113
  ),
114
+ "size_encoder": linear_cls(
115
  config.region.size_feat_dim, config.region.dim, dtype=dtype
116
  ),
117
  "size_decoder": nn.ModuleDict(
118
  {
119
+ "fc1": linear_cls(
120
  config.region.dim, config.region.inner_dim, dtype=dtype
121
  ),
122
+ "fc2": linear_cls(
123
  config.region.inner_dim,
124
  config.region.size_out_dim,
125
  dtype=dtype,
 
171
  def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
172
  return vision_projection(g, r, self.vision, self.config.vision)
173
 
174
+ def _prefill(
175
+ self,
176
+ x: torch.Tensor,
177
+ attn_mask: torch.Tensor,
178
+ pos_ids: torch.Tensor,
179
+ lora: Optional[torch.Tensor],
180
+ ):
181
+ return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
182
 
183
  def _decode_one_tok(
184
+ self,
185
+ x: torch.Tensor,
186
+ attn_mask: torch.Tensor,
187
+ pos_ids: torch.Tensor,
188
+ lora: Optional[torch.Tensor],
189
  ):
190
+ hidden = text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)
191
  logits = lm_head(hidden, self.text)
192
  return logits, hidden
193
 
194
  def compile(self):
195
+ for module in self.modules():
196
+ if isinstance(module, QuantizedLinear):
197
+ module.unpack()
198
+
199
  # TODO: vision_projection is not being compiled
200
  self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
201
  self._prefill = torch.compile(self._prefill, fullgraph=True)
 
205
 
206
  def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
207
  all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
208
+
209
  torch._dynamo.mark_dynamic(all_crops, 0)
210
 
211
  outputs = self._vis_enc(all_crops)
 
227
 
228
  return self._vis_proj(global_features, reconstructed)
229
 
230
+ def encode_image(
231
+ self,
232
+ image: Union[Image.Image, EncodedImage],
233
+ settings: Optional[ImageEncodingSettings] = None,
234
+ ) -> EncodedImage:
235
  if isinstance(image, EncodedImage):
236
  return image
237
  elif not isinstance(image, Image.Image):
238
  raise ValueError("image must be a PIL Image or EncodedImage")
239
 
240
+ lora = (
241
+ variant_state_dict(settings["variant"], device=self.device)
242
+ if settings is not None and settings["variant"] is not None
243
+ else None
244
+ )
245
+
246
  # Run through text model in addition to the vision encoder, to minimize
247
  # re-computation if multiple queries are performed on this image.
248
  with torch.inference_mode():
 
254
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
255
  mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
256
  pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
257
+ self._prefill(inputs_embeds, mask, pos_ids, lora)
258
 
259
  return EncodedImage(
260
  pos=inputs_embeds.size(1),
 
278
  return next_probs
279
 
280
  def _prefill_prompt(
281
+ self,
282
+ prompt_tokens: torch.Tensor,
283
+ pos: int,
284
+ temperature: float,
285
+ top_p: float,
286
+ spatial_refs: Optional[SpatialRefs] = None,
287
+ attn_mask: Optional[torch.Tensor] = None,
288
+ lora: Optional[dict] = None,
289
  ):
290
  with torch.inference_mode():
291
  prompt_emb = text_encoder(prompt_tokens, self.text)
292
+
293
+ if spatial_refs:
294
+ encoded_refs = encode_spatial_refs(spatial_refs, self.region)
295
+ prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
296
+ encoded_refs["coords"]
297
+ )
298
+ if encoded_refs["sizes"] is not None:
299
+ prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
300
+ encoded_refs["sizes"]
301
+ )
302
+
303
  torch._dynamo.mark_dynamic(prompt_emb, 1)
304
+
305
+ if attn_mask is None:
306
+ attn_mask = self.attn_mask
307
+
308
+ mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
309
  pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
310
+ hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
311
+ logits_BV = lm_head(hidden_BC, self.text)
312
 
313
  if temperature == 0:
314
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
315
  else:
316
+ probs = torch.softmax(logits_BV / temperature, dim=-1)
317
  probs = self._apply_top_p(probs, top_p)
318
  next_token = torch.multinomial(probs, num_samples=1)
319
 
320
  pos = pos + prompt_emb.size(1)
321
+ return logits_BV, hidden_BC, next_token, pos
322
 
323
+ def _generate_reasoning(
324
+ self,
325
+ prompt_tokens,
326
+ pos,
327
+ settings: Optional[TextSamplingSettings] = None,
328
+ spatial_refs: Optional[SpatialRefs] = None,
329
+ attn_mask: Optional[torch.Tensor] = None,
330
+ ) -> Tuple[int, str, List[dict]]:
331
+ max_tokens = (
332
+ settings.get("max_tokens", DEFAULT_MAX_TOKENS)
333
+ if settings
334
+ else DEFAULT_MAX_TOKENS
335
+ )
336
+ temperature = (
337
+ settings.get("temperature", DEFAULT_TEMPERATURE)
338
+ if settings
339
+ else DEFAULT_TEMPERATURE
340
+ )
341
+ lora = (
342
+ variant_state_dict(settings["variant"], device=self.device)
343
+ if settings is not None and "variant" in settings
344
+ else None
345
+ )
346
+
347
+ top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
348
+ eos_id = self.config.tokenizer.answer_id
349
+
350
+ _, last_hidden_BC, next_token, pos = self._prefill_prompt(
351
+ prompt_tokens,
352
+ pos,
353
+ temperature,
354
+ top_p,
355
+ spatial_refs,
356
+ attn_mask=attn_mask,
357
+ lora=lora,
358
+ )
359
+
360
+ text_token_chunks = [[]]
361
+ grounding_chunks = [[]]
362
+
363
+ mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
364
+ mask[:, :, :pos] = 1
365
+ pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
366
+ generated_tokens = 0
367
+
368
+ while (
369
+ next_token_id := next_token.item()
370
+ ) != eos_id and generated_tokens < max_tokens:
371
+ if (
372
+ next_token_id == self.config.tokenizer.start_ground_points_id
373
+ or next_token_id == self.config.tokenizer.end_ground_id
374
+ ):
375
+ text_token_chunks.append([])
376
+ grounding_chunks.append([])
377
+
378
+ text_token_chunks[-1].append(next_token_id)
379
+
380
+ with torch.inference_mode():
381
+ if next_token_id == self.config.tokenizer.coord_id:
382
+ coord_logits = decode_coordinate(last_hidden_BC, self.region)
383
+ coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
384
+ grounding_chunks[-1].append(coord.item())
385
+
386
+ next_emb = encode_coordinate(
387
+ coord.to(dtype=coord_logits.dtype), self.region
388
+ ).unsqueeze(0)
389
+ else:
390
+ next_emb = text_encoder(next_token, self.text)
391
+
392
+ mask[:, :, pos], pos_ids[0] = 1, pos
393
+
394
+ logits_BV, last_hidden_BC = self._decode_one_tok(
395
+ next_emb, mask, pos_ids, lora
396
+ )
397
+ logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
398
+ logits_BV[:, self.config.tokenizer.size_id] = float("-inf")
399
+
400
+ pos += 1
401
+
402
+ if temperature == 0:
403
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) # (1, 1)
404
+ else:
405
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
406
+ probs = self._apply_top_p(probs, top_p)
407
+ next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
408
+
409
+ generated_tokens += 1
410
+
411
+ text_chunks = [
412
+ self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
413
+ ]
414
+ text = "".join(text_chunks)
415
+
416
+ start_idx = 0
417
+ grounding = []
418
+ for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
419
+ if len(grounding_chunk) > 1:
420
+ points = []
421
+ for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
422
+ points.append((grounding_chunk[i], grounding_chunk[i + 1]))
423
+ grounding.append(
424
+ {
425
+ "start_idx": start_idx,
426
+ "end_idx": start_idx + len(text_chunk),
427
+ "points": points,
428
+ }
429
+ )
430
+ start_idx += len(text_chunk)
431
+
432
+ return pos, text, grounding
433
+
434
+ def _generate_answer(
435
  self,
436
  prompt_tokens: torch.Tensor,
437
  pos: int,
438
  settings: Optional[TextSamplingSettings] = None,
439
+ spatial_refs: Optional[SpatialRefs] = None,
440
+ eos_id: Optional[int] = None,
441
+ attn_mask: Optional[torch.Tensor] = None,
442
  ):
443
  max_tokens = (
444
  settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
451
  else DEFAULT_TEMPERATURE
452
  )
453
  top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
454
+ eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
455
+ lora = (
456
+ variant_state_dict(settings["variant"], device=self.device)
457
+ if settings is not None and "variant" in settings
458
+ else None
459
+ )
460
 
461
  _, _, next_token, pos = self._prefill_prompt(
462
+ prompt_tokens,
463
+ pos,
464
+ temperature,
465
+ top_p,
466
+ spatial_refs,
467
+ attn_mask=attn_mask,
468
+ lora=lora,
469
  )
470
 
471
  def generator(next_token, pos):
 
480
 
481
  while (
482
  next_token_id := next_token.item()
483
+ ) != eos_id and generated_tokens < max_tokens:
484
  # Add token to our cache
485
  token_cache.append(next_token_id)
486
 
 
500
  print_len += len(printable_text)
501
  if printable_text:
502
  yield printable_text
503
+ # Otherwise, only yield up to the last space to avoid cutting words
504
  else:
505
  last_space_idx = text.rfind(" ", print_len)
506
  if last_space_idx >= print_len:
 
512
  with torch.inference_mode():
513
  next_emb = text_encoder(next_token, self.text)
514
  mask[:, :, pos], pos_ids[0] = 1, pos
515
+
516
+ logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
517
+ logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")
518
+
519
  pos += 1
520
 
521
  if temperature == 0:
522
+ next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
523
+ 1
524
+ ) # (1, 1)
525
  else:
526
+ probs = torch.softmax(logits_BV / temperature, dim=-1) # (1, V)
527
  probs = self._apply_top_p(probs, top_p)
528
  next_token = torch.multinomial(probs, num_samples=1) # (1, 1)
529
 
 
540
 
541
  def query(
542
  self,
543
+ image: Optional[Union[Image.Image, EncodedImage]] = None,
544
+ question: str = None,
545
+ reasoning: bool = False,
546
+ spatial_refs: Optional[SpatialRefs] = None,
547
  stream: bool = False,
548
  settings: Optional[TextSamplingSettings] = None,
549
  ):
550
  if self.config.tokenizer.templates["query"] is None:
551
  raise NotImplementedError("Model does not support querying.")
552
 
553
+ if question is None:
554
+ raise ValueError("question must be provided.")
555
 
556
+ if spatial_refs and image is None:
557
+ raise ValueError("spatial_refs can only be used with an image.")
558
+
559
+ attn_mask = self.attn_mask
560
+ if image is not None:
561
+ image = self.encode_image(image, settings)
562
+ self.load_encoded_image(image)
563
+ pos = image.pos
564
+ prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
565
+ else:
566
+ self._setup_caches()
567
+ pos = 0
568
+ prompt_toks = [
569
+ self.config.tokenizer.bos_id
570
+ ] + self.config.tokenizer.templates["query"]["prefix"]
571
+ max_context = self.config.text.max_context
572
+ attn_mask = torch.tril(
573
+ torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
574
+ ).to(self.device)
575
+
576
+ spatial_toks = []
577
+ if spatial_refs:
578
+ for ref in spatial_refs:
579
+ coord_id = self.config.tokenizer.coord_id
580
+ size_id = self.config.tokenizer.size_id
581
+ if len(ref) == 2:
582
+ spatial_toks.extend([coord_id, coord_id])
583
+ else:
584
+ spatial_toks.extend([coord_id, coord_id, size_id])
585
+
586
+ prompt_tokens = [
587
+ prompt_toks
588
+ + spatial_toks
589
+ + self.tokenizer.encode(question).ids
590
+ + self.config.tokenizer.templates["query"]["suffix"]
591
+ ]
592
+
593
+ if reasoning:
594
+ prompt_tokens[0] += [self.config.tokenizer.thinking_id]
595
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
596
+ pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
597
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
598
+ )
599
+ prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
600
+ reasoning_dict = {
601
+ "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
602
+ }
603
+ else:
604
+ prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
605
+ reasoning_dict = {}
606
+
607
+ prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
608
 
609
  def generator():
610
+ for token in self._generate_answer(
611
+ prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
612
+ ):
613
  yield token
614
 
615
  if stream:
616
+ return {**reasoning_dict, "answer": generator()}
617
  else:
618
+ return {**reasoning_dict, "answer": "".join(list(generator()))}
619
 
620
  def load_encoded_image(self, encoded_image: EncodedImage):
621
  for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
 
634
  if length not in self.config.tokenizer.templates["caption"]:
635
  raise ValueError(f"Model does not support caption length '{length}'.")
636
 
637
+ image = self.encode_image(image, settings)
638
  self.load_encoded_image(image)
639
 
640
  prompt_tokens = torch.tensor(
 
642
  )
643
 
644
  def generator():
645
+ for token in self._generate_answer(prompt_tokens, image.pos, settings):
646
  yield token
647
 
648
  if stream:
 
657
  pos: int,
658
  include_size: bool = True,
659
  max_objects: int = DEFAULT_MAX_OBJECTS,
660
+ lora: Optional[dict] = None,
661
  ):
662
  out = []
663
  mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
 
677
 
678
  # Decode y-coordinate
679
  mask[:, :, pos], pos_ids[0] = 1, pos
680
+ _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
681
  pos += 1
682
  y_logits = decode_coordinate(hidden, self.region)
683
  y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
 
688
  # Decode size
689
  if include_size:
690
  mask[:, :, pos], pos_ids[0] = 1, pos
691
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
692
  pos += 1
693
  size_logits = decode_size(hidden, self.region)
694
 
 
726
 
727
  # Decode next token (x-coordinate, or eos)
728
  mask[:, :, pos], pos_ids[0] = 1, pos
729
+ logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
730
  pos += 1
731
  next_token = torch.argmax(logits, dim=-1)
732
 
 
741
  if self.config.tokenizer.templates["detect"] is None:
742
  raise NotImplementedError("Model does not support object detection.")
743
 
744
+ image = self.encode_image(image, settings)
745
  self.load_encoded_image(image)
746
 
747
  prompt_tokens = torch.tensor(
 
753
  device=self.device,
754
  )
755
 
756
+ lora = (
757
+ variant_state_dict(settings["variant"], device=self.device)
758
+ if settings is not None and "variant" in settings
759
+ else None
760
+ )
761
+
762
  _, hidden, next_token, pos = self._prefill_prompt(
763
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
764
  )
765
  hidden = hidden[:, -1:, :]
766
 
 
770
  else DEFAULT_MAX_OBJECTS
771
  )
772
  objects = self._generate_points(
773
+ hidden,
774
+ next_token,
775
+ pos,
776
+ include_size=True,
777
+ max_objects=max_objects,
778
+ lora=lora,
779
  )
780
 
781
  return {"objects": objects}
 
789
  if self.config.tokenizer.templates["point"] is None:
790
  raise NotImplementedError("Model does not support pointing.")
791
 
792
+ image = self.encode_image(image, settings)
793
  self.load_encoded_image(image)
794
 
795
  prompt_tokens = torch.tensor(
 
801
  device=self.device,
802
  )
803
 
804
+ lora = (
805
+ variant_state_dict(settings["variant"], device=self.device)
806
+ if settings is not None and "variant" in settings
807
+ else None
808
+ )
809
+
810
  _, hidden, next_token, pos = self._prefill_prompt(
811
+ prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
812
  )
813
  hidden = hidden[:, -1:, :]
814
 
 
818
  else DEFAULT_MAX_OBJECTS
819
  )
820
  objects = self._generate_points(
821
+ hidden,
822
+ next_token,
823
+ pos,
824
+ include_size=False,
825
+ max_objects=max_objects,
826
+ lora=lora,
827
  )
828
 
829
  return {"points": objects}
 
848
  self.text,
849
  )
850
  x_emb = encode_coordinate(
851
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
852
  self.region,
853
  )
854
  y_emb = encode_coordinate(
855
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
856
  self.region,
857
  )
858
 
 
864
  pos_ids = torch.arange(
865
  image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
866
  )
867
+ hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
868
  logits = lm_head(hidden, self.text)
869
  next_token = torch.argmax(logits, dim=-1)
870
  pos = image.pos + prompt_emb.size(1)
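Note: a hypothetical usage sketch of the extended query() API added in this file (the model variable is assumed to be a loaded MoondreamModel and the image path is a placeholder): reasoning=True returns a grounded reasoning trace alongside the answer, and spatial_refs passes normalized points or boxes that are encoded through the region model:

    from PIL import Image

    image = Image.open("example.jpg")  # placeholder path
    result = model.query(
        image,
        "What is at this location?",
        reasoning=True,
        spatial_refs=[(0.5, 0.5)],  # one normalized (x, y) point
        settings={"temperature": 0.0},
    )
    print(result["reasoning"]["text"])
    print(result["answer"])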
region.py CHANGED
@@ -2,7 +2,11 @@ import torch
 import torch.nn as nn
 import math
 
-from .layers import linear, mlp
+from typing import List, Tuple, Union
+
+from .layers import mlp
+
+SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]
 
 
 def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
@@ -36,7 +40,7 @@ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
     Returns:
         Encoded hidden states tensor for input to text model
     """
-    return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
+    return w.coord_encoder(fourier_features(coord, w.coord_features))
 
 
 def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
@@ -64,7 +68,7 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     Returns:
         Encoded hidden states tensor for input to text model
     """
-    return linear(fourier_features(size, w.size_features), w.size_encoder)
+    return w.size_encoder(fourier_features(size, w.size_features))
 
 
 def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
@@ -87,3 +91,46 @@ def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
         Shape is (2, 1024) where the first dimension corresponds to width and height.
     """
     return mlp(hidden_state, w.size_decoder).view(2, -1)
+
+
+def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
+    """
+    Takes a list of spatial references (points or regions) and encodes them into
+    hidden states for input to the text model.
+
+    Args:
+        spatial_refs: List of spatial references (points or boxes)
+            - Points are represented as normalized (x, y) tuples
+            - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples
+
+    Returns:
+        {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
+    """
+    coords, sizes = [], []
+    for ref in spatial_refs:
+        if len(ref) == 2:
+            coords.append(ref[0])
+            coords.append(ref[1])
+        else:
+            x_c = (ref[0] + ref[2]) / 2
+            y_c = (ref[1] + ref[3]) / 2
+            width = ref[2] - ref[0]
+            height = ref[3] - ref[1]
+            coords.append(x_c)
+            coords.append(y_c)
+            sizes.append([width, height])
+
+    coords = torch.tensor(
+        coords, device=w.coord_features.device, dtype=w.coord_features.dtype
+    ).view(-1, 1)
+    coords = encode_coordinate(coords, w)
+
+    if sizes:
+        sizes = torch.tensor(
+            sizes, device=w.size_features.device, dtype=w.size_features.dtype
+        )
+        sizes = encode_size(sizes, w)
+    else:
+        sizes = None
+
+    return {"coords": coords, "sizes": sizes}
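Note: a quick check (toy numbers, chosen as exact binary fractions) of the box-to-(center, size) conversion performed by encode_spatial_refs() above; a normalized box contributes one (x_c, y_c) coordinate pair plus one (width, height) size entry:

    box = (0.25, 0.5, 0.75, 1.0)  # (x_min, y_min, x_max, y_max)
    x_c, y_c = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2
    width, height = box[2] - box[0], box[3] - box[1]
    assert (x_c, y_c, width, height) == (0.5, 0.75, 0.5, 0.5)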
text.py CHANGED
@@ -2,8 +2,9 @@ import torch
 import torch.nn as nn
 
 from torch.nn import functional as F
+from typing import Optional
 
-from .layers import layer_norm, mlp
+from .layers import layer_norm, mlp, QuantizedLinear
 from .rope import apply_rotary_emb, precompute_freqs_cis
 from .config import TextConfig
 
@@ -21,25 +22,22 @@ def attn(
     n_heads: int,
     n_kv_heads: int,
     position_ids: torch.Tensor,
+    lora: Optional[dict],
 ):
     bsz, q_len, d_model = x.shape
     head_dim = d_model // n_heads
 
     qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    if lora is not None:
+        qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
+    q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
+    del qkv_out
 
-    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-    k = (
-        qkv_out[..., q_dim : q_dim + kv_dim]
-        .view(bsz, q_len, n_kv_heads, head_dim)
-        .transpose(1, 2)
-    )
-    v = (
-        qkv_out[..., q_dim + kv_dim :]
-        .view(bsz, q_len, n_kv_heads, head_dim)
-        .transpose(1, 2)
-    )
+    q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
+    v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
     k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
@@ -51,7 +49,14 @@
         q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
     )
     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-    out = w.proj(out)
+
+    out0 = w.proj(out)
+    if lora is not None:
+        out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
+        out = out0 + out1
+    else:
+        out = out0
+
     return out
 
 
@@ -126,8 +131,17 @@
     attn_mask: torch.Tensor,
     position_ids: torch.Tensor,
     config: TextConfig,
+    lora: Optional[dict],
 ):
     for i, block in enumerate(w.blocks):
+        if lora is not None:
+            layer_lora = lora["text"]["blocks"][str(i)]
+            mlp_lora = layer_lora["mlp"]
+            attn_lora = layer_lora["attn"]
+        else:
+            mlp_lora = None
+            attn_lora = None
+
         l_in = layer_norm(x, block.ln)
         l_attn = attn(
             l_in,
@@ -138,8 +152,9 @@
             n_heads=config.n_heads,
             n_kv_heads=config.n_kv_heads,
             position_ids=position_ids,
+            lora=attn_lora,
         )
-        l_mlp = mlp(l_in, block.mlp)
+        l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)
         x = x + l_attn + l_mlp
 
     return x
@@ -160,6 +175,7 @@ def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
 
 def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
     qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
 
     text = nn.ModuleDict(
         {
@@ -170,18 +186,18 @@
                             "ln": nn.LayerNorm(config.dim, dtype=dtype),
                             "attn": nn.ModuleDict(
                                 {
-                                    "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
-                                    "proj": nn.Linear(
+                                    "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
+                                    "proj": linear_cls(
                                         config.dim, config.dim, dtype=dtype
                                     ),
                                 }
                             ),
                             "mlp": nn.ModuleDict(
                                 {
-                                    "fc1": nn.Linear(
+                                    "fc1": linear_cls(
                                         config.dim, config.ff_dim, dtype=dtype
                                     ),
-                                    "fc2": nn.Linear(
+                                    "fc2": linear_cls(
                                         config.ff_dim, config.dim, dtype=dtype
                                     ),
                                 }
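Note: the LoRA pattern used throughout attn() and mlp() above, shown in isolation with toy shapes (this snippet is illustrative, not from the repo): the adapter contributes B @ (A @ x) on top of the frozen base projection, which is equivalent to projecting with the merged weight W + B @ A:

    import torch
    import torch.nn.functional as F

    dim, rank, out_dim = 16, 4, 32
    x = torch.randn(1, 3, dim)
    W = torch.randn(out_dim, dim)   # frozen base weight
    A = torch.randn(rank, dim)      # lora["..."]["A"]
    B = torch.randn(out_dim, rank)  # lora["..."]["B"]

    base = F.linear(x, W)
    delta = F.linear(F.linear(x, A), B)
    assert torch.allclose(base + delta, F.linear(x, W + B @ A), atol=1e-4)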
vision.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from typing import Union, Tuple
 from PIL import Image
 
-from .layers import attn, layer_norm, linear, mlp
+from .layers import attn, layer_norm, mlp
 from .image_crops import overlap_crop_image
 from .config import VisionConfig
 
@@ -33,7 +33,7 @@ def prepare_crops(
     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
     all_crops = (
         torch.from_numpy(all_crops)
-        .to(device=device, dtype=torch.float16)
+        .to(device=device, dtype=torch.bfloat16)
         .div_(255.0)
         .sub_(0.5)
         .div_(0.5)
@@ -64,7 +64,7 @@ def create_patches(x, patch_size):
 def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
     x = create_patches(input_BCHW, config.enc_patch_size)
 
-    x = linear(x, w.patch_emb)
+    x = w.patch_emb(x)
     x = x + w.pos_emb
     for block in w.blocks:
         x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)