LinyingLyu committed on
Commit
ff99f02
·
verified ·
1 Parent(s): abc5bd2

Upload ChronoGPT_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ChronoGPT_inference.py +2 -86
ChronoGPT_inference.py CHANGED
@@ -125,90 +125,6 @@ class Block(nn.Module):
125
  return x
126
 
127
  class ValueEmbedding(nn.Module):
128
- def __init__(self, vocab_size, model_dim):
129
- super().__init__()
130
- self.embed = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
131
-
132
- @torch.inference_mode()
133
- def forward(self, inputs):
134
- ve = [emb(inputs).bfloat16() for emb in self.embed]
135
- ve = [ve[0], ve[1], ve[2], None, None, None, None, None, None, ve[0], ve[1], ve[2]]
136
- return ve
137
-
138
- class ChronoGPT(nn.Module, PyTorchModelHubMixin):
139
- def __init__(self, vocab_size, num_layers, num_heads, model_dim, **kwargs):
140
- super().__init__()
141
- self.num_heads = num_heads
142
- self.vocab_size = vocab_size # Store vocab_size as instance variable
143
- self.embed = nn.Embedding(vocab_size, model_dim)
144
- self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=(i != 7))
145
- for i in range(num_layers)])
146
- self.value_embeds = ValueEmbedding(vocab_size, model_dim)
147
- self.lm_head = CastedLinear(model_dim, vocab_size)
148
- self.lm_head.weight.data.zero_()
149
- self.num_encoder_layers = num_layers // 2
150
- self.num_decoder_layers = num_layers - self.num_encoder_layers
151
- self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
152
- @torch.inference_mode()
153
- def forward(self, inputs, past_key_values=None):
154
- B = inputs.size(0)
155
- if inputs.dim() == 1:
156
- inputs = inputs.unsqueeze(0) # Add batch dimension if not present
157
-
158
- x0 = norm(self.embed(inputs).bfloat16())
159
- x = x0
160
-
161
- # Modify value embedding handling for batched input
162
- ve = [self.value_embeds(inputs[i].view(-1)) for i in range(B)]
163
- ve = [torch.stack([ve[b][i] for b in range(B)]) if ve[0][i] is not None else None
164
- for i in range(len(ve[0]))]
165
- ve_enc, ve_dec = ve[:self.num_encoder_layers], ve[self.num_encoder_layers:]
166
-
167
- # Handle cached states for batched input
168
- if past_key_values is not None:
169
- for i, block in enumerate(self.blocks):
170
- if block.attn is not None:
171
- block.attn.kv_cache = past_key_values[i]
172
-
173
- present = []
174
- layer_outputs = []
175
- skip_connections = []
176
-
177
- # Process through encoder layers
178
- for i in range(self.num_encoder_layers):
179
- block = self.blocks[i]
180
- x = block(x, ve_enc[i], x0)
181
- if block.attn is not None:
182
- present.append(block.attn.kv_cache)
183
- block.attn.kv_cache = None
184
- skip_connections.append(x)
185
- layer_outputs.append(norm(x))
186
-
187
- # Process through decoder layers
188
- for i in range(self.num_decoder_layers):
189
- x = x + self.skip_weights[i] * skip_connections.pop()
190
- block = self.blocks[self.num_encoder_layers + i]
191
- x = block(x, ve_dec[i], x0)
192
- layer_outputs.append(norm(x))
193
- if block.attn is not None:
194
- present.append(block.attn.kv_cache)
195
- block.attn.kv_cache = None
196
-
197
- x = norm(x)
198
- logits = self.lm_head(x)
199
- logits = 15 * torch.tanh(logits / 15)
200
-
201
- return logits.float(), layer_outputs
202
- @classmethod
203
- def from_pretrained(cls, repo_id, cache_dir=None, **kwargs):
204
- config_path = hf_hub_download(repo_id=repo_id, filename="config.pt", cache_dir=cache_dir)
205
- bin_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", cache_dir=cache_dir)
206
- config = torch.load(config_path)
207
- model = cls(**config)
208
- model.load_state_dict(torch.load(bin_path))
209
- return model
210
-
211
- class ValueEmbedding_xl(nn.Module):
212
  def __init__(self, vocab_size, model_dim, num_layers=52):
213
  super().__init__()
214
  self.num_layers = num_layers
@@ -228,14 +144,14 @@ class ValueEmbedding_xl(nn.Module):
228
  return encoder + decoder
229
 
230
 
231
- class ChronoGPT_xl(nn.Module, PyTorchModelHubMixin):
232
  def __init__(self, vocab_size, num_layers, num_heads, model_dim, **kwargs):
233
  super().__init__()
234
  self.num_heads = num_heads
235
  self.vocab_size = vocab_size # Store vocab_size as instance variable
236
  self.embed = nn.Embedding(vocab_size, model_dim)
237
  self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=True) for i in range(num_layers)])
238
- self.value_embeds = ValueEmbedding_xl(vocab_size, model_dim, num_layers=num_layers)
239
  self.lm_head = CastedLinear(model_dim, vocab_size)
240
  self.lm_head.weight.data.zero_()
241
  self.num_encoder_layers = num_layers // 2
 
125
  return x
126
 
127
  class ValueEmbedding(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def __init__(self, vocab_size, model_dim, num_layers=52):
129
  super().__init__()
130
  self.num_layers = num_layers
 
144
  return encoder + decoder
145
 
146
 
147
+ class ChronoGPT(nn.Module, PyTorchModelHubMixin):
148
  def __init__(self, vocab_size, num_layers, num_heads, model_dim, **kwargs):
149
  super().__init__()
150
  self.num_heads = num_heads
151
  self.vocab_size = vocab_size # Store vocab_size as instance variable
152
  self.embed = nn.Embedding(vocab_size, model_dim)
153
  self.blocks = nn.ModuleList([Block(model_dim, num_heads, use_attn=True) for i in range(num_layers)])
154
+ self.value_embeds = ValueEmbedding(vocab_size, model_dim, num_layers=num_layers)
155
  self.lm_head = CastedLinear(model_dim, vocab_size)
156
  self.lm_head.weight.data.zero_()
157
  self.num_encoder_layers = num_layers // 2