Ensure both tensors are on the same device.
Browse filesThis code change addresses issue #968 in THUDM/ChatGLM3. Ensure both tensors are on the same device.
- modeling_chatglm.py +3 -0
modeling_chatglm.py
CHANGED
|
@@ -667,6 +667,9 @@ class GLMTransformer(torch.nn.Module):
|
|
| 667 |
if len(presents) == 0:
|
| 668 |
presents = kv_cache
|
| 669 |
else:
|
|
|
|
|
|
|
|
|
|
| 670 |
presents = torch.cat((presents, kv_cache), dim=0)
|
| 671 |
|
| 672 |
if output_hidden_states:
|
|
|
|
| 667 |
if len(presents) == 0:
|
| 668 |
presents = kv_cache
|
| 669 |
else:
|
| 670 |
+
# Ensure both tensors are on the same device
|
| 671 |
+
if presents.device != kv_cache.device:
|
| 672 |
+
presents = presents.to(kv_cache.device)
|
| 673 |
presents = torch.cat((presents, kv_cache), dim=0)
|
| 674 |
|
| 675 |
if output_hidden_states:
|