HV-Khurdula committed
Commit d53b116 · verified · 1 Parent(s): 7eac0da

Update region.py

Files changed (1)
  1. region.py +16 -3
region.py CHANGED
@@ -72,9 +72,22 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
 
 
 def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
-    # Original API expected by moondream.py: shape (2, C) when called on the last hidden state
-    x = mlp(hidden_state, w.size_decoder)  # (..., 2*C)
-    return x.view(2, -1)
+    """
+    Original contract: returns (2, C) for width/height logits.
+    This keeps all downstream code in moondream.py working as-is.
+    """
+    # w.size_decoder is your 2*C-projection MLP/Linear
+    x = w.size_decoder(hidden_state)  # (..., 2*C); in practice called on the last token
+    if x.dim() != 1:
+        # Most of the original code paths call this on a single hidden vector.
+        # If a higher-rank tensor slips in, collapse it conservatively.
+        x = x.reshape(-1)[-x.shape[-1]:]  # take the final vector
+    last = x.shape[-1]
+    if last % 2 != 0:
+        raise RuntimeError(f"size_out_dim must be even, got {last}")
+    # (2, C)
+    return x.view(2, last // 2)
+
 
 
 
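
For context, a minimal sketch of how the updated decode_size could be exercised in isolation. The SizeW container, its dimensions, and the direct import from region are illustrative assumptions, not part of the commit; the only property the change relies on is that w.size_decoder maps the last hidden state to a 2*C vector.

import torch
import torch.nn as nn

from region import decode_size  # assumes region.py is importable as a module

# Hypothetical weight container: any nn.Module whose `size_decoder`
# attribute maps (..., dim) -> (..., 2*C) satisfies the function's contract.
class SizeW(nn.Module):
    def __init__(self, dim: int = 2048, c: int = 1024):
        super().__init__()
        self.size_decoder = nn.Linear(dim, 2 * c)

w = SizeW()
hidden_state = torch.randn(2048)       # last-token hidden vector
logits = decode_size(hidden_state, w)  # width/height logits
print(logits.shape)                    # torch.Size([2, 1024])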