klemenk committed
Commit 79296a4 · verified · 1 Parent(s): 1d352e0

Update modeling_auristream.py

Files changed (1)
  1. modeling_auristream.py +20 -26
modeling_auristream.py CHANGED
@@ -242,26 +242,35 @@ class AuriStream(PreTrainedModel):
         return sampled
 
     @torch.no_grad()
-    def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
-                 top_k=500, top_p=0.5, seed=None):
+    def generate(
+        self,
+        seq: torch.Tensor,
+        n_tokens: int = 1,
+        temp: float = 1.0,
+        top_k: int = None,
+        top_p: float = None,
+        seed: int = None,
+    ):
         """
         Parameters:
-            seq: torch.Tensor of shape (b, t, n_freq_bins)
-                Input cochleagram to use for generation
+            seq: torch.Tensor of shape (b, t)
+                Input cochlear tokens to condition the generation
             n_tokens: int
-                Number of time bins to predict
+                Number of future tokens (5ms time bins) to predict
             temp: float
                 Temperature for sampling logits
+            top_k: int
+                Restrict sampling to k tokens with highest probability (sample from all tokens if None)
+            top_p: float
+                Restrict sampling to most probable tokens with cumulative probability of p (sample form all tokens if None)
             seed: int
                 Random seed for sampling
 
         Returns:
-            pred_coch: torch.Tensor of shape (b, t, n_freq_bins)
-                The predicted cochleagram
-            all_logits: (optional if return_logits is True) torch.Tensor of shape (b, n_tokens, n_freq_bins)
-                The logits for each time step
-            all_embs: (optional if return_embs is not None) list of torch.Tensor
-                The embeddings for each transformer block
+            pred_coch: torch.Tensor of shape (b, t)
+                The generated cochlear tokens
+            all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
+                The logits at each time step
         """
 
         # Set seed if provided
@@ -277,14 +286,6 @@ class AuriStream(PreTrainedModel):
         # grab shape of the cochleagram
         b, t = seq.size()
 
-        # TODO: double check this works then delete the block bellow:
-        # pass the given input through the model to get the predictions and cache
-        # the k and v values for each transformer block in the process
-        # pos = torch.arange(0, t, dtype=torch.long, device=device)
-        # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
-        # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        # x = self.transformer.drop(tok_emb + pos_emb)
-
         #### Embed conditioning sequence into KV cache
 
         tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
@@ -322,13 +323,6 @@ class AuriStream(PreTrainedModel):
         # using the last embedding of the input
         for i in range(n_tokens-1):
 
-            # TODO: double check this works then delete the block bellow:
-            # # Get the emb and pos embedding of just the last token
-            # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
-            # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
-            # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-            # x = self.transformer.drop(tok_emb + pos_emb)
-
             # Get the emb and pos embedding of just the last token
             tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
             # if wpe exists in self.transformer apply leanred positional embedding
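
Note: the sketch below shows how the updated generate() signature could be called, assuming the method returns the (pred_coch, all_logits) pair listed in the new docstring. The checkpoint id, the AutoModel/trust_remote_code loading path, the vocab_size attribute, and the way the conditioning cochlear tokens are produced are assumptions for illustration; only the argument names, defaults, and documented return shapes come from the diff above.

# Usage sketch (assumptions marked inline); not an official example from the repo.
import torch
from transformers import AutoModel

# Assumption: the repo's config maps the AuriStream class for AutoModel loading.
model = AutoModel.from_pretrained("user/AuriStream", trust_remote_code=True)  # hypothetical id
model.eval()

# Assumption: seq holds (b, t) integer cochlear tokens from some upstream audio
# tokenizer; random ids stand in for them here.
vocab_size = getattr(model.config, "vocab_size", 1024)  # attribute name is an assumption
seq = torch.randint(0, vocab_size, (1, 200))

# Arguments and defaults follow the new signature in the diff:
# top_k=None / top_p=None now mean "sample from the full vocabulary".
pred_coch, all_logits = model.generate(
    seq,
    n_tokens=100,   # predict 100 future 5 ms time bins
    temp=1.0,
    top_k=None,
    top_p=0.9,
    seed=0,
)
print(pred_coch.shape)   # (b, t) generated cochlear tokens, per the docstring
print(all_logits.shape)  # (b, n_tokens, vocab_size)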
 
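The new docstring documents top_k and top_p as optional filters, with None disabling them. For reference, this is the standard way such filters are typically applied to the logits before multinomial sampling; it is a generic sketch of the technique the parameters refer to, not the model's own sampling helper.

# Generic temperature / top-k / top-p filtering sketch (None disables a filter).
import torch
import torch.nn.functional as F

def sample_next_token(logits, temp=1.0, top_k=None, top_p=None, generator=None):
    # logits: (b, vocab_size) for the last time step
    logits = logits / temp

    if top_k is not None:
        # keep only the k highest-scoring tokens
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))

    if top_p is not None:
        # keep the smallest set of tokens whose cumulative probability reaches p
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # shift so the boundary token is kept
        remove[..., 0] = False                      # always keep the most probable token
        remove = remove.scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)  # (b, 1)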
 
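The code path kept by this commit prefills a KV cache from the conditioning sequence and then, inside the loop, embeds only the most recently predicted token. The toy block below illustrates that prefill-then-step pattern with a single cached attention layer; it is a sketch of the general technique, not AuriStream's actual transformer blocks or cache layout.

# Toy prefill-then-step KV cache sketch (illustration only).
import torch
import torch.nn.functional as F
from torch import nn

class CachedSelfAttention(nn.Module):
    def __init__(self, n_embd: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x, cache=None):
        # x: (b, t, n_embd); t is the full prompt during prefill, 1 afterwards
        b, t, c = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        shape = (b, t, self.n_head, c // self.n_head)
        q, k, v = (z.view(shape).transpose(1, 2) for z in (q, k, v))  # (b, h, t, d)

        if cache is not None:
            past_k, past_v = cache
            k = torch.cat([past_k, k], dim=2)  # append new keys/values to the cache
            v = torch.cat([past_v, v], dim=2)

        # a causal mask is only needed during prefill; a single new query may
        # attend to everything already in the cache
        y = F.scaled_dot_product_attention(q, k, v, is_causal=(cache is None))
        y = y.transpose(1, 2).reshape(b, t, c)
        return self.proj(y), (k, v)

# prefill: run the whole conditioning sequence once and keep the cache
attn = CachedSelfAttention(n_embd=64, n_head=4)
prompt = torch.randn(1, 10, 64)
out, kv_cache = attn(prompt, cache=None)

# generation steps: feed only the newest embedding and reuse the cache
for _ in range(5):
    new_tok = out[:, -1:, :]  # stand-in for the last token's embedding
    out, kv_cache = attn(new_tok, cache=kv_cache)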