Update modeling_chatglm.py (#2)
Browse files- Ensure both tensors are on the same device. (d56d27dbdf514a7b7e0f39c62883ed0836389af4)
Co-authored-by: Hou <[email protected]>
- modeling_chatglm.py +3 -0
modeling_chatglm.py
CHANGED
|
@@ -666,6 +666,9 @@ class GLMTransformer(torch.nn.Module):
|
|
| 666 |
if len(presents) == 0:
|
| 667 |
presents = kv_cache
|
| 668 |
else:
|
|
|
|
|
|
|
|
|
|
| 669 |
presents = torch.cat((presents, kv_cache), dim=0)
|
| 670 |
|
| 671 |
if output_hidden_states:
|
|
|
|
| 666 |
if len(presents) == 0:
|
| 667 |
presents = kv_cache
|
| 668 |
else:
|
| 669 |
+
# Ensure both tensors are on the same device
|
| 670 |
+
if presents.device != kv_cache.device:
|
| 671 |
+
presents = presents.to(kv_cache.device)
|
| 672 |
presents = torch.cat((presents, kv_cache), dim=0)
|
| 673 |
|
| 674 |
if output_hidden_states:
|