Update modeling_auristream.py
modeling_auristream.py  CHANGED  (+20 −26)

@@ -242,26 +242,35 @@ class AuriStream(PreTrainedModel):
         return sampled
 
     @torch.no_grad()
-    def generate(
-
+    def generate(
+        self,
+        seq: torch.Tensor,
+        n_tokens: int = 1,
+        temp: float = 1.0,
+        top_k: int = None,
+        top_p: float = None,
+        seed: int = None,
+    ):
         """
         Parameters:
-        seq: torch.Tensor of shape (b, t
-            Input
+        seq: torch.Tensor of shape (b, t)
+            Input cochlear tokens to condition the generation
         n_tokens: int
-            Number of time bins to predict
+            Number of future tokens (5ms time bins) to predict
         temp: float
             Temperature for sampling logits
+        top_k: int
+            Restrict sampling to k tokens with highest probability (sample from all tokens if None)
+        top_p: float
+            Restrict sampling to most probable tokens with cumulative probability of p (sample from all tokens if None)
         seed: int
             Random seed for sampling
 
         Returns:
-        pred_coch: torch.Tensor of shape (b, t
-            The
-        all_logits:
-            The logits
-        all_embs: (optional if return_embs is not None) list of torch.Tensor
-            The embeddings for each transformer block
+        pred_coch: torch.Tensor of shape (b, t)
+            The generated cochlear tokens
+        all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
+            The logits at each time step
         """
 
         # Set seed if provided
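
For reference, the new top_k and top_p arguments conventionally correspond to the logit filtering sketched below, applied to a (b, vocab_size) logit tensor before multinomial sampling. This is a minimal, hypothetical stand-alone helper written for illustration; it is not AuriStream's internal `sample` method (the one that produces the `sampled` tensor returned above), which may differ in detail.

import torch
import torch.nn.functional as F

def sample_next_token(logits, temp=1.0, top_k=None, top_p=None):
    # logits: (b, vocab_size) scores for the next time bin.
    logits = logits / temp
    if top_k is not None:
        # Keep only the k highest-scoring tokens.
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p is not None:
        # Nucleus (top-p) filtering: keep the smallest sorted prefix whose
        # cumulative probability reaches top_p, always retaining the best token.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (b, 1) sampled token ids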
@@ -277,14 +286,6 @@ class AuriStream(PreTrainedModel):
         # grab shape of the cochleagram
         b, t = seq.size()
 
-        # TODO: double check this works then delete the block bellow:
-        # pass the given input through the model to get the predictions and cache
-        # the k and v values for each transformer block in the process
-        # pos = torch.arange(0, t, dtype=torch.long, device=device)
-        # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
-        # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-        # x = self.transformer.drop(tok_emb + pos_emb)
-
         #### Embed conditioning sequence into KV cache
 
         tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
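
The code this hunk keeps embeds the whole conditioning sequence once and caches each transformer block's keys and values, so later steps only need to process one new token. A generic sketch of that prefill pattern follows; the block-level interface (`use_cache=True` returning `(x, kv)`) and the `transformer.h` block list are assumptions modeled on GPT-style implementations, not AuriStream's actual module signatures.

import torch

@torch.no_grad()
def prefill_kv_cache(transformer, seq):
    # seq: (b, t) conditioning cochlear tokens.
    b, t = seq.size()
    pos = torch.arange(0, t, dtype=torch.long, device=seq.device)
    x = transformer.wte(seq)                 # (b, t, n_embd) token embeddings
    if hasattr(transformer, "wpe"):          # learned positions may be absent
        x = x + transformer.wpe(pos)
    x = transformer.drop(x)
    kv_cache = []
    for block in transformer.h:
        # Assumed block API: also return this block's (k, v) for reuse later.
        x, kv = block(x, use_cache=True)
        kv_cache.append(kv)
    return x, kv_cache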
@@ -322,13 +323,6 @@ class AuriStream(PreTrainedModel):
         # using the last embedding of the input
         for i in range(n_tokens-1):
 
-            # TODO: double check this works then delete the block bellow:
-            # # Get the emb and pos embedding of just the last token
-            # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t)
-            # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
-            # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
-            # x = self.transformer.drop(tok_emb + pos_emb)
-
             # Get the emb and pos embedding of just the last token
             tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd)
             # if wpe exists in self.transformer apply learned positional embedding
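
Inside the generation loop that this hunk tidies up, each iteration embeds only the most recent prediction, assigns it the next absolute position, and attends to the cached keys/values of everything before it. A hedged sketch of one such step, reusing the assumed block interface from the prefill sketch above (the `past_kv` keyword is likewise an assumption, not AuriStream's actual argument name):

import torch

@torch.no_grad()
def decode_step(transformer, last_token, step, kv_cache):
    # last_token: (b, 1) id of the most recent prediction; step: its absolute position.
    pos = torch.tensor([step], dtype=torch.long, device=last_token.device)
    x = transformer.wte(last_token)          # (b, 1, n_embd)
    if hasattr(transformer, "wpe"):          # apply learned positional embedding if present
        x = x + transformer.wpe(pos)
    x = transformer.drop(x)
    new_cache = []
    for block, past_kv in zip(transformer.h, kv_cache):
        # Attend over the cached keys/values and get back the extended cache.
        x, kv = block(x, past_kv=past_kv, use_cache=True)
        new_cache.append(kv)
    return x, new_cache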