Update modeling_chatglm.py
modeling_chatglm.py  (+12 −8)
@@ -489,9 +489,10 @@ class GLMBlock(torch.nn.Module):
         self.fp32_residual_connection = config.fp32_residual_connection
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         # Layernorm on the input data.
         self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                             dtype=config.torch_dtype)
+                                             dtype=dtype)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
@@ -499,7 +500,7 @@ class GLMBlock(torch.nn.Module):
 
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                      dtype=config.torch_dtype)
+                                                      dtype=dtype)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -567,9 +568,10 @@ class GLMTransformer(torch.nn.Module):
 
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+            dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
             # Final layer norm before output.
             self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
-                                                 dtype=config.torch_dtype)
+                                                 dtype=dtype)
 
         self.gradient_checkpointing = False
 
@@ -690,10 +692,11 @@ class Embedding(torch.nn.Module):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         self.word_embeddings = nn.Embedding(
             config.padded_vocab_size,
             self.hidden_size,
-            dtype=config.torch_dtype,
+            dtype=dtype,
             device=device
         )
         self.fp32_residual_connection = config.fp32_residual_connection
@@ -728,12 +731,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_dim = (
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
-
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
         self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
-                                              device=device, dtype=config.torch_dtype)
+                                              device=device, dtype=dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
-                                        dtype=config.torch_dtype, **init_kwargs)
+                                        dtype=dtype, **init_kwargs)
 
     def get_input_embeddings(self):
         return self.embedding.word_embeddings
@@ -1153,8 +1156,9 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
 
         self.num_labels = config.num_labels
         self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
+        dtype = getattr(torch, config.torch_dtype) if isinstance(config.torch_dtype, str) else config.torch_dtype
 
-        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype)
+        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=dtype)
         if config.classifier_dropout is not None:
             self.dropout = nn.Dropout(config.classifier_dropout)
         else: