cyrilvallez HF Staff committed on
Commit
525fd17
·
verified ·
1 Parent(s): def1b87

Update custom_generate/generate.py

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +10 -6
custom_generate/generate.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import torch.nn as nn
6
 
7
  from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
8
- from transformers.cache_utils import Cache, EncoderDecoderCache
9
  from transformers.configuration_utils import PretrainedConfig
10
  from transformers.generation.utils import (
11
  ALL_CACHE_NAMES,
@@ -249,13 +249,17 @@ def _contrastive_search(
249
  f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
250
  "for contrastive search."
251
  )
252
- elif (
253
- not isinstance(past_key_values[0], (tuple, torch.Tensor))
254
- or past_key_values[0][0].shape[0] != batch_size
 
 
 
 
255
  ):
256
  raise ValueError(
257
- f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
258
- "used for contrastive search without further modifications."
259
  )
260
 
261
  # contrastive_search main logic start:
 
5
  import torch.nn as nn
6
 
7
  from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
8
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
9
  from transformers.configuration_utils import PretrainedConfig
10
  from transformers.generation.utils import (
11
  ALL_CACHE_NAMES,
 
249
  f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
250
  "for contrastive search."
251
  )
252
+ # Only those caches have the necessary methods
253
+ elif not (
254
+ isinstance(past_key_values, DynamicCache)
255
+ or (
256
+ isinstance(past_key_values, EncoderDecoderCache)
257
+ and isinstance(past_key_values.self_attention_cache, DynamicCache)
258
+ )
259
  ):
260
  raise ValueError(
261
+ f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
262
+ "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
263
  )
264
 
265
  # contrastive_search main logic start: