cyrilvallez (HF Staff) committed
Commit 9b4bf51 · verified · 1 Parent(s): 18740b7

Update custom_generate/generate.py

Files changed (1):
  1. custom_generate/generate.py +58 -114
custom_generate/generate.py CHANGED
@@ -1,18 +1,22 @@
-from typing import Union, Optional, TYPE_CHECKING
+import logging
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
+import torch.nn as nn
+
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
-    GenerationMixin,
-    GenerateNonBeamOutput,
+    ALL_CACHE_NAMES,
     GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+    GenerateNonBeamOutput,
+    GenerationMixin,
 )
-from transformers.cache_utils import Cache, EncoderDecoderCache, DynamicCache
 from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from transformers.generation.utils import GenerateEncoderDecoderOutput, ALL_CACHE_NAMES
 from transformers.utils import ModelOutput
-from transformers.configuration_utils import PretrainedConfig
-import torch.nn as nn
-import logging
+
 
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -20,9 +24,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def stack_model_outputs(
-    model_outputs: list[ModelOutput], config: PretrainedConfig
-) -> ModelOutput:
+def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
     """
     Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
     specific ModelOutput subclass from the list provided.
@@ -50,17 +52,11 @@ def stack_model_outputs(
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
                 return tuple(
-                    tuple(
-                        torch.cat([attr[i][j] for attr in data], dim=0)
-                        for j in range(len(data[0][0]))
-                    )
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                     for i in range(len(data[0]))
                 )
             else:
-                return tuple(
-                    torch.cat([attr[i] for attr in data], dim=0)
-                    for i in range(len(data[0]))
-                )
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
         elif isinstance(data[0], (int, float)):
             # If the elements are integers or floats, return a tensor
             return torch.tensor(data)
@@ -92,9 +88,7 @@ def _ranking_fast(
     """
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
-    cosine_matrix = torch.matmul(
-        norm_context_hidden, norm_next_hidden.transpose(1, 2)
-    ).squeeze(-1)  # [B*K, S]
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
 
     # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
     # Using a large negative value for masked positions
@@ -105,9 +99,7 @@ def _ranking_fast(
     degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
-    contrastive_score = torch.stack(
-        torch.split(contrastive_score, beam_width)
-    )  # [B, K]
+    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
 
@@ -163,9 +155,7 @@ def _contrastive_search(
            f"contrastive search is not supported with stateful models, such as {model.__class__.__name__}"
        )
     # init values
-    has_eos_stopping_criteria = any(
-        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
-    )
+    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
     top_k = generation_config.top_k
     penalty_alpha = generation_config.penalty_alpha
     pad_token_id = generation_config._pad_token_tensor
@@ -181,39 +171,22 @@ def _contrastive_search(
     scores = () if (return_dict_in_generate and output_scores) else None
     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-    decoder_hidden_states = (
-        () if (return_dict_in_generate and output_hidden_states) else None
-    )
+    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
     if return_dict_in_generate and model.config.is_encoder_decoder:
-        encoder_attentions = (
-            model_kwargs["encoder_outputs"].get("attentions")
-            if output_attentions
-            else None
-        )
-        encoder_hidden_states = (
-            model_kwargs["encoder_outputs"].get("hidden_states")
-            if output_hidden_states
-            else None
-        )
+        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
 
     # keep track of which sequences are already finished
     batch_size, cur_len = input_ids.shape[:2]
-    unfinished_sequences = torch.ones(
-        batch_size, dtype=torch.long, device=input_ids.device
-    )
-    model_kwargs = model._get_initial_cache_position(
-        cur_len, input_ids.device, model_kwargs
-    )
+    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
 
     # Create cosine_matrix_mask based on the attention_mask
     cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
     if model.config.is_encoder_decoder:
-        if (
-            "decoder_attention_mask" in model_kwargs
-            and model_kwargs["decoder_attention_mask"] is not None
-        ):
+        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
             cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
         else:
             cosine_matrix_mask = model_kwargs["attention_mask"]
@@ -221,9 +194,7 @@ def _contrastive_search(
 
     this_peer_finished = False
 
-    while model._has_unfinished_sequences(
-        this_peer_finished, synced_gpus, device=input_ids.device
-    ):
+    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
         # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
         # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
         if model_kwargs.get("past_key_values") is None or (
@@ -232,9 +203,7 @@ def _contrastive_search(
        ):
            # prepare inputs
            model_kwargs["use_cache"] = True
-            model_inputs = model.prepare_inputs_for_generation(
-                input_ids, **model_kwargs
-            )
+            model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
            # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
            # the `encoder_outputs`
@@ -256,9 +225,7 @@ def _contrastive_search(
            # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
            # (the clone itmodel is always small)
            # torch.float32 is needed to retain precision for later logits manipulations
-            logit_for_next_step = outputs.logits[:, -1, :].to(
-                copy=True, dtype=torch.float32, device=input_ids.device
-            )
+            logit_for_next_step = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 
            model_kwargs = model._update_model_kwargs_for_generation(
                outputs,
@@ -282,7 +249,18 @@ def _contrastive_search(
                    f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
                    "for contrastive search."
                )
+            # Only those caches have the necesary methods
+            elif not (
+                isinstance(past_key_values, DynamicCache)
+                or (
+                    isinstance(past_key_values, EncoderDecoderCache)
+                    and isinstance(past_key_values.self_attention_cache, DynamicCache)
+                )
+            ):
+                raise ValueError(
+                    f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+                )
-
 
        # contrastive_search main logic start:
        # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
@@ -300,18 +278,14 @@ def _contrastive_search(
                scores += (processed_logit_for_next_step,)
            if output_attentions:
                decoder_attentions += (
-                    (outputs.decoder_attentions,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.attentions,)
+                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
 
            if output_hidden_states:
                decoder_hidden_states += (
-                    (outputs.decoder_hidden_states,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.hidden_states,)
+                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                )
 
        # This is needed to properly delete outputs.logits which may be very large for this first iteration
@@ -323,8 +297,7 @@ def _contrastive_search(
            past = model_kwargs["past_key_values"]
            # If it is a static cache, modify it in-place layer after layer to save memory
            if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache)
-                and isinstance(past.self_attention_cache, DynamicCache)
+                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
            ):
                past.batch_repeat_interleave(top_k)
            else:
@@ -344,9 +317,7 @@ def _contrastive_search(
            all_outputs = []
            for i in range(top_k):
                # compute the candidate tokens by the language model and collect their hidden_states
-                next_model_inputs = model.prepare_inputs_for_generation(
-                    top_k_ids[:, i].view(-1, 1), **model_kwargs
-                )
+                next_model_inputs = model.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
 
                outputs = model(
                    **next_model_inputs,
@@ -356,9 +327,7 @@ def _contrastive_search(
                )
                if isinstance(outputs["past_key_values"], DynamicCache) or (
                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(
-                        outputs["past_key_values"].self_attention_cache, DynamicCache
-                    )
+                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
                ):
                    # Remove past K-V from output since we don't need to stack later
                    outputs["past_key_values"] = None
@@ -376,9 +345,7 @@ def _contrastive_search(
        else:
            # compute the candidate tokens by the language model and collect their hidden_states
            # assembles top_k_ids into batch of size k
-            next_model_inputs = model.prepare_inputs_for_generation(
-                top_k_ids.view(-1, 1), **model_kwargs
-            )
+            next_model_inputs = model.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
 
            outputs = model(
                **next_model_inputs,
@@ -424,9 +391,7 @@ def _contrastive_search(
        selected_idx = selected_idx.to("cpu")
 
        # This will be used instead of the previous inneficient torch.stack(torch.split())
-        augmented_idx = torch.tensor(
-            [x + i * top_k for i, x in enumerate(selected_idx)]
-        )
+        augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
 
        # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
        # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
@@ -434,15 +399,11 @@ def _contrastive_search(
        next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
        next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
        next_hidden = next_hidden[range(batch_size), selected_idx, :]
-        last_hidden_states = torch.cat(
-            [last_hidden_states, next_hidden.unsqueeze(1)], dim=1
-        )
+        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
 
        next_decoder_hidden_states = ()
        for layer in full_hidden_states:
-            layer = torch.stack(torch.split(layer, top_k))[
-                range(batch_size), selected_idx, :
-            ]
+            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
            next_decoder_hidden_states += (layer,)
 
        # generate past_key_values cache of only the selected token
@@ -462,9 +423,7 @@ def _contrastive_search(
        else:
            next_past_key_values = None
            for possible_cache_name in ALL_CACHE_NAMES:
-                next_past_key_values = next_past_key_values or getattr(
-                    outputs, possible_cache_name, None
-                )
+                next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
            # Do it in-place layer per layer to save memory
            if isinstance(next_past_key_values, DynamicCache) or (
                isinstance(next_past_key_values, EncoderDecoderCache)
@@ -482,9 +441,7 @@ def _contrastive_search(
 
                next_past_key_values = tuple(new_key_values)
 
-        logit_for_next_step = torch.stack(torch.split(logits, top_k))[
-            range(batch_size), selected_idx, :
-        ]
+        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
        logit_for_next_step = logit_for_next_step.to(input_ids.device)
 
        # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
@@ -493,14 +450,10 @@ def _contrastive_search(
            next_step_decoder_attentions = ()
            if output_attentions:
                for layer in outputs.cross_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_cross_attentions += (layer,)
                for layer in outputs.decoder_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_decoder_attentions += (layer,)
            outputs = Seq2SeqLMOutput(
                past_key_values=next_past_key_values,
@@ -512,9 +465,7 @@ def _contrastive_search(
            next_step_attentions = ()
            if output_attentions:
                for layer in outputs.attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                    next_step_attentions += (layer,)
            outputs = CausalLMOutputWithPast(
                past_key_values=next_past_key_values,
@@ -534,9 +485,7 @@ def _contrastive_search(
 
        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
-            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                1 - unfinished_sequences
-            )
+            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -544,9 +493,7 @@ def _contrastive_search(
            streamer.put(next_tokens.cpu())
 
        # stop when each sentence is finished
-        unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-            input_ids, scores
-        )
+        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
 
    if streamer is not None:
@@ -558,9 +505,7 @@ def _contrastive_search(
        if model_kwargs.get("past_key_values") is not None:
            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-                and isinstance(
-                    model_kwargs["past_key_values"].self_attention_cache, DynamicCache
-                )
+                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
            ):
                model_kwargs["past_key_values"].crop(-1)
            else:
@@ -607,8 +552,7 @@ def generate(model, *args, **kwargs):
    """
    cache_implementation = kwargs.pop("cache_implementation", "dynamic_full")
    if cache_implementation != "dynamic_full" and (
-        "sliding_attention"
-        in getattr(model.config.get_text_config(), "layer_types", [])
+        "sliding_attention" in getattr(model.config.get_text_config(), "layer_types", [])
        or getattr(model.config.get_text_config(), "sliding_window", 0) > 0
    ):
        logger.warning_once(
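
Note: the re-ranking in `_ranking_fast` trades each candidate's model confidence against a degeneration penalty, its maximum cosine similarity to the tokens already in the context. The standalone sketch below (random tensors and toy sizes, an illustration rather than part of the committed file) mirrors that scoring so the [B*K, S] / [B, K] shape comments above can be followed in isolation:

import torch

batch_size, beam_width, seq_len, hidden = 2, 4, 5, 8  # toy sizes: B=2, K=4, S=5
alpha = 0.6  # penalty_alpha

context_hidden = torch.randn(batch_size * beam_width, seq_len, hidden)  # [B*K, S, H]
next_hidden = torch.randn(batch_size * beam_width, 1, hidden)           # [B*K, 1, H]
next_top_k_probs = torch.rand(batch_size, beam_width)                   # [B, K]

# cosine similarity of each candidate hidden state against every context position
norm_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context, norm_next.transpose(1, 2)).squeeze(-1)  # [B*K, S]

# degeneration penalty: highest similarity to any previous token
degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]

# trade off confidence against the penalty, then pick the best candidate per batch item
contrastive_score = (1.0 - alpha) * next_top_k_probs.view(-1) - alpha * degeneration_penalty
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
selected_idx = contrastive_score.argmax(dim=-1)  # [B]
print(selected_idx)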
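
Note: the decoding loop leans on the in-place dynamic-cache helpers visible in the diff, `batch_repeat_interleave(top_k)` to replicate the prefix cache for the candidate continuations and `crop(-1)` to drop the look-ahead token, which is why the new check above rejects non-dynamic caches. A minimal sketch of those calls on a real cache; the model id is a placeholder and `batch_select_indices` is assumed from the upstream implementation rather than shown in this diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder: any causal LM that returns a DynamicCache
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
cache = model(**inputs, use_cache=True).past_key_values  # DynamicCache

top_k = 4
cache.batch_repeat_interleave(top_k)           # replicate the prefix cache for the K candidates
cache.batch_select_indices(torch.tensor([2]))  # keep only the winning candidate's entries
cache.crop(-1)                                 # drop the trailing look-ahead token before reuse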
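
Note: `generate(model, *args, **kwargs)` at the bottom of the file is the entry point transformers loads for Hub-hosted custom decoding loops. A hedged usage sketch, assuming the `custom_generate` loading mechanism of recent transformers releases; the model id and repo id below are placeholders, not values taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The future of open source ML is", return_tensors="pt")
output = model.generate(
    **inputs,
    custom_generate="<org>/<this-repo>",  # placeholder for the Hub repo holding custom_generate/generate.py
    trust_remote_code=True,
    penalty_alpha=0.6,  # weight of the degeneration penalty
    top_k=4,            # number of candidates to re-rank
    max_new_tokens=64,
    # cache_implementation defaults to "dynamic_full"; sliding-window models trigger the warning above otherwise
)
print(tokenizer.decode(output[0], skip_special_tokens=True))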