Rtian committed
Commit 7ef64fe · verified · 1 Parent(s): 1b7f58c

Update modeling_dream.py

Files changed (1)
  1. modeling_dream.py +99 -3
modeling_dream.py CHANGED
@@ -23,6 +23,7 @@ import math
 from typing import List, Optional, Tuple, Union
 import os
 import torch
+import hashlib
 import torch.utils.checkpoint
 from torch import nn
 
@@ -47,6 +48,9 @@ from .generation_utils import DreamGenerationMixin, DreamGenerationConfig
 
 if is_flash_attn_2_available():
     from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+def check_hash(X):
+    t = X.detach().cpu().contiguous().view(torch.uint16); print(hashlib.md5(t.numpy().tobytes()).hexdigest())
 
 
 logger = logging.get_logger(__name__)
@@ -360,7 +364,9 @@ class DreamSdpaAttention(DreamAttention):
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        use_flex_attn: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
@@ -378,14 +384,45 @@ class DreamSdpaAttention(DreamAttention):
 
         bsz, q_len, _ = hidden_states.size()
 
+        # Debug: Print all hidden_states[0] values
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\n=== Layer {self.layer_idx} ===\n")
+        #     f.write(f"hidden_states[0] - all positions:\n")
+        #     for idx in range(len(hidden_states[0])):
+        #         f.write(f"  idx {idx}: {hidden_states[0][idx]}\n")
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
+        # Debug: Print all QKV[0] values after projection (before view/transpose)
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0] (after proj) - all positions:\n")
+        #     for idx in range(len(query_states[0])):
+        #         f.write(f"  idx {idx}: {query_states[0][idx]}\n")
+        #     f.write(f"\nkey_states[0] (after proj) - all positions:\n")
+        #     for idx in range(len(key_states[0])):
+        #         f.write(f"  idx {idx}: {key_states[0][idx]}\n")
+        #     f.write(f"\nvalue_states[0] (after proj) - all positions:\n")
+        #     for idx in range(len(value_states[0])):
+        #         f.write(f"  idx {idx}: {value_states[0][idx]}\n")
+
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        # Debug: Print all QKV[0][0] values after view/transpose
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0][0] (after view/transpose) - all positions:\n")
+        #     for idx in range(len(query_states[0][0])):
+        #         f.write(f"  idx {idx}: {query_states[0][0][idx]}\n")
+        #     f.write(f"\nkey_states[0][0] (after view/transpose) - all positions:\n")
+        #     for idx in range(len(key_states[0][0])):
+        #         f.write(f"  idx {idx}: {key_states[0][0][idx]}\n")
+        #     f.write(f"\nvalue_states[0][0] (after view/transpose) - all positions:\n")
+        #     for idx in range(len(value_states[0][0])):
+        #         f.write(f"  idx {idx}: {value_states[0][0][idx]}\n")
+
         if position_embeddings is None:
             logger.warning_once(
                 "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
@@ -398,6 +435,15 @@ class DreamSdpaAttention(DreamAttention):
             cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        # Debug: Print all QKV[0][0] values after positional embedding
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0][0] (after positional embedding) - all positions:\n")
+        #     for idx in range(len(query_states[0][0])):
+        #         f.write(f"  idx {idx}: {query_states[0][0][idx]}\n")
+        #     f.write(f"\nkey_states[0][0] (after positional embedding) - all positions:\n")
+        #     for idx in range(len(key_states[0][0])):
+        #         f.write(f"  idx {idx}: {key_states[0][0][idx]}\n")
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -405,6 +451,18 @@ class DreamSdpaAttention(DreamAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
+        # Debug: Print all QKV[0][0] values after grouping
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nquery_states[0][0] (after grouping) - all positions:\n")
+        #     for idx in range(len(query_states[0][0])):
+        #         f.write(f"  idx {idx}: {query_states[0][0][idx]}\n")
+        #     f.write(f"\nkey_states[0][0] (after grouping) - all positions:\n")
+        #     for idx in range(len(key_states[0][0])):
+        #         f.write(f"  idx {idx}: {key_states[0][0][idx]}\n")
+        #     f.write(f"\nvalue_states[0][0] (after grouping) - all positions:\n")
+        #     for idx in range(len(value_states[0][0])):
+        #         f.write(f"  idx {idx}: {value_states[0][0][idx]}\n")
+
         # causal_mask = attention_mask
         # if attention_mask is not None:  # no matter the length, we just slice it
         #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
@@ -420,7 +478,14 @@ class DreamSdpaAttention(DreamAttention):
         # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
         # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
         # is_causal = True if causal_mask is None and q_len > 1 else False
+        if use_flex_attn:
+            # L = attention_mask.shape[0]
+            # attention_mask_inverted = 1 - attention_mask
+            # attention_mask = torch.cat([attention_mask, attention_mask_inverted], dim=1)
+            # attention_mask = torch.cat([attention_mask, torch.zeros(L, 2*L, dtype=attention_mask.dtype, device=attention_mask.device)], dim=0)
+            attention_mask = attention_mask.bool()
 
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -430,9 +495,21 @@ class DreamSdpaAttention(DreamAttention):
             is_causal=False,  # hard coded
         )
 
+        # Debug: Print all attn_output[0][0] values after attention
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nattn_output[0][0] (after attention) - all positions:\n")
+        #     for idx in range(len(attn_output[0][0])):
+        #         f.write(f"  idx {idx}: {attn_output[0][0][idx]}\n")
+
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
+        # Debug: Print all attn_output[0] values after view
+        # with open("mabmcm_mmm.txt", "a") as f:
+        #     f.write(f"\nattn_output[0] (after view) - all positions:\n")
+        #     for idx in range(len(attn_output[0])):
+        #         f.write(f"  idx {idx}: {attn_output[0][idx]}\n")
+
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
@@ -466,6 +543,7 @@ class DreamDecoderLayer(nn.Module):
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        use_flex_attn: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -489,9 +567,7 @@ class DreamDecoderLayer(nn.Module):
                 Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                 into the model
         """
-
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
@@ -504,6 +580,7 @@ class DreamDecoderLayer(nn.Module):
             use_cache=use_cache,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
+            use_flex_attn=use_flex_attn,
         )
         hidden_states = residual + hidden_states
 
@@ -642,7 +719,9 @@ class DreamBaseModel(DreamPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
+        use_flex_attn: Optional[bool]=None,
    ) -> Union[Tuple, BaseModelOutput]:
+
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -660,7 +739,13 @@ class DreamBaseModel(DreamPreTrainedModel):
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
-
+
+        # Remark: append an [MASK]*L suffix to the input_ids
+        # if use_flex_attn:
+        #     mask_id = 151666
+        #     L = input_ids.shape[1]
+        #     input_ids = torch.cat([input_ids, torch.full((input_ids.shape[0], L), mask_id, dtype=input_ids.dtype, device=input_ids.device)], dim=1)
+
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
 
@@ -678,6 +763,9 @@ class DreamBaseModel(DreamPreTrainedModel):
 
        hidden_states = inputs_embeds
 
+        if use_flex_attn:
+            position_ids = torch.cat([position_ids[:, :16], torch.tensor([[11, 14, 10, 13, 15]], device=position_ids.device)], dim=1)
+
        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
@@ -711,6 +799,7 @@ class DreamBaseModel(DreamPreTrainedModel):
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
+                use_flex_attn=use_flex_attn,
            )
 
        hidden_states = layer_outputs[0]
@@ -782,8 +871,14 @@ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
+        use_flex_attn: bool = False,
        **loss_kwargs,
    ) -> Union[Tuple, MaskedLMOutput]:
+
+        if not use_flex_attn:
+            attention_mask = "full"
+
+        # Remark: in our method, attention_mask should be an L*L matrix
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -802,6 +897,7 @@ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
+            use_flex_attn=use_flex_attn,
        )
 
        hidden_states = outputs[0]
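The use_flex_attn path leaves the attention mask in the caller's hands: the remark notes it should be an L*L matrix, and DreamSdpaAttention simply casts it with .bool() and hands it to torch.nn.functional.scaled_dot_product_attention with is_causal=False. The sketch below illustrates that masking convention in isolation; the tensor shapes and the block-diagonal pattern are illustrative assumptions, not something the commit prescribes.

import torch
import torch.nn.functional as F

bsz, num_heads, L, head_dim = 1, 2, 6, 8
q = torch.randn(bsz, num_heads, L, head_dim)
k = torch.randn(bsz, num_heads, L, head_dim)
v = torch.randn(bsz, num_heads, L, head_dim)

# Illustrative L x L pattern: two halves of the sequence attend only within
# their own block.
half = L // 2
attention_mask = torch.zeros(L, L)
attention_mask[:half, :half] = 1
attention_mask[half:, half:] = 1

# Mirrors the use_flex_attn branch: cast to bool (True = may attend), no causal mask.
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask.bool(), is_causal=False)
print(attn_output.shape)  # torch.Size([1, 2, 6, 8])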