Update custom_generate/generate.py
Browse files
custom_generate/generate.py
CHANGED
@@ -282,14 +282,7 @@ def _contrastive_search(
|
|
282 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
283 |
"for contrastive search."
|
284 |
)
|
285 |
-
|
286 |
-
not isinstance(past_key_values[0], (tuple, torch.Tensor))
|
287 |
-
or past_key_values[0][0].shape[0] != batch_size
|
288 |
-
):
|
289 |
-
raise ValueError(
|
290 |
-
f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
|
291 |
-
"used for contrastive search without further modifications."
|
292 |
-
)
|
293 |
|
294 |
# contrastive_search main logic start:
|
295 |
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
|
|
|
282 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
283 |
"for contrastive search."
|
284 |
)
|
285 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
# contrastive_search main logic start:
|
288 |
# contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
|