Update modeling_qwen2.py (#33)
Browse files- Update modeling_qwen2.py (38907af6c6838420f3a233eb9cc00833bba592ce)
Co-authored-by: Zhirui Zhang <[email protected]>
- modeling_qwen2.py +6 -5
modeling_qwen2.py
CHANGED
|
@@ -174,16 +174,17 @@ class TimeSeriesEmbedding(nn.Module):
|
|
| 174 |
padding = last_value.repeat(padding_length, 1)
|
| 175 |
xi = torch.cat([xi, padding], dim=0)
|
| 176 |
|
| 177 |
-
# Use special padding index for padding positions
|
| 178 |
-
|
| 179 |
-
|
|
|
|
| 180 |
|
| 181 |
# Reshape to patches
|
| 182 |
xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
|
| 183 |
-
|
| 184 |
-
|
| 185 |
if self.use_position_embedding:
|
| 186 |
# Collect position indices instead of calling embedding immediately
|
|
|
|
| 187 |
all_position_indices.append(position_indices)
|
| 188 |
patch_info_list.append({
|
| 189 |
'xi': xi,
|
|
|
|
| 174 |
padding = last_value.repeat(padding_length, 1)
|
| 175 |
xi = torch.cat([xi, padding], dim=0)
|
| 176 |
|
| 177 |
+
# Use special padding index for padding positions when use_position_embedding enabled
|
| 178 |
+
if self.use_position_embedding:
|
| 179 |
+
padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
|
| 180 |
+
position_indices = torch.cat([position_indices, padding_positions], dim=0)
|
| 181 |
|
| 182 |
# Reshape to patches
|
| 183 |
xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
|
| 184 |
+
|
|
|
|
| 185 |
if self.use_position_embedding:
|
| 186 |
# Collect position indices instead of calling embedding immediately
|
| 187 |
+
position_indices = position_indices.reshape(pc, self.patch_size) # (num_patches, patch_size)
|
| 188 |
all_position_indices.append(position_indices)
|
| 189 |
patch_info_list.append({
|
| 190 |
'xi': xi,
|