Update modeling_dream.py

modeling_dream.py  CHANGED  (+99 -3)
@@ -23,6 +23,7 @@ import math
 from typing import List, Optional, Tuple, Union
 import os
 import torch
+import hashlib
 import torch.utils.checkpoint
 from torch import nn
 
@@ -47,6 +48,9 @@ from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
 
 if is_flash_attn_2_available():
     from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+def check_hash(X):
+    t = X.detach().cpu().contiguous().view(torch.uint16); print(hashlib.md5(t.numpy().tobytes()).hexdigest())
 
 
 logger = logging.get_logger(__name__)
@@ -360,7 +364,9 @@ class DreamSdpaAttention(DreamAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        use_flex_attn: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
@@ -378,14 +384,45 @@ class DreamSdpaAttention(DreamAttention):
 
         bsz, q_len, _ = hidden_states.size()
 
+        # Debug: Print all hidden_states[0] values
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\n=== Layer {self.layer_idx} ===\n")
+        #     f.write(f"hidden_states[0] - all positions:\n")
+        #     for idx in range(len(hidden_states[0])):
+        #         f.write(f"  idx {idx}: {hidden_states[0][idx]}\n")
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
+        # Debug: Print all QKV[0] values after projection (before view/transpose)
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0] (after proj) - all positions:\n")
+        #     for idx in range(len(query_states[0])):
+        #         f.write(f"  idx {idx}: {query_states[0][idx]}\n")
+        #     f.write(f"\nkey_states[0] (after proj) - all positions:\n")
+        #     for idx in range(len(key_states[0])):
+        #         f.write(f"  idx {idx}: {key_states[0][idx]}\n")
+        #     f.write(f"\nvalue_states[0] (after proj) - all positions:\n")
+        #     for idx in range(len(value_states[0])):
+        #         f.write(f"  idx {idx}: {value_states[0][idx]}\n")
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        # Debug: Print all QKV[0][0] values after view/transpose
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0][0] (after view/transpose) - all positions:\n")
+        #     for idx in range(len(query_states[0][0])):
+        #         f.write(f"  idx {idx}: {query_states[0][0][idx]}\n")
+        #     f.write(f"\nkey_states[0][0] (after view/transpose) - all positions:\n")
+        #     for idx in range(len(key_states[0][0])):
+        #         f.write(f"  idx {idx}: {key_states[0][0][idx]}\n")
+        #     f.write(f"\nvalue_states[0][0] (after view/transpose) - all positions:\n")
+        #     for idx in range(len(value_states[0][0])):
+        #         f.write(f"  idx {idx}: {value_states[0][0][idx]}\n")
+
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -398,6 +435,15 @@ class DreamSdpaAttention(DreamAttention):
             cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        # Debug: Print all QKV[0][0] values after positional embedding
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0][0] (after positional embedding) - all positions:\n")
+        #     for idx in range(len(query_states[0][0])):
+        #         f.write(f"  idx {idx}: {query_states[0][0][idx]}\n")
+        #     f.write(f"\nkey_states[0][0] (after positional embedding) - all positions:\n")
+        #     for idx in range(len(key_states[0][0])):
+        #         f.write(f"  idx {idx}: {key_states[0][0][idx]}\n")
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -405,6 +451,18 @@ class DreamSdpaAttention(DreamAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        # Debug: Print all QKV[0][0] values after grouping
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0][0] (after grouping) - all positions:\n")
+        #     for idx in range(len(query_states[0][0])):
+        #         f.write(f"  idx {idx}: {query_states[0][0][idx]}\n")
+        #     f.write(f"\nkey_states[0][0] (after grouping) - all positions:\n")
+        #     for idx in range(len(key_states[0][0])):
+        #         f.write(f"  idx {idx}: {key_states[0][0][idx]}\n")
+        #     f.write(f"\nvalue_states[0][0] (after grouping) - all positions:\n")
+        #     for idx in range(len(value_states[0][0])):
+        #         f.write(f"  idx {idx}: {value_states[0][0][idx]}\n")
+
         # causal_mask = attention_mask
         # if attention_mask is not None:  # no matter the length, we just slice it
         #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -420,7 +478,14 @@ class DreamSdpaAttention(DreamAttention):
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
         # is_causal = True if causal_mask is None and q_len > 1 else False
+        if use_flex_attn:
+            # L = attention_mask.shape[0]
+            # attention_mask_inverted = 1 - attention_mask
+            # attention_mask = torch.cat([attention_mask, attention_mask_inverted], dim=1)
+            # attention_mask = torch.cat([attention_mask, torch.zeros(L, 2*L, dtype=attention_mask.dtype, device=attention_mask.device)], dim=0)
+            attention_mask = attention_mask.bool()
 
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -430,9 +495,21 @@ class DreamSdpaAttention(DreamAttention):
             is_causal=False, # hard coded
         )
 
+        # Debug: Print all attn_output[0][0] values after attention
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nattn_output[0][0] (after attention) - all positions:\n")
+        #     for idx in range(len(attn_output[0][0])):
+        #         f.write(f"  idx {idx}: {attn_output[0][0][idx]}\n")
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
+        # Debug: Print all attn_output[0] values after view
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nattn_output[0] (after view) - all positions:\n")
+        #     for idx in range(len(attn_output[0])):
+        #         f.write(f"  idx {idx}: {attn_output[0][idx]}\n")
+
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
@@ -466,6 +543,7 @@ class DreamDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        use_flex_attn: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -489,9 +567,7 @@ class DreamDecoderLayer(nn.Module):
                 Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                 into the model
         """
-
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
@@ -504,6 +580,7 @@ class DreamDecoderLayer(nn.Module):
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            use_flex_attn=use_flex_attn,
         )
         hidden_states = residual + hidden_states
 
@@ -642,7 +719,9 @@ class DreamBaseModel(DreamPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        use_flex_attn: Optional[bool]=None,
     ) -> Union[Tuple, BaseModelOutput]:
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -660,7 +739,13 @@ class DreamBaseModel(DreamPreTrainedModel):
                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                 )
                 use_cache = False
-
+
+        # Remark: append an [MASK]*L suffix to the input_ids
+        # if use_flex_attn:
+            # mask_id = 151666
+            # L = input_ids.shape[1]
+            # input_ids = torch.cat([input_ids, torch.full((input_ids.shape[0], L), mask_id, dtype=input_ids.dtype, device=input_ids.device)], dim=1)
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
@@ -678,6 +763,9 @@ class DreamBaseModel(DreamPreTrainedModel):
 
         hidden_states = inputs_embeds
 
+        if use_flex_attn:
+            position_ids = torch.cat([position_ids[:, :16], torch.tensor([[11, 14, 10, 13, 15]], device=position_ids.device)], dim=1)
+
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
@@ -711,6 +799,7 @@ class DreamBaseModel(DreamPreTrainedModel):
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    use_flex_attn=use_flex_attn,
                 )
 
             hidden_states = layer_outputs[0]
@@ -782,8 +871,14 @@ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
         return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        use_flex_attn: bool = False,
         **loss_kwargs,
     ) -> Union[Tuple, MaskedLMOutput]:
+
+        if not use_flex_attn:
+            attention_mask = "full"
+
+        # Remark: in our method, attention_mask should be an L*L matrix
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -802,6 +897,7 @@ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            use_flex_attn=use_flex_attn,
         )
 
         hidden_states = outputs[0]
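Note: the check_hash helper added above fingerprints a tensor's raw bits, which makes it easy to confirm that two code paths (for example, the SDPA branch with and without use_flex_attn) produce bit-identical activations. A minimal, self-contained sketch of that usage follows; the driver code, tensor shapes, and dtype are illustrative assumptions, and view(torch.uint16) requires a 2-byte dtype (float16/bfloat16) plus a PyTorch build that exposes torch.uint16.

import hashlib

import torch


def check_hash(X):
    # Equivalent to the helper in the diff: reinterpret the 2-byte elements as uint16
    # so the digest reflects exact bit patterns, then print an MD5 of the raw bytes.
    t = X.detach().cpu().contiguous().view(torch.uint16)
    print(hashlib.md5(t.numpy().tobytes()).hexdigest())


# Illustrative driver (not part of the diff): identical tensors print identical digests.
hidden = torch.randn(2, 8, 16, dtype=torch.float16)
check_hash(hidden)
check_hash(hidden.clone())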
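The commented-out lines in the use_flex_attn branch, together with the "[MASK]*L suffix" remark in DreamBaseModel.forward, outline one mask geometry: double the L-token input with mask tokens and expand the L x L attention mask to 2L x 2L before it is cast to bool for scaled_dot_product_attention. Below is a standalone sketch of that construction, assuming a small L and an arbitrary 0/1 base mask; mask_id = 151666 is taken from the commented code.

import torch

L = 4
mask_id = 151666  # value from the commented remark in DreamBaseModel.forward

# Append an [MASK]*L suffix to the input_ids, as the remark describes.
input_ids = torch.randint(0, 1000, (1, L))
input_ids = torch.cat(
    [input_ids, torch.full((input_ids.shape[0], L), mask_id, dtype=input_ids.dtype, device=input_ids.device)],
    dim=1,
)  # shape (1, 2L)

# Expand an L x L mask to 2L x 2L: the base mask next to its inverse, zero-padded below.
attention_mask = torch.tril(torch.ones(L, L, dtype=torch.long))  # any 0/1 base mask
attention_mask_inverted = 1 - attention_mask
attention_mask = torch.cat([attention_mask, attention_mask_inverted], dim=1)  # (L, 2L)
attention_mask = torch.cat(
    [attention_mask, torch.zeros(L, 2 * L, dtype=attention_mask.dtype, device=attention_mask.device)],
    dim=0,
)  # (2L, 2L)

# The SDPA branch then consumes the mask as boolean.
attention_mask = attention_mask.bool()
print(input_ids.shape, attention_mask.shape)  # torch.Size([1, 8]) torch.Size([8, 8])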