Update region.py
Browse files
region.py
CHANGED
|
@@ -72,9 +72,22 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
|
| 72 |
|
| 73 |
|
| 74 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """Project a hidden state to width/height size logits.

    Original contract: returns a ``(2, C)`` tensor of width/height logits,
    which keeps all downstream code in moondream.py working as-is.

    Args:
        hidden_state: Hidden vector for the token being decoded. Normally a
            single 1-D vector; higher-rank inputs are collapsed to the final
            vector along the leading dimensions.
        w: Module carrying ``size_decoder``, a Linear/MLP projecting to 2*C.

    Returns:
        Tensor of shape ``(2, C)``: row 0 = width logits, row 1 = height logits.

    Raises:
        RuntimeError: if the size head's output dimension is not even.
    """
    # w.size_decoder projects to 2*C; in practice this is called on the
    # last token's hidden vector, so x is usually 1-D already.
    x = w.size_decoder(hidden_state)

    # Capture the last-dim size *before* any reshape. (The previous version
    # relied on `x.reshape(-1)[-x.shape[-1]:]`, which only worked because
    # `x.shape[-1]` was evaluated on the old binding of `x` — an
    # evaluation-order trap. Make the dependency explicit instead.)
    last = x.shape[-1]

    if x.dim() != 1:
        # Most call sites pass a single hidden vector. If a higher-rank
        # tensor slips in, collapse it conservatively by keeping only the
        # final vector along the last dimension.
        x = x.reshape(-1, last)[-1]

    if last % 2 != 0:
        raise RuntimeError(f"size_out_dim must be even, got {last}")

    # Split the 2*C projection into the (2, C) width/height layout.
    return x.view(2, last // 2)
|
| 91 |
|
| 92 |
|
| 93 |
|