Update custom_generate/generate.py
Browse files- custom_generate/generate.py +10 -6
custom_generate/generate.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
8 |
-
from transformers.cache_utils import Cache, EncoderDecoderCache
|
9 |
from transformers.configuration_utils import PretrainedConfig
|
10 |
from transformers.generation.utils import (
|
11 |
ALL_CACHE_NAMES,
|
@@ -249,13 +249,17 @@ def _contrastive_search(
|
|
249 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
250 |
"for contrastive search."
|
251 |
)
|
252 |
-
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
|
|
255 |
):
|
256 |
raise ValueError(
|
257 |
-
f"
|
258 |
-
"
|
259 |
)
|
260 |
|
261 |
# contrastive_search main logic start:
|
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
8 |
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
9 |
from transformers.configuration_utils import PretrainedConfig
|
10 |
from transformers.generation.utils import (
|
11 |
ALL_CACHE_NAMES,
|
|
|
249 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
250 |
"for contrastive search."
|
251 |
)
|
252 |
+
# Only those caches have the necesary methods
|
253 |
+
elif not (
|
254 |
+
isinstance(past_key_values, DynamicCache)
|
255 |
+
or (
|
256 |
+
isinstance(past_key_values, EncoderDecoderCache)
|
257 |
+
and isinstance(past_key_values.self_attention_cache, DynamicCache)
|
258 |
+
)
|
259 |
):
|
260 |
raise ValueError(
|
261 |
+
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
262 |
+
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
263 |
)
|
264 |
|
265 |
# contrastive_search main logic start:
|