Wannabtl committed on
Commit 38907af · verified · 1 Parent(s): 89cccd9

Update modeling_qwen2.py


### **Pull Request Title**

Fix `padding_idx` bug when `use_position_embedding=false` and `use_position_idx=true`

---

### **Description**

This PR fixes a bug where `self.padding_idx` is undefined when `use_position_embedding=false` and `use_position_idx=true`: it is only initialized inside the `if self.use_position_embedding:` block, yet the padding path still reads it.
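For context, the failure mode looks roughly like the sketch below. The constructor shown is a simplified assumption based on this description (the real `TimeSeriesEmbedding.__init__` may differ in names and arguments); the only relevant detail is that `self.padding_idx` is assigned on the guarded branch only:

```python
import torch
import torch.nn as nn

class BuggyTimeSeriesEmbedding(nn.Module):
    # Hypothetical, simplified constructor illustrating the bug:
    # padding_idx is only defined when position embeddings are enabled.
    def __init__(self, num_positions=512, hidden_size=64,
                 use_position_embedding=False, use_position_idx=True):
        super().__init__()
        self.use_position_embedding = use_position_embedding
        self.use_position_idx = use_position_idx
        if self.use_position_embedding:
            self.padding_idx = num_positions  # only ever set on this branch
            self.position_embedding = nn.Embedding(
                num_positions + 1, hidden_size, padding_idx=self.padding_idx)

emb = BuggyTimeSeriesEmbedding(use_position_embedding=False, use_position_idx=True)
print(hasattr(emb, "padding_idx"))  # False
```

With `use_position_embedding=False`, the old padding branch still evaluated `torch.full((padding_length,), self.padding_idx, device=x.device)`, so padding any series raised `AttributeError`; that access is exactly what the new guard removes.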

**Changes made:**

* Added a conditional check around lines 178–179 to avoid accessing `self.padding_idx` when position embeddings are disabled.
* Moved

  ```python
  position_indices = position_indices.reshape(pc, self.patch_size)
  ```

  into the `use_position_embedding` block for consistency.

These changes prevent runtime errors and keep position index reshaping aligned with the intended design.
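As a rough sanity check of the new control flow, the standalone sketch below mirrors the guarded branch from the diff; the `pad_and_reshape_positions` helper, `patch_size=4`, and the concrete numbers are illustrative assumptions rather than code from the repository:

```python
import torch

def pad_and_reshape_positions(position_indices, padding_length, pc, patch_size,
                              use_position_embedding, padding_idx=None):
    # Mirrors the patched branch: building padding positions and reshaping are
    # both gated on use_position_embedding, so padding_idx is never read otherwise.
    if use_position_embedding:
        padding_positions = torch.full((padding_length,), padding_idx)
        position_indices = torch.cat([position_indices, padding_positions], dim=0)
        return position_indices.reshape(pc, patch_size)  # (num_patches, patch_size)
    return None  # nothing is collected for the embedding path in this case

pos = torch.arange(5)  # 5 real positions, 3 padded -> 2 patches of size 4
print(pad_and_reshape_positions(pos, 3, pc=2, patch_size=4,
                                use_position_embedding=True, padding_idx=512).shape)
print(pad_and_reshape_positions(pos, 3, pc=2, patch_size=4,
                                use_position_embedding=False))  # None, padding_idx untouched
```

In the actual `forward`, the `None` case simply corresponds to not appending anything to `all_position_indices`, as shown in the diff below.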

Files changed (1)
  1. modeling_qwen2.py +6 -5
modeling_qwen2.py CHANGED
```diff
@@ -174,16 +174,17 @@ class TimeSeriesEmbedding(nn.Module):
             padding = last_value.repeat(padding_length, 1)
             xi = torch.cat([xi, padding], dim=0)
 
-            # Use special padding index for padding positions
-            padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
-            position_indices = torch.cat([position_indices, padding_positions], dim=0)
+            # Use special padding index for padding positions when use_position_embedding enabled
+            if self.use_position_embedding:
+                padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
+                position_indices = torch.cat([position_indices, padding_positions], dim=0)
 
             # Reshape to patches
             xi = xi.reshape(pc, self.patch_size)  # (num_patches, patch_size)
-            position_indices = position_indices.reshape(pc, self.patch_size)  # (num_patches, patch_size)
-
+
             if self.use_position_embedding:
                 # Collect position indices instead of calling embedding immediately
+                position_indices = position_indices.reshape(pc, self.patch_size)  # (num_patches, patch_size)
                 all_position_indices.append(position_indices)
             patch_info_list.append({
                 'xi': xi,
```