Update modeling_chatglm.py
Browse files — modeling_chatglm.py (+12 −3)
modeling_chatglm.py
CHANGED
@@ -422,7 +422,7 @@ class SelfAttention(torch.nn.Module):
 422
 423  def _config_to_kwargs(args):
 424      common_kwargs = {
 425 -        "dtype": args.torch_dtype,
 426      }
 427      return common_kwargs
 428
@@ -720,7 +720,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 720      init_method = default_init
 721      init_kwargs = {}
 722      if device is not None:
 723 -        init_kwargs["device"] = device
 724      self.embedding = init_method(Embedding, config, **init_kwargs)
 725      self.num_layers = config.num_layers
 726      self.multi_query_group_num = config.multi_query_group_num
@@ -954,6 +954,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 954          for layer_past in past
 955      )
 956
 957      def process_response(self, output, history):
 958          content = ""
 959          history = deepcopy(history)
@@ -1231,4 +1240,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
 1231     past_key_values=transformer_outputs.past_key_values,
 1232     hidden_states=transformer_outputs.hidden_states,
 1233     attentions=transformer_outputs.attentions,
 1234 -    )
|
|
 422
 423  def _config_to_kwargs(args):
 424      common_kwargs = {
 425 +        "dtype": args.torch_dtype if not isinstance(args.torch_dtype, str) else getattr(torch, args.torch_dtype)
 426      }
 427      return common_kwargs
 428
|
|
 720      init_method = default_init
 721      init_kwargs = {}
 722      if device is not None:
 723 +        init_kwargs["device"] = device if not isinstance(device, str) else torch.device(device)
 724      self.embedding = init_method(Embedding, config, **init_kwargs)
 725      self.num_layers = config.num_layers
 726      self.multi_query_group_num = config.multi_query_group_num
|
|
 954          for layer_past in past
 955      )
 956
 957 +    @staticmethod
 958 +    def _extract_past_from_model_output(outputs: ModelOutput, *args, **kwargs):
 959 +        past_key_values = None
 960 +        if "past_key_values" in outputs:
 961 +            past_key_values = outputs.past_key_values
 962 +        if is_transformers_4_42_or_higher:
 963 +            return None, past_key_values
 964 +        return past_key_values
 965 +
 966      def process_response(self, output, history):
 967          content = ""
 968          history = deepcopy(history)
|
|
 1240     past_key_values=transformer_outputs.past_key_values,
 1241     hidden_states=transformer_outputs.hidden_states,
 1242     attentions=transformer_outputs.attentions,
 1243 +    )