xiezhe24 Wannabtl committed on
Commit
6344c1f
·
verified ·
1 Parent(s): 89cccd9

Update modeling_qwen2.py (#33)

Browse files

- Update modeling_qwen2.py (38907af6c6838420f3a233eb9cc00833bba592ce)


Co-authored-by: Zhirui Zhang <[email protected]>

Files changed (1) hide show
  1. modeling_qwen2.py +6 -5
modeling_qwen2.py CHANGED
@@ -174,16 +174,17 @@ class TimeSeriesEmbedding(nn.Module):
174
  padding = last_value.repeat(padding_length, 1)
175
  xi = torch.cat([xi, padding], dim=0)
176
 
177
- # Use special padding index for padding positions
178
- padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
179
- position_indices = torch.cat([position_indices, padding_positions], dim=0)
 
180
 
181
  # Reshape to patches
182
  xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
183
- position_indices = position_indices.reshape(pc, self.patch_size) # (num_patches, patch_size)
184
-
185
  if self.use_position_embedding:
186
  # Collect position indices instead of calling embedding immediately
 
187
  all_position_indices.append(position_indices)
188
  patch_info_list.append({
189
  'xi': xi,
 
174
  padding = last_value.repeat(padding_length, 1)
175
  xi = torch.cat([xi, padding], dim=0)
176
 
177
+ # Use special padding index for padding positions when use_position_embedding enabled
178
+ if self.use_position_embedding:
179
+ padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
180
+ position_indices = torch.cat([position_indices, padding_positions], dim=0)
181
 
182
  # Reshape to patches
183
  xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
184
+
 
185
  if self.use_position_embedding:
186
  # Collect position indices instead of calling embedding immediately
187
+ position_indices = position_indices.reshape(pc, self.patch_size) # (num_patches, patch_size)
188
  all_position_indices.append(position_indices)
189
  patch_info_list.append({
190
  'xi': xi,