npvinHnivqn committed
Commit f209d0d · 1 Parent(s): 01d9133

update bug

Files changed (1):
  1. modeling_stablelm_epoch.py +2 -2
modeling_stablelm_epoch.py CHANGED
@@ -535,7 +535,7 @@ class DecoderLayer(nn.Module):
         bsz, q_len, _ = hidden_states.size()
         _, kv_len, _ = cross_states.size()
 
-        cross_attn_mask = torch.ones((bsz, 1, kv_len, q_len), device=hidden_states.device)
+        cross_attn_mask = torch.zeros((bsz, 1, kv_len, q_len), device=hidden_states.device)
         hidden_states, cross_attn_weights, _ = self.cross_attn(
             hidden_states=hidden_states,
             cross_states=cross_states,
@@ -545,7 +545,7 @@ class DecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
-        hidden_states = residual + hidden_states
+        hidden_states = residual# + hidden_states
 
         # Fully Connected
         residual = hidden_states
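
For readers skimming the diff, the sketch below (not taken from modeling_stablelm_epoch.py) illustrates the additive attention-mask convention that an all-zeros cross_attn_mask usually implies: 0.0 means "attend", a large negative value means "block". Whether self.cross_attn consumes the mask this way is an assumption about this model; the tensor names and shapes here are purely illustrative.

    # Minimal sketch of additive attention masking, assuming the mask is
    # added to the raw attention scores before the softmax.
    import torch

    bsz, q_len, kv_len = 1, 2, 3
    scores = torch.randn(bsz, q_len, kv_len)        # raw attention scores

    attend_all = torch.zeros(bsz, q_len, kv_len)    # all-zeros: softmax unchanged
    block_last = torch.zeros(bsz, q_len, kv_len)
    block_last[..., -1] = torch.finfo(scores.dtype).min  # blocks the last key position

    print(torch.softmax(scores + attend_all, dim=-1))  # identical to softmax(scores)
    print(torch.softmax(scores + block_last, dim=-1))  # last column gets ~0 probability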