Update custom_generate/generate.py
custom_generate/generate.py  CHANGED  (+58 -114)
@@ -1,18 +1,22 @@
-
+import logging
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-
+import torch.nn as nn
+
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
-
-    GenerateNonBeamOutput,
+    ALL_CACHE_NAMES,
     GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+    GenerateNonBeamOutput,
+    GenerationMixin,
 )
-from transformers.cache_utils import Cache, EncoderDecoderCache, DynamicCache
 from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from transformers.generation.utils import GenerateEncoderDecoderOutput, ALL_CACHE_NAMES
 from transformers.utils import ModelOutput
-
-import torch.nn as nn
-import logging
+
 
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -20,9 +24,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def stack_model_outputs(
-    model_outputs: list[ModelOutput], config: PretrainedConfig
-) -> ModelOutput:
+def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
     """
     Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
     specific ModelOutput subclass from the list provided.
@@ -50,17 +52,11 @@ def stack_model_outputs(
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
                 return tuple(
-                    tuple(
-                        torch.cat([attr[i][j] for attr in data], dim=0)
-                        for j in range(len(data[0][0]))
-                    )
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                     for i in range(len(data[0]))
                 )
             else:
-                return tuple(
-                    torch.cat([attr[i] for attr in data], dim=0)
-                    for i in range(len(data[0]))
-                )
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
         elif isinstance(data[0], (int, float)):
             # If the elements are integers or floats, return a tensor
             return torch.tensor(data)
@@ -92,9 +88,7 @@ def _ranking_fast(
     """
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
-    cosine_matrix = torch.matmul(
-        norm_context_hidden, norm_next_hidden.transpose(1, 2)
-    ).squeeze(-1)  # [B*K, S]
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
 
     # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
     # Using a large negative value for masked positions
@@ -105,9 +99,7 @@ def _ranking_fast(
     degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
-    contrastive_score = torch.stack(
-        torch.split(contrastive_score, beam_width)
-    )  # [B, K]
+    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
 
@@ -163,9 +155,7 @@ def _contrastive_search(
             f"contrastive search is not supported with stateful models, such as {model.__class__.__name__}"
         )
     # init values
-    has_eos_stopping_criteria = any(
-        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
-    )
+    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
     top_k = generation_config.top_k
     penalty_alpha = generation_config.penalty_alpha
     pad_token_id = generation_config._pad_token_tensor
@@ -181,39 +171,22 @@ def _contrastive_search(
     scores = () if (return_dict_in_generate and output_scores) else None
     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-    decoder_hidden_states = (
-        () if (return_dict_in_generate and output_hidden_states) else None
-    )
+    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
     if return_dict_in_generate and model.config.is_encoder_decoder:
-        encoder_attentions = (
-            model_kwargs["encoder_outputs"].get("attentions")
-            if output_attentions
-            else None
-        )
-        encoder_hidden_states = (
-            model_kwargs["encoder_outputs"].get("hidden_states")
-            if output_hidden_states
-            else None
-        )
+        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
 
     # keep track of which sequences are already finished
     batch_size, cur_len = input_ids.shape[:2]
-    unfinished_sequences = torch.ones(
-        batch_size, dtype=torch.long, device=input_ids.device
-    )
-    model_kwargs = model._get_initial_cache_position(
-        cur_len, input_ids.device, model_kwargs
-    )
+    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
 
     # Create cosine_matrix_mask based on the attention_mask
     cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
     if model.config.is_encoder_decoder:
-        if (
-            "decoder_attention_mask" in model_kwargs
-            and model_kwargs["decoder_attention_mask"] is not None
-        ):
+        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
             cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
         else:
             cosine_matrix_mask = model_kwargs["attention_mask"]
@@ -221,9 +194,7 @@ def _contrastive_search(
 
     this_peer_finished = False
 
-    while model._has_unfinished_sequences(
-        this_peer_finished, synced_gpus, device=input_ids.device
-    ):
+    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
         # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
         # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
         if model_kwargs.get("past_key_values") is None or (
@@ -232,9 +203,7 @@ def _contrastive_search(
         ):
             # prepare inputs
             model_kwargs["use_cache"] = True
-            model_inputs = model.prepare_inputs_for_generation(
-                input_ids, **model_kwargs
-            )
+            model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
             # the `encoder_outputs`
@@ -256,9 +225,7 @@ def _contrastive_search(
             # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
             # (the clone itmodel is always small)
             # torch.float32 is needed to retain precision for later logits manipulations
-            logit_for_next_step = outputs.logits[:, -1, :].to(
-                copy=True, dtype=torch.float32, device=input_ids.device
-            )
+            logit_for_next_step = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 
             model_kwargs = model._update_model_kwargs_for_generation(
                 outputs,
@@ -282,7 +249,18 @@ def _contrastive_search(
                     f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
                     "for contrastive search."
                 )
-
+            # Only those caches have the necesary methods
+            elif not (
+                isinstance(past_key_values, DynamicCache)
+                or (
+                    isinstance(past_key_values, EncoderDecoderCache)
+                    and isinstance(past_key_values.self_attention_cache, DynamicCache)
+                )
+            ):
+                raise ValueError(
+                    f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+                )
 
         # contrastive_search main logic start:
         # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
@@ -300,18 +278,14 @@ def _contrastive_search(
                 scores += (processed_logit_for_next_step,)
             if output_attentions:
                 decoder_attentions += (
-                    (outputs.decoder_attentions,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.attentions,)
+                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                 )
                 if model.config.is_encoder_decoder:
                     cross_attentions += (outputs.cross_attentions,)
 
             if output_hidden_states:
                 decoder_hidden_states += (
-                    (outputs.decoder_hidden_states,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.hidden_states,)
+                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                 )
 
         # This is needed to properly delete outputs.logits which may be very large for this first iteration
@@ -323,8 +297,7 @@ def _contrastive_search(
             past = model_kwargs["past_key_values"]
             # If it is a static cache, modify it in-place layer after layer to save memory
             if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache)
-                and isinstance(past.self_attention_cache, DynamicCache)
+                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
             ):
                 past.batch_repeat_interleave(top_k)
             else:
@@ -344,9 +317,7 @@ def _contrastive_search(
             all_outputs = []
             for i in range(top_k):
                 # compute the candidate tokens by the language model and collect their hidden_states
-                next_model_inputs = model.prepare_inputs_for_generation(
-                    top_k_ids[:, i].view(-1, 1), **model_kwargs
-                )
+                next_model_inputs = model.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
 
                 outputs = model(
                     **next_model_inputs,
@@ -356,9 +327,7 @@ def _contrastive_search(
                 )
                 if isinstance(outputs["past_key_values"], DynamicCache) or (
                     isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(
-                        outputs["past_key_values"].self_attention_cache, DynamicCache
-                    )
+                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
                 ):
                     # Remove past K-V from output since we don't need to stack later
                     outputs["past_key_values"] = None
@@ -376,9 +345,7 @@ def _contrastive_search(
         else:
             # compute the candidate tokens by the language model and collect their hidden_states
             # assembles top_k_ids into batch of size k
-            next_model_inputs = model.prepare_inputs_for_generation(
-                top_k_ids.view(-1, 1), **model_kwargs
-            )
+            next_model_inputs = model.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
 
             outputs = model(
                 **next_model_inputs,
@@ -424,9 +391,7 @@ def _contrastive_search(
         selected_idx = selected_idx.to("cpu")
 
         # This will be used instead of the previous inneficient torch.stack(torch.split())
-        augmented_idx = torch.tensor(
-            [x + i * top_k for i, x in enumerate(selected_idx)]
-        )
+        augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
 
         # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
        # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
@@ -434,15 +399,11 @@ def _contrastive_search(
         next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
         next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
         next_hidden = next_hidden[range(batch_size), selected_idx, :]
-        last_hidden_states = torch.cat(
-            [last_hidden_states, next_hidden.unsqueeze(1)], dim=1
-        )
+        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
 
         next_decoder_hidden_states = ()
         for layer in full_hidden_states:
-            layer = torch.stack(torch.split(layer, top_k))[
-                range(batch_size), selected_idx, :
-            ]
+            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
             next_decoder_hidden_states += (layer,)
 
         # generate past_key_values cache of only the selected token
@@ -462,9 +423,7 @@ def _contrastive_search(
         else:
             next_past_key_values = None
             for possible_cache_name in ALL_CACHE_NAMES:
-                next_past_key_values = next_past_key_values or getattr(
-                    outputs, possible_cache_name, None
-                )
+                next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
             # Do it in-place layer per layer to save memory
             if isinstance(next_past_key_values, DynamicCache) or (
                 isinstance(next_past_key_values, EncoderDecoderCache)
@@ -482,9 +441,7 @@ def _contrastive_search(
 
             next_past_key_values = tuple(new_key_values)
 
-        logit_for_next_step = torch.stack(torch.split(logits, top_k))[
-            range(batch_size), selected_idx, :
-        ]
+        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
         logit_for_next_step = logit_for_next_step.to(input_ids.device)
 
         # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
@@ -493,14 +450,10 @@ def _contrastive_search(
             next_step_decoder_attentions = ()
             if output_attentions:
                 for layer in outputs.cross_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_cross_attentions += (layer,)
                 for layer in outputs.decoder_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_decoder_attentions += (layer,)
             outputs = Seq2SeqLMOutput(
                 past_key_values=next_past_key_values,
@@ -512,9 +465,7 @@ def _contrastive_search(
             next_step_attentions = ()
             if output_attentions:
                 for layer in outputs.attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_attentions += (layer,)
             outputs = CausalLMOutputWithPast(
                 past_key_values=next_past_key_values,
@@ -534,9 +485,7 @@ def _contrastive_search(
 
         # finished sentences should have their next token be a padding token
         if has_eos_stopping_criteria:
-            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                1 - unfinished_sequences
-            )
+            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -544,9 +493,7 @@ def _contrastive_search(
             streamer.put(next_tokens.cpu())
 
         # stop when each sentence is finished
-        unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-            input_ids, scores
-        )
+        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
         this_peer_finished = unfinished_sequences.max() == 0
 
     if streamer is not None:
@@ -558,9 +505,7 @@ def _contrastive_search(
         if model_kwargs.get("past_key_values") is not None:
             if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
                 isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-                and isinstance(
-                    model_kwargs["past_key_values"].self_attention_cache, DynamicCache
-                )
+                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
             ):
                 model_kwargs["past_key_values"].crop(-1)
             else:
@@ -607,8 +552,7 @@ def generate(model, *args, **kwargs):
     """
     cache_implementation = kwargs.pop("cache_implementation", "dynamic_full")
     if cache_implementation != "dynamic_full" and (
-        "sliding_attention"
-        in getattr(model.config.get_text_config(), "layer_types", [])
+        "sliding_attention" in getattr(model.config.get_text_config(), "layer_types", [])
        or getattr(model.config.get_text_config(), "sliding_window", 0) > 0
     ):
         logger.warning_once(
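
A minimal usage sketch (not part of the diff): assuming this file sits under `custom_generate/` in a Hub repository and a `transformers` release recent enough to load custom decoding loops through the `custom_generate` argument of `generate()`, the contrastive-search implementation above could be exercised roughly as follows. The repository id is a placeholder, and `gpt2` is only a small example model; `penalty_alpha` and `top_k` feed the re-ranking in `_contrastive_search`, while `cache_implementation` is popped by this file's `generate()` wrapper (default `"dynamic_full"`).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# penalty_alpha / top_k drive the contrastive re-ranking; the new elif branch in the
# diff enforces a DynamicCache-backed cache, which is the default dynamic cache.
output_ids = model.generate(
    **inputs,
    custom_generate="<user>/<custom-generate-repo>",  # placeholder: repo containing custom_generate/generate.py
    trust_remote_code=True,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=32,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```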