Hugo Larcher committed on
Commit 1b1b52b · 1 Parent(s): 6725288
Files changed (1)
  1. modelling_RW.py +104 -108
modelling_RW.py CHANGED
@@ -43,7 +43,7 @@ from einops import rearrange
 # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
 def rotate_half(x):
     x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
 
 
 class RotaryEmbedding(torch.nn.Module):
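For background: `rotate_half` is the half-rotation used by rotary position embeddings; together with the cos/sin tables built by `RotaryEmbedding` it rotates each channel pair of the query and key by a position-dependent angle. A minimal, self-contained sketch of that application (illustrative sizes only, not code from this commit):

```python
import torch

def rotate_half(x):
    # swap the two halves of the last dimension and flip the sign of the second half
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

# toy sizes: 1 sequence, 4 positions, head_dim 8 (all values illustrative)
head_dim, seq_len, base = 8, 4, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)   # [seq_len, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)        # [seq_len, head_dim]
cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]

q = torch.randn(1, seq_len, head_dim)
q_rot = (q * cos) + (rotate_half(q) * sin)     # RoPE-rotated query, same shape as q
print(q_rot.shape)                             # torch.Size([1, 4, 8])
```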
@@ -61,16 +61,20 @@ class RotaryEmbedding(torch.nn.Module):
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
-        self.seq_len_cached = -1
+        self.seq_len_cached = None
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
 
-    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
-        total_length = seq_len + past_key_values_length
-        if total_length > self.seq_len_cached:
-            self.seq_len_cached = total_length
-            t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
+    def cos_sin(
+        self,
+        seq_len: int,
+        device="cuda",
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(device)
 
@@ -83,45 +87,44 @@
             self.cos_cached = self.cos_cached.type(dtype)
             self.sin_cached = self.sin_cached.type(dtype)
 
-        return (
-            self.cos_cached[:, past_key_values_length: seq_len + past_key_values_length],
-            self.sin_cached[:, past_key_values_length: seq_len + past_key_values_length],
-        )
+        return self.cos_cached, self.sin_cached
 
-    def forward(self, query, key, past_key_values_length=0):
-        batch, seq_len, head_dim = query.shape
-        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
-        return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length is None:
+            batch, seq_len, head_dim = q.shape
+        else:
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = input_seq_len + past_seq_length
+        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        if past_seq_length is not None:
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (
+                k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
-    """
-    Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
-    just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
-    target_length, target_length+past_key_values_length]`.
-    """
     batch_size, target_length = input_ids_shape
-    mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
-    # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
-    # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
-    # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
-    past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
-    mask = torch.cat([past_mask, mask], dim=-1)
+    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
+    seq_ids = torch.arange(target_length, device=device)
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
+
+    if past_key_values_length > 0:
+        mask[:, :past_key_values_length] = True
+
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
 
 
-def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
-    """
-    Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
-    """
-    batch_size, total_length = mask.shape
-    seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
+def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
+    batch_size, src_length = mask.shape
+    tgt_length = tgt_length if tgt_length is not None else src_length
 
     expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
-    return expanded_mask.expand(batch_size, 1, seq_length, total_length)
+    return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
 
 
 def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
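The rebuilt `_make_causal_mask` above avoids `torch.Tensor.triu` for ONNX export and allocates the past block up front. A small sketch of what that construction yields for a 3-token query with 2 cached positions (a standalone copy for illustration, not the module's function):

```python
import torch

def make_causal_mask(target_length, past_key_values_length, device="cpu"):
    # same construction as the committed _make_causal_mask
    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
    seq_ids = torch.arange(target_length, device=device)
    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = True
    return mask

print(make_causal_mask(target_length=3, past_key_values_length=2).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
```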
@@ -263,60 +266,56 @@ class Attention(nn.Module):
 
         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
-        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
+            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
             #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
-        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-
-        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-
         if alibi is None:
-            if output_attentions:
-                # F.scaled_dot_product_attention doesn't return the attention weights, so we have
-                # to do it by hand if we want them
-                attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
-                attention_scores /= math.sqrt(self.head_dim)
+            query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
-                attention_scores = F.softmax(
-                    attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
+            if attention_mask is not None:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
                 )
-                attn_output = attention_scores @ value_layer_
             else:
                 attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
                 )
-                attention_scores = None
 
-            attn_output = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
-            attn_output = attn_output.permute(0, 2, 1, 3)
-            attn_output = attn_output.reshape(batch_size, q_length, self.num_heads * self.head_dim)
+            x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
+            x = x.permute(0, 2, 1, 3)
+            attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
 
             output_tensor = self.dense(attn_output)
 
-            if output_attentions:
-                return output_tensor, present, attention_scores
-            else:
-                return output_tensor, present
-
+            outputs = (output_tensor, present)
+            assert not output_attentions  # not supported.
+            return outputs
         else:
-            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
+            attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
+            matmul_result = query_layer @ key_layer.transpose(-1, -2)
 
             # change view to [batch_size, num_heads, q_length, kv_length]
             attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
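In the attention hunk above, the cached key is kept transposed as `[batch_size * num_heads, head_dim, kv_length]`: it is permuted back before concatenation with the fresh keys, the past length is read from `past_key.shape[2]`, and the concatenated key is re-permuted when stored in `present`. A shape-only sketch of that bookkeeping (illustrative values, not the model's code):

```python
import torch

# illustrative sizes: 2 (batch*heads), head_dim 4, 3 cached tokens, 1 new token
bsz_times_heads, head_dim, past_len, new_len = 2, 4, 3, 1

past_key = torch.randn(bsz_times_heads, head_dim, past_len)     # cached: [b*h, head_dim, kv_length]
past_value = torch.randn(bsz_times_heads, past_len, head_dim)   # cached: [b*h, kv_length, head_dim]
key_layer = torch.randn(bsz_times_heads, new_len, head_dim)     # fresh keys: [b*h, q_length, head_dim]

past_kv_length = past_key.shape[2]                               # how the commit reads the cached length
key_layer = torch.cat((past_key.permute(0, 2, 1), key_layer), dim=1)  # back to [b*h, kv_length, head_dim]
present_key = key_layer.permute(0, 2, 1)                         # re-transposed before being returned in `present`

print(past_kv_length, key_layer.shape, present_key.shape)
# 3 torch.Size([2, 4, 4]) torch.Size([2, 4, 4])
```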
@@ -326,34 +325,35 @@ class Attention(nn.Module):
             # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
             if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
                 attention_scores = attention_scores.to(torch.float32)
-            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
-            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
-            # equivalent and more performant, but there might be a numerical difference. If you're reading this
-            # and you'd like to experiment and maybe file a PR, feel free!
-            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
-            attention_logits *= self.inv_norm_factor
-            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
+            # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
+            attention_probs = F.softmax(
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1,
+                    -1)) * self.inv_norm_factor + attention_mask_float,
+                dim=-1,
+                dtype=hidden_states.dtype,
+            )
             # [batch_size, num_heads, q_length, kv_length]
             attention_probs = self.attention_dropout(attention_probs)
 
             if head_mask is not None:
                 attention_probs = attention_probs * head_mask
 
-            # change view [batch_size, num_heads, q_length, kv_length]
-            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, q_length, kv_length)
+            # change view [batch_size x num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
 
             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
+            context_layer = attention_probs_reshaped @ value_layer
 
             # change view [batch_size, num_heads, q_length, head_dim]
             context_layer = self._merge_heads(context_layer)
 
             output_tensor = self.dense(context_layer)
 
+            outputs = (output_tensor, present)
             if output_attentions:
-                return output_tensor, present, attention_probs
-            else:
-                return output_tensor, present
+                outputs += (attention_probs,)
+
+            return outputs
 
 
 class MLP(nn.Module):
@@ -484,7 +484,7 @@ class RWPreTrainedModel(PreTrainedModel):
         Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
         num_heads, ...]))
         """
-        batch_size_times_num_heads, seq_length, head_dim = past_key_value[0][0].shape
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
         num_heads = batch_size_times_num_heads // batch_size
         # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
         # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
@@ -502,7 +502,7 @@ class RWPreTrainedModel(PreTrainedModel):
     def _convert_to_rw_cache(
         past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, seq_length, head_dim = past_key_value[0][0].shape
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
         batch_size_times_num_heads = batch_size * num_heads
         # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
         # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
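The two one-line changes above follow from the same cache layout: the RW cache stores the key as `[batch_size * num_heads, head_dim, seq_length]` and the value as `[batch_size * num_heads, seq_length, head_dim]`. A sketch of the round trip that the converters' comments describe (standalone illustration with made-up sizes, not the actual methods):

```python
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# RW layout: key is [batch*heads, head_dim, seq_len], value is [batch*heads, seq_len, head_dim]
rw_key = torch.randn(batch_size * num_heads, head_dim, seq_length)
rw_value = torch.randn(batch_size * num_heads, seq_length, head_dim)

# to the "standard" layout: split batch and heads into separate dimensions
std_key = rw_key.view(batch_size, num_heads, head_dim, seq_length)
std_value = rw_value.view(batch_size, num_heads, seq_length, head_dim)

# and back again, collapsing batch and heads into one dimension
back_key = std_key.view(batch_size * num_heads, head_dim, seq_length)
print(std_key.shape, torch.equal(back_key, rw_key))
# torch.Size([2, 4, 8, 5]) True
```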
@@ -540,31 +540,22 @@ class RWModel(RWPreTrainedModel):
     def get_input_embeddings(self):
         return self.word_embeddings
 
-    @staticmethod
     def _prepare_attn_mask(
-        attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
     ) -> torch.BoolTensor:
-        # Create a causal mask
-        # The attention mask we receive as input should cover the whole extended sequence, including any past
-        # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
-        # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
-        if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
-            raise ValueError(
-                "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
-                f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
-                f" {past_key_values_length}."
-            )
+        # create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         combined_attention_mask = None
         device = attention_mask.device
-        _, seq_length = input_shape
+        _, src_length = input_shape
 
-        if seq_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+        #if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
-        # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
-        expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
         combined_attention_mask = (
             expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
         )
@@ -615,8 +606,6 @@ class RWModel(RWPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.h))
-        else:
-            past_key_values = self._convert_to_rw_cache(past_key_values)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -634,11 +623,13 @@ class RWModel(RWPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None
 
         # Compute alibi tensor: check build_alibi_tensor documentation
+        seq_length_with_past = seq_length
         past_key_values_length = 0
         if past_key_values[0] is not None:
-            past_key_values_length = past_key_values[0][0].shape[1]  # 1 because RW-cache, not standard format
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
+            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         else:
             attention_mask = attention_mask.to(hidden_states.device)
 
@@ -704,9 +695,6 @@ class RWModel(RWPreTrainedModel):
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if presents is not None:
-            presents = self._convert_cache_to_standard_format(presents, batch_size)
-
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
 
@@ -738,13 +726,20 @@ class RWForCausalLM(RWPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         input_ids: torch.LongTensor,
-        past_key_values: Optional[torch.Tensor] = None,
+        past: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+        if kwargs.get("past_key_values", None):
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            past_key_values = kwargs["past_key_values"]
+            # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            #     past_key_values = kwargs["past_key_values"]
+        else:
+            past_key_values = None
 
         return {
             "input_ids": input_ids,
@@ -834,6 +829,7 @@ class RWForCausalLM(RWPreTrainedModel):
 
         Output shares the same memory storage as `past`.
         """
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
 
         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
@@ -844,9 +840,9 @@ class RWForCausalLM(RWPreTrainedModel):
                 layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                 layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in past
+            for layer_past in standardized_past
         )
-        return reordered_past
+        return self._convert_to_rw_cache(reordered_past)
 
 
 class RWForSequenceClassification(RWPreTrainedModel):
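`_reorder_cache` now converts to the standard layout before calling `index_select`, so that the batch dimension lines up with `beam_idx`, and converts back to the RW layout afterwards. A toy illustration of the reordering step on a single cache entry (made-up shapes, not the model's code):

```python
import torch

# toy cache entry in the "standard" layout: [batch_size, num_heads, seq_len, head_dim]
batch_size, num_heads, seq_len, head_dim = 3, 2, 4, 8
value = torch.arange(batch_size).float().view(batch_size, 1, 1, 1).expand(batch_size, num_heads, seq_len, head_dim)

# beam search decided every surviving beam descends from batch entries 2, 0, 0
beam_idx = torch.tensor([2, 0, 0])
reordered = value.index_select(0, beam_idx)

print(reordered[:, 0, 0, 0])  # tensor([2., 0., 0.])
```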
 