HV-Khurdula committed
Commit 17e2272 · verified · 1 Parent(s): 09597f3

Update region.py


fix: decode_size.

Files changed (1): region.py +12 -10

region.py CHANGED
@@ -73,20 +73,22 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
 
 def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Original contract: returns (2, C) for width/height logits.
-    This keeps all downstream code in moondream.py working as-is.
+    Takes as input the last hidden state from the text model and outputs logits
+    for 1024 bins representing width and height in log-scale.
+
+    Returns logits shaped (..., 2, C) so batched code can handle it directly.
     """
-    # w.size_decoder is your 2*C-projection MLP/Linear
-    x = w.size_decoder(hidden_state)  # (..., 2*C) in practice called on the last token
-    if x.dim() != 1:
-        # Most of the original code paths call this on a single hidden vector.
-        # If a higher-rank tensor slips in, collapse it conservatively.
-        x = x.reshape(-1)[-x.shape[-1]:]  # take the final vector
+    # Run the two-layer MLP that projects to 2*C (width+height) bins
+    x = mlp(hidden_state, w.size_decoder)  # shape: (..., 2*C)
+
     last = x.shape[-1]
     if last % 2 != 0:
         raise RuntimeError(f"size_out_dim must be even, got {last}")
-    # (2, C)
-    return x.view(2, last // 2)
+
+    C = last // 2
+    # Keep any leading batch/seq dims intact and split the last dim into (2, C)
+    return x.view(*x.shape[:-1], 2, C)
 
 
 
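For context on what the new contract buys callers, a minimal runnable sketch is below. It is a stand-in under stated assumptions, not the real region.py: SizeWeights, the trivial mlp() body, and the HIDDEN_DIM/C values are all hypothetical; only decode_size mirrors the committed code.

import torch
import torch.nn as nn

HIDDEN_DIM, C = 2048, 1024  # assumed dims; region.py's real config may differ

class SizeWeights(nn.Module):
    # Hypothetical stand-in for the module that holds size_decoder.
    def __init__(self):
        super().__init__()
        self.size_decoder = nn.Linear(HIDDEN_DIM, 2 * C)

def mlp(x: torch.Tensor, module: nn.Module) -> torch.Tensor:
    # Stand-in for region.py's mlp() helper; here it just applies the projection.
    return module(x)

def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    x = mlp(hidden_state, w.size_decoder)  # (..., 2*C)
    last = x.shape[-1]
    if last % 2 != 0:
        raise RuntimeError(f"size_out_dim must be even, got {last}")
    return x.view(*x.shape[:-1], 2, last // 2)  # (..., 2, C)

w = SizeWeights()
single = decode_size(torch.randn(HIDDEN_DIM), w)      # (2, C): old contract still holds
batched = decode_size(torch.randn(4, HIDDEN_DIM), w)  # (4, 2, C): leading dims preserved
print(single.shape, batched.shape)

The key design point is that x.view(*x.shape[:-1], 2, C) reshapes only the final dimension: a single hidden vector still yields the original (2, C), while batched inputs keep their leading batch/seq dims instead of being collapsed to one vector by the old reshape(-1) fallback.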