Yan Wei
committed on
Commit
·
550eb56
0
Parent(s):
Initial commit: DeepSeek Multi-Latent Attention implementation
Browse files
- .DS_Store +0 -0
- CONTRIBUTING.md +0 -0
- README.md +126 -0
- assets/mla_architecture.png +0 -0
- assets/mla_formulas.png +0 -0
- insights/architecture.md +106 -0
- insights/attention_mask.md +35 -0
- src/__init__.py +11 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/__pycache__/mla.cpython-311.pyc +0 -0
- src/mla.py +311 -0
- src/tests/__init__.py +0 -0
- src/tests/__pycache__/__init__.cpython-311.pyc +0 -0
- src/tests/__pycache__/test_mla.cpython-311.pyc +0 -0
- src/tests/test_mla.py +138 -0
.DS_Store
ADDED
Binary file (6.15 kB).
CONTRIBUTING.md
ADDED
File without changes
README.md
ADDED
@@ -0,0 +1,126 @@
# DeepSeek Multi-Latent Attention

A PyTorch implementation of the Multi-Latent Attention (MLA) mechanism introduced in the DeepSeek-V2 paper. MLA significantly reduces the KV cache required for efficient inference while maintaining model performance through low-rank key-value compression and a decoupled rotary position pathway.

## Key Features

- **Low-Rank Key-Value Joint Compression**: Reduces memory footprint during inference
- **Decoupled Rotary Position Embedding**: Enables efficient position-aware attention
- **Optimized Cache Management**: Handles both compressed KV states and rotary embeddings
- **Cross-Attention Support**: Works for both self-attention and cross-attention scenarios

## Installation

Clone this repository:

```bash
git clone https://huggingface.co/bird-of-paradise/deepseek-mla
```

Or download directly from the HuggingFace repository page.

## Quick Start

```python
import torch
from src.mla import MultiLatentAttention

# Initialize MLA
mla = MultiLatentAttention(
    d_model=512,    # Model dimension
    num_head=8,     # Number of attention heads
    d_embed=512,    # Embedding dimension
    d_c=64,         # KV compression dimension
    d_c1=64,        # Query compression dimension
    d_rotate=32,    # Rotary embedding dimension
)

# Input sequence
x = torch.randn(2, 10, 512)  # [batch_size, seq_len, d_model]

# Forward pass
output = mla(x)
```

## Testing

To run the test suite, execute the following command from the project root directory:

```bash
python -m src.tests.test_mla
```

## Architecture Details

![MLA architecture](assets/mla_architecture.png)

MLA combines two key innovations:

1. A low-rank compression pathway for efficient KV caching
2. A decoupled position-aware pathway using RoPE

For detailed architectural insights, see [insights/architecture.md](insights/architecture.md).

## Caching Behavior

During inference, MLA maintains two caches:

```python
cache_kv: [batch, max_len, d_c]   # Compressed KV states
cache_rk: [batch, max_len, d_r]   # Shared rotary key
```

For detailed insights on attention masking and caching, see [insights/attention_mask.md](insights/attention_mask.md).

## Usage Examples

### Basic Attention

```python
# Standard self-attention
output = mla(sequence)

# Cross-attention
output = mla(query, key_value_states=context)
```

### Cached Generation

```python
# Initial forward pass
output = mla(prompt, use_cache=True, start_pos=0)

# Generate tokens using the cache
for i in range(max_new_tokens):
    output = mla(next_token, use_cache=True, start_pos=prompt_len + i)
```

## Implementation Details

The implementation closely follows the formulation in the DeepSeek-V2 paper:

![MLA formulas](assets/mla_formulas.png)

Key aspects:

- Separate compression pathways for queries and key-values (sketched below)
- Position encoding through a decoupled RoPE pathway
- Efficient cache management for both pathways
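As a rough illustration of the first aspect, here is a minimal sketch of the joint key-value compression. Dimensions are taken from the Quick Start example; the standalone `nn.Linear` layers mirror the module's `DKV_proj`, `UK_proj`, and `UV_proj` but are created here only for illustration.

```python
import torch
import torch.nn as nn

d_model, d_c = 512, 64                # dimensions from the Quick Start example
h = torch.randn(2, 10, d_model)       # hidden states [batch_size, seq_len, d_model]

DKV_proj = nn.Linear(d_model, d_c)    # joint down-projection; its output is what gets cached
UK_proj = nn.Linear(d_c, d_model)     # key up-projection
UV_proj = nn.Linear(d_c, d_model)     # value up-projection

c_kv = DKV_proj(h)                    # [2, 10, 64]: only d_c values per token go into the cache
K = UK_proj(c_kv)                     # [2, 10, 512]: reconstructed when attention is computed
V = UV_proj(c_kv)                     # [2, 10, 512]
```

Only `c_kv` (plus the small shared rotary key) needs to persist between decoding steps; the full-size `K` and `V` are recomputed from it on demand.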
## Contributing

Contributions are welcome! Feel free to:

- Report bugs and issues
- Submit pull requests for improvements
- Add additional test cases
- Provide documentation clarifications

Please ensure all tests pass before submitting pull requests.

## Citation

```bibtex
@misc{deepseek2024,
  title={DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model},
  author={DeepSeek-AI},
  year={2024},
  journal={arXiv preprint arXiv:2405.04434}
}
```

## License

[MIT License](LICENSE)
assets/mla_architecture.png
ADDED
assets/mla_formulas.png
ADDED
insights/architecture.md
ADDED
@@ -0,0 +1,106 @@
# Advanced Insights: Multi-Latent Attention Architecture

## Key Architectural Innovations

### Compression-Position Decoupling

```python
# Two parallel pathways with different roles:
[b, s, d] -> [b, s, d_c] -> [b, s, d]   # Compression pathway
[b, s, d] -> [b, s, d_r] -> RoPE()      # Position pathway
```

Critical insight: because matrix multiplication is non-commutative, the position-dependent RoPE rotation cannot be folded into the compression projections, so the two pathways must be kept separate for efficient inference.

### Asymmetric Dimensionality

```
Q pathway: per-head rotary dimensions  [b, s, n_h, d_r]
K pathway: shared rotary dimensions    [b, s, 1, d_r]
```

This design choice allows computational reuse while maintaining positional awareness.

### Cache Optimization Strategy

Two distinct caches with different roles:

```python
cache_kv: [b, max_len, d_c]   # Compressed KV states
cache_rk: [b, max_len, d_r]   # Shared rotary key
```

Optimization insight: `d_c + d_r << d_model`, enabling significant memory reduction.

## Implementation Subtleties

### Matrix Absorption During Inference

```
Standard:  W^Q @ (W^UK @ c^KV)   # three matrix multiplications per step
Optimized: (W^Q @ W^UK) @ c^KV   # two matrix multiplications per step
```

Key requirement: the main pathway is position-agnostic, which is what allows the two weight matrices to be pre-multiplied.
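A minimal numerical check of the absorption argument, using random stand-in weights rather than the module's actual parameters: since no position-dependent operation sits between the two projections, associativity lets `W^Q @ W^UK` be pre-multiplied once and applied to every cached latent directly.

```python
import torch

d_model, d_c, d_head = 512, 64, 64

W_Q = torch.randn(d_head, d_model)    # per-head query projection (stand-in)
W_UK = torch.randn(d_model, d_c)      # key up-projection (stand-in)
c = torch.randn(d_c)                  # one cached compressed KV latent

standard = W_Q @ (W_UK @ c)           # decompress, then project: repeated for every cached token
absorbed = (W_Q @ W_UK) @ c           # weights pre-multiplied once, applied to the cache directly

print(torch.allclose(standard, absorbed, rtol=1e-4))  # True: the two orderings agree
```

If RoPE were applied between `W^UK` and `W^Q`, the rotation would depend on the token position and this rewrite would no longer hold, which is exactly why the rotary part lives in its own small decoupled pathway.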
### Attention Pattern Evolution

```
t=1: Pattern[1:1]   # Initial token
t=2: Pattern[1:2]   # One previous token
t=n: Pattern[1:n]   # Full context window
```

Cache growth introduces subtle position-dependent patterns that require careful mask handling.

### Dimension Flow Control

Critical transitions to monitor:

```
[b, s, d]         -> [b, s, d_c]         # Down projection
[b, s, d_c]       -> [b, s+cache, d_c]   # Cache concatenation
[b, s+cache, d_c] -> [b, s+cache, d]     # Up projection
```

Each transition must preserve both positional and content information.

## Edge Cases and Considerations

### Cross-Attention Scenarios

```python
q_len != kv_len   # Length mismatch
d_c < d_model     # Compression bottleneck
```

Compression and position information must be maintained across different sequence lengths.

### Position-Aware Cache Updates

```python
# Position-dependent attention mask creation
mask[:, :, i, :(start_pos + i + 1)] = 0      # Can attend
mask[:, :, i, (start_pos + i + 1):] = -inf   # Cannot attend
```

The mask must evolve with the cache to maintain causal attention patterns.

### Numerical Stability

1. The scaling factor accounts for both pathways: `1/sqrt(d_head + d_rotate)` (worked out below)
2. Compression dimensions balance efficiency against representation capacity
3. RoPE dimensions affect position encoding granularity
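For the default configuration used in this repository's tests (`d_model=512`, `num_head=8`, `d_rotate=32`), point 1 works out as follows, matching `self.scaler` in `src/mla.py`:

```python
import math

d_head = 512 // 8                              # 64, inferred from d_model // num_head
d_rotate = 32
scaler = 1.0 / math.sqrt(d_head + d_rotate)    # 1/sqrt(96) ≈ 0.102, applied to Q before QK^T
```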
## Performance Implications

### Memory Complexity

```
Standard: O(b * s * d_model)
MLA:      O(b * s * (d_c + d_r))
```

where `d_c + d_r << d_model`.
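As a back-of-the-envelope example with this repository's test dimensions (`d_model=512`, `d_c=64`, `d_rotate=32`), and assuming the baseline caches full per-token K and V vectors (a common convention; the constant factor is dropped in the O(...) expressions above):

```python
d_model, d_c, d_r = 512, 64, 32

standard_per_token = 2 * d_model            # full K and V: 1024 cached values per token
mla_per_token = d_c + d_r                   # compressed KV latent + shared rotary key: 96 values
print(standard_per_token / mla_per_token)   # ≈ 10.7x fewer cached values per token per layer
```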
### Computational Trade-offs

1. Additional projections for the position pathway
2. Reduced cache size enables longer sequences
3. Matrix absorption reduces inference compute

## Integration Considerations

### Initialization Strategy

```python
# Critical hyperparameters
d_c:      Compression dimension
d_rotate: Position encoding dimension
```

There is a trade-off between compression efficiency and position encoding capacity.

### Cache Management

```python
# Two update patterns
cache_kv[:, pos:pos+s] = current_kv   # Content cache
cache_rk[:, pos:pos+s] = current_rk   # Position cache
```

Synchronization between the two caches is crucial for correctness.
insights/attention_mask.md
ADDED
@@ -0,0 +1,35 @@
# Advanced Insights: Attention Masks with KV-Caching

## Key Pitfalls in Complex Attention Implementations

### Dimension Evolution with Caching

```python
# Crucial dimension transitions in cached attention:
[b, s, d_model] -> [b, s+cache, d_c] -> [b, s+cache, d_model] -> [b, num_h, s, d_head]
```

The non-obvious trap: even with a growing K/V cache, the attention output dimensions must match the query length, not the cached length.

### Mask Causality with Growing Cache

Standard causal masks break with KV caching: they don't account for position-dependent attention patterns across cached sequences. Critical edge cases (see the sketch after this list):

- A token at position `i` must attend to `[0:start_pos+i]`
- Naive mask extension leads to incorrect causality preservation
- Position-wise mask generation has a performance cost
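A minimal sketch of building such a mask, following the same logic as the mask-extension loop in `MultiLatentAttention.forward` (batch size and lengths here are arbitrary): row `i` corresponds to the query at absolute position `start_pos + i`, so it may attend to the first `start_pos + i + 1` cached keys and nothing beyond.

```python
import torch

batch, seq_len, start_pos = 1, 2, 5      # two new queries on top of 5 already-cached tokens
cached_len = start_pos + seq_len         # total key/value length visible to attention

# [batch, 1 head, query_len, key_len]; added to the attention logits before softmax
mask = torch.zeros(batch, 1, seq_len, cached_len)
for i in range(seq_len):
    mask[:, :, i, start_pos + i + 1:] = float('-inf')   # block keys that lie in the future

print(mask[0, 0])
# tensor([[0., 0., 0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0., 0., 0.]])
```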
### Optimization Considerations

1. Memory vs. compute trade-off: precomputing extended masks vs. generating them per position
2. Batch dimension handling: mask broadcasting impacts memory usage
3. Fused attention kernels may break with custom mask handling

## Debugging Strategy for Non-Obvious Cases

Monitor these dimension transitions for subtle bugs:

```python
C_KV.shape        # Should grow: [b, s₁, d_c] -> [b, s₁+s₂, d_c]
K_state.shape     # Post-projection growth affects attention patterns
att_output.shape  # Must maintain query dimensions despite K/V growth
```

## Practical Example: DeepSeek's MLA Edge Case

In Multi-Latent Attention, the compressed KV cache introduces subtle interactions with attention masks due to:

1. Joint compression affecting position-dependent patterns
2. Non-standard dimension flow through compression/decompression
3. Mask causality preservation across cached compressed states
src/__init__.py
ADDED
@@ -0,0 +1,11 @@
"""
DeepSeek Multi-Latent Attention Implementation
Copyright (c) 2025

Implementation of the Multi-Latent Attention mechanism from the DeepSeek-V2 paper.
"""

from .mla import MultiLatentAttention, precompute_freqs_cis, reshape_for_broadcast, apply_rotary_emb

__version__ = "0.1.0"
__all__ = ["MultiLatentAttention", "precompute_freqs_cis", "reshape_for_broadcast", "apply_rotary_emb"]
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (580 Bytes).
src/__pycache__/mla.cpython-311.pyc
ADDED
Binary file (14 kB).
src/mla.py
ADDED
@@ -0,0 +1,311 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

import math


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Validate input dimensions
    assert xq.shape[-1] == xk.shape[-1], "Query and Key must have same embedding dimension"
    assert xq.shape[-1] % 2 == 0, "Embedding dimension must be even"

    # Get sequence lengths
    q_len = xq.shape[1]
    k_len = xk.shape[1]

    # Use the appropriate part of freqs_cis for each sequence
    q_freqs = freqs_cis[:q_len]
    k_freqs = freqs_cis[:k_len]

    # Apply rotary embeddings separately
    # Split the last dimension into [..., dim // 2, 2] and view as complex
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Reshape freqs for each
    q_freqs = reshape_for_broadcast(q_freqs, xq_)
    k_freqs = reshape_for_broadcast(k_freqs, xk_)

    # Works for both [bsz, seqlen, n_heads*head_dim] and [bsz, seqlen, n_heads, head_dim]
    xq_out = torch.view_as_real(xq_ * q_freqs).flatten(xq.ndim - 1)
    xk_out = torch.view_as_real(xk_ * k_freqs).flatten(xk.ndim - 1)

    return xq_out.type_as(xq), xk_out.type_as(xk)


class MultiLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA) module as in the DeepSeek-V2 paper.

    Key innovations over standard MHA:
    1. Low-Rank Key-Value Joint Compression
    2. Decoupled Rotary Position Embedding

    Args:
        d_model: Total dimension of the model.
        num_head: Number of attention heads.
        d_embed: Embedding dimension
        d_c: K/V compression dimension
        d_c1: Q compression dimension
        d_rotate: Dimension for Rotary Position Embedding
        dropout: Dropout rate for attention scores.
        bias: Whether to include bias in linear projections.

        d_head: Inferred from d_model // num_head

    Inputs:
        sequence: input sequence for self-attention and the query for cross-attention
        key_value_states: input for the keys/values in cross-attention
    """
    def __init__(
        self,
        d_model,  # Infer d_head from d_model
        num_head,
        d_embed,
        d_c,
        d_c1,
        d_rotate,
        dropout=0.1,
        bias=True,
        max_batch_size=32,  # For KV cache sizing
        max_seq_len=2048    # For KV cache sizing
    ):
        super().__init__()

        assert d_model % num_head == 0, f"d_model ({d_model}) must be divisible by num_head ({num_head})"
        assert d_c < d_embed, "Compression dim should be smaller than embedding dim"
        assert d_c1 < d_embed, "Query compression dim should be smaller than embedding dim"

        self.d_model = d_model
        self.num_head = num_head
        self.d_head = d_model // num_head
        self.d_embed = d_embed
        self.d_c = d_c
        self.d_c1 = d_c1
        self.d_rotate = d_rotate
        self.dropout_rate = dropout  # Store dropout rate separately

        # Linear down-projection (compression) transformations
        self.DKV_proj = nn.Linear(d_embed, d_c, bias=bias)
        self.DQ_proj = nn.Linear(d_embed, d_c1, bias=bias)

        # Linear up-projection transformations
        self.UQ_proj = nn.Linear(d_c1, d_model, bias=bias)
        self.UK_proj = nn.Linear(d_c, d_model, bias=bias)
        self.UV_proj = nn.Linear(d_c, d_model, bias=bias)

        # Linear RoPE projections
        self.RQ_proj = nn.Linear(d_c1, num_head * d_rotate, bias=bias)
        self.RK_proj = nn.Linear(d_embed, d_rotate, bias=bias)

        # Linear output transformation
        self.output_proj = nn.Linear(d_model, d_model, bias=bias)

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # Initialize scaler
        self.scaler = float(1.0 / math.sqrt(self.d_head + d_rotate))  # Store as float in initialization

        # Initialize C_KV and R_K caches for inference
        self.cache_kv = torch.zeros(
            (max_batch_size, max_seq_len, d_c)
        )
        self.cache_rk = torch.zeros(
            (max_batch_size, max_seq_len, d_rotate)
        )

        # Initialize freqs_cis for RoPE
        self.freqs_cis = precompute_freqs_cis(
            d_rotate, max_seq_len * 2
        )

    def forward(
        self,
        sequence,
        key_value_states=None,
        att_mask=None,
        use_cache=False,
        start_pos: int = 0
    ):
        """
        Forward pass supporting both standard attention and cached inference
        Input shape: [batch_size, seq_len, d_model=num_head * d_head]
        Args:
            sequence: Input sequence [batch_size, seq_len, d_model]
            key_value_states: Optional states for cross-attention
            att_mask: Optional attention mask
            use_cache: Whether to use KV caching (for inference)
            start_pos: Position in sequence when using KV cache
        """
        batch_size, seq_len, model_dim = sequence.size()
        # Prepare for RoPE
        self.freqs_cis = self.freqs_cis.to(sequence.device)
        freqs_cis = self.freqs_cis[start_pos:]

        # Check only critical input dimensions
        assert model_dim == self.d_model, f"Input dimension {model_dim} doesn't match model dimension {self.d_model}"
        if key_value_states is not None:
            assert key_value_states.size(-1) == self.d_model, \
                f"Cross attention key/value dimension {key_value_states.size(-1)} doesn't match model dimension {self.d_model}"

        # If key_value_states are provided, this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # Determine kv_seq_len early
        kv_seq_len = key_value_states.size(1) if is_cross_attention else seq_len

        # Linear projections and reshape for multi-head, in the order of Q, K/V
        # Down and up projection for the query
        C_Q = self.DQ_proj(sequence)   # [batch_size, seq_len, d_c1]
        Q_state = self.UQ_proj(C_Q)    # [batch_size, seq_len, d_model]
        # Linear projection for the query RoPE pathway
        Q_rotate = self.RQ_proj(C_Q)   # [batch_size, seq_len, num_head*d_rotate]

        if use_cache:
            # Equation (41) in the DeepSeek-V2 paper: cache c^{KV}_t
            self.cache_kv = self.cache_kv.to(sequence.device)

            # Get current compressed KV states
            current_kv = self.DKV_proj(key_value_states if is_cross_attention else sequence)  # [batch_size, kv_seq_len, d_c]
            # Update cache using kv_seq_len instead of seq_len
            self.cache_kv[:batch_size, start_pos:start_pos + kv_seq_len] = current_kv
            # Use cached compressed KV up to the current position
            C_KV = self.cache_kv[:batch_size, :start_pos + kv_seq_len]

            # Equation (43) in the DeepSeek-V2 paper: cache the RoPE pathway for the shared key k^R_t
            assert self.cache_rk.size(-1) == self.d_rotate, "RoPE cache dimension mismatch"
            self.cache_rk = self.cache_rk.to(sequence.device)
            # Get current RoPE key
            current_K_rotate = self.RK_proj(key_value_states if is_cross_attention else sequence)  # [batch_size, kv_seq_len, d_rotate]
            # Update cache using kv_seq_len instead of seq_len
            self.cache_rk[:batch_size, start_pos:start_pos + kv_seq_len] = current_K_rotate
            # Use cached RoPE key up to the current position
            K_rotate = self.cache_rk[:batch_size, :start_pos + kv_seq_len]  # [batch_size, cached_len, d_rotate]

            """Handle the attention mask"""
            if att_mask is not None:
                # Get the original mask shape
                mask_size = att_mask.size(-1)
                cached_len = start_pos + kv_seq_len  # cached key length, including previous keys
                assert C_KV.size(1) == cached_len, \
                    f"Cached key/value length {C_KV.size(1)} doesn't match theoretical length {cached_len}"

                # Create a new mask matching the attention matrix shape
                extended_mask = torch.zeros(
                    (batch_size, 1, seq_len, cached_len),  # [batch, head, query_len, key_len]
                    device=att_mask.device,
                    dtype=att_mask.dtype
                )

                # Fill in the mask carefully to preserve causality:
                # each query position may only attend to cached positions up to that point
                for i in range(seq_len):
                    extended_mask[:, :, i, :(start_pos + i + 1)] = 0               # Can attend
                    extended_mask[:, :, i, (start_pos + i + 1):] = float('-inf')   # Cannot attend

                att_mask = extended_mask
        else:
            # Compression projection for C_KV
            C_KV = self.DKV_proj(key_value_states if is_cross_attention else sequence)  # [batch_size, kv_seq_len, d_c]
            # RoPE pathway for the *shared* key
            K_rotate = self.RK_proj(key_value_states if is_cross_attention else sequence)

        # Up projection for key and value
        K_state = self.UK_proj(C_KV)  # [batch_size, kv_seq_len/cached_len, d_model]
        V_state = self.UV_proj(C_KV)  # [batch_size, kv_seq_len/cached_len, d_model]

        Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)

        # After getting K_state from the projection, get its actual sequence length
        actual_kv_len = K_state.size(1)  # kv_seq_len or start_pos + kv_seq_len
        # In cross-attention, the key/value sequence length might differ from the query sequence length,
        # so use actual_kv_len instead of kv_seq_len for reshaping
        K_state = K_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)
        V_state = V_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)

        # Apply RoPE to the query and the shared key
        Q_rotate = Q_rotate.view(batch_size, seq_len, self.num_head, self.d_rotate)
        K_rotate = K_rotate.unsqueeze(2).expand(-1, -1, self.num_head, -1)  # [batch, cached_len, num_head, d_rotate]
        Q_rotate, K_rotate = apply_rotary_emb(Q_rotate, K_rotate, freqs_cis=freqs_cis)

        # Concatenate along the head dimension
        Q_state = torch.cat([Q_state, Q_rotate], dim=-1)  # [batch_size, seq_len, num_head, d_head + d_rotate]
        K_state = torch.cat([K_state, K_rotate], dim=-1)  # [batch_size, actual_kv_len, num_head, d_head + d_rotate]

        # Scale Q by 1/sqrt(d_head + d_rotate)
        Q_state = Q_state * self.scaler
        Q_state = Q_state.transpose(1, 2)  # [batch_size, num_head, seq_len, head_dim]
        K_state = K_state.transpose(1, 2)  # [batch_size, num_head, actual_kv_len, head_dim]
        V_state = V_state.transpose(1, 2)  # [batch_size, num_head, actual_kv_len, head_dim]

        # Compute the attention matrix: QK^T
        self.att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))

        # Apply the attention mask to the attention matrix
        if att_mask is not None and not isinstance(att_mask, torch.Tensor):
            raise TypeError("att_mask must be a torch.Tensor")

        if att_mask is not None:
            self.att_matrix = self.att_matrix + att_mask

        # Apply softmax over the last dimension to get the attention scores: softmax(QK^T)
        att_score = F.softmax(self.att_matrix, dim=-1)

        # Apply dropout to the attention scores
        att_score = self.dropout(att_score)

        # Get the final output: softmax(QK^T)V
        att_output = torch.matmul(att_score, V_state)
        assert att_output.size(0) == batch_size, "Batch size mismatch"
        assert att_output.size(2) == seq_len, "Output sequence length should match query sequence length"

        # Concatenate all attention heads
        att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_head * self.d_head)

        # Final linear transformation applied to the concatenated output
        att_output = self.output_proj(att_output)

        assert att_output.size() == (batch_size, seq_len, self.d_model), \
            f"Final output shape {att_output.size()} incorrect"

        return att_output
src/tests/__init__.py
ADDED
File without changes
src/tests/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (182 Bytes).
src/tests/__pycache__/test_mla.cpython-311.pyc
ADDED
Binary file (6.69 kB).
src/tests/test_mla.py
ADDED
@@ -0,0 +1,138 @@
import unittest
import torch
from ..mla import MultiLatentAttention  # Using relative import

class TestMultiLatentAttention(unittest.TestCase):
    def setUp(self):
        # Common dimensions for testing
        self.d_model = 512
        self.num_head = 8
        self.d_embed = 512
        self.d_c = 64        # Compression dim for K/V
        self.d_c1 = 64       # Compression dim for Q
        self.d_rotate = 32   # RoPE dimension
        self.batch_size = 2
        self.seq_len = 10

        # Initialize MLA
        self.mla = MultiLatentAttention(
            d_model=self.d_model,
            num_head=self.num_head,
            d_embed=self.d_embed,
            d_c=self.d_c,
            d_c1=self.d_c1,
            d_rotate=self.d_rotate
        )

    def test_basic_forward(self):
        """Test basic forward pass without caching"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        output = self.mla(x)

        # Check output shape
        self.assertEqual(
            output.shape,
            (self.batch_size, self.seq_len, self.d_model),
            "Output shape mismatch"
        )

    def test_cross_attention(self):
        """Test cross-attention functionality"""
        query = torch.randn(self.batch_size, self.seq_len, self.d_model)
        kv = torch.randn(self.batch_size, self.seq_len * 2, self.d_model)  # Different seq_len

        output = self.mla(query, key_value_states=kv)
        self.assertEqual(
            output.shape,
            (self.batch_size, self.seq_len, self.d_model),
            "Cross-attention output shape mismatch"
        )

    def test_cache_initialization(self):
        """Test if the cache is properly initialized"""
        x = torch.randn(self.batch_size, self.seq_len, self.d_model)
        _ = self.mla(x, use_cache=True, start_pos=0)

        self.assertIsNotNone(self.mla.cache_kv)
        self.assertEqual(
            self.mla.cache_kv.shape[-1],
            self.d_c,
            "Cache compression dimension mismatch"
        )

    def test_sequential_caching(self):
        """Test sequential forward passes with caching"""
        # Initial sequence
        prompt_len = 5
        prompt = torch.randn(self.batch_size, prompt_len, self.d_model)

        # First forward pass with the prompt
        output1 = self.mla(prompt, use_cache=True, start_pos=0)
        cached_kv_1 = self.mla.cache_kv[:, :prompt_len].clone()

        # Second forward pass with one new token
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        output2 = self.mla(new_token, use_cache=True, start_pos=prompt_len)

        # Verify cache consistency:
        # the first part of the cache should remain unchanged
        self.assertTrue(
            torch.allclose(
                self.mla.cache_kv[:, :prompt_len],
                cached_kv_1,
                rtol=1e-5
            ),
            "Cache was modified for previously processed tokens"
        )

        # Verify the new token was added to the cache
        self.assertFalse(
            torch.allclose(
                self.mla.cache_kv[:, prompt_len:prompt_len+1],
                torch.zeros_like(self.mla.cache_kv[:, prompt_len:prompt_len+1]),
                rtol=1e-5
            ),
            "New token was not added to cache"
        )

    def test_attention_mask_with_cache(self):
        """Test attention masking with cached KV"""
        seq_len = 5
        x = torch.randn(self.batch_size, seq_len, self.d_model)

        # Create a causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)

        # First forward pass with the mask
        output1 = self.mla(x, use_cache=True, start_pos=0, att_mask=mask)

        # Second pass with one token
        new_token = torch.randn(self.batch_size, 1, self.d_model)
        extended_mask = torch.triu(
            torch.ones(seq_len + 1, seq_len + 1) * float('-inf'),
            diagonal=1
        ).unsqueeze(0)

        output2 = self.mla(
            new_token,
            use_cache=True,
            start_pos=seq_len,
            att_mask=extended_mask
        )

        self.assertEqual(
            output2.shape,
            (self.batch_size, 1, self.d_model),
            "Output shape incorrect for cached attention with mask"
        )

def run_tests():
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMultiLatentAttention)
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)

# Run the tests when executed with `python -m src.tests.test_mla`
if __name__ == "__main__":
    run_tests()