Update custom_generate/generate.py

custom_generate/generate.py  CHANGED  (+13 -70)
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
-from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.cache_utils import Cache, EncoderDecoderCache
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
     ALL_CACHE_NAMES,
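Note on the import change: the updated file keeps `Cache` and `EncoderDecoderCache` and drops the direct dependency on `DynamicCache`. A minimal sketch of how these classes relate, assuming a transformers version (roughly 4.44+) that ships all three in `transformers.cache_utils`; it is illustrative only and not part of the diff:

```python
# Illustrative sketch, not part of the diff: the cache classes imported above.
# Assumes transformers >= 4.44, where EncoderDecoderCache wraps two sub-caches.
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

self_attention_cache = DynamicCache()   # decoder self-attention K/V
cross_attention_cache = DynamicCache()  # encoder-decoder cross-attention K/V
cache = EncoderDecoderCache(self_attention_cache, cross_attention_cache)

# The code removed in the hunks below checks exactly this attribute to decide
# whether the wrapped self-attention cache is a DynamicCache.
assert cache.self_attention_cache is self_attention_cache
```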
@@ -249,17 +249,13 @@ def _contrastive_search(
             f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
             "for contrastive search."
         )
-
-    elif not (
-        isinstance(past_key_values, DynamicCache)
-        or (
-            isinstance(past_key_values, EncoderDecoderCache)
-            and isinstance(past_key_values.self_attention_cache, DynamicCache)
-        )
+    elif (
+        not isinstance(past_key_values[0], (tuple, torch.Tensor))
+        or past_key_values[0][0].shape[0] != batch_size
     ):
         raise ValueError(
-            f"Unsupported cache type: {type(past_key_values)}. Contrastive search requires "
-            "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+            f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
+            "used for contrastive search without further modifications."
         )
 
     # contrastive_search main logic start:
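The validity check now looks at the layout of `past_key_values` itself rather than its class: the first layer must be a tuple (or tensor) whose key tensor has `batch_size` rows. A minimal sketch of what that condition accepts, with made-up shapes in the usual `(batch, heads, seq_len, head_dim)` legacy layout:

```python
# Illustrative sketch of the condition added in this hunk (shapes are made up).
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
# Legacy-style cache: one (key, value) pair per layer.
past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, seq_len, head_dim),
        torch.zeros(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(3)
)

# Mirrors the new `elif`: reject caches whose first layer is not tuple/tensor-like
# or whose key tensor does not have batch_size rows.
invalid = (
    not isinstance(past_key_values[0], (tuple, torch.Tensor))
    or past_key_values[0][0].shape[0] != batch_size
)
print(invalid)  # False -> this cache passes the check
```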
@@ -294,24 +290,7 @@ def _contrastive_search(
 
         if not sequential:
             # Replicates the new past_key_values to match the `top_k` candidates
-            past = model_kwargs["past_key_values"]
-            # If it is a static cache, modify it in-place layer after layer to save memory
-            if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
-            ):
-                past.batch_repeat_interleave(top_k)
-            else:
-                new_key_values = []
-                for layer in past:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item.repeat_interleave(top_k, dim=0))
-                    new_key_values.append(tuple(items))
-
-                past = tuple(new_key_values)
-
-            model_kwargs["past_key_values"] = past
+            model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
 
         if sequential:
             all_outputs = []
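In the non-sequential path, the per-layer `repeat_interleave` loop over legacy tuples is gone; the cache object is expected to expose `batch_repeat_interleave`, which repeats every cached key/value `top_k` times along the batch dimension in place. A small sketch, assuming a transformers version where `DynamicCache` provides this method; shapes are made up:

```python
# Illustrative sketch (assumes DynamicCache.batch_repeat_interleave exists, as in
# recent transformers releases). Shapes are made up: (batch, heads, seq, head_dim).
import torch
from transformers.cache_utils import DynamicCache

top_k = 4
cache = DynamicCache()
cache.update(torch.zeros(2, 8, 5, 64), torch.zeros(2, 8, 5, 64), layer_idx=0)

# In place: every key/value row is repeated top_k times along the batch dimension,
# matching the removed torch.repeat_interleave(item, top_k, dim=0) loop.
cache.batch_repeat_interleave(top_k)
key, value = cache[0]
print(key.shape)  # torch.Size([8, 8, 5, 64])
```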
@@ -325,19 +304,10 @@ def _contrastive_search(
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
-                if isinstance(outputs["past_key_values"], DynamicCache) or (
-                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
-                ):
-                    # Remove past K-V from output since we don't need to stack later
-                    outputs["past_key_values"] = None
-                    # Remove last token from past K-V since we don't want to append it at this point
-                    model_kwargs["past_key_values"].crop(-1)
-                else:
-                    raise ValueError(
-                        f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
-                        "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
-                    )
+                # Remove past K-V from output since we don't need to stack later
+                outputs["past_key_values"] = None
+                # Remove last token from past K-V since we don't want to append it at this point
+                model_kwargs["past_key_values"].crop(-1)
 
                 all_outputs.append(outputs)
             outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
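The sequential path now assumes the same Cache API: the candidate outputs' cache is dropped and the shared cache is truncated by one position with `crop(-1)`, with no error branch or legacy fallback. A small sketch of what `crop(-1)` does, assuming a recent transformers `DynamicCache` where `crop` accepts a negative length:

```python
# Illustrative sketch (assumes DynamicCache.crop accepts a negative length, as in
# recent transformers). crop(-1) removes the last cached position from every layer.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 8, 6, 64), torch.zeros(1, 8, 6, 64), layer_idx=0)
print(cache.get_seq_length())  # 6

cache.crop(-1)
print(cache.get_seq_length())  # 5
```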
@@ -424,22 +394,7 @@ def _contrastive_search(
         next_past_key_values = None
         for possible_cache_name in ALL_CACHE_NAMES:
             next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
-
-        if isinstance(next_past_key_values, DynamicCache) or (
-            isinstance(next_past_key_values, EncoderDecoderCache)
-            and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
-        ):
-            next_past_key_values.batch_select_indices(augmented_idx)
-        else:
-            new_key_values = []
-            for layer in next_past_key_values:
-                items = []
-                # item is either the key or the value matrix
-                for item in layer:
-                    items.append(item[augmented_idx, ...])
-                new_key_values.append(tuple(items))
-
-            next_past_key_values = tuple(new_key_values)
+        next_past_key_values.batch_select_indices(augmented_idx)
 
         logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
         logit_for_next_step = logit_for_next_step.to(input_ids.device)
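Selecting the winning candidate per batch item also relies on the Cache API now: `batch_select_indices(augmented_idx)` keeps only the rows of the expanded `batch_size * top_k` cache that correspond to the chosen continuation, replacing the manual `item[augmented_idx, ...]` indexing. A small sketch, assuming `DynamicCache.batch_select_indices` is available; `selected_idx` and the flat index computation are illustrative:

```python
# Illustrative sketch (assumes DynamicCache.batch_select_indices exists, as in
# recent transformers). The cache holds batch_size * top_k rows after expansion.
import torch
from transformers.cache_utils import DynamicCache

batch_size, top_k = 2, 3
cache = DynamicCache()
cache.update(
    torch.zeros(batch_size * top_k, 2, 4, 8),
    torch.zeros(batch_size * top_k, 2, 4, 8),
    layer_idx=0,
)

selected_idx = torch.tensor([1, 0])  # winning candidate per original batch item
augmented_idx = torch.arange(batch_size) * top_k + selected_idx  # flat row indices
cache.batch_select_indices(augmented_idx)
print(cache[0][0].shape)  # torch.Size([2, 2, 4, 8])
```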
@@ -503,19 +458,7 @@ def _contrastive_search(
     # Contrastive search works by forward looking at the next token, so we need to exclude it from
     # `past_key_values` to be consistent with the other decoding methods
     if model_kwargs.get("past_key_values") is not None:
-        if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
-            isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-            and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
-        ):
-            model_kwargs["past_key_values"].crop(-1)
-        else:
-            past_key_values = []
-            for layer in model_kwargs["past_key_values"]:
-                layer_past_key_values = []
-                for item in layer:
-                    layer_past_key_values.append(item[..., :-1, :])
-                past_key_values.append(tuple(layer_past_key_values))
-            model_kwargs["past_key_values"] = tuple(past_key_values)
+        model_kwargs["past_key_values"].crop(-1)
 
     if model.config.is_encoder_decoder:
         return GenerateEncoderDecoderOutput(
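The final cleanup is the same story: `crop(-1)` on the cache object replaces the legacy branch that rebuilt the tuple cache with `item[..., :-1, :]`. The two are meant to be equivalent for a standard `(batch, heads, seq_len, head_dim)` layout; a tiny illustration of the tensor slice the removed branch used:

```python
# Illustrative only: the removed legacy branch sliced each key/value tensor like this,
# dropping the same final position that crop(-1) drops on a Cache object.
import torch

item = torch.zeros(2, 8, 6, 64)   # (batch, heads, seq_len, head_dim)
print(item[..., :-1, :].shape)    # torch.Size([2, 8, 5, 64])
```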