cyrilvallez (HF Staff) committed
Commit def1b87 · verified · 1 Parent(s): 9b4bf51

Update custom_generate/generate.py
Files changed (1)
  1. custom_generate/generate.py +13 -70
custom_generate/generate.py CHANGED
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
-from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.cache_utils import Cache, EncoderDecoderCache
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
     ALL_CACHE_NAMES,
@@ -249,17 +249,13 @@ def _contrastive_search(
             f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
             "for contrastive search."
         )
-    # Only those caches have the necesary methods
-    elif not (
-        isinstance(past_key_values, DynamicCache)
-        or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        )
+    elif (
+        not isinstance(past_key_values[0], (tuple, torch.Tensor))
+        or past_key_values[0][0].shape[0] != batch_size
     ):
         raise ValueError(
-            f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
-            "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+            f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
+            "used for contrastive search without further modifications."
         )
 
     # contrastive_search main logic start:
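The rewritten guard drops the `DynamicCache`-only restriction: any cache that can be indexed per layer into key/value tensors with the right batch dimension now passes. A minimal sketch of the contract it probes, assuming a recent `transformers` in which indexing a `DynamicCache` layer yields a `(key, value)` pair (toy shapes for illustration only):

import torch
from transformers.cache_utils import DynamicCache

batch_size, num_heads, seq_len, head_dim = 2, 4, 10, 8

cache = DynamicCache()
cache.update(
    torch.randn(batch_size, num_heads, seq_len, head_dim),  # layer-0 keys
    torch.randn(batch_size, num_heads, seq_len, head_dim),  # layer-0 values
    layer_idx=0,
)

layer0 = cache[0]                                 # a (key, value) tuple
assert isinstance(layer0, (tuple, torch.Tensor))  # first clause of the guard
assert layer0[0].shape[0] == batch_size           # second clause: batch dim matches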
@@ -294,24 +290,7 @@ def _contrastive_search(
 
         if not sequential:
             # Replicates the new past_key_values to match the `top_k` candidates
-            past = model_kwargs["past_key_values"]
-            # If it is a static cache, modify it in-place layer after layer to save memory
-            if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
-            ):
-                past.batch_repeat_interleave(top_k)
-            else:
-                new_key_values = []
-                for layer in past:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item.repeat_interleave(top_k, dim=0))
-                    new_key_values.append(tuple(items))
-
-                past = tuple(new_key_values)
-
-            model_kwargs["past_key_values"] = past
+            model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
 
         if sequential:
             all_outputs = []
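For reference, `batch_repeat_interleave(top_k)` does in place what the deleted fallback spelled out tensor by tensor: every layer's key and value is repeated `top_k` times along the batch dimension. A hypothetical standalone helper, reconstructed from the removed loop:

def batch_repeat_interleave_legacy(past, top_k):
    # past: tuple of (key, value) tensor pairs, one per layer
    new_key_values = []
    for layer in past:
        # e.g. a key of shape [2, heads, seq, dim] with top_k=3 becomes [6, heads, seq, dim]
        new_key_values.append(
            tuple(item.repeat_interleave(top_k, dim=0) for item in layer)
        )
    return tuple(new_key_values)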
@@ -325,19 +304,10 @@ def _contrastive_search(
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
-                if isinstance(outputs["past_key_values"], DynamicCache) or (
-                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
-                ):
-                    # Remove past K-V from output since we don't need to stack later
-                    outputs["past_key_values"] = None
-                    # Remove last token from past K-V since we don't want to append it at this point
-                    model_kwargs["past_key_values"].crop(-1)
-                else:
-                    raise ValueError(
-                        f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
-                        "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
-                    )
+                # Remove past K-V from output since we don't need to stack later
+                outputs["past_key_values"] = None
+                # Remove last token from past K-V since we don't want to append it at this point
+                model_kwargs["past_key_values"].crop(-1)
 
                 all_outputs.append(outputs)
             outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
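The `crop(-1)` used here trims the cache in place along the sequence axis; a negative value removes that many most-recent positions. In legacy tuple terms it matches the `item[..., :-1, :]` slice that the final hunk below deletes. A toy equivalence check:

import torch

key = torch.randn(2, 4, 10, 8)  # [batch, num_heads, seq_len, head_dim]
cropped = key[..., :-1, :]      # what crop(-1) amounts to per cached tensor
assert cropped.shape == (2, 4, 9, 8)  # one position dropped from the sequence axis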
@@ -424,22 +394,7 @@ def _contrastive_search(
             next_past_key_values = None
             for possible_cache_name in ALL_CACHE_NAMES:
                 next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
-            # Do it in-place layer per layer to save memory
-            if isinstance(next_past_key_values, DynamicCache) or (
-                isinstance(next_past_key_values, EncoderDecoderCache)
-                and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
-            ):
-                next_past_key_values.batch_select_indices(augmented_idx)
-            else:
-                new_key_values = []
-                for layer in next_past_key_values:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item[augmented_idx, ...])
-                    new_key_values.append(tuple(items))
-
-                next_past_key_values = tuple(new_key_values)
+            next_past_key_values.batch_select_indices(augmented_idx)
 
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
             logit_for_next_step = logit_for_next_step.to(input_ids.device)
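Likewise, `batch_select_indices(augmented_idx)` subsumes the deleted fallback: it keeps only the chosen rows along the batch dimension of every cached tensor. A hypothetical reconstruction of the removed loop:

def batch_select_indices_legacy(past, indices):
    # past: tuple of (key, value) tensor pairs per layer; indices: 1-D index tensor
    return tuple(
        tuple(item[indices, ...] for item in layer)
        for layer in past
    )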
@@ -503,19 +458,7 @@ def _contrastive_search(
         # Contrastive search works by forward looking at the next token, so we need to exclude it from
         # `past_key_values` to be consistent with the other decoding methods
         if model_kwargs.get("past_key_values") is not None:
-            if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
-                isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
-            ):
-                model_kwargs["past_key_values"].crop(-1)
-            else:
-                past_key_values = []
-                for layer in model_kwargs["past_key_values"]:
-                    layer_past_key_values = []
-                    for item in layer:
-                        layer_past_key_values.append(item[..., :-1, :])
-                    past_key_values.append(tuple(layer_past_key_values))
-                model_kwargs["past_key_values"] = tuple(past_key_values)
+            model_kwargs["past_key_values"].crop(-1)
 
         if model.config.is_encoder_decoder:
             return GenerateEncoderDecoderOutput(
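Net effect of the commit: the loop now relies solely on the `Cache` API (`batch_repeat_interleave`, `batch_select_indices`, `crop`) instead of branching on `DynamicCache` versus legacy tuples, which is where the `+13 -70` stat comes from. A usage sketch, assuming a `transformers` release with Hub `custom_generate` support; the repo id below is a placeholder, not the actual repository name:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello, world", return_tensors="pt")

out = model.generate(
    **inputs,
    custom_generate="<user>/<this-repo>",  # hypothetical repo id
    trust_remote_code=True,
    penalty_alpha=0.6,  # contrastive search degeneration penalty
    top_k=4,            # number of candidates compared per step
    max_new_tokens=20,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))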
 