Yan Wei committed on
Commit 550eb56 · 0 Parent(s)

Initial commit: DeepSeek Multi-Latent Attention implementation
.DS_Store ADDED
Binary file (6.15 kB).
 
CONTRIBUTING.md ADDED
File without changes
README.md ADDED
@@ -0,0 +1,126 @@
+ # DeepSeek Multi-Latent Attention
+
+ A PyTorch implementation of the Multi-Latent Attention (MLA) mechanism introduced in the DeepSeek-V2 paper. MLA significantly reduces the KV cache required for efficient inference while maintaining model performance through its innovative architecture.
+
+ ## Key Features
+
+ - **Low-Rank Key-Value Joint Compression**: Reduces memory footprint during inference
+ - **Decoupled Rotary Position Embedding**: Enables efficient position-aware attention
+ - **Optimized Cache Management**: Handles both compressed KV states and rotary embeddings
+ - **Cross-Attention Support**: Works for both self-attention and cross-attention scenarios
+
+ ## Installation
+ Clone this repository:
+ ```bash
+ git clone https://huggingface.co/bird-of-paradise/deepseek-mla
+ ```
+ Or download directly from the HuggingFace repository page.
+
+ ## Quick Start
+
+ ```python
+ import torch
+ from src.mla import MultiLatentAttention
+
+ # Initialize MLA
+ mla = MultiLatentAttention(
+     d_model=512,    # Model dimension
+     num_head=8,     # Number of attention heads
+     d_embed=512,    # Embedding dimension
+     d_c=64,         # KV compression dimension
+     d_c1=64,        # Query compression dimension
+     d_rotate=32,    # Rotary embedding dimension
+ )
+
+ # Input sequence
+ x = torch.randn(2, 10, 512)  # [batch_size, seq_len, d_model]
+
+ # Forward pass
+ output = mla(x)
+ ```
+
+ ## Testing
+
+ To run the test suite, execute the following command from the project root directory:
+
+ ```bash
+ python -m src.tests.test_mla
+ ```
+
+ ## Architecture Details
+
+ ![MLA Architecture](assets/mla_architecture.png)
+
+ MLA combines two key innovations:
+ 1. Low-rank compression pathway for efficient KV caching
+ 2. Decoupled position-aware pathway using RoPE
+
+ For detailed architectural insights, see [insights/architecture.md](insights/architecture.md).
+
+ ## Caching Behavior
+
+ During inference, MLA maintains two caches:
+ ```python
+ cache_kv: [batch, max_len, d_c]  # Compressed KV states
+ cache_rk: [batch, max_len, d_r]  # Shared rotary key
+ ```
+
+ For detailed insights on attention masking and caching, see [insights/attention_mask.md](insights/attention_mask.md).
+
+ ## Usage Examples
+
+ ### Basic Attention
+
+ ```python
+ # Standard self-attention
+ output = mla(sequence)
+
+ # Cross-attention
+ output = mla(query, key_value_states=context)
+ ```
+
+ ### Cached Generation
+
+ ```python
+ # Initial forward pass
+ output = mla(prompt, use_cache=True, start_pos=0)
+
+ # Generate tokens using cache
+ for i in range(max_new_tokens):
+     output = mla(next_token, use_cache=True, start_pos=prompt_len + i)
+ ```
+
+ ## Implementation Details
+
+ The implementation closely follows the formulation in the DeepSeek-V2 paper:
+
+ ![MLA Formulas](assets/mla_formulas.png)
+
+ Key aspects:
+ - Separate compression pathways for queries and key-values
+ - Position encoding through decoupled RoPE pathway
+ - Efficient cache management for both pathways
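+
+ The sketch below is a rough shape walkthrough of these pathways (illustrative only, not part of the public API), mapping them onto the module's projection layers with the Quick Start configuration:
+
+ ```python
+ import torch
+ from src.mla import MultiLatentAttention
+
+ mla = MultiLatentAttention(d_model=512, num_head=8, d_embed=512, d_c=64, d_c1=64, d_rotate=32)
+ x = torch.randn(2, 10, 512)      # [batch, seq_len, d_model]
+
+ c_q = mla.DQ_proj(x)             # [2, 10, 64]   compressed query
+ q_content = mla.UQ_proj(c_q)     # [2, 10, 512]  content pathway for Q
+ q_rope = mla.RQ_proj(c_q)        # [2, 10, 256]  per-head RoPE pathway (8 heads x 32)
+
+ c_kv = mla.DKV_proj(x)           # [2, 10, 64]   jointly compressed K/V (this is what gets cached)
+ k_content = mla.UK_proj(c_kv)    # [2, 10, 512]  up-projected keys
+ v_content = mla.UV_proj(c_kv)    # [2, 10, 512]  up-projected values
+ k_rope = mla.RK_proj(x)          # [2, 10, 32]   shared rotary key (also cached)
+ ```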
+
+ ## Contributing
+
+ Contributions are welcome! Feel free to:
+ - Report bugs and issues
+ - Submit pull requests for improvements
+ - Add additional test cases
+ - Provide documentation clarifications
+
+ Please ensure all tests pass before submitting pull requests.
+
+ ## Citation
+ ```bibtex
+ @misc{deepseek2024,
+     title={DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model},
+     author={DeepSeek-AI},
+     year={2024},
+     journal={arXiv preprint arXiv:2405.04434}
+ }
+ ```
+
+ ## License
+
+ [MIT License](LICENSE)
assets/mla_architecture.png ADDED
assets/mla_formulas.png ADDED
insights/architecture.md ADDED
@@ -0,0 +1,106 @@
+ # Advanced Insights: Multi-Latent Attention Architecture
+
+ ## Key Architectural Innovations
+
+ ### Compression-Position Decoupling
+ ```python
+ # Two parallel pathways with different roles:
+ [b, s, d] -> [b, s, d_c] -> [b, s, d]   # Compression pathway
+ [b, s, d] -> [b, s, d_r] -> RoPE()      # Position pathway
+ ```
+ Critical insight: Matrix multiplication non-commutativity necessitates pathway separation for efficient inference.
+
+ ### Asymmetric Dimensionality
+ ```
+ Q pathway: per-head rotary dimensions  [b, s, n_h, d_r]
+ K pathway: shared rotary dimensions    [b, s, 1, d_r]
+ ```
+ Design choice allows computational reuse while maintaining positional awareness.
+
+ ### Cache Optimization Strategy
+ Two distinct caches with different roles:
+ ```python
+ cache_kv: [b, max_len, d_c]   # Compressed KV states
+ cache_rk: [b, max_len, d_r]   # Shared rotary key
+ ```
+ Optimization insight: `d_c + d_r << d_model`, enabling significant memory reduction.
+
+ ## Implementation Subtleties
+
+ ### Matrix Absorption During Inference
+ ```
+ Standard:  W^Q @ (W^UK @ c^KV)   # Three matrix multiplications
+ Optimized: (W^Q @ W^UK) @ c^KV   # Two matrix multiplications
+ ```
+ Key requirement: Position-agnostic main pathway enables matrix pre-multiplication.
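+
+ A tiny numerical check of the associativity this absorption relies on (illustrative only: the weight names are stand-ins, and the committed `src/mla.py` keeps its projections separate rather than pre-absorbing them):
+
+ ```python
+ import torch
+
+ d_model, d_c = 512, 64
+ W_Q = torch.randn(d_model, d_model, dtype=torch.float64)  # stand-in query-side weight
+ W_UK = torch.randn(d_model, d_c, dtype=torch.float64)     # stand-in key up-projection
+ c_kv = torch.randn(d_c, 1, dtype=torch.float64)           # one cached compressed KV vector
+
+ standard = W_Q @ (W_UK @ c_kv)     # decompress the key, then project
+ optimized = (W_Q @ W_UK) @ c_kv    # absorb W_Q @ W_UK once, offline, then reuse it
+
+ assert torch.allclose(standard, optimized)  # identical up to floating-point error
+ ```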
+
+ ### Attention Pattern Evolution
+ ```
+ t=1: Pattern[1:1]  # Initial token
+ t=2: Pattern[1:2]  # One previous token
+ t=n: Pattern[1:n]  # Full context window
+ ```
+ Cache growth introduces subtle position-dependent patterns requiring careful mask handling.
+
+ ### Dimension Flow Control
+ Critical transitions to monitor:
+ ```
+ [b, s, d] -> [b, s, d_c]              # Down projection
+ [b, s, d_c] -> [b, s+cache, d_c]      # Cache concatenation
+ [b, s+cache, d_c] -> [b, s+cache, d]  # Up projection
+ ```
+ Each transition must preserve both positional and content information.
+
+ ## Edge Cases and Considerations
+
+ ### Cross-Attention Scenarios
+ ```python
+ q_len != kv_len  # Length mismatch
+ d_c < d_model    # Compression bottleneck
+ ```
+ Compression and position information must be maintained across different sequence lengths.
+
+ ### Position-Aware Cache Updates
+ ```python
+ # Position-dependent attention mask creation
+ mask[:, :, i, :(start_pos + i + 1)] = 0     # Can attend
+ mask[:, :, i, (start_pos + i + 1):] = -inf  # Cannot attend
+ ```
+ Mask must evolve with the cache to maintain causal attention patterns.
+
+ ### Numerical Stability
+ 1. Scaling factor accounts for both pathways: `1/sqrt(d_head + d_rotate)`
+ 2. Compression dimensions balance efficiency against representation capacity
+ 3. RoPE dimensions impact position encoding granularity
+
+ ## Performance Implications
+
+ ### Memory Complexity
+ ```
+ Standard: O(b * s * d_model)
+ MLA: O(b * s * (d_c + d_r))
+ ```
+ Where `d_c + d_r << d_model`
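+
+ As a rough back-of-the-envelope check (assuming the README's demo configuration, a single layer, and fp32 cache entries per token):
+
+ ```python
+ d_model, d_c, d_r = 512, 64, 32
+ bytes_per_float = 4
+
+ # Standard MHA caches a full K and a full V vector per token
+ standard_bytes = 2 * d_model * bytes_per_float   # 4096 bytes
+
+ # MLA caches one compressed KV vector plus one shared rotary key per token
+ mla_bytes = (d_c + d_r) * bytes_per_float        # 384 bytes
+
+ print(f"standard: {standard_bytes} B, MLA: {mla_bytes} B, ratio: {standard_bytes / mla_bytes:.1f}x")
+ ```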
+
+ ### Computational Trade-offs
+ 1. Additional projections for position pathway
+ 2. Reduced cache size enables longer sequences
+ 3. Matrix absorption reduces inference compute
+
+ ## Integration Considerations
+
+ ### Initialization Strategy
+ ```python
+ # Critical hyperparameters
+ d_c: Compression dimension
+ d_rotate: Position encoding dimension
+ ```
+ Trade-off between compression efficiency and position encoding capacity.
+
+ ### Cache Management
+ ```python
+ # Two update patterns
+ cache_kv[:, pos:pos+s] = current_kv  # Content cache
+ cache_rk[:, pos:pos+s] = current_rk  # Position cache
+ ```
+ Synchronization between the two caches is crucial for correctness.
insights/attention_mask.md ADDED
@@ -0,0 +1,35 @@
+ # Advanced Insights: Attention Masks with KV-Caching
+
+ ## Key Pitfalls in Complex Attention Implementations
+
+ ### Dimension Evolution with Caching
+ ```python
+ # Crucial dimension transitions in cached attention:
+ [b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
+ ```
+ The non-obvious trap: even with a growing K/V cache, attention output dimensions must match the query length, not the cached length.
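+
+ A quick shape check of that claim (random stand-in tensors, shapes only):
+
+ ```python
+ import torch
+
+ b, num_h, d_head = 2, 8, 64
+ q_len, cached_len = 1, 11   # one new query token attending over 11 cached key/value positions
+
+ Q = torch.randn(b, num_h, q_len, d_head)
+ K = torch.randn(b, num_h, cached_len, d_head)
+ V = torch.randn(b, num_h, cached_len, d_head)
+
+ att = torch.softmax(Q @ K.transpose(-1, -2) / d_head**0.5, dim=-1)  # [b, num_h, q_len, cached_len]
+ out = att @ V                                                        # [b, num_h, q_len, d_head]
+ assert out.shape == (b, num_h, q_len, d_head)  # output tracks the query length, not the cache length
+ ```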
+
+ ### Mask Causality with Growing Cache
+ Standard causal masks break with KV-caching - they don't account for position-dependent attention patterns across cached sequences. Critical edge cases (a sketch follows the list):
+ - Token at position `i` must attend to `[0 : start_pos + i + 1]`
+ - Naive mask extension leads to incorrect causality preservation
+ - Performance impact of position-wise mask generation
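+
+ A minimal sketch of the position-aware extension (it mirrors the mask loop in `src/mla.py`; the helper name is hypothetical):
+
+ ```python
+ import torch
+
+ def build_extended_mask(batch_size, seq_len, start_pos, dtype=torch.float32):
+     """Causal mask for seq_len new queries attending over start_pos + seq_len cached keys."""
+     cached_len = start_pos + seq_len
+     mask = torch.zeros(batch_size, 1, seq_len, cached_len, dtype=dtype)
+     for i in range(seq_len):
+         mask[:, :, i, (start_pos + i + 1):] = float('-inf')  # positions after the query are blocked
+     return mask
+
+ # One new token (seq_len=1) after a 5-token prompt may attend to all 6 cached positions
+ print(build_extended_mask(batch_size=1, seq_len=1, start_pos=5))
+ ```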
+
+ ### Optimization Considerations
+ 1. Memory vs compute tradeoff: precomputing extended masks vs generating them per position
+ 2. Batch dimension handling: mask broadcasting impacts memory usage
+ 3. Fused attention patterns may break with custom mask handling
+
+ ## Debugging Strategy for Non-Obvious Cases
+ Monitor these dimension transitions for subtle bugs:
+ ```python
+ C_KV.shape        # Should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
+ K_state.shape     # Post-projection growth affects attention patterns
+ att_output.shape  # Must maintain query dimensions despite K/V growth
+ ```
+
+ ## Practical Example: DeepSeek's MLA Edge Case
+ In Multi-Latent Attention, the compressed KV cache introduces subtle interactions with attention masks due to:
+ 1. Joint compression affecting position-dependent patterns
+ 2. Non-standard dimension flow through compression/decompression
+ 3. Mask causality preservation across cached compressed states
src/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """
2
+ DeepSeek Multi-Latent Attention Implementation
3
+ Copyright (c) 2025
4
+
5
+ Implementation of the Multi-Latent Attention mechanism from the DeepSeek-V2 paper.
6
+ """
7
+
8
+ from .mla import MultiLatentAttention, precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb
9
+
10
+ __version__ = "0.1.0"
11
+ __all__ = ["MultiLatentAttention", "precompute_freqs_cis", "reshape_for_broadcast","apply_rotary_emb"]
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (580 Bytes).
 
src/__pycache__/mla.cpython-311.pyc ADDED
Binary file (14 kB).
 
src/mla.py ADDED
@@ -0,0 +1,311 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Tuple
+
+ import math
+
+
+
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+     t = torch.arange(end, device=freqs.device)  # type: ignore
+     freqs = torch.outer(t, freqs).float()  # type: ignore
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+     return freqs_cis
+
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+     return freqs_cis.view(*shape)
+
+
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     # Validate input dimensions
+     assert xq.shape[-1] == xk.shape[-1], "Query and Key must have same embedding dimension"
+     assert xq.shape[-1] % 2 == 0, "Embedding dimension must be even"
+
+     # Get sequence lengths
+     q_len = xq.shape[1]
+     k_len = xk.shape[1]
+
+     # Use the appropriate part of freqs_cis for each sequence
+     q_freqs = freqs_cis[:q_len]
+     k_freqs = freqs_cis[:k_len]
+
+     # Apply rotary embeddings separately
+     # Split the last dimension into [xq.shape[-1]/2, 2] and view as complex
+     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+
+
+     # Reshape freqs for each
+     q_freqs = reshape_for_broadcast(q_freqs, xq_)
+     k_freqs = reshape_for_broadcast(k_freqs, xk_)
+
+     # Works for both [bsz, seqlen, n_heads*head_dim] and [bsz, seqlen, n_heads, head_dim]
+     xq_out = torch.view_as_real(xq_ * q_freqs).flatten(xq.ndim-1)
+     xk_out = torch.view_as_real(xk_ * k_freqs).flatten(xk.ndim-1)
+
+     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+
+
+ class MultiLatentAttention(nn.Module):
+     """
+     Multi-Head Latent Attention (MLA) module, as in the DeepSeek-V2 paper.
+     Key innovations over standard MHA:
+     1. Low-Rank Key-Value Joint Compression
+     2. Decoupled Rotary Position Embedding
+
+     Args:
+         d_model: Total dimension of the model.
+         num_head: Number of attention heads.
+         d_embed: Embedding dimension
+         d_c: K/V compression dimension
+         d_c1: Q compression dimension
+         d_rotate: Dimension for Rotary Position Embedding
+         dropout: Dropout rate for attention scores.
+         bias: Whether to include bias in linear projections.
+
+         d_head: Inferred from d_model // num_head
+
+     Inputs:
+         sequence: input sequence for self-attention, and the query for cross-attention
+         key_value_states: input for the keys/values in cross-attention
+     """
+     def __init__(
+         self,
+         d_model,            # Infer d_head from d_model
+         num_head,
+         d_embed,
+         d_c,
+         d_c1,
+         d_rotate,
+         dropout=0.1,
+         bias=True,
+         max_batch_size=32,  # For KV cache sizing
+         max_seq_len=2048    # For KV cache sizing
+     ):
+         super().__init__()
+
+         assert d_model % num_head == 0, "d_model must be divisible by num_head"
+         assert d_c < d_embed, "Compression dim should be smaller than embedding dim"
+         assert d_c1 < d_embed, "Query compression dim should be smaller than embedding dim"
+
+         self.d_model = d_model
+         self.num_head = num_head
+         # Verify dimensions match up
+         assert d_model % num_head == 0, f"d_model ({d_model}) must be divisible by num_head ({num_head})"
+         self.d_head = d_model // num_head
+         self.d_embed = d_embed
+         self.d_c = d_c
+         self.d_c1 = d_c1
+         self.d_rotate = d_rotate
+         self.dropout_rate = dropout  # Store dropout rate separately
+
+         # Linear down-projection (compression) transformations
+         self.DKV_proj = nn.Linear(d_embed, d_c, bias=bias)
+         self.DQ_proj = nn.Linear(d_embed, d_c1, bias=bias)
+
+         # Linear up-projection transformations
+         self.UQ_proj = nn.Linear(d_c1, d_model, bias=bias)
+         self.UK_proj = nn.Linear(d_c, d_model, bias=bias)
+         self.UV_proj = nn.Linear(d_c, d_model, bias=bias)
+
+         # Linear RoPE projections
+         self.RQ_proj = nn.Linear(d_c1, num_head*d_rotate, bias=bias)
+         self.RK_proj = nn.Linear(d_embed, d_rotate, bias=bias)
+
+         # Linear output transformation
+         self.output_proj = nn.Linear(d_model, d_model, bias=bias)
+
+         # Dropout layer
+         self.dropout = nn.Dropout(p=dropout)
+
+         # Initialize scaling factor
+         self.scaler = float(1.0 / math.sqrt(self.d_head + d_rotate))  # Store as float in initialization
+
+         # Initialize C_KV and R_K caches for inference
+         self.cache_kv = torch.zeros(
+             (max_batch_size, max_seq_len, d_c)
+         )
+         self.cache_rk = torch.zeros(
+             (max_batch_size, max_seq_len, d_rotate)
+         )
+
+         # Initialize freqs_cis for RoPE
+         self.freqs_cis = precompute_freqs_cis(
+             d_rotate, max_seq_len * 2
+         )
+
+     def forward(
+         self,
+         sequence,
+         key_value_states=None,
+         att_mask=None,
+         use_cache=False,
+         start_pos: int = 0
+     ):
+
+         """
+         Forward pass supporting both standard attention and cached inference.
+         Input shape: [batch_size, seq_len, d_model = num_head * d_head]
+         Args:
+             sequence: Input sequence [batch_size, seq_len, d_model]
+             key_value_states: Optional states for cross-attention
+             att_mask: Optional attention mask
+             use_cache: Whether to use KV caching (for inference)
+             start_pos: Position in sequence when using KV cache
+         """
+         batch_size, seq_len, model_dim = sequence.size()
+         # Prepare for RoPE
+         self.freqs_cis = self.freqs_cis.to(sequence.device)
+         freqs_cis = self.freqs_cis[start_pos:]
+
+         # Check only critical input dimensions
+         assert model_dim == self.d_model, f"Input dimension {model_dim} doesn't match model dimension {self.d_model}"
+         if key_value_states is not None:
+             assert key_value_states.size(-1) == self.d_model, \
+                 f"Cross attention key/value dimension {key_value_states.size(-1)} doesn't match model dimension {self.d_model}"
+
+         # If key_value_states are provided, this layer is used as a cross-attention layer
+         # for the decoder
+         is_cross_attention = key_value_states is not None
+
+         # Determine kv_seq_len early
+         kv_seq_len = key_value_states.size(1) if is_cross_attention else seq_len
+
+         # Linear projections and reshape for multi-head, in the order of Q, then K/V
+         # Down- and up-projection for the query
+         C_Q = self.DQ_proj(sequence)   # [batch_size, seq_len, d_c1]
+         Q_state = self.UQ_proj(C_Q)    # [batch_size, seq_len, d_model]
+         # Linear projection for the query RoPE pathway
+         Q_rotate = self.RQ_proj(C_Q)   # [batch_size, seq_len, num_head*d_rotate]
+
+
+         if use_cache:
+             # Equation (41) in the DeepSeek-V2 paper: cache c^{KV}_t
+             self.cache_kv = self.cache_kv.to(sequence.device)
+
+             # Get current compressed KV states
+             current_kv = self.DKV_proj(key_value_states if is_cross_attention else sequence)  # [batch_size, kv_seq_len, d_c]
+             # Update cache using kv_seq_len instead of seq_len
+             self.cache_kv[:batch_size, start_pos:start_pos + kv_seq_len] = current_kv
+             # Use cached compressed KV up to the current position
+             C_KV = self.cache_kv[:batch_size, :start_pos + kv_seq_len]
+
+             # Equation (43) in the DeepSeek-V2 paper: cache the RoPE pathway for the shared key k^R_t
+             assert self.cache_rk.size(-1) == self.d_rotate, "RoPE cache dimension mismatch"
+             self.cache_rk = self.cache_rk.to(sequence.device)
+             # Get current RoPE key
+             current_K_rotate = self.RK_proj(key_value_states if is_cross_attention else sequence)  # [batch_size, kv_seq_len, d_rotate]
+             # Update cache using kv_seq_len instead of seq_len
+             self.cache_rk[:batch_size, start_pos:start_pos + kv_seq_len] = current_K_rotate
+             # Use cached RoPE key up to the current position
+             K_rotate = self.cache_rk[:batch_size, :start_pos + kv_seq_len]  # [batch_size, cached_len, d_rotate]
+
+
+             """Handling the attention mask"""
+             if att_mask is not None:
+                 # Get the original mask shape
+                 mask_size = att_mask.size(-1)
+                 cached_len = start_pos + kv_seq_len  # cached key length, including previous keys
+                 assert C_KV.size(1) == cached_len, \
+                     f"Cached key/value length {C_KV.size(1)} doesn't match theoretical length {cached_len}"
+
+                 # Create a new mask matching the attention matrix shape
+                 extended_mask = torch.zeros(
+                     (batch_size, 1, seq_len, cached_len),  # [batch, head, query_len, key_len]
+                     device=att_mask.device,
+                     dtype=att_mask.dtype
+                 )
+
+                 # Fill in the mask appropriately - we need to be careful about causality here
+                 # For each query position, it should only attend to cached positions up to that point
+                 for i in range(seq_len):
+                     extended_mask[:, :, i, :(start_pos + i + 1)] = 0              # Can attend
+                     extended_mask[:, :, i, (start_pos + i + 1):] = float('-inf')  # Cannot attend
+
+                 att_mask = extended_mask
+         else:
+             # Compression projection for C_KV
+             C_KV = self.DKV_proj(key_value_states if is_cross_attention else sequence)  # [batch_size, kv_seq_len, d_c]
+             # RoPE pathway for the *shared* key
+             K_rotate = self.RK_proj(key_value_states if is_cross_attention else sequence)
+
+
+         # Up-projection for key and value
+         K_state = self.UK_proj(C_KV)  # [batch_size, kv_seq_len/cached_len, d_model]
+         V_state = self.UV_proj(C_KV)  # [batch_size, kv_seq_len/cached_len, d_model]
+
+
+         Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)
+
+         # After getting K_state from the projection, get its actual sequence length
+         actual_kv_len = K_state.size(1)  # kv_seq_len or start_pos + kv_seq_len
+         # In cross-attention, the key/value sequence length might differ from the query sequence length
+         # Use actual_kv_len instead of kv_seq_len for reshaping
+         K_state = K_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)
+         V_state = V_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)
+
+
+         # Apply RoPE to the query and the shared key
+         Q_rotate = Q_rotate.view(batch_size, seq_len, self.num_head, self.d_rotate)
+         K_rotate = K_rotate.unsqueeze(2).expand(-1, -1, self.num_head, -1)  # [batch, cached_len, num_head, d_rotate]
+         Q_rotate, K_rotate = apply_rotary_emb(Q_rotate, K_rotate, freqs_cis=freqs_cis)
+
+
+         # Concatenate along the head dimension
+         Q_state = torch.cat([Q_state, Q_rotate], dim=-1)  # [batch_size, seq_len, num_head, d_head + d_rotate]
+         K_state = torch.cat([K_state, K_rotate], dim=-1)  # [batch_size, actual_kv_len, num_head, d_head + d_rotate]
+
+
+         # Scale Q by 1/sqrt(d_k)
+         Q_state = Q_state * self.scaler
+         Q_state = Q_state.transpose(1, 2)  # [batch_size, num_head, seq_len, head_dim]
+         K_state = K_state.transpose(1, 2)  # [batch_size, num_head, actual_kv_len, head_dim]
+         V_state = V_state.transpose(1, 2)  # [batch_size, num_head, actual_kv_len, head_dim]
+
+
+         # Compute attention matrix: QK^T
+         self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))
+
+         # Apply the attention mask to the attention matrix
+         if att_mask is not None and not isinstance(att_mask, torch.Tensor):
+             raise TypeError("att_mask must be a torch.Tensor")
+
+         if att_mask is not None:
+             self.att_matrix = self.att_matrix + att_mask
+
+         # Apply softmax over the last dimension to get the attention scores: softmax(QK^T)
+         att_score = F.softmax(self.att_matrix, dim=-1)
+
+         # Apply dropout to the attention scores
+         att_score = self.dropout(att_score)
+
+         # Get final output: softmax(QK^T)V
+         att_output = torch.matmul(att_score, V_state)
+         assert att_output.size(0) == batch_size, "Batch size mismatch"
+         assert att_output.size(2) == seq_len, "Output sequence length should match query sequence length"
+
+
+         # Concatenate all attention heads
+         att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_head*self.d_head)
+
+
+         # Final linear transformation applied to the concatenated output
+         att_output = self.output_proj(att_output)
+
+         assert att_output.size() == (batch_size, seq_len, self.d_model), \
+             f"Final output shape {att_output.size()} incorrect"
+
+         return att_output
src/tests/__init__.py ADDED
File without changes
src/tests/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (182 Bytes).
 
src/tests/__pycache__/test_mla.cpython-311.pyc ADDED
Binary file (6.69 kB).
 
src/tests/test_mla.py ADDED
@@ -0,0 +1,138 @@
+ import unittest
+ import torch
+ from ..mla import MultiLatentAttention  # Using relative import
+
+ class TestMultiLatentAttention(unittest.TestCase):
+     def setUp(self):
+         # Common dimensions for testing
+         self.d_model = 512
+         self.num_head = 8
+         self.d_embed = 512
+         self.d_c = 64       # Compression dim for K/V
+         self.d_c1 = 64      # Compression dim for Q
+         self.d_rotate = 32  # For future RoPE implementation
+         self.batch_size = 2
+         self.seq_len = 10
+
+         # Initialize MLA
+         self.mla = MultiLatentAttention(
+             d_model=self.d_model,
+             num_head=self.num_head,
+             d_embed=self.d_embed,
+             d_c=self.d_c,
+             d_c1=self.d_c1,
+             d_rotate=self.d_rotate
+         )
+
+     def test_basic_forward(self):
+         """Test basic forward pass without caching"""
+         x = torch.randn(self.batch_size, self.seq_len, self.d_model)
+         output = self.mla(x)
+
+         # Check output shape
+         self.assertEqual(
+             output.shape,
+             (self.batch_size, self.seq_len, self.d_model),
+             "Output shape mismatch"
+         )
+
+     def test_cross_attention(self):
+         """Test cross-attention functionality"""
+         query = torch.randn(self.batch_size, self.seq_len, self.d_model)
+         kv = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)  # Different seq_len
+
+         output = self.mla(query, key_value_states=kv)
+         self.assertEqual(
+             output.shape,
+             (self.batch_size, self.seq_len, self.d_model),
+             "Cross-attention output shape mismatch"
+         )
+
+     def test_cache_initialization(self):
+         """Test if cache is properly initialized"""
+         x = torch.randn(self.batch_size, self.seq_len, self.d_model)
+         _ = self.mla(x, use_cache=True, start_pos=0)
+
+         self.assertIsNotNone(self.mla.cache_kv)
+         self.assertEqual(
+             self.mla.cache_kv.shape[-1],
+             self.d_c,
+             "Cache compression dimension mismatch"
+         )
+
+     def test_sequential_caching(self):
+         """Test sequential forward passes with caching"""
+         # Initial sequence
+         prompt_len = 5
+         prompt = torch.randn(self.batch_size, prompt_len, self.d_model)
+
+         # First forward pass with prompt
+         output1 = self.mla(prompt, use_cache=True, start_pos=0)
+         cached_kv_1 = self.mla.cache_kv[:, :prompt_len].clone()
+
+         # Second forward pass with one new token
+         new_token = torch.randn(self.batch_size, 1, self.d_model)
+         output2 = self.mla(new_token, use_cache=True, start_pos=prompt_len)
+
+         # Verify cache consistency
+         # First part of cache should remain unchanged
+         self.assertTrue(
+             torch.allclose(
+                 self.mla.cache_kv[:, :prompt_len],
+                 cached_kv_1,
+                 rtol=1e-5
+             ),
+             "Cache was modified for previously processed tokens"
+         )
+
+         # Verify new token was added to cache
+         self.assertFalse(
+             torch.allclose(
+                 self.mla.cache_kv[:, prompt_len:prompt_len+1],
+                 torch.zeros_like(self.mla.cache_kv[:, prompt_len:prompt_len+1]),
+                 rtol=1e-5
+             ),
+             "New token was not added to cache"
+         )
+
+     def test_attention_mask_with_cache(self):
+         """Test attention masking with cached KV"""
+         seq_len = 5
+         x = torch.randn(self.batch_size, seq_len, self.d_model)
+
+         # Create causal mask
+         mask = torch.triu(
+             torch.ones(seq_len, seq_len) * float('-inf'),
+             diagonal=1
+         ).unsqueeze(0)
+
+         # First forward pass with mask
+         output1 = self.mla(x, use_cache=True, start_pos=0, att_mask=mask)
+
+         # Second pass with one token
+         new_token = torch.randn(self.batch_size, 1, self.d_model)
+         extended_mask = torch.triu(
+             torch.ones(seq_len + 1, seq_len + 1) * float('-inf'),
+             diagonal=1
+         ).unsqueeze(0)
+
+         output2 = self.mla(
+             new_token,
+             use_cache=True,
+             start_pos=seq_len,
+             att_mask=extended_mask
+         )
+
+         self.assertEqual(
+             output2.shape,
+             (self.batch_size, 1, self.d_model),
+             "Output shape incorrect for cached attention with mask"
+         )
+
+ def run_tests():
+     suite = unittest.TestLoader().loadTestsFromTestCase(TestMultiLatentAttention)
+     runner = unittest.TextTestRunner(verbosity=2)
+     runner.run(suite)
+
+ # Run the tests
+ run_tests()