Update region.py
fix: decode_size
region.py
CHANGED
@@ -73,20 +73,22 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
|
73 |
|
74 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
75 |
"""
|
76 |
-
|
77 |
-
|
|
|
|
|
78 |
"""
|
79 |
-
#
|
80 |
-
x = w.size_decoder
|
81 |
-
|
82 |
-
# Most of the original code paths call this on a single hidden vector.
|
83 |
-
# If a higher-rank tensor slips in, collapse it conservatively.
|
84 |
-
x = x.reshape(-1)[-x.shape[-1]:] # take the final vector
|
85 |
last = x.shape[-1]
|
86 |
if last % 2 != 0:
|
87 |
raise RuntimeError(f"size_out_dim must be even, got {last}")
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
90 |
|
91 |
|
92 |
|
|
|
73 |
|
74 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """Decode width/height bin logits from the text model's last hidden state.

    Takes as input the last hidden state from the text model and outputs
    logits for 1024 bins representing width and height in log-scale.

    Args:
        hidden_state: Last hidden state from the text model; the trailing
            dimension is the model hidden size. Leading dims (batch/seq)
            are preserved.
        w: Module holding the ``size_decoder`` MLP weights.

    Returns:
        Logits shaped ``(..., 2, C)`` so batched code can handle it
        directly. The penultimate axis separates the two size components
        (presumably index 0 = width, 1 = height, matching the 2*C
        projection order — TODO confirm against encode_size).

    Raises:
        RuntimeError: If the decoder's output dimension is odd, i.e. it
            cannot be split evenly into the two size halves.
    """
    # Run the two-layer MLP that projects to 2*C (width+height) bins.
    x = mlp(hidden_state, w.size_decoder)  # shape: (..., 2*C)

    last = x.shape[-1]
    if last % 2 != 0:
        raise RuntimeError(f"size_out_dim must be even, got {last}")

    C = last // 2
    # Keep any leading batch/seq dims intact and split the last dim into
    # (2, C). Use reshape rather than view: reshape tolerates
    # non-contiguous inputs (view would raise) and makes no copy when the
    # tensor is already contiguous.
    return x.reshape(*x.shape[:-1], 2, C)