katuni4ka committed
Commit b60128f · verified · 1 Parent(s): 2593b9b

Update modeling_chatglm.py

Files changed (1): modeling_chatglm.py +12 -8
modeling_chatglm.py CHANGED
@@ -489,9 +489,10 @@ class GLMBlock(torch.nn.Module):
         self.fp32_residual_connection = config.fp32_residual_connection
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -499,7 +500,7 @@
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                      dtype=config.torch_dtype)
+                                                      dtype=dtype)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -567,9 +568,10 @@ class GLMTransformer(torch.nn.Module):
 
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                 dtype=config.torch_dtype)
+                                                 dtype=dtype)
 
         self.gradient_checkpointing = False
 
@@ -690,10 +692,11 @@ class Embedding(torch.nn.Module):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         self.word_embeddings = nn.Embedding(
             config.padded_vocab_size,
             self.hidden_size,
-            dtype=config.torch_dtype,
+            dtype=dtype,
             device=device
         )
         self.fp32_residual_connection = config.fp32_residual_connection
@@ -728,12 +731,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_dim = (
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
-
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
-                                              device=device, dtype=config.torch_dtype)
+                                              device=device, dtype=dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+                                        dtype=dtype, **init_kwargs)
 
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
@@ -1153,8 +1156,9 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
 
         self.num_labels = config.num_labels
         self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
-        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
+        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=dtype)
         if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
         else:
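
Every hunk applies the same fix: config.torch_dtype may arrive as a plain string (for example "bfloat16" read from config.json) rather than a torch.dtype object, so each constructor now resolves it with getattr(torch, ...) before passing it to the layer. Below is a minimal sketch of that pattern, factored into a hypothetical resolve_dtype helper; the helper name is not part of the patch, which repeats the expression inline at each call site.

    import torch

    def resolve_dtype(torch_dtype):
        # A string such as "bfloat16" (typical when the config was parsed
        # from JSON) is mapped to the corresponding torch dtype object;
        # an actual torch.dtype passes through unchanged.
        if isinstance(torch_dtype, str):
            return getattr(torch, torch_dtype)
        return torch_dtype

    print(resolve_dtype("bfloat16"))     # torch.bfloat16
    print(resolve_dtype(torch.float16))  # torch.float16

Resolving the dtype once per constructor keeps the layer calls (LayerNormFunc, nn.Embedding, nn.Linear, RotaryEmbedding) unchanged apart from their dtype=dtype argument.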