vikhyatk committed on
Commit 2b4cc76 · verified · 1 Parent(s): 5f88d77

Delete modeling_phi.py

Files changed (1)
  1. modeling_phi.py +0 -1463
modeling_phi.py DELETED
@@ -1,1463 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- """PyTorch Phi model."""
17
-
18
- import math
19
- from typing import List, Optional, Tuple, Union
20
-
21
- import torch
22
- import torch.utils.checkpoint
23
- from packaging import version
24
- from torch import nn
25
- from torch.nn import CrossEntropyLoss
26
-
27
- from transformers.activations import ACT2FN
28
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
29
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
- from transformers.modeling_outputs import (
31
- BaseModelOutputWithPast,
32
- CausalLMOutputWithPast,
33
- )
34
- from transformers.modeling_utils import PreTrainedModel
35
- from transformers.utils import (
36
- add_start_docstrings,
37
- add_start_docstrings_to_model_forward,
38
- get_torch_version,
39
- is_flash_attn_2_available,
40
- is_flash_attn_greater_or_equal_2_10,
41
- is_torchdynamo_compiling,
42
- logging,
43
- replace_return_docstrings,
44
- )
45
- from .configuration_moondream import PhiConfig
46
-
47
-
48
- if is_flash_attn_2_available():
49
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
50
-
51
-
52
- logger = logging.get_logger(__name__)
53
-
54
- _CONFIG_FOR_DOC = "PhiConfig"
55
-
56
-
57
- # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
58
- def _prepare_4d_causal_attention_mask_with_cache_position(
59
- attention_mask: torch.Tensor,
60
- sequence_length: int,
61
- target_length: int,
62
- dtype: torch.dtype,
63
- device: torch.device,
64
- min_dtype: float,
65
- cache_position: torch.Tensor,
66
- batch_size: int,
67
- ):
68
- """
69
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
70
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
71
-
72
- Args:
73
- attention_mask (`torch.Tensor`):
74
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
75
- sequence_length (`int`):
76
- The sequence length being processed.
77
- target_length (`int`):
78
- The target length: when generating with a static cache, the mask should be as long as the static cache to account for the 0 padding (the part of the cache that is not yet filled).
79
- dtype (`torch.dtype`):
80
- The dtype to use for the 4D attention mask.
81
- device (`torch.device`):
82
- The device to place the 4D attention mask on.
83
- min_dtype (`float`):
84
- The minimum value representable with the dtype `dtype`.
85
- cache_position (`torch.Tensor`):
86
- Indices depicting the position of the input sequence tokens in the sequence.
87
- batch_size (`int`):
88
- Batch size.
89
- """
90
- if attention_mask is not None and attention_mask.dim() == 4:
91
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
92
- causal_mask = attention_mask
93
- else:
94
- causal_mask = torch.full(
95
- (sequence_length, target_length),
96
- fill_value=min_dtype,
97
- dtype=dtype,
98
- device=device,
99
- )
100
- if sequence_length != 1:
101
- causal_mask = torch.triu(causal_mask, diagonal=1)
102
- causal_mask *= torch.arange(
103
- target_length, device=device
104
- ) > cache_position.reshape(-1, 1)
105
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
106
- if attention_mask is not None:
107
- causal_mask = (
108
- causal_mask.clone()
109
- ) # copy to contiguous memory for in-place edit
110
- mask_length = attention_mask.shape[-1]
111
- padding_mask = (
112
- causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
113
- )
114
- padding_mask = padding_mask == 0
115
- causal_mask[:, :, :, :mask_length] = causal_mask[
116
- :, :, :, :mask_length
117
- ].masked_fill(padding_mask, min_dtype)
118
-
119
- return causal_mask
120
-
121
-
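
To make the mask construction above concrete, here is a minimal standalone sketch of the same steps on toy shapes (the sizes and the padding pattern are invented for illustration; the helper itself is unchanged):

```python
import torch

# Toy re-run of the masking steps above: one sequence of length 4 whose last
# token is padding, with no previously cached tokens. Values are illustrative only.
batch_size, sequence_length, target_length = 1, 4, 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

attention_mask = torch.tensor([[1, 1, 1, 0]])   # 2D padding mask, last token padded
cache_position = torch.arange(sequence_length)  # positions 0..3, empty cache

causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)                     # hide future tokens
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

padding_mask = causal_mask[:, :, :, :sequence_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
    padding_mask == 0, min_dtype
)
print(causal_mask.shape)  # torch.Size([1, 1, 4, 4]); the padded column is fully masked
```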
122
- # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
123
- class PhiRotaryEmbedding(nn.Module):
124
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
125
- super().__init__()
126
-
127
- self.dim = dim
128
- self.max_position_embeddings = max_position_embeddings
129
- self.base = base
130
- inv_freq = 1.0 / (
131
- self.base
132
- ** (
133
- torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
134
- / self.dim
135
- )
136
- )
137
- self.register_buffer("inv_freq", inv_freq, persistent=False)
138
-
139
- # Build here to make `torch.jit.trace` work.
140
- self._set_cos_sin_cache(
141
- seq_len=max_position_embeddings,
142
- device=self.inv_freq.device,
143
- dtype=torch.get_default_dtype(),
144
- )
145
-
146
- def _set_cos_sin_cache(self, seq_len, device, dtype):
147
- self.max_seq_len_cached = seq_len
148
- t = torch.arange(
149
- self.max_seq_len_cached, device=device, dtype=torch.int64
150
- ).type_as(self.inv_freq)
151
-
152
- freqs = torch.outer(t, self.inv_freq)
153
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
154
- emb = torch.cat((freqs, freqs), dim=-1)
155
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
156
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
157
-
158
- def forward(self, x, seq_len=None):
159
- # x: [bs, num_attention_heads, seq_len, head_size]
160
- if seq_len > self.max_seq_len_cached:
161
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
162
-
163
- return (
164
- self.cos_cached[:seq_len].to(dtype=x.dtype),
165
- self.sin_cached[:seq_len].to(dtype=x.dtype),
166
- )
167
-
168
-
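
As a quick reference for what the cached buffers above hold, the following standalone snippet rebuilds the same `cos`/`sin` tables for a toy rotary dimension (sizes chosen arbitrarily for the example, not tied to any Phi checkpoint):

```python
import torch

# Rebuild the rotary cache the same way _set_cos_sin_cache does, for dim=8 and 6 positions.
dim, base, seq_len = 8, 10000, 6

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
t = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq)
freqs = torch.outer(t, inv_freq)          # (seq_len, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim), duplicated halves

cos_cached, sin_cached = emb.cos(), emb.sin()
print(cos_cached.shape, sin_cached.shape)  # torch.Size([6, 8]) torch.Size([6, 8])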
169
- # Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
170
- class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
171
- """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
172
-
173
- def __init__(
174
- self,
175
- dim,
176
- max_position_embeddings=2048,
177
- base=10000,
178
- device=None,
179
- scaling_factor=1.0,
180
- ):
181
- self.scaling_factor = scaling_factor
182
- super().__init__(dim, max_position_embeddings, base, device)
183
-
184
- def _set_cos_sin_cache(self, seq_len, device, dtype):
185
- self.max_seq_len_cached = seq_len
186
- t = torch.arange(
187
- self.max_seq_len_cached, device=device, dtype=torch.int64
188
- ).type_as(self.inv_freq)
189
- t = t / self.scaling_factor
190
-
191
- freqs = torch.outer(t, self.inv_freq)
192
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
193
- emb = torch.cat((freqs, freqs), dim=-1)
194
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
195
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
196
-
197
-
198
- # Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
199
- class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
200
- """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
201
-
202
- def __init__(
203
- self,
204
- dim,
205
- max_position_embeddings=2048,
206
- base=10000,
207
- device=None,
208
- scaling_factor=1.0,
209
- ):
210
- self.scaling_factor = scaling_factor
211
- super().__init__(dim, max_position_embeddings, base, device)
212
-
213
- def _set_cos_sin_cache(self, seq_len, device, dtype):
214
- self.max_seq_len_cached = seq_len
215
-
216
- if seq_len > self.max_position_embeddings:
217
- base = self.base * (
218
- (self.scaling_factor * seq_len / self.max_position_embeddings)
219
- - (self.scaling_factor - 1)
220
- ) ** (self.dim / (self.dim - 2))
221
- inv_freq = 1.0 / (
222
- base
223
- ** (
224
- torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
225
- / self.dim
226
- )
227
- )
228
- self.register_buffer("inv_freq", inv_freq, persistent=False)
229
-
230
- t = torch.arange(
231
- self.max_seq_len_cached, device=device, dtype=torch.int64
232
- ).type_as(self.inv_freq)
233
-
234
- freqs = torch.outer(t, self.inv_freq)
235
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
236
- emb = torch.cat((freqs, freqs), dim=-1)
237
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
238
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
239
-
240
-
241
- # Copied from transformers.models.llama.modeling_llama.rotate_half
242
- def rotate_half(x):
243
- """Rotates half the hidden dims of the input."""
244
- x1 = x[..., : x.shape[-1] // 2]
245
- x2 = x[..., x.shape[-1] // 2 :]
246
- return torch.cat((-x2, x1), dim=-1)
247
-
248
-
249
- # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
250
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
251
- """Applies Rotary Position Embedding to the query and key tensors.
252
-
253
- Args:
254
- q (`torch.Tensor`): The query tensor.
255
- k (`torch.Tensor`): The key tensor.
256
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
257
- sin (`torch.Tensor`): The sine part of the rotary embedding.
258
- position_ids (`torch.Tensor`):
259
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
260
- used to pass offset position ids when working with a KV-cache.
261
- unsqueeze_dim (`int`, *optional*, defaults to 1):
262
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
263
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
264
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
265
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
266
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
267
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
268
- Returns:
269
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
270
- """
271
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
272
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
273
- q_embed = (q * cos) + (rotate_half(q) * sin)
274
- k_embed = (k * cos) + (rotate_half(k) * sin)
275
- return q_embed, k_embed
276
-
277
-
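
The snippet below shows how these two helpers are applied to query/key tensors of shape `[batch, heads, seq_len, head_dim]`. All sizes are toy values and the cos/sin tables are rebuilt inline so the example runs on its own; it is a sketch, not part of the original file:

```python
import torch

def rotate_half(x):
    # Same operation as the helper above: swap the two halves of the last dim, negate one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, heads, seq_len, rot_dim = 1, 2, 3, 4          # toy shapes
q = torch.randn(bsz, heads, seq_len, rot_dim)
k = torch.randn(bsz, heads, seq_len, rot_dim)

# cos/sin tables with the (seq_len, rot_dim) layout produced by PhiRotaryEmbedding.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
emb = torch.cat([torch.outer(torch.arange(seq_len).float(), inv_freq)] * 2, dim=-1)
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq_len)[None, :]       # (1, seq_len)
cos_sel = cos[position_ids].unsqueeze(1)            # unsqueeze_dim=1 -> (1, 1, seq_len, rot_dim)
sin_sel = sin[position_ids].unsqueeze(1)
q_embed = (q * cos_sel) + (rotate_half(q) * sin_sel)
k_embed = (k * cos_sel) + (rotate_half(k) * sin_sel)
print(q_embed.shape, k_embed.shape)                 # shapes are unchanged
```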
278
- # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
279
- class PhiMLP(nn.Module):
280
- def __init__(self, config):
281
- super().__init__()
282
- self.config = config
283
- self.activation_fn = ACT2FN[config.hidden_act]
284
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
285
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
286
-
287
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
288
- hidden_states = self.fc1(hidden_states)
289
- hidden_states = self.activation_fn(hidden_states)
290
- hidden_states = self.fc2(hidden_states)
291
- return hidden_states
292
-
293
-
294
- # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
295
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
296
- """
297
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
298
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
299
- """
300
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
301
- if n_rep == 1:
302
- return hidden_states
303
- hidden_states = hidden_states[:, :, None, :, :].expand(
304
- batch, num_key_value_heads, n_rep, slen, head_dim
305
- )
306
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
307
-
308
-
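
A quick shape check of the expand-and-reshape trick above, with invented sizes; it produces the same result as `torch.repeat_interleave` along the head dimension:

```python
import torch

batch, num_kv_heads, n_rep, seq_len, head_dim = 2, 4, 3, 5, 8   # toy sizes
kv = torch.randn(batch, num_kv_heads, seq_len, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
repeated = expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Same result as repeating each key/value head n_rep times in place.
assert torch.equal(repeated, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
print(repeated.shape)  # torch.Size([2, 12, 5, 8])
```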
309
- class PhiAttention(nn.Module):
310
- """Multi-headed attention from 'Attention Is All You Need' paper"""
311
-
312
- def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
313
- super().__init__()
314
- self.config = config
315
- self.layer_idx = layer_idx
316
- if layer_idx is None:
317
- logger.warning_once(
318
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
319
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
320
- "when creating this class."
321
- )
322
-
323
- self.attention_dropout = config.attention_dropout
324
- self.hidden_size = config.hidden_size
325
- self.num_heads = config.num_attention_heads
326
- self.head_dim = self.hidden_size // self.num_heads
327
- self.num_key_value_heads = config.num_key_value_heads
328
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
329
- self.max_position_embeddings = config.max_position_embeddings
330
- self.rope_theta = config.rope_theta
331
- self.partial_rotary_factor = config.partial_rotary_factor
332
- self.is_causal = True
333
-
334
- if (self.head_dim * self.num_heads) != self.hidden_size:
335
- raise ValueError(
336
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
337
- f" and `num_heads`: {self.num_heads})."
338
- )
339
-
340
- self.Wqkv = nn.Linear(
341
- self.hidden_size, 3 * self.num_heads * self.head_dim, bias=True
342
- )
343
- self.out_proj = nn.Linear(
344
- self.num_heads * self.head_dim, self.hidden_size, bias=True
345
- )
346
-
347
- self._init_rope()
348
-
349
- def _init_rope(self):
350
- if self.config.rope_scaling is None:
351
- self.rotary_emb = PhiRotaryEmbedding(
352
- int(self.partial_rotary_factor * self.head_dim),
353
- max_position_embeddings=self.max_position_embeddings,
354
- base=self.rope_theta,
355
- )
356
- else:
357
- scaling_type = self.config.rope_scaling["type"]
358
- scaling_factor = self.config.rope_scaling["factor"]
359
- if scaling_type == "linear":
360
- self.rotary_emb = PhiLinearScalingRotaryEmbedding(
361
- int(self.partial_rotary_factor * self.head_dim),
362
- max_position_embeddings=self.max_position_embeddings,
363
- scaling_factor=scaling_factor,
364
- base=self.rope_theta,
365
- )
366
- elif scaling_type == "dynamic":
367
- self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
368
- int(self.partial_rotary_factor * self.head_dim),
369
- max_position_embeddings=self.max_position_embeddings,
370
- scaling_factor=scaling_factor,
371
- base=self.rope_theta,
372
- )
373
- else:
374
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
375
-
376
- def forward(
377
- self,
378
- hidden_states: torch.Tensor,
379
- attention_mask: Optional[torch.Tensor] = None,
380
- position_ids: Optional[torch.LongTensor] = None,
381
- past_key_value: Optional[Cache] = None,
382
- output_attentions: bool = False,
383
- use_cache: bool = False,
384
- cache_position: Optional[torch.LongTensor] = None,
385
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
386
- bsz, q_len, _ = hidden_states.size()
387
-
388
- query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
389
- 3, dim=-1
390
- )
391
-
392
- query_states = query_states.view(
393
- bsz, q_len, self.num_heads, self.head_dim
394
- ).transpose(1, 2)
395
- key_states = key_states.view(
396
- bsz, q_len, self.num_key_value_heads, self.head_dim
397
- ).transpose(1, 2)
398
- value_states = value_states.view(
399
- bsz, q_len, self.num_key_value_heads, self.head_dim
400
- ).transpose(1, 2)
401
-
402
- kv_seq_len = key_states.shape[-2]
403
- if past_key_value is not None:
404
- if self.layer_idx is None:
405
- raise ValueError(
406
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
407
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
408
- "with a layer index."
409
- )
410
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
411
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
412
-
413
- # Partial rotary embedding
414
- query_rot, query_pass = (
415
- query_states[..., : self.rotary_emb.dim],
416
- query_states[..., self.rotary_emb.dim :],
417
- )
418
- key_rot, key_pass = (
419
- key_states[..., : self.rotary_emb.dim],
420
- key_states[..., self.rotary_emb.dim :],
421
- )
422
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
423
- query_rot, key_rot = apply_rotary_pos_emb(
424
- query_rot, key_rot, cos, sin, position_ids
425
- )
426
-
427
- # [batch_size, seq_length, num_heads, head_dim]
428
- query_states = torch.cat((query_rot, query_pass), dim=-1)
429
- key_states = torch.cat((key_rot, key_pass), dim=-1)
430
-
431
- if past_key_value is not None:
432
- cache_kwargs = {
433
- "sin": sin,
434
- "cos": cos,
435
- "partial_rotation_size": self.rotary_emb.dim,
436
- "cache_position": cache_position,
437
- }
438
- key_states, value_states = past_key_value.update(
439
- key_states, value_states, self.layer_idx, cache_kwargs
440
- )
441
-
442
- key_states = repeat_kv(key_states, self.num_key_value_groups)
443
- value_states = repeat_kv(value_states, self.num_key_value_groups)
444
-
445
- # Upcasting queries and keys to fp32 is required by Phi-2 to avoid overflow
446
- attn_weights = torch.matmul(
447
- query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
448
- ) / math.sqrt(self.head_dim)
449
-
450
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
451
- raise ValueError(
452
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
453
- f" {attn_weights.size()}"
454
- )
455
-
456
- if attention_mask is not None:
457
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
458
- attn_weights += causal_mask
459
-
460
- # upcast attention to fp32
461
- attn_weights = nn.functional.softmax(
462
- attn_weights, dim=-1, dtype=torch.float32
463
- ).to(value_states.dtype)
464
- attn_weights = nn.functional.dropout(
465
- attn_weights, p=self.attention_dropout, training=self.training
466
- )
467
-
468
- attn_output = torch.matmul(attn_weights, value_states)
469
-
470
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
471
- raise ValueError(
472
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
473
- f" {attn_output.size()}"
474
- )
475
-
476
- attn_output = attn_output.transpose(1, 2).contiguous()
477
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
478
-
479
- attn_output = self.out_proj(attn_output)
480
-
481
- if not output_attentions:
482
- attn_weights = None
483
-
484
- return attn_output, attn_weights, past_key_value
485
-
486
-
487
- class PhiFlashAttention2(PhiAttention):
488
- """
489
- Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
490
- untouched. The only required change is in the forward pass, which needs to correctly call the public API of
491
- flash attention and deal with padding tokens in case the input contains any of them.
492
- """
493
-
494
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
495
- def __init__(self, *args, **kwargs):
496
- super().__init__(*args, **kwargs)
497
-
498
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
499
- # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
500
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
501
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
502
-
503
- def forward(
504
- self,
505
- hidden_states: torch.Tensor,
506
- attention_mask: Optional[torch.LongTensor] = None,
507
- position_ids: Optional[torch.LongTensor] = None,
508
- past_key_value: Optional[Cache] = None,
509
- output_attentions: bool = False,
510
- use_cache: bool = False,
511
- cache_position: Optional[torch.LongTensor] = None,
512
- **kwargs,
513
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
514
- # PhiFlashAttention2 attention does not support output_attentions
515
-
516
- output_attentions = False
517
-
518
- bsz, q_len, _ = hidden_states.size()
519
-
520
- query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
521
- 3, dim=-1
522
- )
523
-
524
- # Flash attention requires the input to have the shape
525
- # batch_size x seq_length x num_heads x head_dim
526
- # therefore we just need to keep the original shape
527
- query_states = query_states.view(
528
- bsz, q_len, self.num_heads, self.head_dim
529
- ).transpose(1, 2)
530
- key_states = key_states.view(
531
- bsz, q_len, self.num_key_value_heads, self.head_dim
532
- ).transpose(1, 2)
533
- value_states = value_states.view(
534
- bsz, q_len, self.num_key_value_heads, self.head_dim
535
- ).transpose(1, 2)
536
-
537
- kv_seq_len = key_states.shape[-2]
538
- if past_key_value is not None:
539
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
540
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
541
-
542
- # Partial rotary embedding
543
- query_rot, query_pass = (
544
- query_states[..., : self.rotary_emb.dim],
545
- query_states[..., self.rotary_emb.dim :],
546
- )
547
- key_rot, key_pass = (
548
- key_states[..., : self.rotary_emb.dim],
549
- key_states[..., self.rotary_emb.dim :],
550
- )
551
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
552
- query_rot, key_rot = apply_rotary_pos_emb(
553
- query_rot, key_rot, cos, sin, position_ids
554
- )
555
-
556
- # [batch_size, seq_length, num_heads, head_dim]
557
- query_states = torch.cat((query_rot, query_pass), dim=-1)
558
- key_states = torch.cat((key_rot, key_pass), dim=-1)
559
-
560
- if past_key_value is not None:
561
- cache_kwargs = {
562
- "sin": sin,
563
- "cos": cos,
564
- "partial_rotation_size": self.rotary_emb.dim,
565
- "cache_position": cache_position,
566
- }
567
- key_states, value_states = past_key_value.update(
568
- key_states, value_states, self.layer_idx, cache_kwargs
569
- )
570
-
571
- # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
572
- # to be able to avoid many of these transpose/reshape/view.
573
- query_states = query_states.transpose(1, 2)
574
- key_states = key_states.transpose(1, 2)
575
- value_states = value_states.transpose(1, 2)
576
-
577
- attn_dropout = self.attention_dropout if self.training else 0.0
578
-
579
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
580
- # therefore the input hidden states get silently cast to float32. Hence, we need to
581
- # cast them back to the correct dtype just to be sure everything works as expected.
582
- # This might slow down training & inference so it is recommended to not cast the LayerNorms
583
- # in fp32.
584
-
585
- if query_states.dtype == torch.float32:
586
- if torch.is_autocast_enabled():
587
- target_dtype = torch.get_autocast_gpu_dtype()
588
- # Handle the case where the model is quantized
589
- elif hasattr(self.config, "_pre_quantization_dtype"):
590
- target_dtype = self.config._pre_quantization_dtype
591
- else:
592
- target_dtype = self.Wqkv.weight.dtype
593
-
594
- logger.warning_once(
595
- f"The input hidden states seems to be silently casted in float32, this might be related to"
596
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
597
- f" {target_dtype}."
598
- )
599
-
600
- query_states = query_states.to(target_dtype)
601
- key_states = key_states.to(target_dtype)
602
- value_states = value_states.to(target_dtype)
603
-
604
- attn_output = _flash_attention_forward(
605
- query_states,
606
- key_states,
607
- value_states,
608
- attention_mask,
609
- q_len,
610
- position_ids=position_ids,
611
- dropout=attn_dropout,
612
- softmax_scale=None,
613
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
614
- is_causal=self.is_causal,
615
- )
616
-
617
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
618
- attn_output = self.out_proj(attn_output)
619
-
620
- if not output_attentions:
621
- attn_weights = None
622
-
623
- return attn_output, attn_weights, past_key_value
624
-
625
-
626
- class PhiSdpaAttention(PhiAttention):
627
- def __init__(self, *args, **kwargs):
628
- super().__init__(*args, **kwargs)
629
- self.require_contiguous_qkv = version.parse(
630
- get_torch_version()
631
- ) < version.parse("2.2.0")
632
-
633
- """
634
- SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
635
- `PhiAttention` as the weights of the module stay untouched. The only changes are in the forward pass to adapt to the
636
- SDPA API.
637
- """
638
-
639
- # Adapted from PhiAttention.forward
640
- def forward(
641
- self,
642
- hidden_states: torch.Tensor,
643
- attention_mask: Optional[torch.Tensor] = None,
644
- position_ids: Optional[torch.LongTensor] = None,
645
- past_key_value: Optional[Cache] = None,
646
- output_attentions: bool = False,
647
- use_cache: bool = False,
648
- cache_position: Optional[torch.LongTensor] = None,
649
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
650
- if output_attentions:
651
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
652
- logger.warning_once(
653
- "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not "
654
- "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying "
655
- "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can "
656
- 'be removed using the argument `attn_implementation="eager"` when loading the model.'
657
- )
658
- return super().forward(
659
- hidden_states=hidden_states,
660
- attention_mask=attention_mask,
661
- position_ids=position_ids,
662
- past_key_value=past_key_value,
663
- output_attentions=output_attentions,
664
- use_cache=use_cache,
665
- )
666
-
667
- bsz, q_len, _ = hidden_states.size()
668
-
669
- query_states, key_states, value_states = self.Wqkv(hidden_states).chunk(
670
- 3, dim=-1
671
- )
672
-
673
- query_states = query_states.view(
674
- bsz, q_len, self.num_heads, self.head_dim
675
- ).transpose(1, 2)
676
- key_states = key_states.view(
677
- bsz, q_len, self.num_key_value_heads, self.head_dim
678
- ).transpose(1, 2)
679
- value_states = value_states.view(
680
- bsz, q_len, self.num_key_value_heads, self.head_dim
681
- ).transpose(1, 2)
682
-
683
- kv_seq_len = key_states.shape[-2]
684
- if past_key_value is not None:
685
- if self.layer_idx is None:
686
- raise ValueError(
687
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
688
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
689
- "with a layer index."
690
- )
691
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
692
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
693
-
694
- # Partial rotary embedding
695
- query_rot, query_pass = (
696
- query_states[..., : self.rotary_emb.dim],
697
- query_states[..., self.rotary_emb.dim :],
698
- )
699
- key_rot, key_pass = (
700
- key_states[..., : self.rotary_emb.dim],
701
- key_states[..., self.rotary_emb.dim :],
702
- )
703
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
704
- query_rot, key_rot = apply_rotary_pos_emb(
705
- query_rot, key_rot, cos, sin, position_ids
706
- )
707
-
708
- # [batch_size, seq_length, num_heads, head_dim]
709
- query_states = torch.cat((query_rot, query_pass), dim=-1)
710
- key_states = torch.cat((key_rot, key_pass), dim=-1)
711
-
712
- if past_key_value is not None:
713
- cache_kwargs = {
714
- "sin": sin,
715
- "cos": cos,
716
- "partial_rotation_size": self.rotary_emb.dim,
717
- "cache_position": cache_position,
718
- }
719
- key_states, value_states = past_key_value.update(
720
- key_states, value_states, self.layer_idx, cache_kwargs
721
- )
722
-
723
- key_states = repeat_kv(key_states, self.num_key_value_groups)
724
- value_states = repeat_kv(value_states, self.num_key_value_groups)
725
-
726
- causal_mask = attention_mask
727
- if attention_mask is not None:
728
- causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
729
-
730
- # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
731
- # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
732
- # Reference: https://github.com/pytorch/pytorch/issues/112577
733
- if (
734
- self.require_contiguous_qkv
735
- and query_states.device.type == "cuda"
736
- and attention_mask is not None
737
- ):
738
- query_states = query_states.contiguous()
739
- key_states = key_states.contiguous()
740
- value_states = value_states.contiguous()
741
-
742
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
743
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
744
- is_causal = True if causal_mask is None and q_len > 1 else False
745
-
746
- attn_output = torch.nn.functional.scaled_dot_product_attention(
747
- query_states,
748
- key_states,
749
- value_states,
750
- attn_mask=causal_mask,
751
- dropout_p=self.attention_dropout if self.training else 0.0,
752
- is_causal=is_causal,
753
- )
754
-
755
- attn_output = attn_output.transpose(1, 2).contiguous()
756
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
757
-
758
- attn_output = self.out_proj(attn_output)
759
-
760
- return attn_output, None, past_key_value
761
-
762
-
763
- PHI_ATTENTION_CLASSES = {
764
- "eager": PhiAttention,
765
- "flash_attention_2": PhiFlashAttention2,
766
- "sdpa": PhiSdpaAttention,
767
- }
768
-
769
-
770
- class PhiDecoderLayer(nn.Module):
771
- def __init__(self, config: PhiConfig, layer_idx: int):
772
- super().__init__()
773
- self.mixer = PHI_ATTENTION_CLASSES[config._attn_implementation](
774
- config, layer_idx=layer_idx
775
- )
776
- self.mlp = PhiMLP(config)
777
- self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
778
- self.resid_dropout = nn.Dropout(config.resid_pdrop)
779
-
780
- def forward(
781
- self,
782
- hidden_states: torch.Tensor,
783
- attention_mask: Optional[torch.Tensor] = None,
784
- position_ids: Optional[torch.LongTensor] = None,
785
- output_attentions: Optional[bool] = False,
786
- use_cache: Optional[bool] = False,
787
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
788
- cache_position: Optional[torch.LongTensor] = None,
789
- **kwargs,
790
- ) -> Tuple[
791
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
792
- ]:
793
- """
794
- Args:
795
- hidden_states (`torch.FloatTensor`):
796
- input to the layer of shape `(batch, seq_len, embed_dim)`
797
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
798
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
799
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
800
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
801
- `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
802
- output_attentions (`bool`, *optional*):
803
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
804
- returned tensors for more detail.
805
- use_cache (`bool`, *optional*):
806
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
807
- (see `past_key_values`).
808
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
809
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
810
- Indices depicting the position of the input sequence tokens in the sequence
811
- kwargs (`dict`, *optional*):
812
- Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
813
- into the model
814
- """
815
-
816
- residual = hidden_states
817
-
818
- hidden_states = self.ln(hidden_states)
819
-
820
- # Self Attention
821
- attn_outputs, self_attn_weights, present_key_value = self.mixer(
822
- hidden_states=hidden_states,
823
- attention_mask=attention_mask,
824
- position_ids=position_ids,
825
- past_key_value=past_key_value,
826
- output_attentions=output_attentions,
827
- use_cache=use_cache,
828
- cache_position=cache_position,
829
- )
830
- attn_outputs = self.resid_dropout(attn_outputs)
831
-
832
- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
833
- hidden_states = attn_outputs + feed_forward_hidden_states + residual
834
- outputs = (hidden_states,)
835
-
836
- if output_attentions:
837
- outputs += (self_attn_weights,)
838
-
839
- if use_cache:
840
- outputs += (present_key_value,)
841
-
842
- return outputs
843
-
844
-
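
Note the parallel residual layout in the layer above: a single pre-LayerNorm feeds both the attention (`mixer`) and the MLP, and their outputs are summed with the untouched residual. The toy module below sketches just that wiring, with plain linear layers standing in for the real attention and MLP blocks (placeholder names, not part of the original file):

```python
import torch
from torch import nn

class ToyParallelBlock(nn.Module):
    """Placeholder block showing the ln -> (mixer, mlp) -> sum-with-residual wiring."""

    def __init__(self, hidden_size: int = 16, resid_pdrop: float = 0.0):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size)
        self.mixer = nn.Linear(hidden_size, hidden_size)  # stand-in for PhiAttention
        self.mlp = nn.Linear(hidden_size, hidden_size)    # stand-in for PhiMLP
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        normed = self.ln(hidden_states)
        attn_out = self.resid_dropout(self.mixer(normed))
        mlp_out = self.resid_dropout(self.mlp(normed))
        return attn_out + mlp_out + residual

print(ToyParallelBlock()(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 3, 16])
```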
845
- PHI_START_DOCSTRING = r"""
846
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
847
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
848
- etc.)
849
-
850
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
851
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
852
- and behavior.
853
-
854
- Parameters:
855
- config ([`PhiConfig`]):
856
- Model configuration class with all the parameters of the model. Initializing with a config file does not
857
- load the weights associated with the model, only the configuration. Check out the
858
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
859
- """
860
-
861
-
862
- @add_start_docstrings(
863
- "The bare Phi Model outputting raw hidden-states without any specific head on top.",
864
- PHI_START_DOCSTRING,
865
- )
866
- class PhiPreTrainedModel(PreTrainedModel):
867
- config_class = PhiConfig
868
- base_model_prefix = "model"
869
- supports_gradient_checkpointing = True
870
- _no_split_modules = ["PhiDecoderLayer"]
871
- _skip_keys_device_placement = "past_key_values"
872
- _supports_flash_attn_2 = True
873
- _supports_sdpa = True
874
- _supports_cache_class = True
875
-
876
- def _init_weights(self, module):
877
- std = self.config.initializer_range
878
- if isinstance(module, nn.Linear):
879
- module.weight.data.normal_(mean=0.0, std=std)
880
- if module.bias is not None:
881
- module.bias.data.zero_()
882
- elif isinstance(module, nn.Embedding):
883
- module.weight.data.normal_(mean=0.0, std=std)
884
- if module.padding_idx is not None:
885
- module.weight.data[module.padding_idx].zero_()
886
-
887
-
888
- class Embedding(nn.Module):
889
- def __init__(self, config: PhiConfig):
890
- super().__init__()
891
- self.wte = nn.Embedding(
892
- config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
893
- )
894
-
895
- def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
896
- return self.wte(input_ids)
897
-
898
- PHI_INPUTS_DOCSTRING = r"""
899
- Args:
900
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
901
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
902
- it.
903
-
904
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
905
- [`PreTrainedTokenizer.__call__`] for details.
906
-
907
- [What are input IDs?](../glossary#input-ids)
908
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
910
-
911
- - 1 for tokens that are **not masked**,
912
- - 0 for tokens that are **masked**.
913
-
914
- [What are attention masks?](../glossary#attention-mask)
915
-
916
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
917
- [`PreTrainedTokenizer.__call__`] for details.
918
-
919
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
920
- `past_key_values`).
921
-
922
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
923
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
924
- information on the default strategy.
925
-
926
- - 1 indicates the head is **not masked**,
927
- - 0 indicates the head is **masked**.
928
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
929
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
930
- config.n_positions - 1]`.
931
-
932
- [What are position IDs?](../glossary#position-ids)
933
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
934
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
935
- blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
936
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
937
-
938
- Two formats are allowed:
939
- - a [`~cache_utils.Cache`] instance;
940
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
941
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
942
- cache format.
943
-
944
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
945
- legacy cache format will be returned.
946
-
947
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
948
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
949
- of shape `(batch_size, sequence_length)`.
950
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
951
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
952
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
953
- model's internal embedding lookup matrix.
954
- use_cache (`bool`, *optional*):
955
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
956
- `past_key_values`).
957
- output_attentions (`bool`, *optional*):
958
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
959
- tensors for more detail.
960
- output_hidden_states (`bool`, *optional*):
961
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
962
- more detail.
963
- return_dict (`bool`, *optional*):
964
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
965
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
966
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
967
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
968
- the complete sequence length.
969
- """
970
-
971
-
972
- @add_start_docstrings(
973
- "The bare Phi Model outputting raw hidden-states without any specific head on top.",
974
- PHI_START_DOCSTRING,
975
- )
976
- class PhiModel(PhiPreTrainedModel):
977
- """
978
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
979
-
980
- Args:
981
- config: PhiConfig
982
- """
983
-
984
- def __init__(self, config: PhiConfig):
985
- super().__init__(config)
986
- self.padding_idx = config.pad_token_id
987
- self.vocab_size = config.vocab_size
988
-
989
- self.embd = Embedding(config)
990
- self.embed_dropout = nn.Dropout(config.embd_pdrop)
991
- self.h = nn.ModuleList(
992
- [
993
- PhiDecoderLayer(config, layer_idx)
994
- for layer_idx in range(config.num_hidden_layers)
995
- ]
996
- )
997
-
998
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
999
- self._use_sdpa = config._attn_implementation == "sdpa"
1000
-
1001
- self.gradient_checkpointing = False
1002
- # Initialize weights and apply final processing
1003
- self.post_init()
1004
-
1005
- def get_input_embeddings(self):
1006
- return self.embd.wte
1007
-
1008
- def set_input_embeddings(self, value):
1009
- self.embd.wte = value
1010
-
1011
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1012
- def forward(
1013
- self,
1014
- input_ids: torch.LongTensor = None,
1015
- attention_mask: Optional[torch.Tensor] = None,
1016
- position_ids: Optional[torch.LongTensor] = None,
1017
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1018
- inputs_embeds: Optional[torch.FloatTensor] = None,
1019
- use_cache: Optional[bool] = None,
1020
- output_attentions: Optional[bool] = None,
1021
- output_hidden_states: Optional[bool] = None,
1022
- return_dict: Optional[bool] = None,
1023
- cache_position: Optional[torch.LongTensor] = None,
1024
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1025
- output_attentions = (
1026
- output_attentions
1027
- if output_attentions is not None
1028
- else self.config.output_attentions
1029
- )
1030
- output_hidden_states = (
1031
- output_hidden_states
1032
- if output_hidden_states is not None
1033
- else self.config.output_hidden_states
1034
- )
1035
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1036
-
1037
- return_dict = (
1038
- return_dict if return_dict is not None else self.config.use_return_dict
1039
- )
1040
-
1041
- if (input_ids is None) ^ (inputs_embeds is not None):
1042
- raise ValueError(
1043
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1044
- )
1045
-
1046
- if self.gradient_checkpointing and self.training:
1047
- if use_cache:
1048
- logger.warning_once(
1049
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1050
- )
1051
- use_cache = False
1052
-
1053
- use_legacy_cache = False
1054
- if use_cache and not isinstance(past_key_values, Cache) and not self.training:
1055
- use_legacy_cache = True
1056
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1057
- logger.warning_once(
1058
- "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
1059
- "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
1060
- )
1061
-
1062
- if inputs_embeds is None:
1063
- inputs_embeds = self.embd(input_ids)
1064
-
1065
- if cache_position is None:
1066
- past_seen_tokens = (
1067
- past_key_values.get_seq_length() if past_key_values is not None else 0
1068
- )
1069
- cache_position = torch.arange(
1070
- past_seen_tokens,
1071
- past_seen_tokens + inputs_embeds.shape[1],
1072
- device=inputs_embeds.device,
1073
- )
1074
- if position_ids is None:
1075
- position_ids = cache_position.unsqueeze(0)
1076
-
1077
- causal_mask = self._update_causal_mask(
1078
- attention_mask,
1079
- inputs_embeds,
1080
- cache_position,
1081
- past_key_values,
1082
- output_attentions,
1083
- )
1084
-
1085
- hidden_states = inputs_embeds
1086
-
1087
- # decoder layers
1088
- all_hidden_states = () if output_hidden_states else None
1089
- all_self_attns = () if output_attentions else None
1090
- next_decoder_cache = None
1091
-
1092
- for decoder_layer in self.h:
1093
- if output_hidden_states:
1094
- all_hidden_states += (hidden_states,)
1095
-
1096
- if self.gradient_checkpointing and self.training:
1097
- layer_outputs = self._gradient_checkpointing_func(
1098
- decoder_layer.__call__,
1099
- hidden_states,
1100
- causal_mask,
1101
- position_ids,
1102
- output_attentions,
1103
- use_cache,
1104
- past_key_values,
1105
- cache_position,
1106
- )
1107
- else:
1108
- layer_outputs = decoder_layer(
1109
- hidden_states,
1110
- attention_mask=causal_mask,
1111
- position_ids=position_ids,
1112
- past_key_value=past_key_values,
1113
- output_attentions=output_attentions,
1114
- use_cache=use_cache,
1115
- cache_position=cache_position,
1116
- )
1117
-
1118
- hidden_states = layer_outputs[0]
1119
-
1120
- if use_cache:
1121
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1122
-
1123
- if output_attentions:
1124
- all_self_attns += (layer_outputs[1],)
1125
-
1126
- # add hidden states from the last decoder layer
1127
- if output_hidden_states:
1128
- all_hidden_states += (hidden_states,)
1129
-
1130
- next_cache = None
1131
- if use_cache:
1132
- next_cache = (
1133
- next_decoder_cache.to_legacy_cache()
1134
- if use_legacy_cache
1135
- else next_decoder_cache
1136
- )
1137
- if not return_dict:
1138
- return tuple(
1139
- v
1140
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1141
- if v is not None
1142
- )
1143
- return BaseModelOutputWithPast(
1144
- last_hidden_state=hidden_states,
1145
- past_key_values=next_cache,
1146
- hidden_states=all_hidden_states,
1147
- attentions=all_self_attns,
1148
- )
1149
-
1150
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
1151
- def _update_causal_mask(
1152
- self,
1153
- attention_mask: torch.Tensor,
1154
- input_tensor: torch.Tensor,
1155
- cache_position: torch.Tensor,
1156
- past_key_values: Cache,
1157
- output_attentions: bool,
1158
- ):
1159
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
1160
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
1161
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
1162
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
1163
-
1164
- if self.config._attn_implementation == "flash_attention_2":
1165
- if attention_mask is not None and 0.0 in attention_mask:
1166
- return attention_mask
1167
- return None
1168
-
1169
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1170
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1171
- # to infer the attention mask.
1172
- past_seen_tokens = (
1173
- past_key_values.get_seq_length() if past_key_values is not None else 0
1174
- )
1175
- using_static_cache = isinstance(past_key_values, StaticCache)
1176
-
1177
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1178
- if (
1179
- self.config._attn_implementation == "sdpa"
1180
- and not using_static_cache
1181
- and not output_attentions
1182
- ):
1183
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1184
- attention_mask,
1185
- inputs_embeds=input_tensor,
1186
- past_key_values_length=past_seen_tokens,
1187
- is_training=self.training,
1188
- ):
1189
- return None
1190
-
1191
- dtype, device = input_tensor.dtype, input_tensor.device
1192
- min_dtype = torch.finfo(dtype).min
1193
- sequence_length = input_tensor.shape[1]
1194
- if using_static_cache:
1195
- target_length = past_key_values.get_max_length()
1196
- else:
1197
- target_length = (
1198
- attention_mask.shape[-1]
1199
- if isinstance(attention_mask, torch.Tensor)
1200
- else past_seen_tokens + sequence_length + 1
1201
- )
1202
-
1203
- # In case the provided `attention_mask` is 2D, we generate a 4D causal mask here.
1204
- causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1205
- attention_mask,
1206
- sequence_length=sequence_length,
1207
- target_length=target_length,
1208
- dtype=dtype,
1209
- device=device,
1210
- min_dtype=min_dtype,
1211
- cache_position=cache_position,
1212
- batch_size=input_tensor.shape[0],
1213
- )
1214
-
1215
- if (
1216
- self.config._attn_implementation == "sdpa"
1217
- and attention_mask is not None
1218
- and attention_mask.device.type == "cuda"
1219
- and not output_attentions
1220
- ):
1221
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1222
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1223
- # Details: https://github.com/pytorch/pytorch/issues/110213
1224
- causal_mask = AttentionMaskConverter._unmask_unattended(
1225
- causal_mask, min_dtype
1226
- )
1227
-
1228
- return causal_mask
1229
-
1230
-
1231
- class CausalLMHead(nn.Module):
1232
- """Causal Language Modeling head. Simplified version."""
1233
-
1234
- def __init__(self, config):
1235
- super().__init__()
1236
- self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1237
- self.linear = nn.Linear(config.hidden_size, config.vocab_size)
1238
-
1239
- def forward(self, hidden_states):
1240
- return self.linear(self.ln(hidden_states))
1241
-
1242
-
1243
- class PhiForCausalLM(PhiPreTrainedModel):
1244
-
1245
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
1246
- def __init__(self, config):
1247
- super().__init__(config)
1248
- self.transformer = PhiModel(config)
1249
- self.vocab_size = config.vocab_size
1250
- self.lm_head = CausalLMHead(config)
1251
-
1252
- # Initialize weights and apply final processing
1253
- self.post_init()
1254
-
1255
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1256
- def get_input_embeddings(self):
1257
- return self.transformer.embd.wte
1258
-
1259
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1260
- def set_input_embeddings(self, value):
1261
- self.transformer.embd.wte = value
1262
-
1263
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1264
- def get_output_embeddings(self):
1265
- return self.lm_head.linear
1266
-
1267
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1268
- def set_output_embeddings(self, new_embeddings):
1269
- self.lm_head.linear = new_embeddings
1270
-
1271
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1272
- def set_decoder(self, decoder):
1273
- self.transformer = decoder
1274
-
1275
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1276
- def get_decoder(self):
1277
- return self.transformer
1278
-
1279
- @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1280
- @replace_return_docstrings(
1281
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
1282
- )
1283
- def forward(
1284
- self,
1285
- input_ids: torch.LongTensor = None,
1286
- attention_mask: Optional[torch.Tensor] = None,
1287
- position_ids: Optional[torch.LongTensor] = None,
1288
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1289
- inputs_embeds: Optional[torch.FloatTensor] = None,
1290
- labels: Optional[torch.LongTensor] = None,
1291
- use_cache: Optional[bool] = None,
1292
- output_attentions: Optional[bool] = None,
1293
- output_hidden_states: Optional[bool] = None,
1294
- return_dict: Optional[bool] = None,
1295
- cache_position: Optional[torch.LongTensor] = None,
1296
- num_logits_to_keep: int = 0,
1297
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
- r"""
1299
- Args:
1300
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
-
1305
- num_logits_to_keep (`int`, *optional*):
1306
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
- `input_ids` (special case). Only the last token's logits are needed for generation, and calculating them only for that
1308
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
-
1310
- Returns:
1311
-
1312
- Example:
1313
-
1314
- ```python
1315
- >>> from transformers import AutoTokenizer, PhiForCausalLM
1316
-
1317
- >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
1318
- >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
1319
-
1320
- >>> prompt = "This is an example script ."
1321
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
-
1323
- >>> # Generate
1324
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
- 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1327
- ```"""
1328
-
1329
- output_attentions = (
1330
- output_attentions
1331
- if output_attentions is not None
1332
- else self.config.output_attentions
1333
- )
1334
- output_hidden_states = (
1335
- output_hidden_states
1336
- if output_hidden_states is not None
1337
- else self.config.output_hidden_states
1338
- )
1339
- return_dict = (
1340
- return_dict if return_dict is not None else self.config.use_return_dict
1341
- )
1342
-
1343
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1344
- outputs = self.transformer(
1345
- input_ids=input_ids,
1346
- attention_mask=attention_mask,
1347
- position_ids=position_ids,
1348
- past_key_values=past_key_values,
1349
- inputs_embeds=inputs_embeds,
1350
- use_cache=use_cache,
1351
- output_attentions=output_attentions,
1352
- output_hidden_states=output_hidden_states,
1353
- return_dict=return_dict,
1354
- cache_position=cache_position,
1355
- )
1356
-
1357
- hidden_states = outputs[0]
1358
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
1359
-
1360
- loss = None
1361
- if labels is not None:
1362
- # Upcast to float if we need to compute the loss to avoid potential precision issues
1363
- logits = logits.float()
1364
- # Shift so that tokens < n predict n
1365
- shift_logits = logits[..., :-1, :].contiguous()
1366
- shift_labels = labels[..., 1:].contiguous()
1367
- # Flatten the tokens
1368
- loss_fct = CrossEntropyLoss()
1369
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1370
- shift_labels = shift_labels.view(-1)
1371
- # Enable model parallelism
1372
- shift_labels = shift_labels.to(shift_logits.device)
1373
- loss = loss_fct(shift_logits, shift_labels)
1374
-
1375
- if not return_dict:
1376
- output = (logits,) + outputs[1:]
1377
- return (loss,) + output if loss is not None else output
1378
-
1379
- return CausalLMOutputWithPast(
1380
- loss=loss,
1381
- logits=logits,
1382
- past_key_values=outputs.past_key_values,
1383
- hidden_states=outputs.hidden_states,
1384
- attentions=outputs.attentions,
1385
- )
1386
-
1387
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1388
- def prepare_inputs_for_generation(
1389
- self,
1390
- input_ids,
1391
- past_key_values=None,
1392
- attention_mask=None,
1393
- inputs_embeds=None,
1394
- cache_position=None,
1395
- position_ids=None,
1396
- use_cache=True,
1397
- num_logits_to_keep=0,
1398
- **kwargs,
1399
- ):
1400
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1401
- # Exception 1: when passing input_embeds, input_ids may be missing entries
1402
- # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1403
- if past_key_values is not None:
1404
- if inputs_embeds is not None: # Exception 1
1405
- input_ids = input_ids[:, -cache_position.shape[0] :]
1406
- elif (
1407
- input_ids.shape[1] != cache_position.shape[0]
1408
- ): # Default case (the "else", a no op, is Exception 2)
1409
- input_ids = input_ids[:, cache_position]
1410
-
1411
- if attention_mask is not None and position_ids is None:
1412
- # create position_ids on the fly for batch generation
1413
- position_ids = attention_mask.long().cumsum(-1) - 1
1414
- position_ids.masked_fill_(attention_mask == 0, 1)
1415
- if past_key_values:
1416
- position_ids = position_ids[:, -input_ids.shape[1] :]
1417
-
1418
- # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride, which retriggers a capture.
1419
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1420
-
1421
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1422
- if inputs_embeds is not None and cache_position[0] == 0:
1423
- model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1424
- else:
1425
- # The clone here is for the same reason as for `position_ids`.
1426
- model_inputs = {
1427
- "input_ids": input_ids.clone(memory_format=torch.contiguous_format),
1428
- "inputs_embeds": None,
1429
- }
1430
-
1431
- if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1432
- if model_inputs["inputs_embeds"] is not None:
1433
- batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
1434
- device = model_inputs["inputs_embeds"].device
1435
- else:
1436
- batch_size, sequence_length = model_inputs["input_ids"].shape
1437
- device = model_inputs["input_ids"].device
1438
-
1439
- dtype = self.lm_head.linear.weight.dtype
1440
- min_dtype = torch.finfo(dtype).min
1441
-
1442
- attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1443
- attention_mask,
1444
- sequence_length=sequence_length,
1445
- target_length=past_key_values.get_max_length(),
1446
- dtype=dtype,
1447
- device=device,
1448
- min_dtype=min_dtype,
1449
- cache_position=cache_position,
1450
- batch_size=batch_size,
1451
- )
1452
-
1453
- model_inputs.update(
1454
- {
1455
- "position_ids": position_ids,
1456
- "cache_position": cache_position,
1457
- "past_key_values": past_key_values,
1458
- "use_cache": use_cache,
1459
- "attention_mask": attention_mask,
1460
- "num_logits_to_keep": num_logits_to_keep,
1461
- }
1462
- )
1463
- return model_inputs