SVECTOR-OFFICIAL committed (verified)
Commit d9adf9c · 1 Parent(s): a708217

Create modeling_theta.py

Files changed (1)
1. modeling_theta.py  +117  -0
modeling_theta.py ADDED
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers import PreTrainedModel
# ThetaConfig is provided by the companion configuration_theta.py in this repository.
from configuration_theta import ThetaConfig

class ThetaAttention(nn.Module):
    """Multi-head self-attention with a causal mask."""

    def __init__(self, config: ThetaConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        batch_size, seq_length, embed_dim = hidden_states.size()

        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # Split the hidden dimension into (num_heads, head_dim) and move the
        # head axis ahead of the sequence axis: (batch, num_heads, seq, head_dim).
        query = query.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: (batch, num_heads, seq, seq).
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        # Causal mask so each position attends only to itself and earlier
        # positions; without it the next-token objective in ThetaForCausalLM
        # would let the model see future tokens.
        causal_mask = torch.triu(
            torch.full((seq_length, seq_length), float("-inf"),
                       device=hidden_states.device, dtype=attn_scores.dtype),
            diagonal=1,
        )
        attn_scores = attn_scores + causal_mask

        attn_probs = nn.functional.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, value)

        # Merge the heads back into a single hidden dimension and project out.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output

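# Side note (sketch, not in the original file): on PyTorch >= 2.0 the score /
# softmax / matmul sequence in ThetaAttention.forward, including the causal
# mask, can equivalently be computed with the fused kernel
#
#     attn_output = nn.functional.scaled_dot_product_attention(
#         query, key, value, is_causal=True
#     )
#
# which applies the same 1 / sqrt(head_dim) scaling by default and can avoid
# materializing the full (seq, seq) attention matrix.
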
class ThetaMLP(nn.Module):
    """Position-wise feed-forward network with a SiLU activation."""

    def __init__(self, config: ThetaConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class ThetaBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each wrapped in a residual connection."""

    def __init__(self, config: ThetaConfig):
        super().__init__()
        self.attention = ThetaAttention(config)
        self.mlp = ThetaMLP(config)
        # Plain LayerNorm; the epsilon is taken from config.rms_norm_eps.
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states):
        hidden_states = hidden_states + self.attention(self.norm1(hidden_states))
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states

class ThetaModel(PreTrainedModel):
    """Token embedding followed by a stack of ThetaBlocks and a final LayerNorm."""

    config_class = ThetaConfig

    def __init__(self, config: ThetaConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [ThetaBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,  # accepted for API compatibility, not applied inside the blocks
        **kwargs,
    ):
        hidden_states = self.embed_tokens(input_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states)

        hidden_states = self.norm(hidden_states)

        # No KV cache is kept, so past_key_values is always None.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=None,
            past_key_values=None,
        )

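# Side note (sketch, not in the original file): attention_mask is accepted by
# ThetaModel.forward but not used. If padding tokens had to be excluded, one
# common pattern is to turn the (batch, seq) mask of 1s and 0s into an additive
# bias and add it to attn_scores before the softmax in ThetaAttention, e.g.
#
#     bias = (1.0 - attention_mask[:, None, None, :].to(attn_scores.dtype)) \
#         * torch.finfo(attn_scores.dtype).min
#     attn_scores = attn_scores + bias
#
# The bias broadcasts as (batch, 1, 1, seq) against the
# (batch, num_heads, seq, seq) score tensor.
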
class ThetaForCausalLM(PreTrainedModel):
    """ThetaModel with a linear language-modeling head for next-token prediction."""

    config_class = ThetaConfig

    def __init__(self, config: ThetaConfig):
        super().__init__(config)
        self.model = ThetaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids=None,
        labels=None,
        **kwargs,
    ):
        outputs = self.model(input_ids=input_ids, **kwargs)
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            # Standard causal-LM objective: logits at position t predict the token at t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
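
For orientation, here is a minimal smoke test of the classes above, sketched under two assumptions: TinyThetaConfig is a hypothetical stand-in for the real ThetaConfig from configuration_theta.py (it defines only the fields the modeling code reads, with small illustrative values), and ThetaForCausalLM is importable in the current environment, for instance when the snippet is run from a checkout of this repository.

import torch
from transformers import PretrainedConfig


class TinyThetaConfig(PretrainedConfig):
    # Hypothetical stand-in for configuration_theta.ThetaConfig; values are illustrative.
    model_type = "theta"

    def __init__(self, vocab_size=32000, hidden_size=256, num_hidden_layers=2,
                 num_attention_heads=4, intermediate_size=512, rms_norm_eps=1e-6, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.rms_norm_eps = rms_norm_eps
        super().__init__(**kwargs)


config = TinyThetaConfig()
model = ThetaForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq)
out = model(input_ids=input_ids, labels=input_ids)
print(out["loss"])           # scalar cross-entropy loss
print(out["logits"].shape)   # torch.Size([2, 16, 32000])

Because forward returns a plain dict with a "loss" entry when labels are given, this call pattern is also what the Hugging Face Trainer looks for when computing the training loss.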