EricB committed
Commit ed30f9d · 1 Parent(s): a0903d3

Add metal paged attention
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  *.so filter=lfs diff=lfs merge=lfs -text
+ *.metallib filter=lfs diff=lfs merge=lfs -text
build.toml CHANGED
@@ -1,5 +1,6 @@
  [general]
  name = "paged_attention"
+ universal = false

  [torch]
  src = [
@@ -8,6 +9,7 @@ src = [
  ]

  [kernel.cuda_utils]
+ backend = "cuda"
  src = [
  "cuda-utils/cuda_utils_kernels.cu",
  ]
@@ -15,6 +17,7 @@ depends = []


  [kernel.paged_attention]
+ backend = "cuda"
  src = [
  "paged-attention/attention/attention_dtypes.h",
  "paged-attention/attention/attention_generic.cuh",
@@ -37,3 +40,20 @@ src = [
  include = [ "." ]
  depends = [ "torch" ]

+
+ [kernel.paged_attention_metal]
+ backend = "metal"
+ src = [
+ "paged-attention-metal/attention/paged_attention.metal",
+ "paged-attention-metal/cache/copy_blocks.metal",
+ "paged-attention-metal/cache/reshape_and_cache.metal",
+ "paged-attention-metal/convert_fp8.metal",
+ "paged-attention-metal/float8.metal",
+ "paged-attention-metal/utils.metal",
+ "paged-attention-metal/paged_attention.mm",
+ "paged-attention-metal/cache.mm",
+ "paged-attention-metal/convert_fp8.mm",
+ "paged-attention-metal/device.mm",
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
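The commit also checks in a prebuilt torch27-metal-aarch64-darwin variant of this kernel (the build/ tree below). A minimal, hypothetical way to try it locally on a matching torch 2.7 / arm64 macOS environment, without going through the kernel-builder tooling (the sys.path manipulation is purely illustrative; a loader such as the Hugging Face kernels tooling would normally resolve the right build variant):

    import sys

    # Path relative to the repository root (assumption for a local checkout).
    sys.path.insert(0, "build/torch27-metal-aarch64-darwin")

    import paged_attention  # imports _paged_attention_9678b89.abi3.so and exposes its torch ops

    print(paged_attention.__all__)
    print(paged_attention.ops)  # torch.ops._paged_attention_9678b89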
build/torch27-metal-aarch64-darwin/paged_attention/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from ._custom_ops import (
+     convert_fp8,
+     copy_blocks,
+     paged_attention_v1,
+     paged_attention_v2,
+     reshape_and_cache,
+     reshape_and_cache_flash,
+     swap_blocks,
+ )
+ from ._ops import ops
+
+ __all__ = [
+     "convert_fp8",
+     "copy_blocks",
+     "ops",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "reshape_and_cache_flash",
+     "swap_blocks",
+ ]
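A hedged usage sketch of the cache-management helpers re-exported here. The cache layouts, dtypes, and the block_mapping convention are assumptions taken from the buffer comments in paged_attention.metal further down (keys as [num_blocks, num_kv_heads, head_size/x, block_size, x], values as [num_blocks, num_kv_heads, head_size, block_size]); they are illustrative, not an authoritative API contract:

    import torch
    import paged_attention as pa  # assumes the build variant above is importable

    num_blocks, num_kv_heads, head_size, block_size = 8, 2, 64, 16
    num_tokens = 4
    dtype = torch.float16
    x = 8  # 16-byte groups of fp16 elements, per the key-cache layout comment in the kernel

    key_cache = torch.zeros(num_blocks, num_kv_heads, head_size // x, block_size, x, dtype=dtype, device="mps")
    value_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size, dtype=dtype, device="mps")

    key = torch.randn(num_tokens, num_kv_heads, head_size, dtype=dtype, device="mps")
    value = torch.randn(num_tokens, num_kv_heads, head_size, dtype=dtype, device="mps")
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="mps")

    # "auto" is assumed to mean "same dtype as the cache" (no fp8 quantization).
    pa.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, "auto", 1.0, 1.0)

    # copy_blocks takes per-layer cache lists and an assumed [num_pairs, 2] (src, dst) mapping.
    block_mapping = torch.tensor([[0, 1]], dtype=torch.int64, device="mps")
    pa.copy_blocks([key_cache], [value_cache], block_mapping)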
build/torch27-metal-aarch64-darwin/paged_attention/_custom_ops.py ADDED
@@ -0,0 +1,173 @@
+ from typing import List, Optional
+
+ import torch
+
+ from ._ops import ops
+
+
+ # page attention ops
+ def paged_attention_v1(
+     out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v1(
+         out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def paged_attention_v2(
+     out: torch.Tensor,
+     exp_sum: torch.Tensor,
+     max_logits: torch.Tensor,
+     tmp_out: torch.Tensor,
+     query: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     num_kv_heads: int,
+     scale: float,
+     block_tables: torch.Tensor,
+     seq_lens: torch.Tensor,
+     block_size: int,
+     max_seq_len: int,
+     alibi_slopes: Optional[torch.Tensor],
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+     tp_rank: int = 0,
+     blocksparse_local_blocks: int = 0,
+     blocksparse_vert_stride: int = 0,
+     blocksparse_block_size: int = 64,
+     blocksparse_head_sliding_step: int = 0,
+ ) -> None:
+     ops.paged_attention_v2(
+         out,
+         exp_sum,
+         max_logits,
+         tmp_out,
+         query,
+         key_cache,
+         value_cache,
+         num_kv_heads,
+         scale,
+         block_tables,
+         seq_lens,
+         block_size,
+         max_seq_len,
+         alibi_slopes,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+         tp_rank,
+         blocksparse_local_blocks,
+         blocksparse_vert_stride,
+         blocksparse_block_size,
+         blocksparse_head_sliding_step,
+     )
+
+
+ def reshape_and_cache(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: float,
+     v_scale: float,
+ ) -> None:
+     ops.reshape_and_cache(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def reshape_and_cache_flash(
+     key: torch.Tensor,
+     value: torch.Tensor,
+     key_cache: torch.Tensor,
+     value_cache: torch.Tensor,
+     slot_mapping: torch.Tensor,
+     kv_cache_dtype: str,
+     k_scale: torch.Tensor,
+     v_scale: torch.Tensor,
+ ) -> None:
+     ops.reshape_and_cache_flash(
+         key,
+         value,
+         key_cache,
+         value_cache,
+         slot_mapping,
+         kv_cache_dtype,
+         k_scale,
+         v_scale,
+     )
+
+
+ def copy_blocks(
+     key_caches: List[torch.Tensor],
+     value_caches: List[torch.Tensor],
+     block_mapping: torch.Tensor,
+ ) -> None:
+     ops.copy_blocks(key_caches, value_caches, block_mapping)
+
+
+ def swap_blocks(
+     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
+ ) -> None:
+     ops.swap_blocks(src, dst, block_mapping)
+
+
+ def convert_fp8(
+     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
+ ) -> None:
+     ops.convert_fp8(output, input, scale, kv_dtype)
+
+
+ __all__ = [
+     "convert_fp8",
+     "paged_attention_v1",
+     "paged_attention_v2",
+     "reshape_and_cache",
+     "copy_blocks",
+ ]
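For the attention op itself, the expected shapes follow the buffer comments in the Metal kernel below (query as [num_seqs, num_heads, head_size], block_tables as [num_seqs, max_num_blocks_per_seq], seq_lens as [num_seqs]). A hedged single-token decode sketch, reusing key_cache/value_cache and the dimensions from the sketch after __init__.py above; the integer dtypes and the "auto" cache-dtype string are assumptions:

    num_seqs, num_heads = 2, 4  # num_heads = num_kv_heads * queries_per_kv (assumed)
    max_blocks_per_seq = 4

    out = torch.empty(num_seqs, num_heads, head_size, dtype=dtype, device="mps")
    query = torch.randn(num_seqs, num_heads, head_size, dtype=dtype, device="mps")
    block_tables = torch.zeros(num_seqs, max_blocks_per_seq, dtype=torch.int32, device="mps")
    seq_lens = torch.full((num_seqs,), 12, dtype=torch.int32, device="mps")

    pa.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads=num_kv_heads,
        scale=head_size ** -0.5,
        block_tables=block_tables,
        seq_lens=seq_lens,
        block_size=block_size,
        max_seq_len=block_size * max_blocks_per_seq,
        alibi_slopes=None,
        kv_cache_dtype="auto",
        k_scale=1.0,
        v_scale=1.0,
    )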
build/torch27-metal-aarch64-darwin/paged_attention/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _paged_attention_9678b89
+ ops = torch.ops._paged_attention_9678b89
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_paged_attention_9678b89::{op_name}"
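The namespace helper is useful whenever an op has to be referred to by its fully qualified name (for example in torch.library registrations). A small sketch, assuming the package is importable:

    from paged_attention._ops import add_op_namespace_prefix, ops

    qualified_name = add_op_namespace_prefix("paged_attention_v1")
    print(qualified_name)          # _paged_attention_9678b89::paged_attention_v1
    print(ops.paged_attention_v1)  # the underlying registered torch op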
build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7a94cee9e553d2bdf8d47d0d9461c871b3e57a33cf6cb259807377f0d1b03c7d
+ size 214800
build/torch27-metal-aarch64-darwin/paged_attention/_paged_attention_9678b89.metallib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c46eaf21c96da70c5227b2566308a8ef73ae09abf303278f40070dd4326ba0be
+ size 4999876
build/torch27-metal-aarch64-darwin/paged_attention/platforms.py ADDED
@@ -0,0 +1,92 @@
+ import os
+ import random
+ from abc import ABC, abstractmethod
+ from functools import lru_cache, wraps
+ from typing import Callable, ParamSpec, TypeVar
+
+ import numpy as np
+ import torch
+
+ IS_ROCM = torch.version.hip is not None
+ IS_MPS = torch.backends.mps.is_available()
+
+
+ class Platform(ABC):
+     @classmethod
+     def seed_everything(cls, seed: int) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` will set seed on all devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
+     @abstractmethod
+     def get_device_name(self, device_id: int = 0) -> str: ...
+
+     @abstractmethod
+     def is_cuda(self) -> bool: ...
+
+     @abstractmethod
+     def is_rocm(self) -> bool: ...
+
+     @abstractmethod
+     def is_mps(self) -> bool: ...
+
+
+ class CudaPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(0)
+
+     def is_cuda(self) -> bool:
+         return True
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class RocmPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return True
+
+     def is_mps(self) -> bool:
+         return False
+
+
+ class MpsPlatform(Platform):
+     @classmethod
+     @lru_cache(maxsize=8)
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return torch.cuda.get_device_name(device_id)
+
+     def is_cuda(self) -> bool:
+         return False
+
+     def is_rocm(self) -> bool:
+         return False
+
+     def is_mps(self) -> bool:
+         return True
+
+ current_platform = (
+     RocmPlatform() if IS_ROCM else
+     MpsPlatform() if IS_MPS else
+     CudaPlatform() if torch.cuda.is_available() else
+     None
+ )
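A sketch of how current_platform might drive device selection in test or benchmark code (note that, as written above, MpsPlatform.get_device_name still delegates to torch.cuda.get_device_name):

    import torch
    from paged_attention.platforms import current_platform

    if current_platform is None:
        device = torch.device("cpu")
    elif current_platform.is_mps():
        device = torch.device("mps")
    else:
        # CUDA and ROCm builds of PyTorch both expose the "cuda" device API.
        device = torch.device("cuda")

    if current_platform is not None:
        current_platform.seed_everything(0)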
flake.lock CHANGED
@@ -1,6 +1,21 @@
1
  {
2
  "nodes": {
3
  "flake-compat": {
4
  "locked": {
5
  "lastModified": 1733328505,
6
  "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
@@ -33,61 +48,82 @@
33
  "type": "github"
34
  }
35
  },
36
- "kernel-builder": {
37
  "inputs": {
38
- "flake-compat": "flake-compat",
39
- "flake-utils": "flake-utils",
40
- "nixpkgs": "nixpkgs",
41
- "rocm-nix": "rocm-nix"
42
  },
43
  "locked": {
44
- "lastModified": 1744976941,
45
- "narHash": "sha256-+csrhVaT6Mj2j1FM7P2BDITvf1Xwj2AKdMm0IKZK340=",
46
- "owner": "huggingface",
47
- "repo": "kernel-builder",
48
- "rev": "0a278c2e9aaf6003a4ec6fe35c7158624762de5a",
49
  "type": "github"
50
  },
51
  "original": {
52
- "owner": "huggingface",
53
- "repo": "kernel-builder",
54
  "type": "github"
55
  }
56
  },
57
- "nixpkgs": {
58
  "locked": {
59
- "lastModified": 1743559129,
60
- "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=",
61
- "owner": "nixos",
62
- "repo": "nixpkgs",
63
- "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e",
64
  "type": "github"
65
  },
66
  "original": {
67
- "owner": "nixos",
68
- "ref": "nixos-unstable-small",
69
- "repo": "nixpkgs",
70
  "type": "github"
71
  }
72
  },
73
- "rocm-nix": {
74
  "inputs": {
 
 
 
75
  "nixpkgs": [
76
  "kernel-builder",
 
77
  "nixpkgs"
78
  ]
79
  },
80
  "locked": {
81
- "lastModified": 1743085847,
82
- "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=",
83
  "owner": "huggingface",
84
- "repo": "rocm-nix",
85
- "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39",
86
  "type": "github"
87
  },
88
  "original": {
89
  "owner": "huggingface",
90
- "repo": "rocm-nix",
91
  "type": "github"
92
  }
93
  },
@@ -110,6 +146,21 @@
110
  "repo": "default",
111
  "type": "github"
112
  }
113
  }
114
  },
115
  "root": "root",
 
1
  {
2
  "nodes": {
3
  "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
  "locked": {
20
  "lastModified": 1733328505,
21
  "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
 
48
  "type": "github"
49
  }
50
  },
51
+ "flake-utils_2": {
52
  "inputs": {
53
+ "systems": "systems_2"
 
 
 
54
  },
55
  "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
  "type": "github"
62
  },
63
  "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
  "type": "github"
67
  }
68
  },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
  "locked": {
76
+ "lastModified": 1750234878,
77
+ "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
81
  "type": "github"
82
  },
83
  "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
 
86
  "type": "github"
87
  }
88
  },
89
+ "kernel-builder": {
90
  "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
  "nixpkgs": [
95
  "kernel-builder",
96
+ "hf-nix",
97
  "nixpkgs"
98
  ]
99
  },
100
  "locked": {
101
+ "lastModified": 1750917308,
102
+ "narHash": "sha256-/kRwI2GgYwhgFwFGZ/tOgQr1qdihidU89ngDviqxTtU=",
103
  "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "5fb8be4d148b5e4d0e2130998d02bafca71520c7",
106
  "type": "github"
107
  },
108
  "original": {
109
  "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1747820358,
117
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
+ "owner": "danieldk",
119
+ "repo": "nixpkgs",
120
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "danieldk",
125
+ "ref": "cudatoolkit-12.9-kernel-builder",
126
+ "repo": "nixpkgs",
127
  "type": "github"
128
  }
129
  },
 
146
  "repo": "default",
147
  "type": "github"
148
  }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
  }
165
  },
166
  "root": "root",
paged-attention-metal/attention/paged_attention.metal ADDED
@@ -0,0 +1,1401 @@
1
+ // Updated from MLX commit hash f70764a
2
+
3
+ #include "../utils.metal"
4
+ #include "../float8.metal"
5
+ #include <metal_simdgroup>
6
+ #include <metal_stdlib>
7
+
8
+ using namespace metal;
9
+
10
+ // ========================================== Generic vector types
11
+
12
+ // A vector type to store Q, K, V elements.
13
+ template <typename T, int VEC_SIZE> struct Vec {};
14
+
15
+ // A vector type to store FP32 accumulators.
16
+ template <typename T> struct FloatVec {};
17
+
18
+ // Template vector operations.
19
+ template <typename Acc, typename A, typename B> inline Acc mul(A a, B b);
20
+
21
+ template <typename T> inline float sum(T v);
22
+
23
+ template <typename T> inline float dot(T a, T b) {
24
+ return sum(mul<T, T, T>(a, b));
25
+ }
26
+
27
+ template <typename A, typename T> inline float dot(T a, T b) {
28
+ return sum(mul<A, T, T>(a, b));
29
+ }
30
+
31
+ // FP32 vector data types.
32
+ struct Float8_ {
33
+ float4 x;
34
+ float4 y;
35
+ };
36
+
37
+ template <> struct Vec<float, 1> {
38
+ using Type = float;
39
+ };
40
+ template <> struct Vec<float, 2> {
41
+ using Type = float2;
42
+ };
43
+ template <> struct Vec<float, 4> {
44
+ using Type = float4;
45
+ };
46
+ template <> struct Vec<float, 8> {
47
+ using Type = Float8_;
48
+ };
49
+
50
+ template <> struct FloatVec<float> {
51
+ using Type = float;
52
+ };
53
+ template <> struct FloatVec<float2> {
54
+ using Type = float2;
55
+ };
56
+ template <> struct FloatVec<float4> {
57
+ using Type = float4;
58
+ };
59
+ template <> struct FloatVec<Float8_> {
60
+ using Type = Float8_;
61
+ };
62
+
63
+ template <> inline float mul(float a, float b) { return a * b; }
64
+
65
+ template <> inline float2 mul(float2 a, float2 b) { return a * b; }
66
+
67
+ template <> inline float4 mul(float4 a, float4 b) { return a * b; }
68
+
69
+ template <> inline Float8_ mul(Float8_ a, Float8_ b) {
70
+ Float8_ c;
71
+ c.x = a.x * b.x;
72
+ c.y = a.y * b.y;
73
+ return c;
74
+ }
75
+
76
+ template <> inline float sum(float a) { return a; }
77
+
78
+ template <> inline float sum(float2 a) { return a.x + a.y; }
79
+
80
+ template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; }
81
+
82
+ template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); }
83
+
84
+ inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) {
85
+ Float8_ res;
86
+ res.x = fma(a.x, b.x, c.x);
87
+ res.y = fma(a.y, b.y, c.y);
88
+ return res;
89
+ }
90
+
91
+ inline void from_float(thread float &dst, float src) { dst = src; }
92
+ inline void from_float(thread float2 &dst, float2 src) { dst = src; }
93
+ inline void from_float(thread float4 &dst, float4 src) { dst = src; }
94
+ inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; }
95
+
96
+ // BF16 vector data types.
97
+ // #if defined(__HAVE_BFLOAT__)
98
+
99
+ // struct Bfloat8_ {
100
+ // bfloat4 x;
101
+ // bfloat4 y;
102
+ // };
103
+
104
+ // template<>
105
+ // struct Vec<bfloat, 1> {
106
+ // using Type = bfloat;
107
+ // };
108
+ // template<>
109
+ // struct Vec<bfloat, 2> {
110
+ // using Type = bfloat2;
111
+ // };
112
+ // template<>
113
+ // struct Vec<bfloat, 4> {
114
+ // using Type = bfloat4;
115
+ // };
116
+ // template<>
117
+ // struct Vec<bfloat, 8> {
118
+ // using Type = Bfloat8_;
119
+ // };
120
+
121
+ // template<>
122
+ // struct FloatVec<bfloat> {
123
+ // using Type = float;
124
+ // };
125
+ // template<>
126
+ // struct FloatVec<bfloat2> {
127
+ // using Type = float2;
128
+ // };
129
+ // template<>
130
+ // struct FloatVec<bfloat4> {
131
+ // using Type = float4;
132
+ // };
133
+ // template<>
134
+ // struct FloatVec<Bfloat8_> {
135
+ // using Type = Float8_;
136
+ // };
137
+
138
+ // template<>
139
+ // inline float mul(bfloat a, bfloat b) {
140
+ // return (float)a * (float)b;
141
+ // }
142
+ // template<>
143
+ // inline bfloat mul(bfloat a, bfloat b) {
144
+ // return a*b;
145
+ // }
146
+
147
+ // template<>
148
+ // inline float2 mul(bfloat2 a, bfloat2 b) {
149
+ // return (float2)a * (float2)b;
150
+ // }
151
+ // template<>
152
+ // inline bfloat2 mul(bfloat2 a, bfloat2 b) {
153
+ // return a * b;
154
+ // }
155
+
156
+ // template<>
157
+ // inline float4 mul(bfloat4 a, bfloat4 b) {
158
+ // return (float4)a * (float4)b;
159
+ // }
160
+ // template<>
161
+ // inline bfloat4 mul(bfloat4 a, bfloat4 b) {
162
+ // return a * b;
163
+ // }
164
+
165
+ // template<>
166
+ // inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
167
+ // Float8_ c;
168
+ // c.x = mul<float4, bfloat4, bfloat4>(a.x, b.x);
169
+ // c.y = mul<float4, bfloat4, bfloat4>(a.y, b.y);
170
+ // return c;
171
+ // }
172
+ // template<>
173
+ // inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
174
+ // Bfloat8_ c;
175
+ // c.x = mul<bfloat4, bfloat4, bfloat4>(a.x, b.x);
176
+ // c.y = mul<bfloat4, bfloat4, bfloat4>(a.y, b.y);
177
+ // return c;
178
+ // }
179
+
180
+ // template<>
181
+ // inline float sum(bfloat a) {
182
+ // return (float)a;
183
+ // }
184
+
185
+ // template<>
186
+ // inline float sum(bfloat2 a) {
187
+ // return (float)a.x + (float)a.y;
188
+ // }
189
+
190
+ // template<>
191
+ // inline float sum(bfloat4 a) {
192
+ // return sum(a.x) + sum(a.y);
193
+ // }
194
+
195
+ // template<>
196
+ // inline float sum(Bfloat8_ a) {
197
+ // return sum(a.x) + sum(a.y);
198
+ // }
199
+
200
+ // inline float fma(bfloat a, bfloat b, float c) {
201
+ // return (float)a * (float)b + c;
202
+ // }
203
+
204
+ // inline float2 fma(bfloat2 a, bfloat2 b, float2 c) {
205
+ // return (float2)a * (float2)b + c;
206
+ // }
207
+
208
+ // inline float4 fma(bfloat4 a, bfloat4 b, float4 c) {
209
+ // return (float4)a * (float4)b + c;
210
+ // }
211
+
212
+ // inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
213
+ // Float8_ res;
214
+ // res.x = fma((float4)a.x, (float4)b.x, (float4)c.x);
215
+ // res.y = fma((float4)a.y, (float4)b.y, (float4)c.y);
216
+ // return res;
217
+ // }
218
+ // inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
219
+ // Bfloat8_ res;
220
+ // res.x = (bfloat4)fma((float4)a.x, (float4)b.x, (float4)c.x);
221
+ // res.y = (bfloat4)fma((float4)a.y, (float4)b.x, (float4)c.y);
222
+ // return c;
223
+ // }
224
+
225
+ // inline void from_float(thread bfloat& dst, float src) {
226
+ // dst = static_cast<bfloat>(src);
227
+ // }
228
+ // inline void from_float(thread bfloat2& dst, float2 src) {
229
+ // dst.x = static_cast<bfloat>(src.x);
230
+ // dst.y = static_cast<bfloat>(src.y);
231
+ // }
232
+ // inline void from_float(thread bfloat4& dst, float4 src) {
233
+ // dst.x = static_cast<bfloat>(src.x);
234
+ // dst.y = static_cast<bfloat>(src.y);
235
+ // dst.z = static_cast<bfloat>(src.z);
236
+ // dst.w = static_cast<bfloat>(src.w);
237
+ // }
238
+ // inline void from_float(thread Bfloat8_& dst, Float8_ src) {
239
+ // bfloat4 x;
240
+ // bfloat4 y;
241
+ // from_float(x, src.x);
242
+ // from_float(y, src.y);
243
+ // dst.x = x;
244
+ // dst.y = y;
245
+ // }
246
+
247
+ // #else
248
+
249
+ struct Bfloat2_ {
250
+ bfloat16_t x;
251
+ bfloat16_t y;
252
+ };
253
+
254
+ struct Bfloat4_ {
255
+ Bfloat2_ x;
256
+ Bfloat2_ y;
257
+ };
258
+
259
+ struct Bfloat8_ {
260
+ Bfloat4_ x;
261
+ Bfloat4_ y;
262
+ };
263
+
264
+ template <> struct Vec<bfloat16_t, 1> {
265
+ using Type = bfloat16_t;
266
+ };
267
+ template <> struct Vec<bfloat16_t, 2> {
268
+ using Type = Bfloat2_;
269
+ };
270
+ template <> struct Vec<bfloat16_t, 4> {
271
+ using Type = Bfloat4_;
272
+ };
273
+ template <> struct Vec<bfloat16_t, 8> {
274
+ using Type = Bfloat8_;
275
+ };
276
+
277
+ template <> struct FloatVec<bfloat16_t> {
278
+ using Type = float;
279
+ };
280
+ template <> struct FloatVec<Bfloat2_> {
281
+ using Type = float2;
282
+ };
283
+ template <> struct FloatVec<Bfloat4_> {
284
+ using Type = float4;
285
+ };
286
+ template <> struct FloatVec<Bfloat8_> {
287
+ using Type = Float8_;
288
+ };
289
+
290
+ template <> inline float mul(bfloat16_t a, bfloat16_t b) {
291
+ return (float)a * (float)b;
292
+ }
293
+ template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; }
294
+
295
+ template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) {
296
+ float2 a_f((float)a.x, (float)a.y);
297
+ float2 b_f((float)b.x, (float)b.y);
298
+ return a_f * b_f;
299
+ }
300
+ template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) {
301
+ Bfloat2_ c;
302
+ c.x = a.x * b.x;
303
+ c.y = a.y * b.y;
304
+ return c;
305
+ }
306
+
307
+ template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) {
308
+ float2 x = mul<float2, Bfloat2_, Bfloat2_>(a.x, b.x);
309
+ float2 y = mul<float2, Bfloat2_, Bfloat2_>(a.y, b.y);
310
+ float4 c;
311
+ c.x = x.x;
312
+ c.y = x.y;
313
+ c.z = y.x;
314
+ c.w = y.y;
315
+ return c;
316
+ }
317
+ template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) {
318
+ Bfloat4_ c;
319
+ c.x = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.x, b.x);
320
+ c.y = mul<Bfloat2_, Bfloat2_, Bfloat2_>(a.y, b.y);
321
+ return c;
322
+ }
323
+
324
+ template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) {
325
+ Float8_ c;
326
+ c.x = mul<float4, Bfloat4_, Bfloat4_>(a.x, b.x);
327
+ c.y = mul<float4, Bfloat4_, Bfloat4_>(a.y, b.y);
328
+ return c;
329
+ }
330
+ template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) {
331
+ Bfloat8_ c;
332
+ c.x = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.x, b.x);
333
+ c.y = mul<Bfloat4_, Bfloat4_, Bfloat4_>(a.y, b.y);
334
+ return c;
335
+ }
336
+
337
+ template <> inline float sum(bfloat16_t a) { return (float)a; }
338
+
339
+ template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; }
340
+
341
+ template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); }
342
+
343
+ template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); }
344
+
345
+ inline float fma(bfloat16_t a, bfloat16_t b, float c) {
346
+ return (float)a * (float)b + c;
347
+ }
348
+ inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) {
349
+ return a * b + c;
350
+ }
351
+
352
+ inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) {
353
+ float2 a_f((float)a.x, (float)a.y);
354
+ float2 b_f((float)b.x, (float)b.y);
355
+ return a_f * b_f + c;
356
+ }
357
+ inline Bfloat2_ fma(Bfloat2_ a, Bfloat2_ b, Bfloat2_ c) {
358
+ Bfloat2_ res;
359
+ res.x = a.x * b.x + c.x;
360
+ res.y = a.y * b.y + c.y;
361
+ return res;
362
+ }
363
+
364
+ inline float4 fma(Bfloat4_ a, Bfloat4_ b, float4 c) {
365
+ float4 res;
366
+ res.x = fma(a.x.x, b.x.x, c.x);
367
+ res.y = fma(a.x.y, b.x.y, c.y);
368
+ res.z = fma(a.y.x, b.y.x, c.z);
369
+ res.w = fma(a.y.y, b.y.y, c.w);
370
+ return res;
371
+ }
372
+ inline Bfloat4_ fma(Bfloat4_ a, Bfloat4_ b, Bfloat4_ c) {
373
+ Bfloat4_ res;
374
+ res.x = fma(a.x, b.x, c.x);
375
+ res.y = fma(a.y, b.y, c.y);
376
+ return res;
377
+ }
378
+
379
+ inline Float8_ fma(Bfloat8_ a, Bfloat8_ b, Float8_ c) {
380
+ float4 x = fma(a.x, b.x, c.x);
381
+ float4 y = fma(a.y, b.y, c.y);
382
+ Float8_ res;
383
+ res.x = x;
384
+ res.y = y;
385
+ return res;
386
+ }
387
+ inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) {
388
+ Bfloat8_ res;
389
+ res.x = fma(a.x, b.x, c.x);
390
+ res.y = fma(a.y, b.y, c.y);
391
+ return res;
392
+ }
393
+
394
+ inline void from_float(thread bfloat16_t &dst, float src) {
395
+ dst = static_cast<bfloat16_t>(src);
396
+ }
397
+ inline void from_float(thread Bfloat2_ &dst, float2 src) {
398
+ dst.x = static_cast<bfloat16_t>(src.x);
399
+ dst.y = static_cast<bfloat16_t>(src.y);
400
+ }
401
+ inline void from_float(thread Bfloat4_ &dst, float4 src) {
402
+ dst.x.x = static_cast<bfloat16_t>(src.x);
403
+ dst.x.y = static_cast<bfloat16_t>(src.y);
404
+ dst.y.x = static_cast<bfloat16_t>(src.z);
405
+ dst.y.y = static_cast<bfloat16_t>(src.w);
406
+ }
407
+ inline void from_float(thread Bfloat8_ &dst, Float8_ src) {
408
+ Bfloat4_ x;
409
+ Bfloat4_ y;
410
+ from_float(x, src.x);
411
+ from_float(y, src.y);
412
+ dst.x = x;
413
+ dst.y = y;
414
+ }
415
+
416
+ // #endif
417
+
418
+ // FP16 vector data types.
419
+ struct Half8_ {
420
+ half4 x;
421
+ half4 y;
422
+ };
423
+
424
+ template <> struct Vec<half, 1> {
425
+ using Type = half;
426
+ };
427
+ template <> struct Vec<half, 2> {
428
+ using Type = half2;
429
+ };
430
+ template <> struct Vec<half, 4> {
431
+ using Type = half4;
432
+ };
433
+ template <> struct Vec<half, 8> {
434
+ using Type = Half8_;
435
+ };
436
+
437
+ template <> struct FloatVec<half> {
438
+ using Type = float;
439
+ };
440
+ template <> struct FloatVec<half2> {
441
+ using Type = float2;
442
+ };
443
+ template <> struct FloatVec<half4> {
444
+ using Type = float4;
445
+ };
446
+ template <> struct FloatVec<Half8_> {
447
+ using Type = Float8_;
448
+ };
449
+
450
+ template <> inline float mul(half a, half b) { return (float)a * (float)b; }
451
+ template <> inline half mul(half a, half b) { return a * b; }
452
+
453
+ template <> inline float2 mul(half2 a, half2 b) {
454
+ return (float2)a * (float2)b;
455
+ }
456
+ template <> inline half2 mul(half2 a, half2 b) { return a * b; }
457
+
458
+ template <> inline float4 mul(half4 a, half4 b) {
459
+ return (float4)a * (float4)b;
460
+ }
461
+ template <> inline half4 mul(half4 a, half4 b) { return a * b; }
462
+
463
+ template <> inline Float8_ mul(Half8_ a, Half8_ b) {
464
+ float4 x = mul<float4, half4, half4>(a.x, b.x);
465
+ float4 y = mul<float4, half4, half4>(a.y, b.y);
466
+ Float8_ c;
467
+ c.x = x;
468
+ c.y = y;
469
+ return c;
470
+ }
471
+ template <> inline Half8_ mul(Half8_ a, Half8_ b) {
472
+ Half8_ c;
473
+ c.x = mul<half4, half4, half4>(a.x, b.x);
474
+ c.y = mul<half4, half4, half4>(a.y, b.y);
475
+ return c;
476
+ }
477
+
478
+ template <> inline float sum(half a) { return (float)a; }
479
+
480
+ template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; }
481
+
482
+ template <> inline float sum(half4 a) { return a.x + a.y + a.z + a.w; }
483
+
484
+ template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); }
485
+
486
+ inline float fma(half a, half b, float c) { return (float)a * (float)b + c; }
487
+
488
+ inline float2 fma(half2 a, half2 b, float2 c) {
489
+ return (float2)a * (float2)b + c;
490
+ }
491
+
492
+ inline float4 fma(half4 a, half4 b, float4 c) {
493
+ return (float4)a * (float4)b + c;
494
+ }
495
+
496
+ inline Float8_ fma(Half8_ a, Half8_ b, Float8_ c) {
497
+ float4 x = fma(a.x, b.x, c.x);
498
+ float4 y = fma(a.y, b.y, c.y);
499
+ Float8_ res;
500
+ res.x = x;
501
+ res.y = y;
502
+ return res;
503
+ }
504
+ inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) {
505
+ Half8_ res;
506
+ res.x = fma(a.x, b.x, c.x);
507
+ res.y = fma(a.y, b.y, c.y);
508
+ return res;
509
+ }
510
+
511
+ inline void from_float(thread half &dst, float src) {
512
+ dst = static_cast<half>(src);
513
+ }
514
+ inline void from_float(thread half2 &dst, float2 src) {
515
+ dst.x = static_cast<half>(src.x);
516
+ dst.y = static_cast<half>(src.y);
517
+ }
518
+ inline void from_float(thread half4 &dst, float4 src) {
519
+ dst.x = static_cast<half>(src.x);
520
+ dst.y = static_cast<half>(src.y);
521
+ dst.z = static_cast<half>(src.z);
522
+ dst.w = static_cast<half>(src.w);
523
+ }
524
+ inline void from_float(thread Half8_ &dst, Float8_ src) {
525
+ half4 x;
526
+ half4 y;
527
+ from_float(x, src.x);
528
+ from_float(y, src.y);
529
+ dst.x = x;
530
+ dst.y = y;
531
+ }
532
+
533
+ // ========================================== FP8 (uchar) vector data types.
534
+
535
+ // 8‑lane uchar vector – Metal only provides up to uchar4, so build our own.
536
+ struct Uchar8_ {
537
+ uchar4 x;
538
+ uchar4 y;
539
+ };
540
+
541
+ // Vec specialisations so Vec<uchar, N>::Type resolves correctly.
542
+ template <> struct Vec<uchar, 1> {
543
+ using Type = uchar;
544
+ };
545
+ template <> struct Vec<uchar, 2> {
546
+ using Type = uchar2;
547
+ };
548
+ template <> struct Vec<uchar, 4> {
549
+ using Type = uchar4;
550
+ };
551
+ template <> struct Vec<uchar, 8> {
552
+ using Type = Uchar8_;
553
+ };
554
+
555
+ // General case: not uchar
556
+ template <typename T> inline constexpr bool is_uchar() { return false; }
557
+
558
+ // Specialization: T is uchar
559
+ template <> inline constexpr bool is_uchar<uchar>() { return true; }
560
+
561
+ // Generic fallback – will fail to compile if a required specialisation is
562
+ // missing.
563
+ template <typename Vec, typename Quant_vec>
564
+ inline Vec fp8_convert(const thread Quant_vec &, float scale) {
565
+ static_assert(sizeof(Vec) == 0, "Missing fp8_convert specialisation");
566
+ }
567
+
568
+ // ========================================== FP8 → float/half/bfloat
569
+ inline float __dequant_single(uchar v, float scale) {
570
+ return fp8_e4m3_to_float(v) * scale;
571
+ }
572
+
573
+ // ---- 1‑lane ----
574
+ template <>
575
+ inline float fp8_convert<float, uchar>(const thread uchar &in, float scale) {
576
+ return __dequant_single(in, scale);
577
+ }
578
+ template <>
579
+ inline half fp8_convert<half, uchar>(const thread uchar &in, float scale) {
580
+ return half(__dequant_single(in, scale));
581
+ }
582
+ template <>
583
+ inline bfloat16_t fp8_convert<bfloat16_t, uchar>(const thread uchar &in,
584
+ float scale) {
585
+ return bfloat16_t(__dequant_single(in, scale));
586
+ }
587
+
588
+ // ---- 2‑lane ----
589
+ template <>
590
+ inline float2 fp8_convert<float2, uchar2>(const thread uchar2 &in,
591
+ float scale) {
592
+ return float2(__dequant_single(in.x, scale), __dequant_single(in.y, scale));
593
+ }
594
+ template <>
595
+ inline half2 fp8_convert<half2, uchar2>(const thread uchar2 &in, float scale) {
596
+ half2 out;
597
+ out.x = half(__dequant_single(in.x, scale));
598
+ out.y = half(__dequant_single(in.y, scale));
599
+ return out;
600
+ }
601
+ template <>
602
+ inline Bfloat2_ fp8_convert<Bfloat2_, uchar2>(const thread uchar2 &in,
603
+ float scale) {
604
+ Bfloat2_ out;
605
+ out.x = bfloat16_t(__dequant_single(in.x, scale));
606
+ out.y = bfloat16_t(__dequant_single(in.y, scale));
607
+ return out;
608
+ }
609
+
610
+ // ---- 4‑lane ----
611
+ template <>
612
+ inline float4 fp8_convert<float4, uchar4>(const thread uchar4 &in,
613
+ float scale) {
614
+ return float4(__dequant_single(in.x, scale), __dequant_single(in.y, scale),
615
+ __dequant_single(in.z, scale), __dequant_single(in.w, scale));
616
+ }
617
+ template <>
618
+ inline half4 fp8_convert<half4, uchar4>(const thread uchar4 &in, float scale) {
619
+ half4 out;
620
+ out.x = half(__dequant_single(in.x, scale));
621
+ out.y = half(__dequant_single(in.y, scale));
622
+ out.z = half(__dequant_single(in.z, scale));
623
+ out.w = half(__dequant_single(in.w, scale));
624
+ return out;
625
+ }
626
+ template <>
627
+ inline Bfloat4_ fp8_convert<Bfloat4_, uchar4>(const thread uchar4 &in,
628
+ float scale) {
629
+ Bfloat4_ out;
630
+ out.x.x = bfloat16_t(__dequant_single(in.x, scale));
631
+ out.x.y = bfloat16_t(__dequant_single(in.y, scale));
632
+ out.y.x = bfloat16_t(__dequant_single(in.z, scale));
633
+ out.y.y = bfloat16_t(__dequant_single(in.w, scale));
634
+ return out;
635
+ }
636
+
637
+ // ---- 8‑lane ----
638
+ template <>
639
+ inline Float8_ fp8_convert<Float8_, Uchar8_>(const thread Uchar8_ &in,
640
+ float scale) {
641
+ Float8_ out;
642
+ out.x =
643
+ float4(__dequant_single(in.x.x, scale), __dequant_single(in.x.y, scale),
644
+ __dequant_single(in.x.z, scale), __dequant_single(in.x.w, scale));
645
+ out.y =
646
+ float4(__dequant_single(in.y.x, scale), __dequant_single(in.y.y, scale),
647
+ __dequant_single(in.y.z, scale), __dequant_single(in.y.w, scale));
648
+ return out;
649
+ }
650
+ template <>
651
+ inline Half8_ fp8_convert<Half8_, Uchar8_>(const thread Uchar8_ &in,
652
+ float scale) {
653
+ Half8_ out;
654
+ out.x = half4(half(__dequant_single(in.x.x, scale)),
655
+ half(__dequant_single(in.x.y, scale)),
656
+ half(__dequant_single(in.x.z, scale)),
657
+ half(__dequant_single(in.x.w, scale)));
658
+ out.y = half4(half(__dequant_single(in.y.x, scale)),
659
+ half(__dequant_single(in.y.y, scale)),
660
+ half(__dequant_single(in.y.z, scale)),
661
+ half(__dequant_single(in.y.w, scale)));
662
+ return out;
663
+ }
664
+ template <>
665
+ inline Bfloat8_ fp8_convert<Bfloat8_, Uchar8_>(const thread Uchar8_ &in,
666
+ float scale) {
667
+ Bfloat8_ out;
668
+ // first 4
669
+ out.x.x.x = bfloat16_t(__dequant_single(in.x.x, scale));
670
+ out.x.x.y = bfloat16_t(__dequant_single(in.x.y, scale));
671
+ out.x.y.x = bfloat16_t(__dequant_single(in.x.z, scale));
672
+ out.x.y.y = bfloat16_t(__dequant_single(in.x.w, scale));
673
+ // second 4
674
+ out.y.x.x = bfloat16_t(__dequant_single(in.y.x, scale));
675
+ out.y.x.y = bfloat16_t(__dequant_single(in.y.y, scale));
676
+ out.y.y.x = bfloat16_t(__dequant_single(in.y.z, scale));
677
+ out.y.y.y = bfloat16_t(__dequant_single(in.y.w, scale));
678
+ return out;
679
+ }
680
+
681
+ // ========================================== Dot product utilities
682
+
683
+ // TODO(EricLBuehler): optimize with vectorization
684
+ template <int THREAD_GROUP_SIZE, typename Vec, int N>
685
+ inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
686
+ // Compute the parallel products for Q*K^T (treat vector lanes separately).
687
+ using A_vec = typename FloatVec<Vec>::Type;
688
+ A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
689
+ #pragma unroll
690
+ for (int ii = 1; ii < N; ++ii) {
691
+ qk_vec = fma(q[ii], k[ii], qk_vec);
692
+ }
693
+
694
+ // Finalize the reduction across lanes.
695
+ float qk = sum(qk_vec);
696
+ #pragma unroll
697
+ for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
698
+ qk += simd_shuffle_xor(qk, mask);
699
+ }
700
+ return qk;
701
+ }
702
+
703
+ template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
704
+ template <typename Vec, int N>
705
+ static inline float dot(const threadgroup Vec (&q)[N],
706
+ const thread Vec (&k)[N]) {
707
+ return qk_dot_<THREAD_GROUP_SIZE>(q, k);
708
+ }
709
+ };
710
+
711
+ // ========================================== Block sum utility
712
+
713
+ // Utility function for attention softmax.
714
+ template <int NUM_WARPS, int NUM_SIMD_LANES>
715
+ inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid,
716
+ uint simd_lid) {
717
+ // Compute the sum per simdgroup.
718
+ #pragma unroll
719
+ for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
720
+ sum += simd_shuffle_xor(sum, mask);
721
+ }
722
+
723
+ // Simd leaders store the data to shared memory.
724
+ if (simd_lid == 0) {
725
+ red_smem[simd_tid] = sum;
726
+ }
727
+
728
+ // Make sure the data is in shared memory.
729
+ threadgroup_barrier(mem_flags::mem_threadgroup);
730
+
731
+ // The warps compute the final sums.
732
+ if (simd_lid < NUM_WARPS) {
733
+ sum = red_smem[simd_lid];
734
+ }
735
+
736
+ // Parallel reduction inside the simd group.
737
+ #pragma unroll
738
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
739
+ sum += simd_shuffle_xor(sum, mask);
740
+ }
741
+
742
+ // Broadcast to other threads.
743
+ return simd_shuffle(sum, 0);
744
+ }
745
+
746
+ // ========================================== Paged Attention kernel
747
+
748
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
749
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
750
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
751
+
752
+ constant bool use_partitioning [[function_constant(10)]];
753
+ constant bool use_alibi [[function_constant(20)]];
754
+ constant bool use_fp8_scales [[function_constant(30)]];
755
+
756
+ template <typename T, typename CACHE_T, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS,
757
+ int NUM_SIMD_LANES, int PARTITION_SIZE = 0>
758
+ [[kernel]] void paged_attention(
759
+ device float *exp_sums
760
+ [[buffer(0)]], // [num_seqs, num_heads, max_num_partitions] - only used when
761
+ // use_partitioning
762
+ device float *max_logits
763
+ [[buffer(1)]], // [num_seqs, num_heads, max_num_partitions] - only used when
764
+ // use_partitioning
765
+ device T *out
766
+ [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size]
767
+ device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size]
768
+ device const CACHE_T *k_cache
769
+ [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x]
770
+ device const CACHE_T *v_cache
771
+ [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size]
772
+ const device float *__restrict__ k_scale
773
+ [[buffer(6)]], // [1] - only used when use_fp8_scales
774
+ const device float *__restrict__ v_scale
775
+ [[buffer(7)]], // [1] - only used when use_fp8_scales
776
+ const constant int &num_kv_heads [[buffer(8)]], // [num_heads]
777
+ const constant float &scale [[buffer(9)]],
778
+ const constant float &softcapping [[buffer(10)]],
779
+ device const uint32_t *block_tables
780
+ [[buffer(11)]], // [num_seqs, max_num_blocks_per_seq]
781
+ device const uint32_t *context_lens [[buffer(12)]], // [num_seqs]
782
+ const constant int &max_num_blocks_per_seq [[buffer(13)]],
783
+ device const float *alibi_slopes
784
+ [[buffer(14)]], // [num_heads] - only used when use_alibi
785
+ const constant int &q_stride [[buffer(15)]],
786
+ const constant int &kv_block_stride [[buffer(16)]],
787
+ const constant int &kv_head_stride [[buffer(17)]],
788
+ threadgroup char *shared_mem [[threadgroup(0)]],
789
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
790
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]],
791
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
792
+ uint simd_tid [[simdgroup_index_in_threadgroup]],
793
+ uint simd_lid [[thread_index_in_simdgroup]]) {
794
+ const int seq_idx = threadgroup_position_in_grid.y;
795
+ const int partition_idx = threadgroup_position_in_grid.z;
796
+ const int max_num_partitions = threadgroups_per_grid.z;
797
+ const int thread_idx = thread_position_in_threadgroup.x;
798
+ constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
799
+ const uint32_t context_len = context_lens[seq_idx];
800
+ if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
801
+ // No work to do. Terminate the thread block.
802
+ return;
803
+ }
804
+
805
+ const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
806
+ const int num_blocks_per_partition =
807
+ USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
808
+
809
+ // [start_block_idx, end_block_idx) is the range of blocks to process.
810
+ const int start_block_idx =
811
+ USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
812
+ const int end_block_idx =
813
+ MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
814
+ const int num_blocks = end_block_idx - start_block_idx;
815
+
816
+ // [start_token_idx, end_token_idx) is the range of tokens to process.
817
+ const int start_token_idx = start_block_idx * BLOCK_SIZE;
818
+ const int end_token_idx =
819
+ MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
820
+ const int num_tokens = end_token_idx - start_token_idx;
821
+
822
+ constexpr int THREAD_GROUP_SIZE = MAX(NUM_SIMD_LANES / BLOCK_SIZE, 1);
823
+ constexpr int NUM_THREAD_GROUPS =
824
+ NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
825
+ // divides NUM_THREADS
826
+ assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
827
+ constexpr int NUM_TOKENS_PER_THREAD_GROUP =
828
+ DIVIDE_ROUND_UP(BLOCK_SIZE, NUM_SIMD_LANES);
829
+ constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
830
+ const int warp_idx = simd_tid;
831
+ const int lane = simd_lid;
832
+
833
+ const int head_idx = threadgroup_position_in_grid.x;
834
+ const int num_heads = threadgroups_per_grid.x;
835
+ const int num_queries_per_kv = num_heads / num_kv_heads;
836
+ const int kv_head_idx = head_idx / num_queries_per_kv;
837
+ const float alibi_slope = !use_alibi ? 0.f : alibi_slopes[head_idx];
838
+
839
+ // A vector type to store a part of a key or a query.
840
+ // The vector size is configured in such a way that the threads in a thread
841
+ // group fetch or compute 16 bytes at a time. For example, if the size of a
842
+ // thread group is 4 and the data type is half, then the vector size is 16 /
843
+ // (4 * sizeof(half)) == 2.
844
+ constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(T)), 1);
845
+ using K_vec = typename Vec<T, VEC_SIZE>::Type;
846
+ using Q_vec = typename Vec<T, VEC_SIZE>::Type;
847
+ using Quant_vec = typename Vec<CACHE_T, VEC_SIZE>::Type;
848
+
849
+ constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
850
+ constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
851
+
852
+ const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
853
+ const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
854
+
855
+ // Load the query to registers.
856
+ // Each thread in a thread group has a different part of the query.
857
+ // For example, if the thread group size is 4, then the first thread in the
858
+ // group has 0, 4, 8, ... th vectors of the query, and the second thread has
859
+ // 1, 5, 9, ... th vectors of the query, and so on.
860
+ const device T *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
861
+ threadgroup Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
862
+ #pragma unroll
863
+ for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
864
+ i += NUM_THREAD_GROUPS) {
865
+ const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
866
+ q_vecs[thread_group_offset][i] =
867
+ *reinterpret_cast<const device Q_vec *>(q_ptr + vec_idx * VEC_SIZE);
868
+ }
869
+ threadgroup_barrier(mem_flags::mem_threadgroup);
870
+
871
+ // Use fp32 on softmax logits for better accuracy
872
+ threadgroup float *logits = reinterpret_cast<threadgroup float *>(shared_mem);
873
+ // Workspace for reduction
874
+ threadgroup float red_smem[2 * NUM_WARPS];
875
+
876
+ // x == THREAD_GROUP_SIZE * VEC_SIZE
877
+ // Each thread group fetches x elements from the key at a time.
878
+ constexpr int x = 16 / sizeof(CACHE_T);
879
+ float qk_max = -FLT_MAX;
880
+
881
+ // Iterate over the key blocks.
882
+ // Each warp fetches a block of keys for each iteration.
883
+ // Each thread group in a warp fetches a key from the block, and computes
884
+ // dot product with the query.
885
+ const device uint32_t *block_table =
886
+ block_tables + seq_idx * max_num_blocks_per_seq;
887
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
888
+ block_idx += NUM_WARPS) {
889
+ // NOTE: The block number is stored in int32. However, we cast it to int64
890
+ // because int32 can lead to overflow when this variable is multiplied by
891
+ // large numbers (e.g., kv_block_stride).
892
+ const int64_t physical_block_number =
893
+ static_cast<int64_t>(block_table[block_idx]);
894
+
895
+ // Load a key to registers.
896
+ // Each thread in a thread group has a different part of the key.
897
+ // For example, if the thread group size is 4, then the first thread in the
898
+ // group has 0, 4, 8, ... th vectors of the key, and the second thread has
899
+ // 1, 5, 9, ... th vectors of the key, and so on.
900
+ for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
901
+ const int physical_block_offset =
902
+ (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE;
903
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
904
+ K_vec k_vecs[NUM_VECS_PER_THREAD];
905
+
906
+ #pragma unroll
907
+ for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
908
+ const device CACHE_T *k_ptr =
909
+ k_cache + physical_block_number * kv_block_stride +
910
+ kv_head_idx * kv_head_stride + physical_block_offset * x;
911
+ const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
912
+ const int offset1 = (vec_idx * VEC_SIZE) / x;
913
+ const int offset2 = (vec_idx * VEC_SIZE) % x;
914
+
915
+ if constexpr (is_uchar<CACHE_T>()) {
916
+ // FP8 support
917
+ Quant_vec k_vec_quant = *reinterpret_cast<const device Quant_vec *>(
918
+ k_ptr + offset1 * BLOCK_SIZE * x + offset2);
919
+ k_vecs[j] = fp8_convert<K_vec, Quant_vec>(k_vec_quant, *k_scale);
920
+ } else {
921
+ // Non-FP8 default
922
+ k_vecs[j] = *reinterpret_cast<const device K_vec *>(
923
+ k_ptr + offset1 * BLOCK_SIZE * x + offset2);
924
+ }
925
+ }
926
+
927
+ // Compute dot product.
928
+ // This includes a reduction across the threads in the same thread group.
929
+ float qk = scale * Qk_dot<T, THREAD_GROUP_SIZE>::dot(
930
+ q_vecs[thread_group_offset], k_vecs);
931
+
932
+ // Apply softcapping
933
+ if (softcapping != 1.0) {
934
+ qk = precise::tanh(qk / softcapping) * softcapping;
935
+ }
936
+
937
+ // Add the ALiBi bias if slopes are given.
938
+ if (use_alibi && alibi_slope != 0) {
939
+ // Compute bias with explicit float precision to minimize precision loss
940
+ int position_offset = token_idx - int(context_len) + 1;
941
+ float alibi_bias = alibi_slope * float(position_offset);
942
+ qk += alibi_bias;
943
+ }
944
+
945
+ if (thread_group_offset == 0) {
946
+ // Store the partial reductions to shared memory.
947
+ // NOTE: It is required to zero out the masked logits.
948
+ const bool mask = token_idx >= context_len;
949
+ logits[token_idx - start_token_idx] = mask ? 0.f : qk;
950
+ // Update the max value.
951
+ qk_max = mask ? qk_max : max(qk_max, qk);
952
+ }
953
+ }
954
+ }
955
+
956
+ // Perform reduction across the threads in the same warp to get the
957
+ // max qk value for each "warp" (not across the thread block yet).
958
+ // The 0-th thread of each thread group already has its max qk value.
959
+ #pragma unroll
960
+ for (int mask = NUM_SIMD_LANES / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
961
+ qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
962
+ }
963
+ if (lane == 0) {
964
+ red_smem[warp_idx] = qk_max;
965
+ }
966
+ threadgroup_barrier(mem_flags::mem_threadgroup);
967
+
968
+ // Get the max qk value for the sequence.
969
+ qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
970
+ #pragma unroll
971
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
972
+ qk_max = max(qk_max, simd_shuffle_xor(qk_max, mask));
973
+ }
974
+ // Broadcast the max qk value to all threads.
975
+ qk_max = simd_shuffle(qk_max, 0);
976
+
977
+ // Get the sum of the exp values.
978
+ float exp_sum = 0.f;
979
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
980
+ float val = exp(logits[i] - qk_max);
981
+ logits[i] = val;
982
+ exp_sum += val;
983
+ }
984
+ exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(&red_smem[NUM_WARPS], exp_sum,
985
+ simd_tid, simd_lid);
986
+
987
+ // Compute softmax.
988
+ const float inv_sum = divide(1.f, exp_sum + 1e-6f);
989
+ for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
990
+ logits[i] *= inv_sum;
991
+ }
992
+ threadgroup_barrier(mem_flags::mem_threadgroup);
993
+
994
+ // If partitioning is enabled, store the max logit and exp_sum.
995
+ if (USE_PARTITIONING && thread_idx == 0 && use_partitioning) {
996
+ device float *max_logits_ptr =
997
+ max_logits + seq_idx * num_heads * max_num_partitions +
998
+ head_idx * max_num_partitions + partition_idx;
999
+ *max_logits_ptr = qk_max;
1000
+ device float *exp_sums_ptr = exp_sums +
1001
+ seq_idx * num_heads * max_num_partitions +
1002
+ head_idx * max_num_partitions + partition_idx;
1003
+ *exp_sums_ptr = exp_sum;
1004
+ }
1005
+
1006
+ // Each thread will fetch 16 bytes from the value cache at a time.
1007
+ constexpr int V_VEC_SIZE = MIN(16 / sizeof(T), BLOCK_SIZE);
1008
+ using V_vec = typename Vec<T, V_VEC_SIZE>::Type;
1009
+ using L_vec = typename Vec<T, V_VEC_SIZE>::Type;
1010
+ using Float_L_vec = typename FloatVec<L_vec>::Type;
1011
+ using V_quant_vec = typename Vec<CACHE_T, V_VEC_SIZE>::Type;
1012
+
1013
+ constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
1014
+ constexpr int NUM_ROWS_PER_ITER = NUM_SIMD_LANES / NUM_V_VECS_PER_ROW;
1015
+ constexpr int NUM_ROWS_PER_THREAD =
1016
+ DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
1017
+
1018
+ // NOTE: We use FP32 for the accumulator for better accuracy.
1019
+ float accs[NUM_ROWS_PER_THREAD];
1020
+ #pragma unroll
1021
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
1022
+ accs[i] = 0.f;
1023
+ }
1024
+
1025
+ T zero_value = 0;
1026
+ for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
1027
+ block_idx += NUM_WARPS) {
1028
+ // NOTE: The block number is stored in int32. However, we cast it to int64
1029
+ // because int32 can lead to overflow when this variable is multiplied by
1030
+ // large numbers (e.g., kv_block_stride).
1031
+ const int64_t physical_block_number =
1032
+ static_cast<int64_t>(block_table[block_idx]);
1033
+ const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
1034
+ const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
1035
+ L_vec logits_vec;
1036
+ Float_L_vec logits_float_vec = *reinterpret_cast<threadgroup Float_L_vec *>(
1037
+ logits + token_idx - start_token_idx);
1038
+ from_float(logits_vec, logits_float_vec);
1039
+
1040
+ const device CACHE_T *v_ptr = v_cache + physical_block_number * kv_block_stride +
1041
+ kv_head_idx * kv_head_stride;
1042
+ #pragma unroll
1043
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
1044
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
1045
+ if (row_idx < HEAD_SIZE) {
1046
+ const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
1047
+ // NOTE: When v_vec contains the tokens that are out of the context,
1048
+ // we should explicitly zero out the values since they may contain NaNs.
1049
+ // See
1050
+ // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
1051
+ V_vec v_vec;
1052
+
1053
+ if constexpr (is_uchar<CACHE_T>()) {
1054
+ // FP8 support
1055
+ V_quant_vec v_quant_vec =
1056
+ *reinterpret_cast<const device V_quant_vec *>(v_ptr + offset);
1057
+ v_vec = fp8_convert<V_vec, V_quant_vec>(v_quant_vec, *v_scale);
1058
+ } else {
1059
+ // Non-FP8 default
1060
+ v_vec = *reinterpret_cast<const device V_vec *>(v_ptr + offset);
1061
+ }
1062
+
1063
+ if (block_idx == num_context_blocks - 1) {
1064
+ thread T *v_vec_ptr = reinterpret_cast<thread T *>(&v_vec);
1065
+ #pragma unroll
1066
+ for (int j = 0; j < V_VEC_SIZE; j++) {
1067
+ v_vec_ptr[j] =
1068
+ token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
1069
+ }
1070
+ }
1071
+ accs[i] += dot(logits_vec, v_vec);
1072
+ }
1073
+ }
1074
+ }
1075
+
1076
+ // Perform reduction within each warp.
1077
+ #pragma unroll
1078
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
1079
+ float acc = accs[i];
1080
+ #pragma unroll
1081
+ for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
1082
+ acc += simd_shuffle_xor(acc, mask);
1083
+ }
1084
+ accs[i] = acc;
1085
+ }
1086
+
1087
+ // NOTE: A barrier is required because the shared memory space for logits
1088
+ // is reused for the output.
1089
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1090
+
1091
+ // Perform reduction across warps.
1092
+ threadgroup float *out_smem =
1093
+ reinterpret_cast<threadgroup float *>(shared_mem);
1094
+ #pragma unroll
1095
+ for (int i = NUM_WARPS; i > 1; i /= 2) {
1096
+ int mid = i / 2;
1097
+ // Upper warps write to shared memory.
1098
+ if (warp_idx >= mid && warp_idx < i) {
1099
+ threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
1100
+ #pragma unroll
1101
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
1102
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
1103
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
1104
+ dst[row_idx] = accs[i];
1105
+ }
1106
+ }
1107
+ }
1108
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1109
+
1110
+ // Lower warps update the output.
1111
+ if (warp_idx < mid) {
1112
+ const threadgroup float *src = &out_smem[warp_idx * HEAD_SIZE];
1113
+ #pragma unroll
1114
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
1115
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
1116
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
1117
+ accs[i] += src[row_idx];
1118
+ }
1119
+ }
1120
+ }
1121
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1122
+ }
1123
+
1124
+ // Write the final output.
1125
+ if (warp_idx == 0) {
1126
+ device T *out_ptr =
1127
+ out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
1128
+ head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
1129
+ #pragma unroll
1130
+ for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
1131
+ const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
1132
+ if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
1133
+ *(out_ptr + row_idx) = T(accs[i]);
1134
+ }
1135
+ }
1136
+ }
1137
+ }
1138
+
1139
+ template <typename T, int HEAD_SIZE, int NUM_THREADS, int NUM_SIMD_LANES,
1140
+ int PARTITION_SIZE = 0>
1141
+ [[kernel]] void paged_attention_v2_reduce(
1142
+ device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]],
1143
+ const device float *max_logits [[buffer(2)]],
1144
+ const device T *tmp_out [[buffer(3)]],
1145
+ device uint32_t *context_lens [[buffer(4)]],
1146
+ const constant int &max_num_partitions [[buffer(5)]],
1147
+ threadgroup char *shared_mem [[threadgroup(0)]],
1148
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
1149
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]],
1150
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]],
1151
+ uint3 threads_per_threadgroup [[threads_per_threadgroup]],
1152
+ uint simd_tid [[simdgroup_index_in_threadgroup]],
1153
+ uint simd_lid [[thread_index_in_simdgroup]]) {
1154
+ const int num_heads = threadgroups_per_grid.x;
1155
+ const int head_idx = threadgroup_position_in_grid.x;
1156
+ const int seq_idx = threadgroup_position_in_grid.y;
1157
+ const uint32_t context_len = context_lens[seq_idx];
1158
+ const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
1159
+ if (num_partitions == 1) {
1160
+ // No need to reduce. Only copy tmp_out to out.
1161
+ device T *out_ptr =
1162
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
1163
+ const device T *tmp_out_ptr =
1164
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
1165
+ head_idx * max_num_partitions * HEAD_SIZE;
1166
+ for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
1167
+ i += threads_per_threadgroup.x) {
1168
+ out_ptr[i] = tmp_out_ptr[i];
1169
+ }
1170
+ // Terminate the thread block.
1171
+ return;
1172
+ }
1173
+
1174
+ constexpr int NUM_WARPS = NUM_THREADS / NUM_SIMD_LANES;
1175
+ const int warp_idx = simd_tid;
1176
+ const int lane = simd_lid;
1177
+
1178
+ // Workspace for reduction.
1179
+ threadgroup float red_smem[2 * NUM_WARPS];
1180
+
1181
+ // Load max logits to shared memory.
1182
+ threadgroup float *shared_max_logits =
1183
+ reinterpret_cast<threadgroup float *>(shared_mem);
1184
+ const device float *max_logits_ptr =
1185
+ max_logits + seq_idx * num_heads * max_num_partitions +
1186
+ head_idx * max_num_partitions;
1187
+ float max_logit = -FLT_MAX;
1188
+ for (int i = thread_position_in_threadgroup.x; i < num_partitions;
1189
+ i += threads_per_threadgroup.x) {
1190
+ const float l = max_logits_ptr[i];
1191
+ shared_max_logits[i] = l;
1192
+ max_logit = max(max_logit, l);
1193
+ }
1194
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1195
+
1196
+ // Get the global max logit.
1197
+ // Reduce within the warp.
1198
+ #pragma unroll
1199
+ for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) {
1200
+ max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
1201
+ }
1202
+ if (lane == 0) {
1203
+ red_smem[warp_idx] = max_logit;
1204
+ }
1205
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1206
+ // Reduce across warps.
1207
+ max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
1208
+ #pragma unroll
1209
+ for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
1210
+ max_logit = max(max_logit, simd_shuffle_xor(max_logit, mask));
1211
+ }
1212
+ // Broadcast the max value to all threads.
1213
+ max_logit = simd_shuffle(max_logit, 0);
1214
+
1215
+ // Load rescaled exp sums to shared memory.
1216
+ threadgroup float *shared_exp_sums = reinterpret_cast<threadgroup float *>(
1217
+ shared_mem + sizeof(float) * num_partitions);
1218
+ const device float *exp_sums_ptr = exp_sums +
1219
+ seq_idx * num_heads * max_num_partitions +
1220
+ head_idx * max_num_partitions;
1221
+ float global_exp_sum = 0.0f;
1222
+ for (int i = thread_position_in_threadgroup.x; i < num_partitions;
1223
+ i += threads_per_threadgroup.x) {
1224
+ float l = shared_max_logits[i];
1225
+ float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit);
1226
+ global_exp_sum += rescaled_exp_sum;
1227
+ shared_exp_sums[i] = rescaled_exp_sum;
1228
+ }
1229
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1230
+ global_exp_sum = block_sum<NUM_WARPS, NUM_SIMD_LANES>(
1231
+ &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid);
1232
+ const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f);
1233
+
1234
+ // Aggregate tmp_out to out.
1235
+ const device T *tmp_out_ptr =
1236
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
1237
+ head_idx * max_num_partitions * HEAD_SIZE;
1238
+ device T *out_ptr =
1239
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
1240
+ #pragma unroll
1241
+ for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE;
1242
+ i += NUM_THREADS) {
1243
+ float acc = 0.0f;
1244
+ for (int j = 0; j < num_partitions; ++j) {
1245
+ acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
1246
+ inv_global_exp_sum;
1247
+ }
1248
+ out_ptr[i] = T(acc);
1249
+ }
1250
+ }
1251
+
1252
+ #define instantiate_paged_attention_inner(type, cache_type, head_size, \
1253
+ block_size, num_threads, \
1254
+ num_simd_lanes, partition_size) \
1255
+ template [[host_name("paged_attention_" #type "_cache_" #cache_type \
1256
+ "_hs" #head_size "_bs" #block_size "_nt" #num_threads \
1257
+ "_nsl" #num_simd_lanes \
1258
+ "_ps" #partition_size)]] [[kernel]] void \
1259
+ paged_attention<type, cache_type, head_size, block_size, num_threads, \
1260
+ num_simd_lanes, partition_size>( \
1261
+ device float *exp_sums [[buffer(0)]], \
1262
+ device float *max_logits [[buffer(1)]], \
1263
+ device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \
1264
+ device const cache_type *k_cache [[buffer(4)]], \
1265
+ device const cache_type *v_cache [[buffer(5)]], \
1266
+ const device float *__restrict__ k_scale [[buffer(6)]], \
1267
+ const device float *__restrict__ v_scale [[buffer(7)]], \
1268
+ const constant int &num_kv_heads [[buffer(8)]], \
1269
+ const constant float &scale [[buffer(9)]], \
1270
+ const constant float &softcapping [[buffer(10)]], \
1271
+ device const uint32_t *block_tables [[buffer(11)]], \
1272
+ device const uint32_t *context_lens [[buffer(12)]], \
1273
+ const constant int &max_num_blocks_per_seq [[buffer(13)]], \
1274
+ device const float *alibi_slopes [[buffer(14)]], \
1275
+ const constant int &q_stride [[buffer(15)]], \
1276
+ const constant int &kv_block_stride [[buffer(16)]], \
1277
+ const constant int &kv_head_stride [[buffer(17)]], \
1278
+ threadgroup char *shared_mem [[threadgroup(0)]], \
1279
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
1280
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
1281
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
1282
+ uint simd_tid [[simdgroup_index_in_threadgroup]], \
1283
+ uint simd_lid [[thread_index_in_simdgroup]]);
1284
+
1285
+ #define instantiate_paged_attention_v2_reduce_inner( \
1286
+ type, head_size, num_threads, num_simd_lanes, partition_size) \
1287
+ template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \
1288
+ "_nt" #num_threads "_nsl" #num_simd_lanes \
1289
+ "_ps" #partition_size)]] [[kernel]] void \
1290
+ paged_attention_v2_reduce<type, head_size, num_threads, num_simd_lanes, \
1291
+ partition_size>( \
1292
+ device type * out [[buffer(0)]], \
1293
+ const device float *exp_sums [[buffer(1)]], \
1294
+ const device float *max_logits [[buffer(2)]], \
1295
+ const device type *tmp_out [[buffer(3)]], \
1296
+ device uint32_t *context_lens [[buffer(4)]], \
1297
+ const constant int &max_num_partitions [[buffer(5)]], \
1298
+ threadgroup char *shared_mem [[threadgroup(0)]], \
1299
+ uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
1300
+ uint3 threadgroups_per_grid [[threadgroups_per_grid]], \
1301
+ uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \
1302
+ uint3 threads_per_threadgroup [[threads_per_threadgroup]], \
1303
+ uint simd_tid [[simdgroup_index_in_threadgroup]], \
1304
+ uint simd_lid [[thread_index_in_simdgroup]]);
1305
+
1306
+ #define instantiate_paged_attention_heads( \
1307
+ type, cache_type, block_size, num_threads, num_simd_lanes, partition_size) \
1308
+ instantiate_paged_attention_inner(type, cache_type, 32, block_size, \
1309
+ num_threads, num_simd_lanes, \
1310
+ partition_size); \
1311
+ instantiate_paged_attention_inner(type, cache_type, 64, block_size, \
1312
+ num_threads, num_simd_lanes, \
1313
+ partition_size); \
1314
+ instantiate_paged_attention_inner(type, cache_type, 80, block_size, \
1315
+ num_threads, num_simd_lanes, \
1316
+ partition_size); \
1317
+ instantiate_paged_attention_inner(type, cache_type, 96, block_size, \
1318
+ num_threads, num_simd_lanes, \
1319
+ partition_size); \
1320
+ instantiate_paged_attention_inner(type, cache_type, 112, block_size, \
1321
+ num_threads, num_simd_lanes, \
1322
+ partition_size); \
1323
+ instantiate_paged_attention_inner(type, cache_type, 120, block_size, \
1324
+ num_threads, num_simd_lanes, \
1325
+ partition_size); \
1326
+ instantiate_paged_attention_inner(type, cache_type, 128, block_size, \
1327
+ num_threads, num_simd_lanes, \
1328
+ partition_size); \
1329
+ instantiate_paged_attention_inner(type, cache_type, 192, block_size, \
1330
+ num_threads, num_simd_lanes, \
1331
+ partition_size); \
1332
+ instantiate_paged_attention_inner(type, cache_type, 256, block_size, \
1333
+ num_threads, num_simd_lanes, \
1334
+ partition_size);
1335
+
1336
+ #define instantiate_paged_attention_v2_reduce_heads( \
1337
+ type, num_threads, num_simd_lanes, partition_size) \
1338
+ instantiate_paged_attention_v2_reduce_inner(type, 32, num_threads, \
1339
+ num_simd_lanes, partition_size); \
1340
+ instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \
1341
+ num_simd_lanes, partition_size); \
1342
+ instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, \
1343
+ num_simd_lanes, partition_size); \
1344
+ instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, \
1345
+ num_simd_lanes, partition_size); \
1346
+ instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, \
1347
+ num_simd_lanes, partition_size); \
1348
+ instantiate_paged_attention_v2_reduce_inner(type, 120, num_threads, \
1349
+ num_simd_lanes, partition_size); \
1350
+ instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, \
1351
+ num_simd_lanes, partition_size); \
1352
+ instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, \
1353
+ num_simd_lanes, partition_size); \
1354
+ instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, \
1355
+ num_simd_lanes, partition_size);
1356
+
1357
+ #define instantiate_paged_attention_block_size(type, cache_type, num_threads, \
1358
+ num_simd_lanes, partition_size) \
1359
+ instantiate_paged_attention_heads(type, cache_type, 8, num_threads, \
1360
+ num_simd_lanes, partition_size); \
1361
+ instantiate_paged_attention_heads(type, cache_type, 16, num_threads, \
1362
+ num_simd_lanes, partition_size); \
1363
+ instantiate_paged_attention_heads(type, cache_type, 32, num_threads, \
1364
+ num_simd_lanes, partition_size);
1365
+
1366
+ // TODO: tune num_threads = 256
1367
+ // NOTE: partition_size = 0
1368
+ #define instantiate_paged_attention_v1(type, cache_type, num_simd_lanes) \
1369
+ instantiate_paged_attention_block_size(type, cache_type, 256, \
1370
+ num_simd_lanes, 0);
1371
+
1372
+ // TODO: tune num_threads = 256
1373
+ // NOTE: partition_size = 512
1374
+ #define instantiate_paged_attention_v2(type, cache_type, num_simd_lanes) \
1375
+ instantiate_paged_attention_block_size(type, cache_type, 256, \
1376
+ num_simd_lanes, 512);
1377
+
1378
+ // TODO: tune num_threads = 256
1379
+ // NOTE: partition_size = 512
1380
+ #define instantiate_paged_attention_v2_reduce(type, num_simd_lanes) \
1381
+ instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512);
1382
+
1383
+ instantiate_paged_attention_v1(float, float, 32);
1384
+ instantiate_paged_attention_v1(bfloat16_t, bfloat16_t, 32);
1385
+ instantiate_paged_attention_v1(half, half, 32);
1386
+
1387
+ instantiate_paged_attention_v1(float, uchar, 32);
1388
+ instantiate_paged_attention_v1(bfloat16_t, uchar, 32);
1389
+ instantiate_paged_attention_v1(half, uchar, 32);
1390
+
1391
+ instantiate_paged_attention_v2_reduce(float, 32);
1392
+ instantiate_paged_attention_v2_reduce(bfloat16_t, 32);
1393
+ instantiate_paged_attention_v2_reduce(half, 32);
1394
+
1395
+ instantiate_paged_attention_v2(float, float, 32);
1396
+ instantiate_paged_attention_v2(bfloat16_t, bfloat16_t, 32);
1397
+ instantiate_paged_attention_v2(half, half, 32);
1398
+
1399
+ instantiate_paged_attention_v2(float, uchar, 32);
1400
+ instantiate_paged_attention_v2(bfloat16_t, uchar, 32);
1401
+ instantiate_paged_attention_v2(half, uchar, 32);
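
Note: the instantiate_* macros above stamp out one specialized pipeline per (dtype, cache dtype, head size, block size, thread count, SIMD width, partition size) combination, and the [[host_name(...)]] string is the only handle the host has on each specialization. A minimal sketch of how that name can be reassembled on the host side (this helper is illustrative only; the actual lookup lives in paged_attention.mm, which is not shown in this hunk):

#include <string>

// Illustrative helper mirroring the host_name pattern emitted by
// instantiate_paged_attention_inner above.
static std::string paged_attention_kernel_name(
    const std::string &dtype, const std::string &cache_dtype, int head_size,
    int block_size, int num_threads, int num_simd_lanes, int partition_size) {
  return "paged_attention_" + dtype + "_cache_" + cache_dtype +
         "_hs" + std::to_string(head_size) + "_bs" + std::to_string(block_size) +
         "_nt" + std::to_string(num_threads) + "_nsl" + std::to_string(num_simd_lanes) +
         "_ps" + std::to_string(partition_size);
}

// paged_attention_kernel_name("half", "uchar", 128, 16, 256, 32, 512)
//   -> "paged_attention_half_cache_uchar_hs128_bs16_nt256_nsl32_ps512"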
paged-attention-metal/cache.mm ADDED
@@ -0,0 +1,562 @@
1
+ #include <ATen/mps/MPSDevice.h>
2
+ #include <ATen/mps/MPSStream.h>
3
+ #include <torch/torch.h>
4
+
5
+ #import <Foundation/Foundation.h>
6
+ #import <Metal/Metal.h>
7
+ #include <dlfcn.h>
8
+ #include <mach-o/dyld.h>
9
+ #include <string>
10
+
11
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
12
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
13
+ }
14
+
15
+ static std::string getModuleDirectory() {
16
+ Dl_info dl_info;
17
+ if (dladdr((void *)getModuleDirectory, &dl_info)) {
18
+ std::string path(dl_info.dli_fname);
19
+ size_t pos = path.find_last_of('/');
20
+ if (pos != std::string::npos) {
21
+ return path.substr(0, pos);
22
+ }
23
+ }
24
+ return ".";
25
+ }
26
+
27
+ void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
28
+ const torch::Tensor &block_mapping) {
29
+ TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
30
+
31
+ const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
32
+ const int64_t num_blocks = block_mapping.size(0);
33
+
34
+ // Handle different device combinations
35
+ if (src.device().is_mps() && dst.device().is_mps()) {
36
+ // MPS to MPS: Use Metal blit encoder
37
+ @autoreleasepool {
38
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
39
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
40
+
41
+ id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
42
+ TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");
43
+
44
+ dispatch_queue_t serialQueue = stream->queue();
45
+
46
+ dispatch_sync(serialQueue, ^{
47
+ id<MTLBlitCommandEncoder> blitEncoder =
48
+ [commandBuffer blitCommandEncoder];
49
+ TORCH_CHECK(blitEncoder, "Failed to create blit command encoder");
50
+
51
+ id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
52
+ id<MTLBuffer> dstBuf = getMTLBufferStorage(dst);
53
+
54
+ for (int64_t i = 0; i < num_blocks; ++i) {
55
+ int64_t src_block_number = block_mapping[i][0].item<int64_t>();
56
+ int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
57
+ NSUInteger src_offset = src_block_number * block_size_in_bytes;
58
+ NSUInteger dst_offset = dst_block_number * block_size_in_bytes;
59
+
60
+ [blitEncoder copyFromBuffer:srcBuf
61
+ sourceOffset:src_offset
62
+ toBuffer:dstBuf
63
+ destinationOffset:dst_offset
64
+ size:block_size_in_bytes];
65
+ }
66
+
67
+ [blitEncoder endEncoding];
68
+ stream->synchronize(at::mps::SyncType::COMMIT);
69
+ });
70
+ }
71
+ } else {
72
+ // Cross-device transfers (MPS-CPU, CPU-MPS, CPU-CPU): Use PyTorch's copy
73
+ for (int64_t i = 0; i < num_blocks; ++i) {
74
+ int64_t src_block_number = block_mapping[i][0].item<int64_t>();
75
+ int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
76
+
77
+ // Copy the entire block
78
+ dst[dst_block_number].copy_(src[src_block_number]);
79
+ }
80
+ }
81
+ }
82
+
83
+ void copy_blocks(const std::vector<torch::Tensor> &key_caches,
84
+ const std::vector<torch::Tensor> &value_caches,
85
+ const torch::Tensor &block_mapping) {
86
+ const int64_t num_layers = key_caches.size();
87
+ TORCH_CHECK(num_layers == static_cast<int64_t>(value_caches.size()),
88
+ "key_caches and value_caches must have the same length");
89
+ if (num_layers == 0) {
90
+ return;
91
+ }
92
+
93
+ // --- Preconditions --------------------------------------------------
94
+ torch::Device dev = key_caches[0].device();
95
+ TORCH_CHECK(dev.is_mps(), "copy_blocks: expected MPS tensors");
96
+
97
+ // Move block_mapping to CPU if it's on MPS
98
+ torch::Tensor block_mapping_cpu = block_mapping;
99
+ if (block_mapping.device().is_mps()) {
100
+ block_mapping_cpu = block_mapping.cpu();
101
+ }
102
+
103
+ for (int64_t i = 0; i < num_layers; ++i) {
104
+ TORCH_CHECK(key_caches[i].device() == dev &&
105
+ value_caches[i].device() == dev,
106
+ "All cache tensors must be on the same MPS device");
107
+ TORCH_CHECK(key_caches[i].dtype() == value_caches[i].dtype(),
108
+ "Key/value cache dtype mismatch at layer ", i);
109
+ }
110
+
111
+ const int64_t num_pairs = block_mapping.size(0);
112
+ const int32_t numel_per_block =
113
+ static_cast<int32_t>(key_caches[0][0].numel());
114
+
115
+ @autoreleasepool {
116
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
117
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
118
+
119
+ id<MTLDevice> device = stream->device();
120
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
121
+ TORCH_CHECK(cmdBuf, "Failed to get command buffer");
122
+
123
+ // Construct the full path to the metallib file
124
+ std::string moduleDir = getModuleDirectory();
125
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
126
+
127
+ NSString *metallibPathStr =
128
+ [NSString stringWithUTF8String:metallibPath.c_str()];
129
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
130
+ NSError *error = nil;
131
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
132
+ if (!lib) {
133
+ NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@",
134
+ metallibPathStr, error.localizedDescription);
135
+ }
136
+
137
+ // Process each layer separately
138
+ for (int64_t layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
139
+ NSString *kernName = nil;
140
+ switch (key_caches[layer_idx].scalar_type()) {
141
+ case torch::kFloat:
142
+ kernName = @"copy_blocks_float";
143
+ break;
144
+ case torch::kHalf:
145
+ kernName = @"copy_blocks_half";
146
+ break;
147
+ case torch::kBFloat16:
148
+ kernName = @"copy_blocks_bfloat16_t";
149
+ break;
150
+ case torch::kUInt8:
151
+ kernName = @"copy_blocks_uchar";
152
+ break;
153
+ default:
154
+ TORCH_CHECK(false, "Unsupported dtype for copy_blocks");
155
+ }
156
+
157
+ id<MTLFunction> fn = [lib newFunctionWithName:kernName];
158
+ TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String);
159
+
160
+ id<MTLComputePipelineState> pso =
161
+ [device newComputePipelineStateWithFunction:fn error:&error];
162
+ TORCH_CHECK(pso, error.localizedDescription.UTF8String);
163
+
164
+ dispatch_queue_t q = stream->queue();
165
+ dispatch_sync(q, ^{
166
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
167
+ TORCH_CHECK(enc, "Failed to create compute encoder");
168
+
169
+ [enc setComputePipelineState:pso];
170
+
171
+ // Set key and value cache buffers
172
+ [enc setBuffer:getMTLBufferStorage(key_caches[layer_idx])
173
+ offset:key_caches[layer_idx].storage_offset() *
174
+ key_caches[layer_idx].element_size()
175
+ atIndex:0];
176
+ [enc setBuffer:getMTLBufferStorage(value_caches[layer_idx])
177
+ offset:value_caches[layer_idx].storage_offset() *
178
+ value_caches[layer_idx].element_size()
179
+ atIndex:1];
180
+
181
+ // Set block mapping buffer
182
+ id<MTLBuffer> mappingBuf =
183
+ [device newBufferWithBytes:block_mapping_cpu.data_ptr<int64_t>()
184
+ length:num_pairs * 2 * sizeof(int64_t)
185
+ options:MTLResourceStorageModeShared];
186
+ [enc setBuffer:mappingBuf offset:0 atIndex:2];
187
+
188
+ // Set numel_per_block as buffer
189
+ id<MTLBuffer> numelBuf =
190
+ [device newBufferWithBytes:&numel_per_block
191
+ length:sizeof(int32_t)
192
+ options:MTLResourceStorageModeShared];
193
+ [enc setBuffer:numelBuf offset:0 atIndex:3];
194
+
195
+ const uint32_t threadsPerThreadgroup =
196
+ std::min<uint32_t>(256, numel_per_block);
197
+ MTLSize tg = MTLSizeMake(threadsPerThreadgroup, 1, 1);
198
+ MTLSize grid = MTLSizeMake(threadsPerThreadgroup * num_pairs, 1, 1);
199
+
200
+ [enc dispatchThreads:grid threadsPerThreadgroup:tg];
201
+ [enc endEncoding];
202
+ });
203
+ }
204
+
205
+ stream->synchronize(at::mps::SyncType::COMMIT);
206
+ }
207
+ }
208
+
209
+ void reshape_and_cache(
210
+ torch::Tensor &key, // [num_tokens, num_heads, head_size]
211
+ torch::Tensor &value, // [num_tokens, num_heads, head_size]
212
+ torch::Tensor
213
+ &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
214
+ torch::Tensor
215
+ &value_cache, // [num_blocks, num_heads, head_size, block_size]
216
+ torch::Tensor &slot_mapping, // [num_tokens]
217
+ const std::string &kv_cache_dtype, torch::Tensor &k_scale,
218
+ torch::Tensor &v_scale) {
219
+
220
+ // Determine cache dtype and FP8 usage
221
+ torch::ScalarType cache_dtype = key_cache.scalar_type();
222
+ bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
223
+ if (use_fp8_scales) {
224
+ TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type");
225
+ TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars");
226
+ TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32,
227
+ "FP8 scales must be float32");
228
+ }
229
+
230
+ TORCH_CHECK(key.device().is_mps() && value.device().is_mps() &&
231
+ key_cache.device().is_mps() && value_cache.device().is_mps(),
232
+ "All tensors must be on MPS device");
233
+
234
+ // Move slot_mapping to CPU if it's on MPS
235
+ torch::Tensor slot_mapping_cpu = slot_mapping;
236
+ if (slot_mapping.device().is_mps()) {
237
+ slot_mapping_cpu = slot_mapping.cpu();
238
+ }
239
+
240
+ const int64_t num_tokens = key.size(0);
241
+ const int64_t num_heads = key.size(1);
242
+ const int64_t head_size = key.size(2);
243
+ const int64_t block_size = key_cache.size(3);
244
+ const int64_t x = key_cache.size(4);
245
+
246
+ const int32_t key_stride = key.stride(0);
247
+ const int32_t value_stride = value.stride(0);
248
+
249
+ @autoreleasepool {
250
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
251
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
252
+
253
+ id<MTLDevice> device = stream->device();
254
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
255
+ TORCH_CHECK(cmdBuf, "Failed to get command buffer");
256
+
257
+ // Construct the full path to the metallib file
258
+ std::string moduleDir = getModuleDirectory();
259
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
260
+
261
+ NSString *metallibPathStr =
262
+ [NSString stringWithUTF8String:metallibPath.c_str()];
263
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
264
+ NSError *error = nil;
265
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
266
+ if (!lib) {
267
+ NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@",
268
+ metallibPathStr, error.localizedDescription);
269
+ }
270
+
271
+ NSString *kernName = nil;
272
+ std::string kv_dtype_str, cache_dtype_str;
273
+
274
+ // Get KV dtype string
275
+ switch (key.scalar_type()) {
276
+ case torch::kFloat:
277
+ kv_dtype_str = "float";
278
+ break;
279
+ case torch::kHalf:
280
+ kv_dtype_str = "half";
281
+ break;
282
+ case torch::kBFloat16:
283
+ kv_dtype_str = "bfloat16_t";
284
+ break;
285
+ default:
286
+ TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache");
287
+ }
288
+
289
+ // Get cache dtype string
290
+ switch (cache_dtype) {
291
+ case torch::kFloat:
292
+ cache_dtype_str = "float";
293
+ break;
294
+ case torch::kHalf:
295
+ cache_dtype_str = "half";
296
+ break;
297
+ case torch::kBFloat16:
298
+ cache_dtype_str = "bfloat16_t";
299
+ break;
300
+ case torch::kUInt8:
301
+ cache_dtype_str = "uchar";
302
+ break;
303
+ default:
304
+ TORCH_CHECK(false, "Unsupported cache dtype for reshape_and_cache");
305
+ }
306
+
307
+ std::string kernName_str = "reshape_and_cache_kv_" + kv_dtype_str + "_cache_" + cache_dtype_str;
308
+ kernName = [NSString stringWithUTF8String:kernName_str.c_str()];
309
+
310
+ // Create function constants for FP8 support
311
+ MTLFunctionConstantValues *constants = [[MTLFunctionConstantValues alloc] init];
312
+ [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:10];
313
+
314
+ id<MTLFunction> fn = [lib newFunctionWithName:kernName constantValues:constants error:&error];
315
+ TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String,
316
+ error ? [NSString stringWithFormat:@": %@", error.localizedDescription].UTF8String : "");
317
+
318
+ id<MTLComputePipelineState> pso =
319
+ [device newComputePipelineStateWithFunction:fn error:&error];
320
+ TORCH_CHECK(pso, error.localizedDescription.UTF8String);
321
+
322
+ dispatch_queue_t q = stream->queue();
323
+ dispatch_sync(q, ^{
324
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
325
+ TORCH_CHECK(enc, "Failed to create compute encoder");
326
+
327
+ [enc setComputePipelineState:pso];
328
+
329
+ // Set tensor buffers
330
+ [enc setBuffer:getMTLBufferStorage(key)
331
+ offset:key.storage_offset() * key.element_size()
332
+ atIndex:0];
333
+ [enc setBuffer:getMTLBufferStorage(value)
334
+ offset:value.storage_offset() * value.element_size()
335
+ atIndex:1];
336
+ [enc setBuffer:getMTLBufferStorage(key_cache)
337
+ offset:key_cache.storage_offset() * key_cache.element_size()
338
+ atIndex:2];
339
+ [enc setBuffer:getMTLBufferStorage(value_cache)
340
+ offset:value_cache.storage_offset() * value_cache.element_size()
341
+ atIndex:3];
342
+
343
+ // Set slot mapping buffer
344
+ id<MTLBuffer> slotMappingBuf =
345
+ [device newBufferWithBytes:slot_mapping_cpu.data_ptr<int64_t>()
346
+ length:num_tokens * sizeof(int64_t)
347
+ options:MTLResourceStorageModeShared];
348
+ [enc setBuffer:slotMappingBuf offset:0 atIndex:4];
349
+
350
+ // k_scale and v_scale buffers (for FP8)
351
+ if (use_fp8_scales) {
352
+ [enc setBuffer:getMTLBufferStorage(k_scale)
353
+ offset:k_scale.storage_offset() * k_scale.element_size()
354
+ atIndex:5];
355
+ [enc setBuffer:getMTLBufferStorage(v_scale)
356
+ offset:v_scale.storage_offset() * v_scale.element_size()
357
+ atIndex:6];
358
+ } else {
359
+ // For non-FP8, we still need to increment buffer indices
360
+ // The Metal kernel expects buffers at indices 5 and 6 even if unused
361
+ }
362
+
363
+ // Set parameters as individual buffers (matching mistralrs pattern)
364
+ id<MTLBuffer> keyStrideBuf =
365
+ [device newBufferWithBytes:&key_stride
366
+ length:sizeof(int32_t)
367
+ options:MTLResourceStorageModeShared];
368
+ [enc setBuffer:keyStrideBuf offset:0 atIndex:7];
369
+
370
+ id<MTLBuffer> valueStrideBuf =
371
+ [device newBufferWithBytes:&value_stride
372
+ length:sizeof(int32_t)
373
+ options:MTLResourceStorageModeShared];
374
+ [enc setBuffer:valueStrideBuf offset:0 atIndex:8];
375
+
376
+ const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
377
+ id<MTLBuffer> numHeadsBuf =
378
+ [device newBufferWithBytes:&num_heads_i32
379
+ length:sizeof(int32_t)
380
+ options:MTLResourceStorageModeShared];
381
+ [enc setBuffer:numHeadsBuf offset:0 atIndex:9];
382
+
383
+ const int32_t head_size_i32 = static_cast<int32_t>(head_size);
384
+ id<MTLBuffer> headSizeBuf =
385
+ [device newBufferWithBytes:&head_size_i32
386
+ length:sizeof(int32_t)
387
+ options:MTLResourceStorageModeShared];
388
+ [enc setBuffer:headSizeBuf offset:0 atIndex:10];
389
+
390
+ const int32_t block_size_i32 = static_cast<int32_t>(block_size);
391
+ id<MTLBuffer> blockSizeBuf =
392
+ [device newBufferWithBytes:&block_size_i32
393
+ length:sizeof(int32_t)
394
+ options:MTLResourceStorageModeShared];
395
+ [enc setBuffer:blockSizeBuf offset:0 atIndex:11];
396
+
397
+ const int32_t x_i32 = static_cast<int32_t>(x);
398
+ id<MTLBuffer> xBuf =
399
+ [device newBufferWithBytes:&x_i32
400
+ length:sizeof(int32_t)
401
+ options:MTLResourceStorageModeShared];
402
+ [enc setBuffer:xBuf offset:0 atIndex:12];
403
+
404
+ const uint64_t threads_per_threadgroup =
405
+ std::min<uint64_t>(512, num_heads * head_size);
406
+ MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1);
407
+ MTLSize grid = MTLSizeMake(num_tokens, 1, 1);
408
+
409
+ [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
410
+ [enc endEncoding];
411
+ });
412
+
413
+ stream->synchronize(at::mps::SyncType::COMMIT);
414
+ }
415
+ }
416
+
417
+ void reshape_and_cache_flash(
418
+ torch::Tensor &key, // [num_tokens, num_heads, head_size]
419
+ torch::Tensor &value, // [num_tokens, num_heads, head_size]
420
+ torch::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size]
421
+ torch::Tensor
422
+ &value_cache, // [num_blocks, block_size, num_heads, head_size]
423
+ torch::Tensor &slot_mapping, // [num_tokens]
424
+ const std::string &kv_cache_dtype, torch::Tensor &k_scale,
425
+ torch::Tensor &v_scale) {
426
+
427
+ TORCH_CHECK(key.device().is_mps() && value.device().is_mps() &&
428
+ key_cache.device().is_mps() && value_cache.device().is_mps(),
429
+ "All tensors must be on MPS device");
430
+
431
+ // Move slot_mapping to CPU if it's on MPS
432
+ torch::Tensor slot_mapping_cpu = slot_mapping;
433
+ if (slot_mapping.device().is_mps()) {
434
+ slot_mapping_cpu = slot_mapping.cpu();
435
+ }
436
+
437
+ const int64_t num_tokens = key.size(0);
438
+ const int64_t num_heads = key.size(1);
439
+ const int64_t head_size = key.size(2);
440
+ const int64_t block_size = key_cache.size(1);
441
+
442
+ const int32_t key_stride = key.stride(0);
443
+ const int32_t value_stride = value.stride(0);
444
+
445
+ @autoreleasepool {
446
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
447
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
448
+
449
+ id<MTLDevice> device = stream->device();
450
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
451
+ TORCH_CHECK(cmdBuf, "Failed to get command buffer");
452
+
453
+ // Construct the full path to the metallib file
454
+ std::string moduleDir = getModuleDirectory();
455
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
456
+
457
+ NSString *metallibPathStr =
458
+ [NSString stringWithUTF8String:metallibPath.c_str()];
459
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
460
+ NSError *error = nil;
461
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
462
+ if (!lib) {
463
+ NSLog(@"[cache.mm] Failed to load pre-compiled Metal library at %@: %@",
464
+ metallibPathStr, error.localizedDescription);
465
+ }
466
+
467
+ NSString *kernName = nil;
468
+ switch (key.scalar_type()) {
469
+ case torch::kFloat:
470
+ kernName = @"reshape_and_cache_flash_float";
471
+ break;
472
+ case torch::kHalf:
473
+ kernName = @"reshape_and_cache_flash_half";
474
+ break;
475
+ case torch::kBFloat16:
476
+ kernName = @"reshape_and_cache_flash_bfloat16_t";
477
+ break;
478
+ default:
479
+ TORCH_CHECK(false, "Unsupported dtype for reshape_and_cache_flash");
480
+ }
481
+
482
+ id<MTLFunction> fn = [lib newFunctionWithName:kernName];
483
+ TORCH_CHECK(fn, "Missing Metal kernel function: ", kernName.UTF8String);
484
+
485
+ id<MTLComputePipelineState> pso =
486
+ [device newComputePipelineStateWithFunction:fn error:&error];
487
+ TORCH_CHECK(pso, error.localizedDescription.UTF8String);
488
+
489
+ dispatch_queue_t q = stream->queue();
490
+ dispatch_sync(q, ^{
491
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
492
+ TORCH_CHECK(enc, "Failed to create compute encoder");
493
+
494
+ [enc setComputePipelineState:pso];
495
+
496
+ // Set tensor buffers
497
+ [enc setBuffer:getMTLBufferStorage(key)
498
+ offset:key.storage_offset() * key.element_size()
499
+ atIndex:0];
500
+ [enc setBuffer:getMTLBufferStorage(value)
501
+ offset:value.storage_offset() * value.element_size()
502
+ atIndex:1];
503
+ [enc setBuffer:getMTLBufferStorage(key_cache)
504
+ offset:key_cache.storage_offset() * key_cache.element_size()
505
+ atIndex:2];
506
+ [enc setBuffer:getMTLBufferStorage(value_cache)
507
+ offset:value_cache.storage_offset() * value_cache.element_size()
508
+ atIndex:3];
509
+
510
+ // Set slot mapping buffer
511
+ id<MTLBuffer> slotMappingBuf =
512
+ [device newBufferWithBytes:slot_mapping_cpu.data_ptr<int64_t>()
513
+ length:num_tokens * sizeof(int64_t)
514
+ options:MTLResourceStorageModeShared];
515
+ [enc setBuffer:slotMappingBuf offset:0 atIndex:4];
516
+
517
+ // Set parameters as individual buffers
518
+ id<MTLBuffer> keyStrideBuf =
519
+ [device newBufferWithBytes:&key_stride
520
+ length:sizeof(int32_t)
521
+ options:MTLResourceStorageModeShared];
522
+ [enc setBuffer:keyStrideBuf offset:0 atIndex:5];
523
+
524
+ id<MTLBuffer> valueStrideBuf =
525
+ [device newBufferWithBytes:&value_stride
526
+ length:sizeof(int32_t)
527
+ options:MTLResourceStorageModeShared];
528
+ [enc setBuffer:valueStrideBuf offset:0 atIndex:6];
529
+
530
+ const int32_t num_heads_i32 = static_cast<int32_t>(num_heads);
531
+ id<MTLBuffer> numHeadsBuf =
532
+ [device newBufferWithBytes:&num_heads_i32
533
+ length:sizeof(int32_t)
534
+ options:MTLResourceStorageModeShared];
535
+ [enc setBuffer:numHeadsBuf offset:0 atIndex:7];
536
+
537
+ const int32_t head_size_i32 = static_cast<int32_t>(head_size);
538
+ id<MTLBuffer> headSizeBuf =
539
+ [device newBufferWithBytes:&head_size_i32
540
+ length:sizeof(int32_t)
541
+ options:MTLResourceStorageModeShared];
542
+ [enc setBuffer:headSizeBuf offset:0 atIndex:8];
543
+
544
+ const int32_t block_size_i32 = static_cast<int32_t>(block_size);
545
+ id<MTLBuffer> blockSizeBuf =
546
+ [device newBufferWithBytes:&block_size_i32
547
+ length:sizeof(int32_t)
548
+ options:MTLResourceStorageModeShared];
549
+ [enc setBuffer:blockSizeBuf offset:0 atIndex:9];
550
+
551
+ const uint64_t threads_per_threadgroup =
552
+ std::min<uint64_t>(512, num_heads * head_size);
553
+ MTLSize tg = MTLSizeMake(threads_per_threadgroup, 1, 1);
554
+ MTLSize grid = MTLSizeMake(num_tokens, 1, 1);
555
+
556
+ [enc dispatchThreadgroups:grid threadsPerThreadgroup:tg];
557
+ [enc endEncoding];
558
+ });
559
+
560
+ stream->synchronize(at::mps::SyncType::COMMIT);
561
+ }
562
+ }
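
Note: reshape_and_cache above expects the split key-cache layout [num_blocks, num_heads, head_size/x, block_size, x] and the value-cache layout [num_blocks, num_heads, head_size, block_size], with slot_mapping staged through a CPU copy. A minimal usage sketch under those assumptions (libtorch on an MPS build; any kv_cache_dtype other than "fp8"/"fp8_e4m3" selects the non-quantized path, and the function declaration is assumed to come from the extension's binding header):

#include <torch/torch.h>

void example_reshape_and_cache() {
  const int64_t num_tokens = 4, num_heads = 8, head_size = 128;
  const int64_t num_blocks = 16, block_size = 16, x = 8;
  auto opts = torch::dtype(torch::kHalf).device(torch::kMPS);

  auto key   = torch::randn({num_tokens, num_heads, head_size}, opts);
  auto value = torch::randn({num_tokens, num_heads, head_size}, opts);
  auto key_cache   = torch::zeros({num_blocks, num_heads, head_size / x, block_size, x}, opts);
  auto value_cache = torch::zeros({num_blocks, num_heads, head_size, block_size}, opts);
  auto slot_mapping = torch::arange(num_tokens, torch::dtype(torch::kLong).device(torch::kMPS));
  auto k_scale = torch::ones({1}, torch::dtype(torch::kFloat).device(torch::kMPS));
  auto v_scale = torch::ones({1}, torch::dtype(torch::kFloat).device(torch::kMPS));

  // "auto" (or any non-fp8 string) keeps the cache dtype equal to the KV dtype.
  reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                    "auto", k_scale, v_scale);
}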
paged-attention-metal/cache/copy_blocks.metal ADDED
@@ -0,0 +1,51 @@
1
+ #include "../utils.metal"
2
+ #include <metal_stdlib>
3
+
4
+ using namespace metal;
5
+
6
+ template <typename T>
7
+ [[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]],
8
+ device T *value_cache [[buffer(1)]],
9
+ const device int64_t *block_mapping [[buffer(2)]],
10
+ device const int &numel_per_block,
11
+ uint tgid [[threadgroup_position_in_grid]],
12
+ uint tid [[thread_position_in_threadgroup]],
13
+ uint threads_per_threadgroup
14
+ [[threads_per_threadgroup]]) {
15
+ const int pair_idx = tgid;
16
+
17
+ int64_t src_block_number = block_mapping[2 * pair_idx];
18
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
19
+
20
+ const int64_t src_block_offset = src_block_number * numel_per_block;
21
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
22
+
23
+ // Copy key cache blocks
24
+ for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
25
+ int64_t src_offset = src_block_offset + i;
26
+ int64_t dst_offset = dst_block_offset + i;
27
+ key_cache[dst_offset] = key_cache[src_offset];
28
+ }
29
+
30
+ // Copy value cache blocks
31
+ for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) {
32
+ int64_t src_offset = src_block_offset + i;
33
+ int64_t dst_offset = dst_block_offset + i;
34
+ value_cache[dst_offset] = value_cache[src_offset];
35
+ }
36
+ }
37
+
38
+ #define instantiate_copy_blocks(type) \
39
+ template [[host_name("copy_blocks_" #type)]] [[kernel]] void \
40
+ copy_blocks<type>(device type * key_cache [[buffer(0)]], \
41
+ device type * value_cache [[buffer(1)]], \
42
+ const device int64_t *block_mapping [[buffer(2)]], \
43
+ device const int &numel_per_block, \
44
+ uint tgid [[threadgroup_position_in_grid]], \
45
+ uint tid [[thread_position_in_threadgroup]], \
46
+ uint threads_per_threadgroup [[threads_per_threadgroup]]);
47
+
48
+ instantiate_copy_blocks(float);
49
+ instantiate_copy_blocks(bfloat16_t);
50
+ instantiate_copy_blocks(half);
51
+ instantiate_copy_blocks(uchar);
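
Note: each (src, dst) pair in block_mapping gets one threadgroup, and the threads in that group stride over numel_per_block elements of both caches. A CPU reference of the same semantics for a single layer with flat [num_blocks, numel_per_block] storage (a sketch for comparison only, not part of this commit):

#include <cstdint>

void copy_blocks_reference(float *key_cache, float *value_cache,
                           const int64_t *block_mapping, // [num_pairs, 2]
                           int64_t num_pairs, int64_t numel_per_block) {
  for (int64_t pair = 0; pair < num_pairs; ++pair) {
    const int64_t src = block_mapping[2 * pair];
    const int64_t dst = block_mapping[2 * pair + 1];
    for (int64_t i = 0; i < numel_per_block; ++i) {
      key_cache[dst * numel_per_block + i]   = key_cache[src * numel_per_block + i];
      value_cache[dst * numel_per_block + i] = value_cache[src * numel_per_block + i];
    }
  }
}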
paged-attention-metal/cache/reshape_and_cache.metal ADDED
@@ -0,0 +1,193 @@
1
+ #include "../utils.metal"
2
+ #include "../float8.metal"
3
+ #include <metal_stdlib>
4
+
5
+ using namespace metal;
6
+
7
+ template <typename KV_T, typename CACHE_T>
8
+ inline CACHE_T to_cache(KV_T v) = delete;
9
+
10
+ template <> inline uchar to_cache<float, uchar>(float v) {
11
+ return float_to_fp8_e4m3(v);
12
+ }
13
+
14
+ template <> inline uchar to_cache<bfloat16_t, uchar>(bfloat16_t v) {
15
+ return float_to_fp8_e4m3((float)v);
16
+ }
17
+
18
+ template <> inline uchar to_cache<half, uchar>(half v) {
19
+ return float_to_fp8_e4m3((float)v);
20
+ }
21
+
22
+ template <> inline float to_cache<float, float>(float v) { return v; }
23
+
24
+ template <> inline bfloat16_t to_cache<bfloat16_t, bfloat16_t>(bfloat16_t v) {
25
+ return v;
26
+ }
27
+
28
+ template <> inline half to_cache<half, half>(half v) { return v; }
29
+
30
+ constant bool use_fp8_scales [[function_constant(10)]];
31
+
32
+ template <typename KV_T, typename CACHE_T>
33
+ [[kernel]] void reshape_and_cache(
34
+ const device KV_T *__restrict__ key
35
+ [[buffer(0)]], // [num_tokens, num_heads, head_size]
36
+ const device KV_T *__restrict__ value
37
+ [[buffer(1)]], // [num_tokens, num_heads, head_size]
38
+ device CACHE_T *__restrict__ key_cache
39
+ [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x]
40
+ device CACHE_T *__restrict__ value_cache
41
+ [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size]
42
+ const device int64_t *__restrict__ slot_mapping
43
+ [[buffer(4)]], // [num_tokens]
44
+ const device float *__restrict__ k_scale
45
+ [[buffer(5)]], // [1] - only used when use_fp8_scales
46
+ const device float *__restrict__ v_scale
47
+ [[buffer(6)]], // [1] - only used when use_fp8_scales
48
+ device const int &key_stride [[buffer(7)]],
49
+ device const int &value_stride [[buffer(8)]],
50
+ device const int &num_heads [[buffer(9)]],
51
+ device const int &head_size [[buffer(10)]],
52
+ device const int &block_size [[buffer(11)]],
53
+ device const int &x [[buffer(12)]],
54
+ uint gid [[threadgroup_position_in_grid]],
55
+ uint tid [[thread_position_in_threadgroup]],
56
+ uint threads_per_threadgroup [[threads_per_threadgroup]]) {
57
+ const int64_t token_idx = gid;
58
+ const int64_t slot_idx = slot_mapping[token_idx];
59
+ if (slot_idx < 0) {
60
+ // Padding token that should be ignored.
61
+ return;
62
+ }
63
+
64
+ const int64_t block_idx = slot_idx / block_size;
65
+ const int64_t block_offset = slot_idx % block_size;
66
+
67
+ const int n = num_heads * head_size;
68
+ for (int i = tid; i < n; i += threads_per_threadgroup) {
69
+ const int64_t src_key_idx = token_idx * key_stride + i;
70
+ const int64_t src_value_idx = token_idx * value_stride + i;
71
+
72
+ const int head_idx = i / head_size;
73
+ const int head_offset = i % head_size;
74
+ const int x_idx = head_offset / x;
75
+ const int x_offset = head_offset % x;
76
+
77
+ const int64_t tgt_key_idx =
78
+ block_idx * num_heads * (head_size / x) * block_size * x +
79
+ head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
80
+ block_offset * x + x_offset;
81
+ const int64_t tgt_value_idx =
82
+ block_idx * num_heads * head_size * block_size +
83
+ head_idx * head_size * block_size + head_offset * block_size +
84
+ block_offset;
85
+
86
+ if (use_fp8_scales) {
87
+ key_cache[tgt_key_idx] =
88
+ to_cache<KV_T, CACHE_T>(KV_T((float)key[src_key_idx] / *k_scale));
89
+ value_cache[tgt_value_idx] =
90
+ to_cache<KV_T, CACHE_T>(KV_T((float)value[src_value_idx] / *v_scale));
91
+ } else {
92
+ key_cache[tgt_key_idx] = to_cache<KV_T, CACHE_T>(key[src_key_idx]);
93
+ value_cache[tgt_value_idx] = to_cache<KV_T, CACHE_T>(value[src_value_idx]);
94
+ }
95
+ }
96
+ }
97
+
98
+ #define instantiate_reshape_and_cache(kv_type, cache_type) \
99
+ template [[host_name("reshape_and_cache_kv_" #kv_type \
100
+ "_cache_" #cache_type)]] [[kernel]] void \
101
+ reshape_and_cache<kv_type, cache_type>( \
102
+ const device kv_type *__restrict__ key [[buffer(0)]], \
103
+ const device kv_type *__restrict__ value [[buffer(1)]], \
104
+ device cache_type *__restrict__ key_cache [[buffer(2)]], \
105
+ device cache_type *__restrict__ value_cache [[buffer(3)]], \
106
+ const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
107
+ const device float *__restrict__ k_scale [[buffer(5)]], \
108
+ const device float *__restrict__ v_scale [[buffer(6)]], \
109
+ device const int &key_stride [[buffer(7)]], \
110
+ device const int &value_stride [[buffer(8)]], \
111
+ device const int &num_heads [[buffer(9)]], \
112
+ device const int &head_size [[buffer(10)]], \
113
+ device const int &block_size [[buffer(11)]], \
114
+ device const int &x [[buffer(12)]], \
115
+ uint gid [[threadgroup_position_in_grid]], \
116
+ uint tid [[thread_position_in_threadgroup]], \
117
+ uint threads_per_threadgroup [[threads_per_threadgroup]]);
118
+
119
+ instantiate_reshape_and_cache(float, float);
120
+ instantiate_reshape_and_cache(bfloat16_t, bfloat16_t);
121
+ instantiate_reshape_and_cache(half, half);
122
+
123
+ instantiate_reshape_and_cache(float, uchar);
124
+ instantiate_reshape_and_cache(bfloat16_t, uchar);
125
+ instantiate_reshape_and_cache(half, uchar);
126
+
127
+ // Flash version with different cache layout: [num_blocks, block_size,
128
+ // num_heads, head_size]
129
+ template <typename T>
130
+ [[kernel]] void reshape_and_cache_flash(
131
+ const device T *__restrict__ key
132
+ [[buffer(0)]], // [num_tokens, num_heads, head_size]
133
+ const device T *__restrict__ value
134
+ [[buffer(1)]], // [num_tokens, num_heads, head_size]
135
+ device T *__restrict__ key_cache
136
+ [[buffer(2)]], // [num_blocks, block_size, num_heads, head_size]
137
+ device T *__restrict__ value_cache
138
+ [[buffer(3)]], // [num_blocks, block_size, num_heads, head_size]
139
+ const device int64_t *__restrict__ slot_mapping
140
+ [[buffer(4)]], // [num_tokens]
141
+ device const int &key_stride, device const int &value_stride,
142
+ device const int &num_heads, device const int &head_size,
143
+ device const int &block_size, uint gid [[threadgroup_position_in_grid]],
144
+ uint tid [[thread_position_in_threadgroup]],
145
+ uint threads_per_threadgroup [[threads_per_threadgroup]]) {
146
+ const int64_t token_idx = gid;
147
+ const int64_t slot_idx = slot_mapping[token_idx];
148
+ if (slot_idx < 0) {
149
+ // Padding token that should be ignored.
150
+ return;
151
+ }
152
+
153
+ const int64_t block_idx = slot_idx / block_size;
154
+ const int64_t block_offset = slot_idx % block_size;
155
+
156
+ const int n = num_heads * head_size;
157
+ for (int i = tid; i < n; i += threads_per_threadgroup) {
158
+ const int64_t src_key_idx = token_idx * key_stride + i;
159
+ const int64_t src_value_idx = token_idx * value_stride + i;
160
+
161
+ const int head_idx = i / head_size;
162
+ const int head_offset = i % head_size;
163
+
164
+ // Flash cache layout: [num_blocks, block_size, num_heads, head_size]
165
+ const int64_t tgt_key_idx = block_idx * block_size * num_heads * head_size +
166
+ block_offset * num_heads * head_size +
167
+ head_idx * head_size + head_offset;
168
+ const int64_t tgt_value_idx =
169
+ block_idx * block_size * num_heads * head_size +
170
+ block_offset * num_heads * head_size + head_idx * head_size +
171
+ head_offset;
172
+ key_cache[tgt_key_idx] = key[src_key_idx];
173
+ value_cache[tgt_value_idx] = value[src_value_idx];
174
+ }
175
+ }
176
+
177
+ #define instantiate_reshape_and_cache_flash(type) \
178
+ template [[host_name("reshape_and_cache_flash_" #type)]] [[kernel]] void \
179
+ reshape_and_cache_flash<type>( \
180
+ const device type *__restrict__ key [[buffer(0)]], \
181
+ const device type *__restrict__ value [[buffer(1)]], \
182
+ device type *__restrict__ key_cache [[buffer(2)]], \
183
+ device type *__restrict__ value_cache [[buffer(3)]], \
184
+ const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \
185
+ device const int &key_stride, device const int &value_stride, \
186
+ device const int &num_heads, device const int &head_size, \
187
+ device const int &block_size, uint gid [[threadgroup_position_in_grid]], \
188
+ uint tid [[thread_position_in_threadgroup]], \
189
+ uint threads_per_threadgroup [[threads_per_threadgroup]]);
190
+
191
+ instantiate_reshape_and_cache_flash(float);
192
+ instantiate_reshape_and_cache_flash(bfloat16_t);
193
+ instantiate_reshape_and_cache_flash(half);
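
Note: the non-flash key cache uses the split layout [num_blocks, num_heads, head_size/x, block_size, x], so the flat offset is not a plain row-major index over (head, offset). A small host-side helper that reproduces tgt_key_idx from the kernel above, for reference (a sketch, not part of this commit):

#include <cstdint>

// Flat index into a key cache of shape
// [num_blocks, num_heads, head_size/x, block_size, x].
int64_t key_cache_index(int64_t block_idx, int64_t head_idx, int64_t head_offset,
                        int64_t block_offset, int64_t num_heads,
                        int64_t head_size, int64_t block_size, int64_t x) {
  const int64_t x_idx = head_offset / x;
  const int64_t x_offset = head_offset % x;
  return block_idx * num_heads * (head_size / x) * block_size * x +
         head_idx * (head_size / x) * block_size * x +
         x_idx * block_size * x +
         block_offset * x +
         x_offset;
}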
paged-attention-metal/convert_fp8.metal ADDED
@@ -0,0 +1,77 @@
1
+ #include "float8.metal"
2
+ #include "utils.metal"
3
+ #include <metal_stdlib>
4
+
5
+ using namespace metal;
6
+
7
+ // Convert between different precision formats for cache tensors
8
+ // This kernel handles conversions like float->fp8, fp8->float, etc.
9
+
10
+ template <typename SRC_T, typename DST_T>
11
+ [[kernel]] void convert_fp8_kernel(
12
+ const device SRC_T *__restrict__ src [[buffer(0)]],
13
+ device DST_T *__restrict__ dst [[buffer(1)]],
14
+ const device float &scale [[buffer(2)]],
15
+ const device uint32_t &num_elements [[buffer(3)]],
16
+ uint gid [[thread_position_in_grid]]) {
17
+
18
+ if (gid >= num_elements) {
19
+ return;
20
+ }
21
+
22
+ // Load source value
23
+ SRC_T src_val = src[gid];
24
+
25
+ // Convert based on source and destination types
26
+ if constexpr (is_same_v<SRC_T, uchar> && !is_same_v<DST_T, uchar>) {
27
+ // FP8 -> higher precision (dequantization)
28
+ float fp32_val = fp8_e4m3_to_float(src_val) * scale;
29
+ dst[gid] = static_cast<DST_T>(fp32_val);
30
+ } else if constexpr (!is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
31
+ // Higher precision -> FP8 (quantization)
32
+ float fp32_val = static_cast<float>(src_val) / scale;
33
+ dst[gid] = float_to_fp8_e4m3(fp32_val);
34
+ } else if constexpr (is_same_v<SRC_T, uchar> && is_same_v<DST_T, uchar>) {
35
+ // FP8 -> FP8 (with rescaling)
36
+ float fp32_val = fp8_e4m3_to_float(src_val) * scale;
37
+ dst[gid] = float_to_fp8_e4m3(fp32_val);
38
+ } else {
39
+ // Regular precision -> regular precision (with scaling)
40
+ float fp32_val = static_cast<float>(src_val) * scale;
41
+ dst[gid] = static_cast<DST_T>(fp32_val);
42
+ }
43
+ }
44
+
45
+ // Instantiate all required combinations
46
+ #define INSTANTIATE_CONVERT_FP8(src_type, dst_type) \
47
+ template [[host_name("convert_fp8_" #src_type "_to_" #dst_type)]] \
48
+ [[kernel]] void convert_fp8_kernel<src_type, dst_type>( \
49
+ const device src_type *__restrict__ src [[buffer(0)]], \
50
+ device dst_type *__restrict__ dst [[buffer(1)]], \
51
+ const device float &scale [[buffer(2)]], \
52
+ const device uint32_t &num_elements [[buffer(3)]], \
53
+ uint gid [[thread_position_in_grid]]);
54
+
55
+ // FP8 to other formats (dequantization)
56
+ INSTANTIATE_CONVERT_FP8(uchar, float);
57
+ INSTANTIATE_CONVERT_FP8(uchar, half);
58
+ INSTANTIATE_CONVERT_FP8(uchar, bfloat16_t);
59
+
60
+ // Other formats to FP8 (quantization)
61
+ INSTANTIATE_CONVERT_FP8(float, uchar);
62
+ INSTANTIATE_CONVERT_FP8(half, uchar);
63
+ INSTANTIATE_CONVERT_FP8(bfloat16_t, uchar);
64
+
65
+ // FP8 to FP8 (rescaling)
66
+ INSTANTIATE_CONVERT_FP8(uchar, uchar);
67
+
68
+ // Regular precision conversions with scaling
69
+ INSTANTIATE_CONVERT_FP8(float, float);
70
+ INSTANTIATE_CONVERT_FP8(float, half);
71
+ INSTANTIATE_CONVERT_FP8(float, bfloat16_t);
72
+ INSTANTIATE_CONVERT_FP8(half, float);
73
+ INSTANTIATE_CONVERT_FP8(half, half);
74
+ INSTANTIATE_CONVERT_FP8(half, bfloat16_t);
75
+ INSTANTIATE_CONVERT_FP8(bfloat16_t, float);
76
+ INSTANTIATE_CONVERT_FP8(bfloat16_t, half);
77
+ INSTANTIATE_CONVERT_FP8(bfloat16_t, bfloat16_t);
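
Note: convert_fp8_kernel applies the scale directionally: decoding multiplies by scale (dequantization), encoding divides by scale (quantization), and the fp8-to-fp8 path composes the two. A CPU-side sketch of those conventions (the encode/decode callbacks stand in for float_to_fp8_e4m3 / fp8_e4m3_to_float from float8.metal; illustrative only):

#include <cstdint>

float dequantize_fp8(uint8_t v, float scale, float (*decode)(uint8_t)) {
  return decode(v) * scale;            // fp8 -> higher precision
}
uint8_t quantize_fp8(float v, float scale, uint8_t (*encode)(float)) {
  return encode(v / scale);            // higher precision -> fp8
}
uint8_t rescale_fp8(uint8_t v, float scale, float (*decode)(uint8_t),
                    uint8_t (*encode)(float)) {
  return encode(decode(v) * scale);    // fp8 -> fp8 with rescaling
}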
paged-attention-metal/convert_fp8.mm ADDED
@@ -0,0 +1,138 @@
1
+ #include <ATen/mps/MPSDevice.h>
2
+ #include <ATen/mps/MPSStream.h>
3
+ #include <torch/torch.h>
4
+
5
+ #import <Foundation/Foundation.h>
6
+ #import <Metal/Metal.h>
7
+ #include <algorithm>
8
+ #include <dlfcn.h>
9
+ #include <mach-o/dyld.h>
10
+ #include <string>
11
+ #include <vector>
12
+
13
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
14
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
15
+ }
16
+
17
+ static std::string getModuleDirectory() {
18
+ Dl_info dl_info;
19
+ if (dladdr((void *)getModuleDirectory, &dl_info)) {
20
+ std::string path(dl_info.dli_fname);
21
+ size_t pos = path.find_last_of('/');
22
+ if (pos != std::string::npos) {
23
+ return path.substr(0, pos);
24
+ }
25
+ }
26
+ return ".";
27
+ }
28
+
29
+ // Helper function to get conversion kernel name
30
+ static std::string getConvertKernelName(torch::ScalarType src_dtype, torch::ScalarType dst_dtype) {
31
+ std::string src_str, dst_str;
32
+
33
+ auto dtype_to_string = [](torch::ScalarType dtype) -> std::string {
34
+ switch (dtype) {
35
+ case torch::kFloat: return "float";
36
+ case torch::kHalf: return "half";
37
+ case torch::kBFloat16: return "bfloat16_t";
38
+ case torch::kUInt8: return "uchar";
39
+ default:
40
+ TORCH_CHECK(false, "Unsupported dtype for convert_fp8: ", dtype);
41
+ }
42
+ };
43
+
44
+ src_str = dtype_to_string(src_dtype);
45
+ dst_str = dtype_to_string(dst_dtype);
46
+
47
+ return "convert_fp8_" + src_str + "_to_" + dst_str;
48
+ }
49
+
50
+ void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache,
51
+ const double scale, const std::string &kv_cache_dtype) {
52
+ // Validate input tensors
53
+ TORCH_CHECK(src_cache.device().is_mps() && dst_cache.device().is_mps(),
54
+ "Both tensors must be on MPS device");
55
+ TORCH_CHECK(src_cache.device() == dst_cache.device(),
56
+ "Source and destination tensors must be on the same device");
57
+ TORCH_CHECK(src_cache.numel() == dst_cache.numel(),
58
+ "Source and destination tensors must have the same number of elements");
59
+ TORCH_CHECK(src_cache.is_contiguous() && dst_cache.is_contiguous(),
60
+ "Both tensors must be contiguous");
61
+
62
+ const uint32_t num_elements = static_cast<uint32_t>(src_cache.numel());
63
+ if (num_elements == 0) {
64
+ return; // Nothing to convert
65
+ }
66
+
67
+ // Determine conversion kernel name
68
+ std::string kernel_name = getConvertKernelName(src_cache.scalar_type(), dst_cache.scalar_type());
69
+
70
+ @autoreleasepool {
71
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
72
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
73
+
74
+ id<MTLDevice> device = stream->device();
75
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
76
+ TORCH_CHECK(cmdBuf, "Failed to get command buffer");
77
+
78
+ // Load Metal library
79
+ std::string moduleDir = getModuleDirectory();
80
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
81
+ NSString *metallibPathStr = [NSString stringWithUTF8String:metallibPath.c_str()];
82
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
83
+ NSError *error = nil;
84
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
85
+ TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ",
86
+ error ? error.localizedDescription.UTF8String : "unknown error");
87
+
88
+ // Create kernel function
89
+ NSString *kernelNameStr = [NSString stringWithUTF8String:kernel_name.c_str()];
90
+ id<MTLFunction> fn = [lib newFunctionWithName:kernelNameStr];
91
+ TORCH_CHECK(fn, "Failed to find Metal kernel function: ", kernel_name);
92
+
93
+ id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:fn error:&error];
94
+ TORCH_CHECK(pso, "Failed to create compute pipeline state: ",
95
+ error ? error.localizedDescription.UTF8String : "unknown error");
96
+
97
+ dispatch_queue_t q = stream->queue();
98
+ dispatch_sync(q, ^{
99
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
100
+ TORCH_CHECK(enc, "Failed to create compute encoder");
101
+
102
+ [enc setComputePipelineState:pso];
103
+
104
+ // Set buffers
105
+ [enc setBuffer:getMTLBufferStorage(src_cache)
106
+ offset:src_cache.storage_offset() * src_cache.element_size()
107
+ atIndex:0];
108
+ [enc setBuffer:getMTLBufferStorage(dst_cache)
109
+ offset:dst_cache.storage_offset() * dst_cache.element_size()
110
+ atIndex:1];
111
+
112
+ // Set scale parameter
113
+ float scale_f32 = static_cast<float>(scale);
114
+ id<MTLBuffer> scaleBuf = [device newBufferWithBytes:&scale_f32
115
+ length:sizeof(float)
116
+ options:MTLResourceStorageModeShared];
117
+ [enc setBuffer:scaleBuf offset:0 atIndex:2];
118
+
119
+ // Set num_elements parameter
120
+ id<MTLBuffer> numElementsBuf = [device newBufferWithBytes:&num_elements
121
+ length:sizeof(uint32_t)
122
+ options:MTLResourceStorageModeShared];
123
+ [enc setBuffer:numElementsBuf offset:0 atIndex:3];
124
+
125
+ // Dispatch threads
126
+ const uint32_t threads_per_threadgroup = std::min<uint32_t>(1024, num_elements);
127
+ const uint32_t threadgroups = (num_elements + threads_per_threadgroup - 1) / threads_per_threadgroup;
128
+
129
+ MTLSize threadsPerThreadgroup = MTLSizeMake(threads_per_threadgroup, 1, 1);
130
+ MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);
131
+
132
+ [enc dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
133
+ [enc endEncoding];
134
+ });
135
+
136
+ stream->synchronize(at::mps::SyncType::COMMIT);
137
+ }
138
+ }
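
For reference, a rough sketch of how the convert_fp8 entry point above can be exercised from Python on an MPS device. It mirrors the round-trip used in tests/kernels/test_cache.py; the Python wrapper's positional defaults (scale, kv_cache_dtype) are assumptions here, not something this file defines.

import torch
from paged_attention import ops

cache = torch.rand(2, 8, 16, 64, device="mps", dtype=torch.float32)

# Quantize into an fp8 (uint8-backed) cache, then decode it back.
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache)        # float -> fp8, scale left at its default

restored = torch.empty_like(cache)
ops.convert_fp8(restored, cache_fp8)     # fp8 -> float, as in test_fp8_e4m3_conversion

torch.testing.assert_close(cache, restored, atol=0.02, rtol=0.2)
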
paged-attention-metal/device.mm ADDED
@@ -0,0 +1,17 @@
1
+ #include "../torch-ext/torch_binding.h"
2
+ #import <Metal/Metal.h>
3
+ #include <torch/torch.h>
4
+
5
+ int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
6
+ TORCH_CHECK(false, "get_device_attribute is not supported on Metal");
7
+ }
8
+
9
+ int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
10
+ // On macOS you can have multiple GPUs; fetch the N-th one.
11
+ NSArray<id<MTLDevice>> *all = MTLCopyAllDevices();
12
+ TORCH_CHECK(device_id >= 0 && device_id < (int64_t)all.count,
13
+ "Invalid Metal device index");
14
+
15
+ id<MTLDevice> dev = all[device_id];
16
+ return static_cast<int64_t>(dev.maxThreadgroupMemoryLength);
17
+ }
paged-attention-metal/float8.metal ADDED
@@ -0,0 +1,122 @@
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ // Helpers ------------------------------------------------------------
5
+ static inline uint as_bits(float x) { return as_type<uint>(x); }
6
+ static inline float from_bits(uint b) { return as_type<float>(b); }
7
+
8
+ // -------------------------------------------------------------------
9
+ // FP8 E4M3 (bias = 7)
10
+ // -------------------------------------------------------------------
11
+ inline float fp8_e4m3_to_float(uchar v) {
12
+ const uint s = v >> 7;
13
+ const uint exp = (v >> 3) & 0xF;
14
+ const uint man = v & 0x7;
15
+
16
+ if (exp == 0) { // zero / sub-normal
17
+ if (man == 0)
18
+ return s ? -0.f : 0.f;
19
+ const float m = float(man) / 8.f; // already scaled by 2^-3
20
+ float val = ldexp(m, 1 - 7); // 2^(1-bias) = 2^-6
21
+ return s ? -val : val;
22
+ }
23
+
24
+ if (exp == 0xF) { // Inf / NaN (E4M3FN keeps only NaN)
25
+ if (man != 0)
26
+ return NAN;
27
+ return s ? -INFINITY : INFINITY;
28
+ }
29
+
30
+ const float m = 1.f + float(man) / 8.f;
31
+ float val = ldexp(m, int(exp) - 7);
32
+ return s ? -val : val;
33
+ }
34
+
35
+ // -------------------------------------------------------------------
36
+ // FP8 E5M2 (bias = 15)
37
+ // -------------------------------------------------------------------
38
+ inline float fp8_e5m2_to_float(uchar v) {
39
+ const uint s = v >> 7;
40
+ const uint exp = (v >> 2) & 0x1F;
41
+ const uint man = v & 0x3;
42
+
43
+ if (exp == 0) {
44
+ if (man == 0)
45
+ return s ? -0.f : 0.f;
46
+ const float m = float(man) / 4.f;
47
+ float val = ldexp(m, 1 - 15); // 2^(1-bias) = 2^-14
48
+ return s ? -val : val;
49
+ }
50
+
51
+ if (exp == 0x1F) {
52
+ if (man != 0)
53
+ return NAN;
54
+ return s ? -INFINITY : INFINITY;
55
+ }
56
+
57
+ const float m = 1.f + float(man) / 4.f;
58
+ float val = ldexp(m, int(exp) - 15);
59
+ return s ? -val : val;
60
+ }
61
+
62
+ // -------------------------------------------------------------------
63
+ // Encoding helpers (round-to-nearest-even, gradual under-flow, sat-to-∞)
64
+ // -------------------------------------------------------------------
65
+ namespace detail {
66
+ template <int EXP_BITS, int MAN_BITS, int BIAS>
67
+ inline uchar fp32_to_fp8(float f) {
68
+ const uint bits = as_bits(f);
69
+ const uint s = bits >> 31;
70
+ const uint abs = bits & 0x7FFFFFFF;
71
+
72
+ // NaN propagates, Inf saturates
73
+ if (abs >= 0x7F800000u) {
74
+ return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS) |
75
+ (abs != 0x7F800000u));
76
+ }
77
+
78
+ int e = int((abs >> 23) & 0xFF) - 127; // unbiased exponent
79
+ uint m = abs & 0x7FFFFFu; // 23-bit mantissa
80
+ const int EXP_MAX = (1 << EXP_BITS) - 2; // last finite exponent
81
+
82
+ // ---------- Normal path -------------------------------------------------
83
+ int e_fp8 = e + BIAS;
84
+ if (e_fp8 >= 1 && e_fp8 <= EXP_MAX) {
85
+ // round-to-nearest-even
86
+ const int shift = 23 - MAN_BITS;
87
+ uint mant = m >> shift;
88
+ const uint lsb = mant & 1u;
89
+ const uint round = (m >> (shift - 1)) & 1u;
90
+ const uint sticky = (m & ((1u << (shift - 1)) - 1u)) != 0u;
91
+ mant += (round & (sticky | lsb));
92
+ if (mant >> MAN_BITS) { // mantissa overflow
93
+ mant = 0;
94
+ ++e_fp8;
95
+ if (e_fp8 > EXP_MAX)
96
+ return uchar((s << 7) | (((1u << EXP_BITS) - 1u) << MAN_BITS)); // ∞
97
+ }
98
+ return uchar((s << 7) | (uint(e_fp8) << MAN_BITS) |
99
+ (mant & ((1u << MAN_BITS) - 1u)));
100
+ }
101
+
102
+ // ---------- Sub-normal / under-flow ------------------------------------
103
+ if (e_fp8 < 1 - MAN_BITS) // too small -> ±0
104
+ return uchar(s << 7);
105
+
106
+ // shift so that exponent becomes 1
107
+ int rshift = (1 - e_fp8) + (23 - MAN_BITS);
108
+ uint mant = (0x800000u | m); // implicit 1
109
+ uint rounded = (mant + (1u << (rshift - 1))) >> rshift;
110
+ if (rounded == 0)
111
+ return uchar(s << 7); // rounds to zero
112
+
113
+ return uchar((s << 7) | (rounded & ((1u << MAN_BITS) - 1u)));
114
+ }
115
+ } // namespace detail
116
+
117
+ inline uchar float_to_fp8_e4m3(float f) {
118
+ return detail::fp32_to_fp8<4, 3, 7>(f);
119
+ }
120
+ inline uchar float_to_fp8_e5m2(float f) {
121
+ return detail::fp32_to_fp8<5, 2, 15>(f);
122
+ }
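
As a sanity check for the two decode helpers above, here is an illustrative pure-Python equivalent (not part of the extension). The bit layouts and biases are taken directly from the Metal code, and the Inf/NaN handling intentionally mirrors it.

import math

def fp8_to_float(v: int, exp_bits: int, man_bits: int, bias: int) -> float:
    s = -1.0 if (v >> 7) & 1 else 1.0
    exp = (v >> man_bits) & ((1 << exp_bits) - 1)
    man = v & ((1 << man_bits) - 1)
    if exp == 0:                                   # zero / sub-normal
        return s * math.ldexp(man / (1 << man_bits), 1 - bias)
    if exp == (1 << exp_bits) - 1:                 # Inf / NaN encodings
        return math.nan if man else s * math.inf
    return s * math.ldexp(1.0 + man / (1 << man_bits), exp - bias)

fp8_e4m3_to_float = lambda v: fp8_to_float(v, 4, 3, 7)    # matches the Metal E4M3 helper
fp8_e5m2_to_float = lambda v: fp8_to_float(v, 5, 2, 15)   # matches the Metal E5M2 helper
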
paged-attention-metal/paged_attention.mm ADDED
@@ -0,0 +1,693 @@
1
+ #include <ATen/mps/MPSDevice.h>
2
+ #include <ATen/mps/MPSStream.h>
3
+ #include <torch/torch.h>
4
+
5
+ #import <Foundation/Foundation.h>
6
+ #import <Metal/Metal.h>
7
+ #include <algorithm>
8
+ #include <dlfcn.h>
9
+ #include <mach-o/dyld.h>
10
+ #include <string>
11
+ #include <vector>
12
+
13
+ static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
14
+ return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
15
+ }
16
+
17
+ static std::string getModuleDirectory() {
18
+ Dl_info dl_info;
19
+ if (dladdr((void *)getModuleDirectory, &dl_info)) {
20
+ std::string path(dl_info.dli_fname);
21
+ size_t pos = path.find_last_of('/');
22
+ if (pos != std::string::npos) {
23
+ return path.substr(0, pos);
24
+ }
25
+ }
26
+ return ".";
27
+ }
28
+
29
+ // Helper function to get kernel name based on dtype and parameters
30
+ static std::string getKernelName(const std::string &base_name,
31
+ torch::ScalarType dtype,
32
+ torch::ScalarType cache_dtype,
33
+ int head_size,
34
+ int block_size, int num_threads,
35
+ int num_simd_lanes, int partition_size = 0) {
36
+ std::string dtype_str;
37
+ switch (dtype) {
38
+ case torch::kFloat:
39
+ dtype_str = "float";
40
+ break;
41
+ case torch::kHalf:
42
+ dtype_str = "half";
43
+ break;
44
+ case torch::kBFloat16:
45
+ dtype_str = "bfloat16_t";
46
+ break;
47
+ default:
48
+ TORCH_CHECK(false, "Unsupported dtype for paged attention: ", dtype);
49
+ }
50
+
51
+ std::string cache_dtype_str;
52
+ switch (cache_dtype) {
53
+ case torch::kFloat:
54
+ cache_dtype_str = "float";
55
+ break;
56
+ case torch::kHalf:
57
+ cache_dtype_str = "half";
58
+ break;
59
+ case torch::kBFloat16:
60
+ cache_dtype_str = "bfloat16_t";
61
+ break;
62
+ case torch::kUInt8:
63
+ cache_dtype_str = "uchar";
64
+ break;
65
+ default:
66
+ TORCH_CHECK(false, "Unsupported cache dtype for paged attention: ", cache_dtype);
67
+ }
68
+
69
+ std::string kernel_name =
70
+ base_name + "_" + dtype_str + "_cache_" + cache_dtype_str + "_hs" + std::to_string(head_size) + "_bs" +
71
+ std::to_string(block_size) + "_nt" + std::to_string(num_threads) +
72
+ "_nsl" + std::to_string(num_simd_lanes);
73
+
74
+ if (partition_size >= 0) {
75
+ kernel_name += "_ps" + std::to_string(partition_size);
76
+ }
77
+
78
+ return kernel_name;
79
+ }
80
+
81
+ // Helper function to calculate shared memory size
82
+ static size_t calculateSharedMemorySize(int max_seq_len, int head_size,
83
+ int num_threads, int num_simd_lanes) {
84
+ // Logits storage: max_seq_len * sizeof(float)
85
+ size_t logits_size = max_seq_len * sizeof(float);
86
+
87
+ // Reduction workspace: 2 * (num_threads / num_simd_lanes) * sizeof(float)
88
+ size_t reduction_size = 2 * (num_threads / num_simd_lanes) * sizeof(float);
89
+
90
+ // Output workspace for cross-warp reduction: head_size * sizeof(float)
91
+ size_t output_size = head_size * sizeof(float);
92
+ return std::max(logits_size + reduction_size, output_size);
93
+ }
94
+
95
+ // Helper function to get supported configurations
96
+ static bool isValidConfiguration(int head_size, int block_size) {
97
+ // Supported head sizes from the Metal kernel instantiations
98
+ std::vector<int> supported_head_sizes = {32, 64, 80, 96, 112,
99
+ 120, 128, 192, 256};
100
+ std::vector<int> supported_block_sizes = {8, 16, 32};
101
+
102
+ return std::find(supported_head_sizes.begin(), supported_head_sizes.end(),
103
+ head_size) != supported_head_sizes.end() &&
104
+ std::find(supported_block_sizes.begin(), supported_block_sizes.end(),
105
+ block_size) != supported_block_sizes.end();
106
+ }
107
+
108
+ void paged_attention_v1(
109
+ torch::Tensor &out, // [num_seqs, num_heads, head_size]
110
+ torch::Tensor &query, // [num_seqs, num_heads, head_size]
111
+ torch::Tensor
112
+ &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
113
+ torch::Tensor
114
+ &value_cache, // [num_blocks, num_heads, head_size, block_size]
115
+ int64_t num_kv_heads, // [num_heads]
116
+ double scale,
117
+ torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
118
+ torch::Tensor &seq_lens, // [num_seqs]
119
+ int64_t block_size, int64_t max_seq_len,
120
+ const std::optional<torch::Tensor> &alibi_slopes,
121
+ const std::string &kv_cache_dtype, torch::Tensor &k_scale,
122
+ torch::Tensor &v_scale, const int64_t tp_rank,
123
+ const int64_t blocksparse_local_blocks,
124
+ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
125
+ const int64_t blocksparse_head_sliding_step) {
126
+ const bool is_block_sparse = (blocksparse_vert_stride > 1);
127
+
128
+ // Validate block sparse is not supported yet
129
+ // TODO: support blocksparse.
130
+ TORCH_CHECK(
131
+ !is_block_sparse,
132
+ "Block sparse attention is not yet supported in Metal implementation");
133
+
134
+ // Determine cache dtype based on kv_cache_dtype
135
+ torch::ScalarType cache_dtype = key_cache.scalar_type();
136
+ bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
137
+ if (use_fp8_scales) {
138
+ TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type");
139
+ TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars");
140
+ TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32,
141
+ "FP8 scales must be float32");
142
+ }
143
+
144
+ // Validate input tensors
145
+ TORCH_CHECK(out.device().is_mps() && query.device().is_mps() &&
146
+ key_cache.device().is_mps() &&
147
+ value_cache.device().is_mps() &&
148
+ block_tables.device().is_mps() && seq_lens.device().is_mps(),
149
+ "All tensors must be on MPS device");
150
+
151
+ const int64_t num_seqs = query.size(0);
152
+ const int64_t num_heads = query.size(1);
153
+ const int64_t head_size = query.size(2);
154
+ const int64_t max_num_blocks_per_seq = block_tables.size(1);
155
+
156
+ // Validate configurations
157
+ TORCH_CHECK(isValidConfiguration(head_size, block_size),
158
+ "Unsupported head_size/block_size combination: ", head_size, "/",
159
+ block_size);
160
+
161
+ // For v1, no partitioning - each sequence processed by one threadgroup
162
+ // Kernel configuration (should match the instantiated kernels)
163
+ const int num_threads = 256;
164
+ const int num_simd_lanes = 32;
165
+ const int partition_size = 0; // v1 doesn't use partitioning
166
+
167
+ // Calculate shared memory requirements (from mistral.rs)
168
+ const int num_simds = num_threads / num_simd_lanes;
169
+ const int padded_max_context_len =
170
+ ((max_seq_len + block_size - 1) / block_size) * block_size;
171
+ const int logits_size = padded_max_context_len * sizeof(float);
172
+ const int outputs_size = (num_simds / 2) * head_size * sizeof(float);
173
+ const size_t shared_memory_size = std::max(logits_size, outputs_size);
174
+
175
+ // Get kernel name - v1 kernels have partition_size=0 in their name
176
+ std::string kernel_name =
177
+ getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size,
178
+ block_size, num_threads, num_simd_lanes, partition_size);
179
+
180
+ @autoreleasepool {
181
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
182
+
183
+ // Load Metal library
184
+ std::string moduleDir = getModuleDirectory();
185
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
186
+ NSString *metallibPathStr =
187
+ [NSString stringWithUTF8String:metallibPath.c_str()];
188
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
189
+ NSError *error = nil;
190
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
191
+ TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ",
192
+ error ? error.localizedDescription.UTF8String
193
+ : "unknown error");
194
+
195
+ // Create function constants for conditional compilation
196
+ MTLFunctionConstantValues *constants =
197
+ [[MTLFunctionConstantValues alloc] init];
198
+ bool use_partitioning = false;
199
+ bool use_alibi = alibi_slopes.has_value();
200
+ [constants setConstantValue:&use_partitioning
201
+ type:MTLDataTypeBool
202
+ atIndex:10];
203
+ [constants setConstantValue:&use_alibi type:MTLDataTypeBool atIndex:20];
204
+ [constants setConstantValue:&use_fp8_scales type:MTLDataTypeBool atIndex:30];
205
+
206
+ NSString *kernelNameStr =
207
+ [NSString stringWithUTF8String:kernel_name.c_str()];
208
+ id<MTLFunction> fn = [lib newFunctionWithName:kernelNameStr
209
+ constantValues:constants
210
+ error:&error];
211
+ TORCH_CHECK(
212
+ fn, "Failed to create Metal function '", kernel_name,
213
+ "': ", error ? error.localizedDescription.UTF8String : "unknown error");
214
+
215
+ id<MTLComputePipelineState> pso =
216
+ [device newComputePipelineStateWithFunction:fn error:&error];
217
+ TORCH_CHECK(pso, "Failed to create compute pipeline state: ",
218
+ error ? error.localizedDescription.UTF8String
219
+ : "unknown error");
220
+
221
+ // Setup command buffer and encoder
222
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
223
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
224
+
225
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
226
+ TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");
227
+
228
+ dispatch_queue_t q = stream->queue();
229
+ dispatch_sync(q, ^{
230
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
231
+ TORCH_CHECK(enc, "Failed to create compute command encoder");
232
+
233
+ [enc setComputePipelineState:pso];
234
+
235
+ // Set threadgroup memory
236
+ [enc setThreadgroupMemoryLength:shared_memory_size atIndex:0];
237
+
238
+ // Buffer arguments (matching the Metal kernel signature)
239
+ int buffer_idx = 0;
240
+
241
+ // Skip exp_sums and max_logits for v1 (buffers 0, 1)
242
+ buffer_idx = 2;
243
+
244
+ // out buffer
245
+ [enc setBuffer:getMTLBufferStorage(out)
246
+ offset:out.storage_offset() * out.element_size()
247
+ atIndex:buffer_idx++];
248
+
249
+ // query buffer
250
+ [enc setBuffer:getMTLBufferStorage(query)
251
+ offset:query.storage_offset() * query.element_size()
252
+ atIndex:buffer_idx++];
253
+
254
+ // key_cache buffer
255
+ [enc setBuffer:getMTLBufferStorage(key_cache)
256
+ offset:key_cache.storage_offset() * key_cache.element_size()
257
+ atIndex:buffer_idx++];
258
+
259
+ // value_cache buffer
260
+ [enc setBuffer:getMTLBufferStorage(value_cache)
261
+ offset:value_cache.storage_offset() * value_cache.element_size()
262
+ atIndex:buffer_idx++];
263
+
264
+ // k_scale and v_scale (for FP8)
265
+ if (use_fp8_scales) {
266
+ [enc setBuffer:getMTLBufferStorage(k_scale)
267
+ offset:k_scale.storage_offset() * k_scale.element_size()
268
+ atIndex:buffer_idx++];
269
+ [enc setBuffer:getMTLBufferStorage(v_scale)
270
+ offset:v_scale.storage_offset() * v_scale.element_size()
271
+ atIndex:buffer_idx++];
272
+ } else {
273
+ buffer_idx += 2; // Skip k_scale and v_scale buffer slots
274
+ }
275
+
276
+ // num_kv_heads
277
+ int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
278
+ [enc setBytes:&num_kv_heads_i32
279
+ length:sizeof(int32_t)
280
+ atIndex:buffer_idx++];
281
+
282
+ // scale
283
+ float scale_f32 = static_cast<float>(scale);
284
+ [enc setBytes:&scale_f32 length:sizeof(float) atIndex:buffer_idx++];
285
+
286
+ // softcapping (default to 1.0 for no capping)
287
+ float softcapping = 1.0f;
288
+ [enc setBytes:&softcapping length:sizeof(float) atIndex:buffer_idx++];
289
+
290
+ // block_tables buffer
291
+ [enc setBuffer:getMTLBufferStorage(block_tables)
292
+ offset:block_tables.storage_offset() * block_tables.element_size()
293
+ atIndex:buffer_idx++];
294
+
295
+ // seq_lens buffer (context_lens in kernel)
296
+ [enc setBuffer:getMTLBufferStorage(seq_lens)
297
+ offset:seq_lens.storage_offset() * seq_lens.element_size()
298
+ atIndex:buffer_idx++];
299
+
300
+ // max_num_blocks_per_seq
301
+ int32_t max_num_blocks_per_seq_i32 =
302
+ static_cast<int32_t>(max_num_blocks_per_seq);
303
+ [enc setBytes:&max_num_blocks_per_seq_i32
304
+ length:sizeof(int32_t)
305
+ atIndex:buffer_idx++];
306
+
307
+ // alibi_slopes (optional)
308
+ if (use_alibi) {
309
+ [enc setBuffer:getMTLBufferStorage(alibi_slopes.value())
310
+ offset:alibi_slopes.value().storage_offset() *
311
+ alibi_slopes.value().element_size()
312
+ atIndex:buffer_idx++];
313
+ } else {
314
+ buffer_idx++; // Skip this buffer slot
315
+ }
316
+
317
+ // Stride parameters
318
+ int32_t q_stride = static_cast<int32_t>(query.stride(0));
319
+ int32_t kv_block_stride = static_cast<int32_t>(key_cache.stride(0));
320
+ int32_t kv_head_stride = static_cast<int32_t>(key_cache.stride(1));
321
+
322
+ [enc setBytes:&q_stride length:sizeof(int32_t) atIndex:buffer_idx++];
323
+ [enc setBytes:&kv_block_stride
324
+ length:sizeof(int32_t)
325
+ atIndex:buffer_idx++];
326
+ [enc setBytes:&kv_head_stride
327
+ length:sizeof(int32_t)
328
+ atIndex:buffer_idx++];
329
+
330
+ // Dispatch configuration
331
+ // Grid: (num_heads, num_seqs, 1) - no partitioning for v1
332
+ MTLSize grid = MTLSizeMake(num_heads, num_seqs, 1);
333
+ MTLSize threadgroup = MTLSizeMake(num_threads, 1, 1);
334
+
335
+ [enc dispatchThreadgroups:grid threadsPerThreadgroup:threadgroup];
336
+ [enc endEncoding];
337
+
338
+ stream->synchronize(at::mps::SyncType::COMMIT);
339
+ });
340
+ }
341
+ }
342
+
343
+ void paged_attention_v2(
344
+ torch::Tensor &out, // [num_seqs, num_heads, head_size]
345
+ torch::Tensor &exp_sums, // [num_seqs, num_heads, max_num_partitions]
346
+ torch::Tensor &max_logits, // [num_seqs, num_heads, max_num_partitions]
347
+ torch::Tensor
348
+ &tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
349
+ torch::Tensor &query, // [num_seqs, num_heads, head_size]
350
+ torch::Tensor
351
+ &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
352
+ torch::Tensor
353
+ &value_cache, // [num_blocks, num_heads, head_size, block_size]
354
+ int64_t num_kv_heads, // [num_heads]
355
+ double scale,
356
+ torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq]
357
+ torch::Tensor &seq_lens, // [num_seqs]
358
+ int64_t block_size, int64_t max_seq_len,
359
+ const std::optional<torch::Tensor> &alibi_slopes,
360
+ const std::string &kv_cache_dtype, torch::Tensor &k_scale,
361
+ torch::Tensor &v_scale, const int64_t tp_rank,
362
+ const int64_t blocksparse_local_blocks,
363
+ const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
364
+ const int64_t blocksparse_head_sliding_step) {
365
+ const bool is_block_sparse = (blocksparse_vert_stride > 1);
366
+
367
+ // TODO: support blocksparse.
368
+ // Validate block sparse is not supported yet
369
+ TORCH_CHECK(
370
+ !is_block_sparse,
371
+ "Block sparse attention is not yet supported in Metal implementation");
372
+
373
+ // Determine cache dtype based on kv_cache_dtype
374
+ torch::ScalarType cache_dtype = key_cache.scalar_type();
375
+ bool use_fp8_scales = (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3");
376
+ if (use_fp8_scales) {
377
+ TORCH_CHECK(cache_dtype == torch::kUInt8, "FP8 cache requires UInt8 tensor type");
378
+ TORCH_CHECK(k_scale.numel() == 1 && v_scale.numel() == 1, "FP8 scales must be scalars");
379
+ TORCH_CHECK(k_scale.scalar_type() == torch::kFloat32 && v_scale.scalar_type() == torch::kFloat32,
380
+ "FP8 scales must be float32");
381
+ }
382
+
383
+ // Validate input tensors
384
+ TORCH_CHECK(out.device().is_mps() && query.device().is_mps() &&
385
+ key_cache.device().is_mps() &&
386
+ value_cache.device().is_mps() && exp_sums.device().is_mps() &&
387
+ max_logits.device().is_mps() && tmp_out.device().is_mps() &&
388
+ block_tables.device().is_mps() && seq_lens.device().is_mps(),
389
+ "All tensors must be on MPS device");
390
+
391
+ const int64_t num_seqs = query.size(0);
392
+ const int64_t num_heads = query.size(1);
393
+ const int64_t head_size = query.size(2);
394
+ const int64_t max_num_blocks_per_seq = block_tables.size(1);
395
+ const int64_t max_num_partitions = exp_sums.size(2);
396
+
397
+ // Validate configurations
398
+ TORCH_CHECK(isValidConfiguration(head_size, block_size),
399
+ "Unsupported head_size/block_size combination: ", head_size, "/",
400
+ block_size);
401
+
402
+ // For v2, use partitioning (matching the instantiated kernels)
403
+ const int num_threads = 256;
404
+ const int num_simd_lanes = 32;
405
+ const int partition_size = 512; // v2 uses partitioning
406
+
407
+ // Calculate shared memory requirements (from mistral.rs)
408
+ const int num_simds = num_threads / num_simd_lanes;
409
+ const int logits_size = partition_size * sizeof(float);
410
+ const int outputs_size = (num_simds / 2) * head_size * sizeof(float);
411
+ const size_t shared_memory_size = std::max(logits_size, outputs_size);
412
+
413
+ // Get kernel names
414
+ std::string kernel_name =
415
+ getKernelName("paged_attention", query.scalar_type(), cache_dtype, head_size,
416
+ block_size, num_threads, num_simd_lanes, partition_size);
417
+ // Reduce kernel doesn't have block_size in its name
418
+ std::string reduce_kernel_name = "paged_attention_v2_reduce";
419
+ switch (query.scalar_type()) {
420
+ case torch::kFloat:
421
+ reduce_kernel_name += "_float";
422
+ break;
423
+ case torch::kHalf:
424
+ reduce_kernel_name += "_half";
425
+ break;
426
+ case torch::kBFloat16:
427
+ reduce_kernel_name += "_bfloat16_t";
428
+ break;
429
+ default:
430
+ TORCH_CHECK(false,
431
+ "Unsupported dtype for paged attention: ", query.scalar_type());
432
+ }
433
+ reduce_kernel_name += "_hs" + std::to_string(head_size) + "_nt" +
434
+ std::to_string(num_threads) + "_nsl" +
435
+ std::to_string(num_simd_lanes) + "_ps" +
436
+ std::to_string(partition_size);
437
+
438
+ @autoreleasepool {
439
+ id<MTLDevice> device = MTLCreateSystemDefaultDevice();
440
+
441
+ // Load Metal library
442
+ std::string moduleDir = getModuleDirectory();
443
+ std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
444
+ NSString *metallibPathStr =
445
+ [NSString stringWithUTF8String:metallibPath.c_str()];
446
+ NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
447
+ NSError *error = nil;
448
+ id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
449
+ TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath, ": ",
450
+ error ? error.localizedDescription.UTF8String
451
+ : "unknown error");
452
+
453
+ // Setup command buffer and queue
454
+ at::mps::MPSStream *stream = at::mps::getCurrentMPSStream();
455
+ TORCH_CHECK(stream, "Failed to get current MPS stream");
456
+
457
+ id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
458
+ TORCH_CHECK(cmdBuf, "Failed to get MPS command buffer");
459
+
460
+ dispatch_queue_t q = stream->queue();
461
+ dispatch_sync(q, ^{
462
+ // ==================================================================
463
+ // Phase 1: Main paged attention kernel with partitioning
464
+ // ==================================================================
465
+
466
+ // Create function constants for main kernel
467
+ MTLFunctionConstantValues *mainConstants =
468
+ [[MTLFunctionConstantValues alloc] init];
469
+ bool use_partitioning = true;
470
+ bool use_alibi = alibi_slopes.has_value();
471
+ [mainConstants setConstantValue:&use_partitioning
472
+ type:MTLDataTypeBool
473
+ atIndex:10];
474
+ [mainConstants setConstantValue:&use_alibi
475
+ type:MTLDataTypeBool
476
+ atIndex:20];
477
+ [mainConstants setConstantValue:&use_fp8_scales
478
+ type:MTLDataTypeBool
479
+ atIndex:30];
480
+
481
+ NSString *kernelNameStr =
482
+ [NSString stringWithUTF8String:kernel_name.c_str()];
483
+ NSError *mainError = nil;
484
+ id<MTLFunction> mainFn = [lib newFunctionWithName:kernelNameStr
485
+ constantValues:mainConstants
486
+ error:&mainError];
487
+ TORCH_CHECK(mainFn, "Failed to create Metal function '", kernel_name,
488
+ "': ",
489
+ mainError ? mainError.localizedDescription.UTF8String
490
+ : "unknown error");
491
+
492
+ NSError *psoError = nil;
493
+ id<MTLComputePipelineState> mainPso =
494
+ [device newComputePipelineStateWithFunction:mainFn error:&psoError];
495
+ TORCH_CHECK(mainPso, "Failed to create compute pipeline state: ",
496
+ psoError ? psoError.localizedDescription.UTF8String
497
+ : "unknown error");
498
+
499
+ id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
500
+ TORCH_CHECK(enc, "Failed to create compute command encoder");
501
+
502
+ [enc setComputePipelineState:mainPso];
503
+ [enc setThreadgroupMemoryLength:shared_memory_size atIndex:0];
504
+
505
+ // Set buffers for main kernel
506
+ int buffer_idx = 0;
507
+
508
+ // exp_sums buffer
509
+ [enc setBuffer:getMTLBufferStorage(exp_sums)
510
+ offset:exp_sums.storage_offset() * exp_sums.element_size()
511
+ atIndex:buffer_idx++];
512
+
513
+ // max_logits buffer
514
+ [enc setBuffer:getMTLBufferStorage(max_logits)
515
+ offset:max_logits.storage_offset() * max_logits.element_size()
516
+ atIndex:buffer_idx++];
517
+
518
+ // tmp_out buffer
519
+ [enc setBuffer:getMTLBufferStorage(tmp_out)
520
+ offset:tmp_out.storage_offset() * tmp_out.element_size()
521
+ atIndex:buffer_idx++];
522
+
523
+ // query buffer
524
+ [enc setBuffer:getMTLBufferStorage(query)
525
+ offset:query.storage_offset() * query.element_size()
526
+ atIndex:buffer_idx++];
527
+
528
+ // key_cache buffer
529
+ [enc setBuffer:getMTLBufferStorage(key_cache)
530
+ offset:key_cache.storage_offset() * key_cache.element_size()
531
+ atIndex:buffer_idx++];
532
+
533
+ // value_cache buffer
534
+ [enc setBuffer:getMTLBufferStorage(value_cache)
535
+ offset:value_cache.storage_offset() * value_cache.element_size()
536
+ atIndex:buffer_idx++];
537
+
538
+ // k_scale and v_scale (for FP8)
539
+ if (use_fp8_scales) {
540
+ [enc setBuffer:getMTLBufferStorage(k_scale)
541
+ offset:k_scale.storage_offset() * k_scale.element_size()
542
+ atIndex:buffer_idx++];
543
+ [enc setBuffer:getMTLBufferStorage(v_scale)
544
+ offset:v_scale.storage_offset() * v_scale.element_size()
545
+ atIndex:buffer_idx++];
546
+ } else {
547
+ buffer_idx += 2; // Skip k_scale and v_scale buffer slots
548
+ }
549
+
550
+ // num_kv_heads
551
+ int32_t num_kv_heads_i32 = static_cast<int32_t>(num_kv_heads);
552
+ [enc setBytes:&num_kv_heads_i32
553
+ length:sizeof(int32_t)
554
+ atIndex:buffer_idx++];
555
+
556
+ // scale
557
+ float scale_f32 = static_cast<float>(scale);
558
+ [enc setBytes:&scale_f32 length:sizeof(float) atIndex:buffer_idx++];
559
+
560
+ // softcapping (default to 1.0 for no capping)
561
+ float softcapping = 1.0f;
562
+ [enc setBytes:&softcapping length:sizeof(float) atIndex:buffer_idx++];
563
+
564
+ // block_tables buffer
565
+ [enc setBuffer:getMTLBufferStorage(block_tables)
566
+ offset:block_tables.storage_offset() * block_tables.element_size()
567
+ atIndex:buffer_idx++];
568
+
569
+ // seq_lens buffer (context_lens in kernel)
570
+ [enc setBuffer:getMTLBufferStorage(seq_lens)
571
+ offset:seq_lens.storage_offset() * seq_lens.element_size()
572
+ atIndex:buffer_idx++];
573
+
574
+ // max_num_blocks_per_seq
575
+ int32_t max_num_blocks_per_seq_i32 =
576
+ static_cast<int32_t>(max_num_blocks_per_seq);
577
+ [enc setBytes:&max_num_blocks_per_seq_i32
578
+ length:sizeof(int32_t)
579
+ atIndex:buffer_idx++];
580
+
581
+ // alibi_slopes (optional)
582
+ if (use_alibi) {
583
+ [enc setBuffer:getMTLBufferStorage(alibi_slopes.value())
584
+ offset:alibi_slopes.value().storage_offset() *
585
+ alibi_slopes.value().element_size()
586
+ atIndex:buffer_idx++];
587
+ } else {
588
+ buffer_idx++; // Skip this buffer slot
589
+ }
590
+
591
+ // Stride parameters
592
+ int32_t q_stride = static_cast<int32_t>(query.stride(0));
593
+ int32_t kv_block_stride = static_cast<int32_t>(key_cache.stride(0));
594
+ int32_t kv_head_stride = static_cast<int32_t>(key_cache.stride(1));
595
+
596
+ [enc setBytes:&q_stride length:sizeof(int32_t) atIndex:buffer_idx++];
597
+ [enc setBytes:&kv_block_stride
598
+ length:sizeof(int32_t)
599
+ atIndex:buffer_idx++];
600
+ [enc setBytes:&kv_head_stride
601
+ length:sizeof(int32_t)
602
+ atIndex:buffer_idx++];
603
+
604
+ // Dispatch main kernel
605
+ // Grid: (num_heads, num_seqs, max_num_partitions) - with partitioning for
606
+ // v2
607
+ MTLSize mainGrid = MTLSizeMake(num_heads, num_seqs, max_num_partitions);
608
+ MTLSize mainThreadgroup = MTLSizeMake(num_threads, 1, 1);
609
+
610
+ [enc dispatchThreadgroups:mainGrid threadsPerThreadgroup:mainThreadgroup];
611
+ [enc endEncoding];
612
+
613
+ // ==================================================================
614
+ // Phase 2: Reduction kernel to combine partitions
615
+ // ==================================================================
616
+
617
+ // Create reduction kernel
618
+ NSString *reduceKernelNameStr =
619
+ [NSString stringWithUTF8String:reduce_kernel_name.c_str()];
620
+ id<MTLFunction> reduceFn = [lib newFunctionWithName:reduceKernelNameStr];
621
+ TORCH_CHECK(reduceFn, "Failed to create Metal function '",
622
+ reduce_kernel_name, "'");
623
+
624
+ NSError *reducePsoError = nil;
625
+ id<MTLComputePipelineState> reducePso =
626
+ [device newComputePipelineStateWithFunction:reduceFn
627
+ error:&reducePsoError];
628
+ TORCH_CHECK(
629
+ reducePso, "Failed to create compute pipeline state for reduction: ",
630
+ reducePsoError ? reducePsoError.localizedDescription.UTF8String
631
+ : "unknown error");
632
+
633
+ // Calculate shared memory for reduction kernel
634
+ size_t reduce_shared_memory_size =
635
+ max_num_partitions * sizeof(float) * 2; // max_logits + exp_sums
636
+
637
+ id<MTLComputeCommandEncoder> reduceEnc = [cmdBuf computeCommandEncoder];
638
+ TORCH_CHECK(reduceEnc,
639
+ "Failed to create compute command encoder for reduction");
640
+
641
+ [reduceEnc setComputePipelineState:reducePso];
642
+ [reduceEnc setThreadgroupMemoryLength:reduce_shared_memory_size
643
+ atIndex:0];
644
+
645
+ // Set buffers for reduction kernel
646
+ buffer_idx = 0;
647
+
648
+ // out buffer (final output)
649
+ [reduceEnc setBuffer:getMTLBufferStorage(out)
650
+ offset:out.storage_offset() * out.element_size()
651
+ atIndex:buffer_idx++];
652
+
653
+ // exp_sums buffer
654
+ [reduceEnc setBuffer:getMTLBufferStorage(exp_sums)
655
+ offset:exp_sums.storage_offset() * exp_sums.element_size()
656
+ atIndex:buffer_idx++];
657
+
658
+ // max_logits buffer
659
+ [reduceEnc
660
+ setBuffer:getMTLBufferStorage(max_logits)
661
+ offset:max_logits.storage_offset() * max_logits.element_size()
662
+ atIndex:buffer_idx++];
663
+
664
+ // tmp_out buffer
665
+ [reduceEnc setBuffer:getMTLBufferStorage(tmp_out)
666
+ offset:tmp_out.storage_offset() * tmp_out.element_size()
667
+ atIndex:buffer_idx++];
668
+
669
+ // seq_lens buffer (context_lens in kernel)
670
+ [reduceEnc setBuffer:getMTLBufferStorage(seq_lens)
671
+ offset:seq_lens.storage_offset() * seq_lens.element_size()
672
+ atIndex:buffer_idx++];
673
+
674
+ // max_num_partitions
675
+ int32_t max_num_partitions_i32 = static_cast<int32_t>(max_num_partitions);
676
+ [reduceEnc setBytes:&max_num_partitions_i32
677
+ length:sizeof(int32_t)
678
+ atIndex:buffer_idx++];
679
+
680
+ // Dispatch reduction kernel
681
+ // Grid: (num_heads, num_seqs) - one threadgroup per sequence/head
682
+ // combination
683
+ MTLSize reduceGrid = MTLSizeMake(num_heads, num_seqs, 1);
684
+ MTLSize reduceThreadgroup = MTLSizeMake(num_threads, 1, 1);
685
+
686
+ [reduceEnc dispatchThreadgroups:reduceGrid
687
+ threadsPerThreadgroup:reduceThreadgroup];
688
+ [reduceEnc endEncoding];
689
+
690
+ stream->synchronize(at::mps::SyncType::COMMIT);
691
+ });
692
+ }
693
+ }
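
To make the argument plumbing above concrete, a rough Python-side sketch of a v1 call on MPS. Tensor shapes follow the comments in this file; the packing factor x, the softmax scale, and the trailing tp_rank/blocksparse values are illustrative assumptions (the blocksparse path is rejected above whenever vert_stride > 1).

import torch
from paged_attention import paged_attention_v1

num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
block_size, num_blocks, max_seq_len = 16, 64, 256
x = 8  # assumed cache packing factor for fp16 (16 bytes / 2-byte elements)

query = torch.randn(num_seqs, num_heads, head_size, device="mps", dtype=torch.float16)
out = torch.empty_like(query)
key_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                        device="mps", dtype=torch.float16)
value_cache = torch.randn(num_blocks, num_kv_heads, head_size, block_size,
                          device="mps", dtype=torch.float16)
block_tables = torch.randint(0, num_blocks, (num_seqs, max_seq_len // block_size),
                             device="mps", dtype=torch.int32)
seq_lens = torch.full((num_seqs,), max_seq_len, device="mps", dtype=torch.int32)
k_scale = v_scale = torch.ones(1, device="mps", dtype=torch.float32)

paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads,
                   1.0 / head_size ** 0.5, block_tables, seq_lens,
                   block_size, max_seq_len, None, "auto", k_scale, v_scale,
                   0, 0, 0, 64, 0)  # tp_rank and blocksparse_* placeholders
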
paged-attention-metal/utils.metal ADDED
@@ -0,0 +1,246 @@
1
+ #include <metal_stdlib>
2
+ using namespace metal;
3
+
4
+ #if defined(__HAVE_BFLOAT__)
5
+
6
+ typedef bfloat bfloat16_t;
7
+
8
+ #else
9
+
10
+ /////////////////////////////////////////////////////////////////////////////
11
+ // Helpers
12
+ /////////////////////////////////////////////////////////////////////////////
13
+
14
+ constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
15
+ // Check for nan
16
+ if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
17
+ _fp_encoding_traits<float>::inf_mask) {
18
+ return uint16_t(as_type<uint32_t>(0x7FC0));
19
+ }
20
+ // Take bits
21
+ uint32_t float_bits = as_type<uint32_t>(x);
22
+
23
+ // Round to nearest even
24
+ float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
25
+
26
+ // Take upper 16 bits
27
+ return float_bits >> 16;
28
+ }
29
+
30
+ constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
31
+ // Upper 16 bits are the data and lower 16 bits are 0s
32
+ return as_type<float>((uint32_t)x << 16);
33
+ }
34
+
35
+ struct _MLX_BFloat16;
36
+
37
+ template <typename T>
38
+ static constexpr constant bool can_convert_to_bfloat =
39
+ !is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
40
+
41
+ template <typename T>
42
+ static constexpr constant bool can_convert_from_bfloat =
43
+ !is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
44
+
45
+ /////////////////////////////////////////////////////////////////////////////
46
+ // Bfloat struct
47
+ /////////////////////////////////////////////////////////////////////////////
48
+
49
+ struct _MLX_BFloat16 {
50
+ /////////////////////////////////////////////////////////////////////////////
51
+ // Constructors
52
+ uint16_t bits_;
53
+ _MLX_BFloat16() thread = default;
54
+ _MLX_BFloat16() threadgroup = default;
55
+ _MLX_BFloat16() device = default;
56
+ _MLX_BFloat16() constant = default;
57
+
58
+ struct bits_to_bfloat_struct {};
59
+ static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
60
+ return bits_to_bfloat_struct();
61
+ }
62
+ constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
63
+ : bits_(bits) {}
64
+
65
+ /////////////////////////////////////////////////////////////////////////////
66
+ // Conversions to bfloat
67
+
68
+ template <typename T,
69
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
70
+ constexpr METAL_FUNC _MLX_BFloat16(T x) thread
71
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
72
+
73
+ template <typename T,
74
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
75
+ constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
76
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
77
+
78
+ template <typename T,
79
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
80
+ constexpr METAL_FUNC _MLX_BFloat16(T x) device
81
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
82
+
83
+ template <typename T,
84
+ typename = typename enable_if<can_convert_to_bfloat<T>>::type>
85
+ constexpr METAL_FUNC _MLX_BFloat16(T x) constant
86
+ : bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
87
+
88
+ /////////////////////////////////////////////////////////////////////////////
89
+ // Conversions from bfloat
90
+
91
+ template <typename T,
92
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
93
+ constexpr METAL_FUNC operator T() const thread {
94
+ return static_cast<T>(bfloat_bits_to_float(bits_));
95
+ }
96
+
97
+ template <typename T,
98
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
99
+ constexpr METAL_FUNC operator T() const threadgroup {
100
+ return static_cast<T>(bfloat_bits_to_float(bits_));
101
+ }
102
+
103
+ template <typename T,
104
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
105
+ constexpr METAL_FUNC operator T() const device {
106
+ return static_cast<T>(bfloat_bits_to_float(bits_));
107
+ }
108
+
109
+ template <typename T,
110
+ typename = typename enable_if<can_convert_from_bfloat<T>>::type>
111
+ constexpr METAL_FUNC operator T() constant {
112
+ return static_cast<T>(bfloat_bits_to_float(bits_));
113
+ }
114
+ };
115
+
116
+ /////////////////////////////////////////////////////////////////////////////
117
+ // Bfloat operators
118
+ /////////////////////////////////////////////////////////////////////////////
119
+
120
+ /////////////////////////////////////////////////////////////////////////////
121
+ // Unary ops
122
+ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
123
+ return -static_cast<float>(x);
124
+ }
125
+
126
+ /////////////////////////////////////////////////////////////////////////////
127
+ // Binary operators
128
+ #define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
129
+ constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
130
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
131
+ }
132
+
133
+ #define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
134
+ constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
135
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
136
+ } \
137
+ constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
138
+ return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
139
+ }
140
+
141
+ /////////////////////////////////////////////////////////////////////////////
142
+ // Arithmetic Operators
143
+ #define bfloat_binop(_op_, _operator_) \
144
+ bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \
145
+ _MLX_BFloat16, float); \
146
+ bfloat_binop_helper(_op_, _operator_, float, float, float); \
147
+ bfloat_binop_helper(_op_, _operator_, float, half, float); \
148
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
149
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
150
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
151
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
152
+
153
+ bfloat_binop(+, operator+);
154
+ bfloat_binop(-, operator-);
155
+ bfloat_binop(*, operator*);
156
+ bfloat_binop(/, operator/);
157
+
158
+ /////////////////////////////////////////////////////////////////////////////
159
+ // Comparison ops
160
+ #define bfloat_compop(__op__, __operator__) \
161
+ bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \
162
+ float); \
163
+ bfloat_binop_helper(__op__, __operator__, bool, float, float); \
164
+ bfloat_binop_helper(__op__, __operator__, bool, half, float); \
165
+ bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
166
+ bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
167
+ bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
168
+ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
169
+
170
+ bfloat_compop(>, operator>);
171
+ bfloat_compop(<, operator<);
172
+ bfloat_compop(>=, operator>=);
173
+ bfloat_compop(<=, operator<=);
174
+ bfloat_compop(==, operator==);
175
+ bfloat_compop(!=, operator!=);
176
+
177
+ #undef bfloat_compop
178
+ #undef bfloat_binop_base
179
+ #undef bfloat_binop_helper
180
+ #undef bfloat_binop
181
+
182
+ /////////////////////////////////////////////////////////////////////////////
183
+ // Inplace Operators
184
+ #define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
185
+ constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
186
+ addr_space _MLX_BFloat16 &lhs, itype rhs) { \
187
+ lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
188
+ return lhs; \
189
+ } \
190
+ constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \
191
+ _MLX_BFloat16 rhs) { \
192
+ lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
193
+ return lhs; \
194
+ }
195
+
196
+ #define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
197
+ bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
198
+ bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
199
+ bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
200
+
201
+ #define bfloat_inplace_op(itype) \
202
+ bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
203
+ bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
204
+ bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
205
+ bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
206
+
207
+ bfloat_inplace_op(float);
208
+ bfloat_inplace_op(half);
209
+ bfloat_inplace_op(int16_t);
210
+ bfloat_inplace_op(int32_t);
211
+ bfloat_inplace_op(int64_t);
212
+ bfloat_inplace_op(uint16_t);
213
+ bfloat_inplace_op(uint32_t);
214
+ bfloat_inplace_op(uint64_t);
215
+
216
+ #undef bfloat_inplace_op_helper
217
+ #undef bfloat_inplace_op_addr_space_helper
218
+ #undef bfloat_inplace_op
219
+
220
+ #define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
221
+ constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \
222
+ addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \
223
+ lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
224
+ return lhs; \
225
+ }
226
+
227
+ #define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
228
+ bfloat_inplace_op_helper(__op__, __operator__, device); \
229
+ bfloat_inplace_op_helper(__op__, __operator__, thread); \
230
+ bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
231
+
232
+ bfloat_inplace_op_addr_space_helper(+, operator+=);
233
+ bfloat_inplace_op_addr_space_helper(-, operator-=);
234
+ bfloat_inplace_op_addr_space_helper(*, operator*=);
235
+ bfloat_inplace_op_addr_space_helper(/, operator/=);
236
+
237
+ #undef bfloat_inplace_op_helper
238
+ #undef bfloat_inplace_op_addr_space_helper
239
+
240
+ /////////////////////////////////////////////////////////////////////////////
241
+ // Bfloat typedef
242
+ /////////////////////////////////////////////////////////////////////////////
243
+
244
+ typedef struct _MLX_BFloat16 bfloat16_t;
245
+
246
+ #endif
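
The fallback bfloat16 type above hinges on float_to_bfloat_bits doing round-to-nearest-even. A small host-side spot check (illustrative only, using torch's own bfloat16 cast as the reference) might look like:

import struct
import torch

def float_to_bfloat_bits(x: float) -> int:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7FFFFFFF) > 0x7F800000:      # NaN -> canonical quiet NaN
        return 0x7FC0
    bits += ((bits >> 16) & 1) + 0x7FFF       # round to nearest, ties to even
    return (bits >> 16) & 0xFFFF

for v in (1.0, -2.5, 3.14159, 65504.0):
    expected = torch.tensor([v], dtype=torch.bfloat16).view(torch.uint16)[0].item()
    assert float_to_bfloat_bits(v) == expected
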
tests/kernels/test_attention.py CHANGED
@@ -33,10 +33,15 @@ HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
33
 
34
  BLOCK_SIZES = [16, 32]
35
  USE_ALIBI = [False, True]
36
- KV_CACHE_DTYPE = ["auto", "fp8"]
 
 
 
37
  SEEDS = [0]
38
- CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
39
-
 
 
40
 
41
  def ref_masked_attention(
42
  query: torch.Tensor,
@@ -119,7 +124,7 @@ def ref_single_query_cached_kv_attention(
119
  @pytest.mark.parametrize("dtype", DTYPES)
120
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
121
  @pytest.mark.parametrize("seed", SEEDS)
122
- @pytest.mark.parametrize("device", CUDA_DEVICES)
123
  def test_paged_attention(
124
  kv_cache_factory,
125
  version: str,
@@ -227,7 +232,7 @@ def test_paged_attention(
227
  64,
228
  0,
229
  ),
230
- cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
231
  )
232
 
233
  elif version in ("v2", "rocm"):
@@ -290,7 +295,7 @@ def test_paged_attention(
290
  64,
291
  0,
292
  ),
293
- cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
294
  )
295
 
296
  else:
@@ -335,7 +340,7 @@ def test_paged_attention(
335
  k_scale,
336
  v_scale,
337
  ),
338
- cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
339
  )
340
 
341
  else:
@@ -383,6 +388,9 @@ def test_paged_attention(
383
  atol, rtol = 1e-3, 1e-5
384
  if kv_cache_dtype == "fp8":
385
  atol, rtol = 1e-2, 1e-5
 
 
 
386
  torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
387
 
388
 
 
33
 
34
  BLOCK_SIZES = [16, 32]
35
  USE_ALIBI = [False, True]
36
+ if current_platform.is_mps():
37
+ KV_CACHE_DTYPE = ["auto", "fp8"]
38
+ else:
39
+ KV_CACHE_DTYPE = ["auto", "fp8"]
40
  SEEDS = [0]
41
+ if current_platform.is_mps():
42
+ DEVICES = ["mps:0"]
43
+ else:
44
+ DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
45
 
46
  def ref_masked_attention(
47
  query: torch.Tensor,
 
124
  @pytest.mark.parametrize("dtype", DTYPES)
125
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
126
  @pytest.mark.parametrize("seed", SEEDS)
127
+ @pytest.mark.parametrize("device", DEVICES)
128
  def test_paged_attention(
129
  kv_cache_factory,
130
  version: str,
 
232
  64,
233
  0,
234
  ),
235
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
236
  )
237
 
238
  elif version in ("v2", "rocm"):
 
295
  64,
296
  0,
297
  ),
298
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
299
  )
300
 
301
  else:
 
340
  k_scale,
341
  v_scale,
342
  ),
343
+ cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0] and not device.startswith("mps")),
344
  )
345
 
346
  else:
 
388
  atol, rtol = 1e-3, 1e-5
389
  if kv_cache_dtype == "fp8":
390
  atol, rtol = 1e-2, 1e-5
391
+ # NOTE: bfloat16 with ALiBi can have slightly higher precision differences
392
+ elif dtype == torch.bfloat16 and use_alibi:
393
+ atol, rtol = 2e-3, 1e-5
394
  torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
395
 
396
 
tests/kernels/test_cache.py CHANGED
@@ -8,7 +8,7 @@ from paged_attention.platforms import current_platform
8
 
9
  from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
 
11
- COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
12
  DTYPES = [torch.half, torch.bfloat16, torch.float]
13
  NUM_TOKENS = [42] # Arbitrary values for testing
14
  NUM_LAYERS = [1] # Arbitrary values for testing
@@ -22,10 +22,15 @@ NUM_BLOCKS = [1024, 10000]
22
 
23
  NUM_MAPPINGS = [256] # Arbitrary values for testing
24
  SEEDS = [0]
25
- CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 
 
 
26
 
27
- # We assume fp8 is always enabled for testing.
28
- KV_CACHE_DTYPE = ["auto", "fp8"]
 
 
29
 
30
 
31
  @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -36,7 +41,7 @@ KV_CACHE_DTYPE = ["auto", "fp8"]
36
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
37
  @pytest.mark.parametrize("dtype", DTYPES)
38
  @pytest.mark.parametrize("seed", SEEDS)
39
- @pytest.mark.parametrize("device", CUDA_DEVICES)
40
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
41
  @torch.inference_mode()
42
  def test_copy_blocks(
@@ -121,7 +126,7 @@ def test_copy_blocks(
121
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
122
  @pytest.mark.parametrize("dtype", DTYPES)
123
  @pytest.mark.parametrize("seed", SEEDS)
124
- @pytest.mark.parametrize("device", CUDA_DEVICES)
125
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
126
  @torch.inference_mode()
127
  def test_reshape_and_cache(
@@ -221,10 +226,10 @@ def test_reshape_and_cache(
221
 
222
  if kv_cache_dtype == "fp8":
223
  torch.testing.assert_close(
224
- result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
225
  )
226
  torch.testing.assert_close(
227
- result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
228
  )
229
  else:
230
  torch.testing.assert_close(key_cache, cloned_key_cache)
@@ -238,7 +243,7 @@ def test_reshape_and_cache(
238
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
239
  @pytest.mark.parametrize("dtype", DTYPES)
240
  @pytest.mark.parametrize("seed", SEEDS)
241
- @pytest.mark.parametrize("device", CUDA_DEVICES)
242
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
243
  @torch.inference_mode()
244
  def test_reshape_and_cache_flash(
@@ -253,6 +258,9 @@ def test_reshape_and_cache_flash(
253
  device: str,
254
  kv_cache_dtype: str,
255
  ) -> None:
 
 
 
256
  current_platform.seed_everything(seed)
257
  torch.set_default_device(device)
258
 
@@ -341,10 +349,10 @@ def test_reshape_and_cache_flash(
341
 
342
  if kv_cache_dtype == "fp8":
343
  torch.testing.assert_close(
344
- result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
345
  )
346
  torch.testing.assert_close(
347
- result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
348
  )
349
  else:
350
  torch.testing.assert_close(key_cache, cloned_key_cache)
@@ -359,7 +367,7 @@ def test_reshape_and_cache_flash(
359
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
360
  @pytest.mark.parametrize("dtype", DTYPES)
361
  @pytest.mark.parametrize("seed", SEEDS)
362
- @pytest.mark.parametrize("device", CUDA_DEVICES)
363
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
364
  @torch.inference_mode()
365
  def test_swap_blocks(
@@ -382,8 +390,8 @@ def test_swap_blocks(
382
 
383
  current_platform.seed_everything(seed)
384
 
385
- src_device = device if direction[0] == "cuda" else "cpu"
386
- dst_device = device if direction[1] == "cuda" else "cpu"
387
 
388
  src_blocks = random.sample(range(num_blocks), num_mappings)
389
  # For the same device, mapping must not overlap
@@ -458,7 +466,7 @@ def test_swap_blocks(
458
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
459
  @pytest.mark.parametrize("dtype", DTYPES)
460
  @pytest.mark.parametrize("seed", SEEDS)
461
- @pytest.mark.parametrize("device", CUDA_DEVICES)
462
  @torch.inference_mode()
463
  def test_fp8_e4m3_conversion(
464
  num_heads: int,
@@ -483,4 +491,4 @@ def test_fp8_e4m3_conversion(
483
  converted_cache = torch.empty_like(cache)
484
  ops.convert_fp8(converted_cache, cache_fp8)
485
 
486
- torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
 
8
 
9
  from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
10
 
11
+ COPYING_DIRECTION = [("gpu", "cpu"), ("gpu", "gpu"), ("cpu", "gpu")]
12
  DTYPES = [torch.half, torch.bfloat16, torch.float]
13
  NUM_TOKENS = [42] # Arbitrary values for testing
14
  NUM_LAYERS = [1] # Arbitrary values for testing
 
22
 
23
  NUM_MAPPINGS = [256] # Arbitrary values for testing
24
  SEEDS = [0]
25
+ if current_platform.is_mps():
26
+ DEVICES = ["mps:0"]
27
+ else:
28
+ DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
29
 
30
+ if current_platform.is_mps():
31
+ KV_CACHE_DTYPE = ["auto", "fp8"]
32
+ else:
33
+ KV_CACHE_DTYPE = ["auto", "fp8"]
34
 
35
 
36
  @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
 
41
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
42
  @pytest.mark.parametrize("dtype", DTYPES)
43
  @pytest.mark.parametrize("seed", SEEDS)
44
+ @pytest.mark.parametrize("device", DEVICES)
45
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
46
  @torch.inference_mode()
47
  def test_copy_blocks(
 
126
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
127
  @pytest.mark.parametrize("dtype", DTYPES)
128
  @pytest.mark.parametrize("seed", SEEDS)
129
+ @pytest.mark.parametrize("device", DEVICES)
130
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
131
  @torch.inference_mode()
132
  def test_reshape_and_cache(
 
226
 
227
  if kv_cache_dtype == "fp8":
228
  torch.testing.assert_close(
229
+ result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
230
  )
231
  torch.testing.assert_close(
232
+ result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
233
  )
234
  else:
235
  torch.testing.assert_close(key_cache, cloned_key_cache)
 
243
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
244
  @pytest.mark.parametrize("dtype", DTYPES)
245
  @pytest.mark.parametrize("seed", SEEDS)
246
+ @pytest.mark.parametrize("device", DEVICES)
247
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
248
  @torch.inference_mode()
249
  def test_reshape_and_cache_flash(
 
258
  device: str,
259
  kv_cache_dtype: str,
260
  ) -> None:
261
+ # Flash variant doesn't support FP8 on MPS devices yet
262
+ if current_platform.is_mps() and kv_cache_dtype == "fp8":
263
+ pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
264
  current_platform.seed_everything(seed)
265
  torch.set_default_device(device)
266
 
 
349
 
350
  if kv_cache_dtype == "fp8":
351
  torch.testing.assert_close(
352
+ result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
353
  )
354
  torch.testing.assert_close(
355
+ result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
356
  )
357
  else:
358
  torch.testing.assert_close(key_cache, cloned_key_cache)
 
367
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
368
  @pytest.mark.parametrize("dtype", DTYPES)
369
  @pytest.mark.parametrize("seed", SEEDS)
370
+ @pytest.mark.parametrize("device", DEVICES)
371
  @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
372
  @torch.inference_mode()
373
  def test_swap_blocks(
 
390
 
391
  current_platform.seed_everything(seed)
392
 
393
+ src_device = device if direction[0] == "gpu" else "cpu"
394
+ dst_device = device if direction[1] == "gpu" else "cpu"
395
 
396
  src_blocks = random.sample(range(num_blocks), num_mappings)
397
  # For the same device, mapping must not overlap
 
466
  @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
467
  @pytest.mark.parametrize("dtype", DTYPES)
468
  @pytest.mark.parametrize("seed", SEEDS)
469
+ @pytest.mark.parametrize("device", DEVICES)
470
  @torch.inference_mode()
471
  def test_fp8_e4m3_conversion(
472
  num_heads: int,
 
491
  converted_cache = torch.empty_like(cache)
492
  ops.convert_fp8(converted_cache, cache_fp8)
493
 
494
+ torch.testing.assert_close(cache, converted_cache, atol=0.02, rtol=0.2)
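The atol=0.02 / rtol=0.2 tolerances above match the quantization error of an e4m3 round trip. A minimal sketch of that headroom, assuming a PyTorch build that exposes torch.float8_e4m3fn (the kernels' own scale handling may differ):

import torch

# Values in [-1, 1): e4m3 keeps 3 mantissa bits, so the relative rounding
# error stays under ~6.25%, well inside rtol=0.2; atol=0.02 absorbs the
# coarser spacing near zero.
x = torch.rand(1024) * 2 - 1
roundtrip = x.to(torch.float8_e4m3fn).to(torch.float32)
torch.testing.assert_close(roundtrip, x, atol=0.02, rtol=0.2)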
tests/kernels/utils.py CHANGED
@@ -71,12 +71,24 @@ def opcheck(
71
  cond: bool = True
72
  ) -> Dict[str, str]:
73
  with unittest.mock.patch("torch.allclose", new=fp8_allclose):
74
- return (
75
- torch.library.opcheck(
76
- op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
77
- )
78
- if cond
79
- else {}
80
  )
81
 
82
 
 
71
  cond: bool = True
72
  ) -> Dict[str, str]:
73
  with unittest.mock.patch("torch.allclose", new=fp8_allclose):
74
+ if not cond:
75
+ return {}
76
+
77
+ # Check if any arguments are on an MPS device and skip opcheck if so,
78
+ # as MPS has issues with placeholder storage allocation in opcheck
79
+ def is_mps_tensor(x):
80
+ return hasattr(x, 'device') and x.device.type == 'mps'
81
+
82
+ def check_args_for_mps(args):
83
+ if isinstance(args, (list, tuple)):
84
+ return any(check_args_for_mps(arg) for arg in args)
85
+ return is_mps_tensor(args)
86
+
87
+ if check_args_for_mps(args):
88
+ return {}
89
+
90
+ return torch.library.opcheck(
91
+ op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
92
  )
93
 
94
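A usage sketch of the patched helper: with any MPS tensor among the arguments it now returns an empty dict instead of invoking torch.library.opcheck, otherwise it behaves as before. The op handle and cache tensors below are illustrative placeholders, not the exact objects used in the tests above:

import torch

from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck

def check_cache_op(op, key_caches, value_caches, block_mapping):
    # Short-circuits to {} for MPS inputs; defers to torch.library.opcheck
    # with the default test utils everywhere else.
    return opcheck(
        op,  # e.g. a registered copy_blocks OpOverload
        (key_caches, value_caches, block_mapping),
        {},
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )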
 
torch-ext/paged_attention/platforms.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import torch
9
 
10
  IS_ROCM = torch.version.hip is not None
 
11
 
12
 
13
  class Platform(ABC):
@@ -32,6 +33,9 @@ class Platform(ABC):
32
  @abstractmethod
33
  def is_rocm(self) -> bool: ...
34
 
 
 
 
35
 
36
  class CudaPlatform(Platform):
37
  @classmethod
@@ -45,6 +49,9 @@ class CudaPlatform(Platform):
45
  def is_rocm(self) -> bool:
46
  return False
47
 
 
 
 
48
 
49
  class RocmPlatform(Platform):
50
  @classmethod
@@ -58,5 +65,28 @@ class RocmPlatform(Platform):
58
  def is_rocm(self) -> bool:
59
  return True
60
61
 
62
- current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
 
 
 
 
 
8
  import torch
9
 
10
  IS_ROCM = torch.version.hip is not None
11
+ IS_MPS = torch.backends.mps.is_available()
12
 
13
 
14
  class Platform(ABC):
 
33
  @abstractmethod
34
  def is_rocm(self) -> bool: ...
35
 
36
+ @abstractmethod
37
+ def is_mps(self) -> bool: ...
38
+
39
 
40
  class CudaPlatform(Platform):
41
  @classmethod
 
49
  def is_rocm(self) -> bool:
50
  return False
51
 
52
+ def is_mps(self) -> bool:
53
+ return False
54
+
55
 
56
  class RocmPlatform(Platform):
57
  @classmethod
 
65
  def is_rocm(self) -> bool:
66
  return True
67
 
68
+ def is_mps(self) -> bool:
69
+ return False
70
+
71
+
72
+ class MpsPlatform(Platform):
73
+ @classmethod
74
+ @lru_cache(maxsize=8)
75
+ def get_device_name(cls, device_id: int = 0) -> str:
76
+ # torch.cuda is not available on MPS-only machines; report the backend name.
+ return "mps"
77
+
78
+ def is_cuda(self) -> bool:
79
+ return False
80
+
81
+ def is_rocm(self) -> bool:
82
+ return False
83
+
84
+ def is_mps(self) -> bool:
85
+ return True
86
 
87
+ current_platform = (
88
+ RocmPlatform() if IS_ROCM else
89
+ MpsPlatform() if IS_MPS else
90
+ CudaPlatform() if torch.cuda.is_available() else
91
+ None
92
+ )
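For reference, a minimal sketch of how the platform object can drive device selection, assuming the built package exposes this module as paged_attention.platforms; the None guard reflects that current_platform is left unset when no ROCm, MPS or CUDA backend is detected:

import torch

from paged_attention.platforms import current_platform  # assumed import path

if current_platform is None:
    raise RuntimeError("no ROCm, MPS or CUDA backend detected")

# MPS machines map to "mps:0"; CUDA/ROCm machines use the first GPU.
device = "mps:0" if current_platform.is_mps() else "cuda:0"
current_platform.seed_everything(0)
torch.set_default_device(device)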
torch-ext/torch_binding.cpp CHANGED
@@ -15,81 +15,108 @@
15
  // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
16
 
17
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
18
- // Attention ops
19
- // Compute the attention between an input query and the cached
20
- // keys/values using PagedAttention.
21
- ops.def(
22
- "paged_attention_v1("
23
- " Tensor! out, Tensor query, Tensor key_cache,"
24
- " Tensor value_cache, int num_kv_heads, float scale,"
25
- " Tensor block_tables, Tensor seq_lens, int block_size,"
26
- " int max_seq_len, Tensor? alibi_slopes,"
27
- " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
28
- " int tp_rank, int blocksparse_local_blocks,"
29
- " int blocksparse_vert_stride, int blocksparse_block_size,"
30
- " int blocksparse_head_sliding_step) -> ()");
31
- ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
 
 
 
 
32
 
33
- // PagedAttention V2.
34
- ops.def(
35
- "paged_attention_v2("
36
- " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
37
- " Tensor! tmp_out, Tensor query, Tensor key_cache,"
38
- " Tensor value_cache, int num_kv_heads, float scale,"
39
- " Tensor block_tables, Tensor seq_lens, int block_size,"
40
- " int max_seq_len, Tensor? alibi_slopes,"
41
- " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
42
- " int tp_rank, int blocksparse_local_blocks,"
43
- " int blocksparse_vert_stride, int blocksparse_block_size,"
44
- " int blocksparse_head_sliding_step) -> ()");
45
- ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
 
 
 
 
46
 
47
- // Swap in (out) the cache blocks from src to dst.
48
- ops.def(
49
- "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
50
- ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
 
 
 
 
51
 
52
- // Copy the cache blocks from src to dst.
53
- ops.def(
54
- "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
55
- "Tensor block_mapping) -> ()");
56
- ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
 
 
 
57
 
58
- // Reshape the key and value tensors and cache them.
59
- ops.def(
60
- "reshape_and_cache(Tensor key, Tensor value,"
61
- " Tensor! key_cache, Tensor! value_cache,"
62
- " Tensor slot_mapping,"
63
- " str kv_cache_dtype,"
64
- " Tensor k_scale, Tensor v_scale) -> ()");
65
- ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
 
 
 
 
66
 
67
- // Reshape the key and value tensors and cache them.
68
- ops.def(
69
- "reshape_and_cache_flash(Tensor key, Tensor value,"
70
- " Tensor! key_cache,"
71
- " Tensor! value_cache,"
72
- " Tensor slot_mapping,"
73
- " str kv_cache_dtype,"
74
- " Tensor k_scale, Tensor v_scale) -> ()");
75
- ops.impl("reshape_and_cache_flash", torch::kCUDA,
76
- &reshape_and_cache_flash);
 
 
 
77
 
78
- // Gets the specified device attribute.
79
- ops.def("get_device_attribute(int attribute, int device_id) -> int");
80
- ops.impl("get_device_attribute", &get_device_attribute);
81
 
82
- // Gets the maximum shared memory per block device attribute.
83
- ops.def(
84
- "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
85
- ops.impl("get_max_shared_memory_per_block_device_attribute",
86
- &get_max_shared_memory_per_block_device_attribute);
87
 
88
- // Convert the key and value cache to fp8 data type.
89
- ops.def(
90
- "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
91
- "str kv_cache_dtype) -> ()");
92
- ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 
 
 
 
93
  }
94
 
95
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
15
  // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
16
 
17
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
18
+ // Attention ops
19
+ // Compute the attention between an input query and the cached
20
+ // keys/values using PagedAttention.
21
+ ops.def(
22
+ "paged_attention_v1("
23
+ " Tensor! out, Tensor query, Tensor key_cache,"
24
+ " Tensor value_cache, int num_kv_heads, float scale,"
25
+ " Tensor block_tables, Tensor seq_lens, int block_size,"
26
+ " int max_seq_len, Tensor? alibi_slopes,"
27
+ " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
28
+ " int tp_rank, int blocksparse_local_blocks,"
29
+ " int blocksparse_vert_stride, int blocksparse_block_size,"
30
+ " int blocksparse_head_sliding_step) -> ()");
31
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
32
+ ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
33
+ #elif defined(METAL_KERNEL)
34
+ ops.impl("paged_attention_v1", torch::kMPS, paged_attention_v1);
35
+ #endif
36
 
37
+ // PagedAttention V2.
38
+ ops.def(
39
+ "paged_attention_v2("
40
+ " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
41
+ " Tensor! tmp_out, Tensor query, Tensor key_cache,"
42
+ " Tensor value_cache, int num_kv_heads, float scale,"
43
+ " Tensor block_tables, Tensor seq_lens, int block_size,"
44
+ " int max_seq_len, Tensor? alibi_slopes,"
45
+ " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
46
+ " int tp_rank, int blocksparse_local_blocks,"
47
+ " int blocksparse_vert_stride, int blocksparse_block_size,"
48
+ " int blocksparse_head_sliding_step) -> ()");
49
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
50
+ ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
51
+ #elif defined(METAL_KERNEL)
52
+ ops.impl("paged_attention_v2", torch::kMPS, paged_attention_v2);
53
+ #endif
54
 
55
+ // Swap in (out) the cache blocks from src to dst.
56
+ ops.def(
57
+ "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
58
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
59
+ ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
60
+ #elif defined(METAL_KERNEL)
61
+ ops.impl("swap_blocks", torch::kMPS, swap_blocks);
62
+ #endif
63
 
64
+ // Copy the cache blocks from src to dst.
65
+ ops.def(
66
+ "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
67
+ "Tensor block_mapping) -> ()");
68
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
69
+ ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
70
+ #elif defined(METAL_KERNEL)
71
+ ops.impl("copy_blocks", torch::kMPS, copy_blocks);
72
+ #endif
73
 
74
+ // Reshape the key and value tensors and cache them.
75
+ ops.def(
76
+ "reshape_and_cache(Tensor key, Tensor value,"
77
+ " Tensor! key_cache, Tensor! value_cache,"
78
+ " Tensor slot_mapping,"
79
+ " str kv_cache_dtype,"
80
+ " Tensor k_scale, Tensor v_scale) -> ()");
81
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
82
+ ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
83
+ #elif defined(METAL_KERNEL)
84
+ ops.impl("reshape_and_cache", torch::kMPS, reshape_and_cache);
85
+ #endif
86
 
87
+ // Reshape the key and value tensors and cache them.
88
+ ops.def(
89
+ "reshape_and_cache_flash(Tensor key, Tensor value,"
90
+ " Tensor! key_cache,"
91
+ " Tensor! value_cache,"
92
+ " Tensor slot_mapping,"
93
+ " str kv_cache_dtype,"
94
+ " Tensor k_scale, Tensor v_scale) -> ()");
95
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
96
+ ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash);
97
+ #elif defined(METAL_KERNEL)
98
+ ops.impl("reshape_and_cache_flash", torch::kMPS, reshape_and_cache_flash);
99
+ #endif
100
 
101
+ // Gets the specified device attribute.
102
+ ops.def("get_device_attribute(int attribute, int device_id) -> int");
103
+ ops.impl("get_device_attribute", &get_device_attribute);
104
 
105
+ // Gets the maximum shared memory per block device attribute.
106
+ ops.def(
107
+ "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
108
+ ops.impl("get_max_shared_memory_per_block_device_attribute",
109
+ &get_max_shared_memory_per_block_device_attribute);
110
 
111
+ // Convert the key and value cache to fp8 data type.
112
+ ops.def(
113
+ "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
114
+ "str kv_cache_dtype) -> ()");
115
+ #if defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
116
+ ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
117
+ #elif defined(METAL_KERNEL)
118
+ ops.impl("convert_fp8", torch::kMPS, convert_fp8);
119
+ #endif
120
  }
121
 
122
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
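To make the registered schemas concrete, a hedged sketch of driving one cache op from Python. The paged_attention.ops import and the [num_blocks, num_heads, head_size/x, block_size, x] key-cache layout mirror the CUDA path and are assumptions here; shapes are illustrative only:

import torch

from paged_attention import ops  # assumed re-export of the registered op namespace

device = "mps" if torch.backends.mps.is_available() else "cuda"
num_tokens, num_heads, head_size, block_size, num_blocks = 16, 8, 64, 16, 32
x = 8  # 16-byte vectors of 2-byte fp16 elements, per the CUDA cache layout

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device=device)
value = torch.randn_like(key)
key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.float16, device=device)
value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.float16, device=device)
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=device)
k_scale = v_scale = torch.ones(1, dtype=torch.float32, device=device)

# Dispatches to the torch::kCUDA or torch::kMPS implementation registered above;
# "auto" keeps the cache in the key/value dtype instead of fp8.
ops.reshape_and_cache(key, value, key_cache, value_cache,
                      slot_mapping, "auto", k_scale, v_scale)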