joaogante (HF Staff) and cyrilvallez (HF Staff) committed
Commit bbf36bb · verified · 1 parent: 245ec5a

Update cache format (#3)


- Update custom_generate/generate.py (18740b73b7543c98b8961bc5489661f181889a2f)
- Update custom_generate/generate.py (9b4bf516f80bb86996ecf6c8e8c99ed1a1afdead)
- Update custom_generate/generate.py (def1b87d83f9a9d7a52c517e21e517946ffcf67b)
- Update custom_generate/generate.py (525fd175fe5561603d20af6cd032c33f5f91b52c)


Co-authored-by: Cyril Vallez <[email protected]>

Files changed (1)
  1. custom_generate/generate.py +60 -176
custom_generate/generate.py CHANGED
@@ -1,18 +1,22 @@
-from typing import Union, Optional, TYPE_CHECKING
+import logging
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
+import torch.nn as nn
+
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
-    GenerationMixin,
-    GenerateNonBeamOutput,
+    ALL_CACHE_NAMES,
     GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+    GenerateNonBeamOutput,
+    GenerationMixin,
 )
-from transformers.cache_utils import Cache, EncoderDecoderCache, DynamicCache
 from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from transformers.generation.utils import GenerateEncoderDecoderOutput, ALL_CACHE_NAMES
 from transformers.utils import ModelOutput
-from transformers.configuration_utils import PretrainedConfig
-import torch.nn as nn
-import logging
+
 
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -20,9 +24,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def stack_model_outputs(
-    model_outputs: list[ModelOutput], config: PretrainedConfig
-) -> ModelOutput:
+def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
     """
     Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
     specific ModelOutput subclass from the list provided.
@@ -50,17 +52,11 @@ def stack_model_outputs(
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
                 return tuple(
-                    tuple(
-                        torch.cat([attr[i][j] for attr in data], dim=0)
-                        for j in range(len(data[0][0]))
-                    )
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                     for i in range(len(data[0]))
                 )
             else:
-                return tuple(
-                    torch.cat([attr[i] for attr in data], dim=0)
-                    for i in range(len(data[0]))
-                )
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
         elif isinstance(data[0], (int, float)):
             # If the elements are integers or floats, return a tensor
             return torch.tensor(data)
@@ -92,9 +88,7 @@ def _ranking_fast(
     """
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
-    cosine_matrix = torch.matmul(
-        norm_context_hidden, norm_next_hidden.transpose(1, 2)
-    ).squeeze(-1)  # [B*K, S]
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
 
     # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
     # Using a large negative value for masked positions
@@ -105,9 +99,7 @@ def _ranking_fast(
     degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
-    contrastive_score = torch.stack(
-        torch.split(contrastive_score, beam_width)
-    )  # [B, K]
+    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
 
@@ -163,9 +155,7 @@ def _contrastive_search(
             f"contrastive search is not supported with stateful models, such as {model.__class__.__name__}"
         )
     # init values
-    has_eos_stopping_criteria = any(
-        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
-    )
+    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
     top_k = generation_config.top_k
     penalty_alpha = generation_config.penalty_alpha
     pad_token_id = generation_config._pad_token_tensor
@@ -181,39 +171,22 @@ def _contrastive_search(
     scores = () if (return_dict_in_generate and output_scores) else None
     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-    decoder_hidden_states = (
-        () if (return_dict_in_generate and output_hidden_states) else None
-    )
+    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
     if return_dict_in_generate and model.config.is_encoder_decoder:
-        encoder_attentions = (
-            model_kwargs["encoder_outputs"].get("attentions")
-            if output_attentions
-            else None
-        )
-        encoder_hidden_states = (
-            model_kwargs["encoder_outputs"].get("hidden_states")
-            if output_hidden_states
-            else None
-        )
+        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
 
     # keep track of which sequences are already finished
     batch_size, cur_len = input_ids.shape[:2]
-    unfinished_sequences = torch.ones(
-        batch_size, dtype=torch.long, device=input_ids.device
-    )
-    model_kwargs = model._get_initial_cache_position(
-        cur_len, input_ids.device, model_kwargs
-    )
+    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
 
     # Create cosine_matrix_mask based on the attention_mask
     cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
     if model.config.is_encoder_decoder:
-        if (
-            "decoder_attention_mask" in model_kwargs
-            and model_kwargs["decoder_attention_mask"] is not None
-        ):
+        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
            cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
     else:
         cosine_matrix_mask = model_kwargs["attention_mask"]
@@ -221,9 +194,7 @@ def _contrastive_search(
 
     this_peer_finished = False
 
-    while model._has_unfinished_sequences(
-        this_peer_finished, synced_gpus, device=input_ids.device
-    ):
+    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
         # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
         # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
         if model_kwargs.get("past_key_values") is None or (
@@ -232,9 +203,7 @@ def _contrastive_search(
         ):
             # prepare inputs
             model_kwargs["use_cache"] = True
-            model_inputs = model.prepare_inputs_for_generation(
-                input_ids, **model_kwargs
-            )
+            model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
             # the `encoder_outputs`
@@ -256,9 +225,7 @@ def _contrastive_search(
             # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
            # (the clone itmodel is always small)
             # torch.float32 is needed to retain precision for later logits manipulations
-            logit_for_next_step = outputs.logits[:, -1, :].to(
-                copy=True, dtype=torch.float32, device=input_ids.device
-            )
+            logit_for_next_step = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 
             model_kwargs = model._update_model_kwargs_for_generation(
                 outputs,
@@ -282,13 +249,17 @@ def _contrastive_search(
                     f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
                     "for contrastive search."
                 )
-            elif (
-                not isinstance(past_key_values[0], (tuple, torch.Tensor))
-                or past_key_values[0][0].shape[0] != batch_size
+            # Only those caches have the necesary methods
+            elif not (
+                isinstance(past_key_values, DynamicCache)
+                or (
+                    isinstance(past_key_values, EncoderDecoderCache)
+                    and isinstance(past_key_values.self_attention_cache, DynamicCache)
+                )
             ):
                 raise ValueError(
-                    f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
-                    "used for contrastive search without further modifications."
+                    f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
                 )
 
             # contrastive_search main logic start:
@@ -307,18 +278,14 @@ def _contrastive_search(
                 scores += (processed_logit_for_next_step,)
             if output_attentions:
                 decoder_attentions += (
-                    (outputs.decoder_attentions,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.attentions,)
+                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                 )
                 if model.config.is_encoder_decoder:
                     cross_attentions += (outputs.cross_attentions,)
 
             if output_hidden_states:
                 decoder_hidden_states += (
-                    (outputs.decoder_hidden_states,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.hidden_states,)
+                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                 )
 
         # This is needed to properly delete outputs.logits which may be very large for this first iteration
@@ -327,33 +294,13 @@ def _contrastive_search(
 
         if not sequential:
             # Replicates the new past_key_values to match the `top_k` candidates
-            past = model_kwargs["past_key_values"]
-            # If it is a static cache, modify it in-place layer after layer to save memory
-            if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache)
-                and isinstance(past.self_attention_cache, DynamicCache)
-            ):
-                past.batch_repeat_interleave(top_k)
-            else:
-                new_key_values = []
-                for layer in past:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item.repeat_interleave(top_k, dim=0))
-                    new_key_values.append(tuple(items))
-
-                past = tuple(new_key_values)
-
-            model_kwargs["past_key_values"] = past
+            model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
 
         if sequential:
             all_outputs = []
             for i in range(top_k):
                 # compute the candidate tokens by the language model and collect their hidden_states
-                next_model_inputs = model.prepare_inputs_for_generation(
-                    top_k_ids[:, i].view(-1, 1), **model_kwargs
-                )
+                next_model_inputs = model.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
 
                 outputs = model(
                     **next_model_inputs,
@@ -361,21 +308,10 @@ def _contrastive_search(
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
-                if isinstance(outputs["past_key_values"], DynamicCache) or (
-                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(
-                        outputs["past_key_values"].self_attention_cache, DynamicCache
-                    )
-                ):
-                    # Remove past K-V from output since we don't need to stack later
-                    outputs["past_key_values"] = None
-                    # Remove last token from past K-V since we don't want to append it at this point
-                    model_kwargs["past_key_values"].crop(-1)
-                else:
-                    raise ValueError(
-                        f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
-                        "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
-                    )
+                # Remove past K-V from output since we don't need to stack later
+                outputs["past_key_values"] = None
+                # Remove last token from past K-V since we don't want to append it at this point
+                model_kwargs["past_key_values"].crop(-1)
 
                 all_outputs.append(outputs)
             outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
@@ -383,9 +319,7 @@ def _contrastive_search(
         else:
             # compute the candidate tokens by the language model and collect their hidden_states
             # assembles top_k_ids into batch of size k
-            next_model_inputs = model.prepare_inputs_for_generation(
-                top_k_ids.view(-1, 1), **model_kwargs
-            )
+            next_model_inputs = model.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
 
             outputs = model(
                 **next_model_inputs,
@@ -431,9 +365,7 @@ def _contrastive_search(
         selected_idx = selected_idx.to("cpu")
 
         # This will be used instead of the previous inneficient torch.stack(torch.split())
-        augmented_idx = torch.tensor(
-            [x + i * top_k for i, x in enumerate(selected_idx)]
-        )
+        augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
 
         # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
         # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
@@ -441,15 +373,11 @@ def _contrastive_search(
         next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
         next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
         next_hidden = next_hidden[range(batch_size), selected_idx, :]
-        last_hidden_states = torch.cat(
-            [last_hidden_states, next_hidden.unsqueeze(1)], dim=1
-        )
+        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
 
         next_decoder_hidden_states = ()
         for layer in full_hidden_states:
-            layer = torch.stack(torch.split(layer, top_k))[
-                range(batch_size), selected_idx, :
-            ]
+            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
             next_decoder_hidden_states += (layer,)
 
         # generate past_key_values cache of only the selected token
@@ -469,29 +397,10 @@ def _contrastive_search(
         else:
             next_past_key_values = None
             for possible_cache_name in ALL_CACHE_NAMES:
-                next_past_key_values = next_past_key_values or getattr(
-                    outputs, possible_cache_name, None
-                )
-            # Do it in-place layer per layer to save memory
-            if isinstance(next_past_key_values, DynamicCache) or (
-                isinstance(next_past_key_values, EncoderDecoderCache)
-                and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
-            ):
-                next_past_key_values.batch_select_indices(augmented_idx)
-            else:
-                new_key_values = []
-                for layer in next_past_key_values:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item[augmented_idx, ...])
-                    new_key_values.append(tuple(items))
-
-                next_past_key_values = tuple(new_key_values)
-
-        logit_for_next_step = torch.stack(torch.split(logits, top_k))[
-            range(batch_size), selected_idx, :
-        ]
+                next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
+            next_past_key_values.batch_select_indices(augmented_idx)
+
+        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
         logit_for_next_step = logit_for_next_step.to(input_ids.device)
 
         # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
@@ -500,14 +409,10 @@ def _contrastive_search(
             next_step_decoder_attentions = ()
             if output_attentions:
                 for layer in outputs.cross_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_cross_attentions += (layer,)
                 for layer in outputs.decoder_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_decoder_attentions += (layer,)
             outputs = Seq2SeqLMOutput(
                 past_key_values=next_past_key_values,
@@ -519,9 +424,7 @@ def _contrastive_search(
             next_step_attentions = ()
             if output_attentions:
                 for layer in outputs.attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_attentions += (layer,)
             outputs = CausalLMOutputWithPast(
                 past_key_values=next_past_key_values,
@@ -541,9 +444,7 @@ def _contrastive_search(
 
         # finished sentences should have their next token be a padding token
         if has_eos_stopping_criteria:
-            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                1 - unfinished_sequences
-            )
+            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -551,9 +452,7 @@ def _contrastive_search(
             streamer.put(next_tokens.cpu())
 
         # stop when each sentence is finished
-        unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-            input_ids, scores
-        )
+        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
         this_peer_finished = unfinished_sequences.max() == 0
 
     if streamer is not None:
@@ -563,21 +462,7 @@ def _contrastive_search(
         # Contrastive search works by forward looking at the next token, so we need to exclude it from
        # `past_key_values` to be consistent with the other decoding methods
         if model_kwargs.get("past_key_values") is not None:
-            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
-                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-                and isinstance(
-                    model_kwargs["past_key_values"].self_attention_cache, DynamicCache
-                )
-            ):
-                model_kwargs["past_key_values"].crop(-1)
-            else:
-                past_key_values = []
-                for layer in model_kwargs["past_key_values"]:
-                    layer_past_key_values = []
-                    for item in layer:
-                        layer_past_key_values.append(item[..., :-1, :])
-                    past_key_values.append(tuple(layer_past_key_values))
-                model_kwargs["past_key_values"] = tuple(past_key_values)
+            model_kwargs["past_key_values"].crop(-1)
 
         if model.config.is_encoder_decoder:
             return GenerateEncoderDecoderOutput(
@@ -614,8 +499,7 @@ def generate(model, *args, **kwargs):
     """
     cache_implementation = kwargs.pop("cache_implementation", "dynamic_full")
     if cache_implementation != "dynamic_full" and (
-        "sliding_attention"
-        in getattr(model.config.get_text_config(), "layer_types", [])
+        "sliding_attention" in getattr(model.config.get_text_config(), "layer_types", [])
         or getattr(model.config.get_text_config(), "sliding_window", 0) > 0
     ):
         logger.warning_once(
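
For orientation, a minimal usage sketch of how a Hub repository shipping a `custom_generate/generate.py` like the one above is typically invoked (this assumes a recent `transformers` release with Hub `custom_generate` support; the model name, repo id, and generation settings below are illustrative placeholders, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# `custom_generate` points at a Hub repo containing custom_generate/generate.py; transformers
# downloads it (with trust_remote_code=True) and calls its `generate(model, ...)` entry point.
# `top_k` and `penalty_alpha` are the knobs read by `_contrastive_search` above.
output_ids = model.generate(
    **inputs,
    custom_generate="<hub-repo-id-with-this-file>",  # placeholder repo id
    trust_remote_code=True,
    top_k=4,
    penalty_alpha=0.6,
    max_new_tokens=32,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))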