Hugo Larcher committed · Commit 1b1b52b · Parent(s): 6725288

Revert

modelling_RW.py CHANGED (+104 −108)
@@ -43,7 +43,7 @@ from einops import rearrange
 # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
 def rotate_half(x):
     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0


 class RotaryEmbedding(torch.nn.Module):
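For reference, `rotate_half` negates the second half of the last dimension and swaps it with the first half; on recent PyTorch `dim=x1.ndim - 1` and `dim=-1` are equivalent, the restored spelling only works around the old torch < 1.8.0 bug. A minimal standalone check (illustrative, not part of the commit):

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # same result as dim=-1 on torch >= 1.8.0

x = torch.arange(6.0).reshape(1, 1, 6)  # [[[0., 1., 2., 3., 4., 5.]]]
print(rotate_half(x))                   # [[[-3., -4., -5., 0., 1., 2.]]]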
@@ -61,16 +61,20 @@ class RotaryEmbedding(torch.nn.Module):
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
-        self.seq_len_cached =
+        self.seq_len_cached = None
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None

-    def cos_sin(
-
-
-
-
+    def cos_sin(
+        self,
+        seq_len: int,
+        device="cuda",
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(device)

@@ -83,45 +87,44 @@ class RotaryEmbedding(torch.nn.Module):
             self.cos_cached = self.cos_cached.type(dtype)
             self.sin_cached = self.sin_cached.type(dtype)

-        return (
-            self.cos_cached[:, past_key_values_length: seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length: seq_len + past_key_values_length],
-        )
+        return self.cos_cached, self.sin_cached

-    def forward(self,
-
-
-
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length is None:
+            batch, seq_len, head_dim = q.shape
+        else:
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = input_seq_len + past_seq_length
+        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        if past_seq_length is not None:
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (
+                k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


 def _make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
-    """
-    Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
-    just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
-    target_length, target_length+past_key_values_length]`.
-    """
     batch_size, target_length = input_ids_shape
-    mask = torch.
-    #
-
-
-
-
+    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
+
+    if past_key_values_length > 0:
+        mask[:, :past_key_values_length] = True
+
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask


-def _expand_mask(mask: torch.Tensor,
-
-
-    """
-    batch_size, total_length = mask.shape
-    seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+    batch_size, src_length = mask.shape
+    tgt_length = tgt_length if tgt_length is not None else src_length

     expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
-    return expanded_mask.expand(batch_size, 1,
+    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)


 def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
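A rough usage sketch of the restored rotary path, assuming the cached cos/sin tensors are shaped [1, seq_len, head_dim] as implied by the indexing in `forward` above, and with query/key arriving with the heads folded into the batch dimension as in `Attention.forward`. Tensor sizes are invented for illustration; this is not part of the commit:

import torch

head_dim, n_heads, batch, q_len = 64, 4, 2, 10
rotary = RotaryEmbedding(head_dim)  # the class restored above, default base

q = torch.randn(batch * n_heads, q_len, head_dim)
k = torch.randn(batch * n_heads, q_len, head_dim)

# prefill: rotate every position of the sequence
q_rot, k_rot = rotary(q, k)

# decode step with a cache: only the newest token is passed in,
# and its cos/sin entries are offset by the past length
q_step, k_step = rotary(q[:, -1:, :], k[:, -1:, :], past_seq_length=q_len - 1)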
@@ -263,60 +266,56 @@ class Attention(nn.Module):

         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)

-
-
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)

         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            # - key: [batch_size * self.num_heads,
+            # - key: [batch_size * self.num_heads, head_dim, kv_length]
             # - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)

         _, kv_length, _ = key_layer.shape

         if use_cache is True:
-
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None

-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-
-        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-
         if alibi is None:
-
-
-
-            attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
-            attention_scores /= math.sqrt(self.head_dim)
+            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)

-
-
+            if attention_mask is not None:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
                 )
-            attn_output = attention_scores @ value_layer_
             else:
                 attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_,
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
                 )
-            attention_scores = None

-
-
-            attn_output =
+            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
+            x = x.permute(0, 2, 1, 3)
+            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

             output_tensor = self.dense(attn_output)

-
-
-
-            return output_tensor, present
-
+            outputs = (output_tensor, present)
+            assert not output_attentions  # not supported.
+            return outputs
         else:
-

             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
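The no-alibi branch restored above delegates to `F.scaled_dot_product_attention`. As a rough sanity sketch with made-up shapes (standalone, not repo code), the fused causal call matches a manual softmax(QKᵀ/√d)·V within numerical tolerance:

import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 10, 64)  # [batch, heads, q_length, head_dim]
k = torch.randn(2, 4, 10, 64)
v = torch.randn(2, 4, 10, 64)

fused = F.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True)

scores = (q @ k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
causal = torch.ones(10, 10).tril().bool()          # lower-triangular causal mask
scores = scores.masked_fill(~causal, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(fused, manual, atol=1e-5))    # expected: True (up to numerical tolerance)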
@@ -326,34 +325,35 @@ class Attention(nn.Module):
             # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
             if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                 attention_scores = attention_scores.to(torch.float32)
-            #
-
-
-
-
-
-
+            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+            attention_probs = F.softmax(
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1,
+                                               -1)) * self.inv_norm_factor + attention_mask_float,
+                dim=-1,
+                dtype=hidden_states.dtype,
+            )
             # [batch_size, num_heads, q_length, kv_length]
             attention_probs = self.attention_dropout(attention_probs)

             if head_mask is not None:
                 attention_probs = attention_probs * head_mask

-            # change view [batch_size
-            attention_probs_reshaped = attention_probs.view(batch_size
+            # change view [batch_size x num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)

             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer =
+            context_layer = attention_probs_reshaped @ value_layer

             # change view [batch_size, num_heads, q_length, head_dim]
             context_layer = self._merge_heads(context_layer)

             output_tensor = self.dense(context_layer)

+            outputs = (output_tensor, present)
             if output_attentions:
-
-
-
+                outputs += (attention_probs,)
+
+            return outputs


 class MLP(nn.Module):
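In the alibi branch, the bias returned by `build_alibi_tensor` (shaped [batch_size * num_heads, 1, kv_length]) is reshaped to [batch_size, num_heads, 1, kv_length] and broadcast over the query dimension before the softmax. A shape-only sketch with dummy values (not repo code):

import math
import torch
import torch.nn.functional as F

batch_size, num_heads, q_length, kv_length, head_dim = 2, 4, 5, 5, 64
inv_norm_factor = 1.0 / math.sqrt(head_dim)

attention_scores = torch.randn(batch_size, num_heads, q_length, kv_length)
alibi = torch.randn(batch_size * num_heads, 1, kv_length)                # bias per head
attention_mask_float = torch.zeros(batch_size, 1, q_length, kv_length)   # 0.0 keep, -1e9 masked

attention_probs = F.softmax(
    (attention_scores + alibi.view(batch_size, num_heads, 1, -1)) * inv_norm_factor + attention_mask_float,
    dim=-1,
)
print(attention_probs.shape)  # torch.Size([2, 4, 5, 5])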
@@ -484,7 +484,7 @@ class RWPreTrainedModel(PreTrainedModel):
         Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
         num_heads, ...]))
         """
-        batch_size_times_num_heads,
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
         num_heads = batch_size_times_num_heads // batch_size
         # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
         # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
@@ -502,7 +502,7 @@ class RWPreTrainedModel(PreTrainedModel):
     def _convert_to_rw_cache(
         past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads,
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
         batch_size_times_num_heads = batch_size * num_heads
         # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
         # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
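The two converters restored above only regroup the fused batch*heads dimension; keys stay as [.., head_dim, seq_length] and values as [.., seq_length, head_dim], following the shape comments in the diff. A quick shape illustration with dummy tensors (a sketch, not part of the commit):

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 64, 7

# RW cache layout (heads fused into the batch dimension)
rw_key = torch.randn(batch_size * num_heads, head_dim, seq_length)
rw_value = torch.randn(batch_size * num_heads, seq_length, head_dim)

# standard layout used by most implementations
std_key = rw_key.view(batch_size, num_heads, head_dim, seq_length)
std_value = rw_value.view(batch_size, num_heads, seq_length, head_dim)

# converting back is just the inverse view
assert std_key.reshape(batch_size * num_heads, head_dim, seq_length).shape == rw_key.shape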
@@ -540,31 +540,22 @@ class RWModel(RWPreTrainedModel):
     def get_input_embeddings(self):
         return self.word_embeddings

-    @staticmethod
     def _prepare_attn_mask(
-        attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
     ) -> torch.BoolTensor:
-        #
-        #
-        # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
-        # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
-        if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
-            raise ValueError(
-                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
-                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
-                f" {past_key_values_length}."
-            )
+        # create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         combined_attention_mask = None
         device = attention_mask.device
-        _,
+        _, src_length = input_shape

-        if
-
-
-
+        #if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )

-        # [batch_size, seq_length
-        expanded_attn_mask = _expand_mask(attention_mask,
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
         combined_attention_mask = (
             expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
         )
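For reference, a small shape walk-through of how `_prepare_attn_mask` combines the two mask helpers restored earlier. It assumes the module-level `_make_causal_mask` and `_expand_mask` from modelling_RW.py are in scope; the numbers are arbitrary and this is a sketch, not repo code:

import torch

batch_size, seq_length, past_length = 1, 3, 2
# padding mask over past + current tokens, as expected by _expand_mask
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])

causal = _make_causal_mask((batch_size, seq_length), torch.device("cpu"), past_length)
padding = _expand_mask(attention_mask, tgt_length=seq_length)
combined = padding | causal

print(causal.shape, padding.shape, combined.shape)
# torch.Size([1, 1, 3, 5]) torch.Size([1, 1, 3, 5]) torch.Size([1, 1, 3, 5])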
@@ -615,8 +606,6 @@ class RWModel(RWPreTrainedModel):

         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))
-        else:
-            past_key_values = self._convert_to_rw_cache(past_key_values)

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -634,11 +623,13 @@ class RWModel(RWPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None

         # Compute alibi tensor: check build_alibi_tensor documentation
+        seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size,
+            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)

@@ -704,9 +695,6 @@ class RWModel(RWPreTrainedModel):
         if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

-        if presents is not None:
-            presents = self._convert_cache_to_standard_format(presents, batch_size)
-
        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

@@ -738,13 +726,20 @@ class RWForCausalLM(RWPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
-
+        past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past_key_values
-            input_ids = input_ids[:, -1
+        if kwargs.get("past_key_values", None):
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_key_values = kwargs["past_key_values"]
+            # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            #     past_key_values = kwargs["past_key_values"]
+        else:
+            past_key_values = None

         return {
             "input_ids": input_ids,
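A toy sketch of what the restored branch does on a decoding step: once a cache is present, only the newest token is fed back in. The names below are invented for illustration:

import torch

input_ids = torch.tensor([[11, 42, 7]])   # tokens generated so far
past_key_values = ("dummy-cache",)        # stand-in for a non-empty RW cache

if past_key_values:
    input_ids = input_ids[:, -1].unsqueeze(-1)  # keep only the last token -> shape [batch_size, 1]

print(input_ids)  # tensor([[7]])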
@@ -834,6 +829,7 @@ class RWForCausalLM(RWPreTrainedModel):

        Output shares the same memory storage as `past`.
        """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
@@ -844,9 +840,9 @@ class RWForCausalLM(RWPreTrainedModel):
                 layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                 layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in
+            for layer_past in standardized_past
         )
-        return reordered_past
+        return self._convert_to_rw_cache(reordered_past)


 class RWForSequenceClassification(RWPreTrainedModel):
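Beam reordering, as restored above, boils down to an `index_select` over the batch dimension of the standardized cache before converting back to the RW layout. A toy illustration with a dummy tensor and invented sizes:

import torch

beam_idx = torch.tensor([2, 0, 1])     # new beam order chosen by beam search
layer_key = torch.randn(3, 64, 5)      # one cached key in standard [batch, ...] layout

reordered = layer_key.index_select(0, beam_idx)  # row i now holds the cache of old beam beam_idx[i]
assert torch.equal(reordered[0], layer_key[2])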