katuni4ka commited on
Commit
2ed169c
·
verified ·
1 Parent(s): b60128f

Update modeling_chatglm.py

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +12 -3
modeling_chatglm.py CHANGED
@@ -422,7 +422,7 @@ class SelfAttention(torch.nn.Module):
422
 
423
  def _config_to_kwargs(args):
424
  common_kwargs = {
425
- "dtype": args.torch_dtype,
426
  }
427
  return common_kwargs
428
 
@@ -720,7 +720,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
720
  init_method = default_init
721
  init_kwargs = {}
722
  if device is not None:
723
- init_kwargs["device"] = device
724
  self.embedding = init_method(Embedding, config, **init_kwargs)
725
  self.num_layers = config.num_layers
726
  self.multi_query_group_num = config.multi_query_group_num
@@ -954,6 +954,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
954
  for layer_past in past
955
  )
956
 
 
 
 
 
 
 
 
 
 
957
  def process_response(self, output, history):
958
  content = ""
959
  history = deepcopy(history)
@@ -1231,4 +1240,4 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1231
  past_key_values=transformer_outputs.past_key_values,
1232
  hidden_states=transformer_outputs.hidden_states,
1233
  attentions=transformer_outputs.attentions,
1234
- )
 
422
 
423
  def _config_to_kwargs(args):
424
  common_kwargs = {
425
+ "dtype": args.torch_dtype if not isinstance(args.torch_dtype, str) else getattr(torch, args.torch_dtype)
426
  }
427
  return common_kwargs
428
 
 
720
  init_method = default_init
721
  init_kwargs = {}
722
  if device is not None:
723
+ init_kwargs["device"] = device if not isinstance(device, str) else torch.device(device)
724
  self.embedding = init_method(Embedding, config, **init_kwargs)
725
  self.num_layers = config.num_layers
726
  self.multi_query_group_num = config.multi_query_group_num
 
954
  for layer_past in past
955
  )
956
 
957
+ @staticmethod
958
+ def _extract_past_from_model_output(outputs: ModelOutput, *args, **kwargs):
959
+ past_key_values = None
960
+ if "past_key_values" in outputs:
961
+ past_key_values = outputs.past_key_values
962
+ if is_transformers_4_42_or_higher:
963
+ return None, past_key_values
964
+ return past_key_values
965
+
966
  def process_response(self, output, history):
967
  content = ""
968
  history = deepcopy(history)
 
1240
  past_key_values=transformer_outputs.past_key_values,
1241
  hidden_states=transformer_outputs.hidden_states,
1242
  attentions=transformer_outputs.attentions,
1243
+ )