Update modeling_qwen2.py
Browse files### **Pull Request Title**
Fix padding\_idx bug when `use_position_embedding=false` and `use_position_idx=true`
---
### **Description**
This PR fixes a bug where `self.padding_idx` is undefined if `use_position_embedding=false` and `use_position_idx=true`, since its initialization only happens inside the `if self.use_position_embedding:` block.
**Changes made:**
* Added a conditional check around lines 178–179 to avoid accessing `self.padding_idx` when position embeddings are disabled.
* Moved
```python
position_indices = position_indices.reshape(pc, self.patch_size)
```
into the `use_position_embedding` block for consistency.
These changes prevent runtime errors and keep position index reshaping aligned with the intended design.
- modeling_qwen2.py +6 -5
modeling_qwen2.py
CHANGED
|
@@ -174,16 +174,17 @@ class TimeSeriesEmbedding(nn.Module):
|
|
| 174 |
padding = last_value.repeat(padding_length, 1)
|
| 175 |
xi = torch.cat([xi, padding], dim=0)
|
| 176 |
|
| 177 |
-
# Use special padding index for padding positions
|
| 178 |
-
|
| 179 |
-
|
|
|
|
| 180 |
|
| 181 |
# Reshape to patches
|
| 182 |
xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
|
| 183 |
-
|
| 184 |
-
|
| 185 |
if self.use_position_embedding:
|
| 186 |
# Collect position indices instead of calling embedding immediately
|
|
|
|
| 187 |
all_position_indices.append(position_indices)
|
| 188 |
patch_info_list.append({
|
| 189 |
'xi': xi,
|
|
|
|
| 174 |
padding = last_value.repeat(padding_length, 1)
|
| 175 |
xi = torch.cat([xi, padding], dim=0)
|
| 176 |
|
| 177 |
+
# Use special padding index for padding positions when use_position_embedding enabled
|
| 178 |
+
if self.use_position_embedding:
|
| 179 |
+
padding_positions = torch.full((padding_length,), self.padding_idx, device=x.device)
|
| 180 |
+
position_indices = torch.cat([position_indices, padding_positions], dim=0)
|
| 181 |
|
| 182 |
# Reshape to patches
|
| 183 |
xi = xi.reshape(pc, self.patch_size) # (num_patches, patch_size)
|
| 184 |
+
|
|
|
|
| 185 |
if self.use_position_embedding:
|
| 186 |
# Collect position indices instead of calling embedding immediately
|
| 187 |
+
position_indices = position_indices.reshape(pc, self.patch_size) # (num_patches, patch_size)
|
| 188 |
all_position_indices.append(position_indices)
|
| 189 |
patch_info_list.append({
|
| 190 |
'xi': xi,
|