Mohamed Mekkouri committed
Commit 9ffd725
1 Parent(s): 9964bae

new builds
build/torch28-metal-aarch64-darwin/gptoss_kernels/__init__.py CHANGED
@@ -1,8 +1,174 @@
 from ._ops import ops
 import torch
 
-def f32_bf16w_matmul(input: torch.Tensor, weight_bf16: torch.Tensor, bias_bf16: torch.Tensor, output: torch.Tensor, num_tokens: int, num_cols: int, num_rows: int, threadgroup_size: int) -> None:
-    ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, num_tokens, num_cols, num_rows, threadgroup_size)
+def f32_bf16w_matmul(input: torch.Tensor,
+                     weight_bf16: torch.Tensor,
+                     bias_bf16: torch.Tensor,
+                     output: torch.Tensor,
+                     num_tokens: int,
+                     num_cols: int,
+                     num_rows: int,
+                     threadgroup_size: int) -> torch.Tensor:
+    ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output,
+                         num_tokens, num_cols, num_rows, threadgroup_size)
     return output
 
-__all__ = ["f32_bf16w_matmul"]
+def bf16_f32_embeddings(token_ids: torch.Tensor,
+                        weight_bf16: torch.Tensor,
+                        output: torch.Tensor,
+                        threadgroup_size: int) -> torch.Tensor:
+    ops.bf16_f32_embeddings(token_ids, weight_bf16, output, threadgroup_size)
+    return output
+
+def f32_bf16w_rmsnorm(input: torch.Tensor,
+                      weight_bf16: torch.Tensor,
+                      output: torch.Tensor,
+                      epsilon: float) -> torch.Tensor:
+    ops.f32_bf16w_rmsnorm(input, weight_bf16, output, epsilon)
+    return output
+
+def f32_bf16w_dense_matmul_qkv(input: torch.Tensor,
+                               weight_bf16: torch.Tensor,
+                               bias_bf16: torch.Tensor,
+                               output: torch.Tensor) -> torch.Tensor:
+    ops.f32_bf16w_dense_matmul_qkv(input, weight_bf16, bias_bf16, output)
+    return output
+
+def f32_bf16w_dense_matmul_attn_output(input: torch.Tensor,
+                                       weight_bf16: torch.Tensor,
+                                       bias_bf16: torch.Tensor,
+                                       output: torch.Tensor) -> torch.Tensor:
+    ops.f32_bf16w_dense_matmul_attn_output(input, weight_bf16, bias_bf16, output)
+    return output
+
+def f32_bf16w_dense_matmul_mlp_gate(input: torch.Tensor,
+                                    weight_bf16: torch.Tensor,
+                                    bias_bf16: torch.Tensor,
+                                    output: torch.Tensor) -> torch.Tensor:
+    ops.f32_bf16w_dense_matmul_mlp_gate(input, weight_bf16, bias_bf16, output)
+    return output
+
+def f32_rope(activations: torch.Tensor,
+             rope_base: float,
+             interpolation_scale: float,
+             yarn_offset: float,
+             yarn_scale: float,
+             yarn_multiplier: float,
+             num_tokens: int,
+             num_q_heads: int,
+             num_kv_heads: int,
+             attn_head_dim: int,
+             token_offset: int,
+             threadgroup_size: int) -> torch.Tensor:
+    ops.f32_rope(activations, rope_base, interpolation_scale, yarn_offset,
+                 yarn_scale, yarn_multiplier, num_tokens, num_q_heads,
+                 num_kv_heads, attn_head_dim, token_offset, threadgroup_size)
+    return activations
+
+def f32_bf16w_matmul_qkv(input: torch.Tensor,
+                         weight_bf16: torch.Tensor,
+                         bias_bf16: torch.Tensor,
+                         output: torch.Tensor,
+                         kv_cache: torch.Tensor,
+                         kv_cache_offset_bytes: int,
+                         num_tokens: int,
+                         num_cols: int,
+                         num_q_heads: int,
+                         num_kv_heads: int,
+                         attn_head_dim: int,
+                         token_offset: int,
+                         max_tokens: int,
+                         rope_base: float,
+                         interpolation_scale: float,
+                         yarn_offset: float,
+                         yarn_scale: float,
+                         yarn_multiplier: float,
+                         threadgroup_size: int) -> torch.Tensor:
+    ops.f32_bf16w_matmul_qkv(input, weight_bf16, bias_bf16, output, kv_cache,
+                             kv_cache_offset_bytes, num_tokens, num_cols,
+                             num_q_heads, num_kv_heads, attn_head_dim,
+                             token_offset, max_tokens, rope_base,
+                             interpolation_scale, yarn_offset, yarn_scale,
+                             yarn_multiplier, threadgroup_size)
+    return output
+
+def f32_sdpa(q: torch.Tensor,
+             q_offset_bytes: int,
+             kv: torch.Tensor,
+             kv_offset_bytes: int,
+             s_bf16: torch.Tensor,
+             s_offset_bytes: int,
+             output: torch.Tensor,
+             output_offset_bytes: int,
+             window: int,
+             kv_stride: int,
+             num_q_tokens: int,
+             num_kv_tokens: int,
+             num_q_heads: int,
+             num_kv_heads: int,
+             head_dim: int) -> torch.Tensor:
+    ops.f32_sdpa(q, q_offset_bytes, kv, kv_offset_bytes, s_bf16, s_offset_bytes,
+                 output, output_offset_bytes, window, kv_stride,
+                 num_q_tokens, num_kv_tokens, num_q_heads, num_kv_heads, head_dim)
+    return output
+
+def f32_topk(scores: torch.Tensor,
+             expert_ids: torch.Tensor,
+             expert_scores: torch.Tensor,
+             num_tokens: int,
+             num_experts: int,
+             num_active_experts: int) -> None:
+    ops.f32_topk(scores, expert_ids, expert_scores,
+                 num_tokens, num_experts, num_active_experts)
+
+def expert_routing_metadata(expert_ids: torch.Tensor,
+                            expert_scores: torch.Tensor,
+                            expert_offsets: torch.Tensor,
+                            intra_expert_offsets: torch.Tensor,
+                            num_tokens: int,
+                            num_experts: int) -> None:
+    ops.expert_routing_metadata(expert_ids, expert_scores,
+                                expert_offsets, intra_expert_offsets,
+                                num_tokens, num_experts)
+
+def f32_scatter(input: torch.Tensor,
+                expert_ids: torch.Tensor,
+                expert_scores: torch.Tensor,
+                expert_offsets: torch.Tensor,
+                intra_expert_offsets: torch.Tensor,
+                output: torch.Tensor,
+                num_channels: int,
+                num_tokens: int,
+                num_active_experts: int) -> torch.Tensor:
+    ops.f32_scatter(input, expert_ids, expert_scores,
+                    expert_offsets, intra_expert_offsets,
+                    output, num_channels, num_tokens, num_active_experts)
+    return output
+
+def f32_bf16w_matmul_add(input: torch.Tensor,
+                         weight_bf16: torch.Tensor,
+                         bias_bf16: torch.Tensor,
+                         output: torch.Tensor,
+                         num_tokens: int,
+                         num_cols: int,
+                         num_rows: int,
+                         threadgroup_size: int) -> torch.Tensor:
+    ops.f32_bf16w_matmul_add(input, weight_bf16, bias_bf16, output,
+                             num_tokens, num_cols, num_rows, threadgroup_size)
+    return output
+
+__all__ = [
+    "f32_bf16w_matmul",
+    "bf16_f32_embeddings",
+    "f32_bf16w_rmsnorm",
+    "f32_bf16w_dense_matmul_qkv",
+    "f32_bf16w_dense_matmul_attn_output",
+    "f32_bf16w_dense_matmul_mlp_gate",
+    "f32_rope",
+    "f32_bf16w_matmul_qkv",
+    "f32_sdpa",
+    "f32_topk",
+    "expert_routing_metadata",
+    "f32_scatter",
+    "f32_bf16w_matmul_add",
+]
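A short usage note for the expanded wrapper API (not part of the commit itself): the sketch below assumes the shape and dtype constraints enforced by the TORCH_CHECKs in tensor_wrappers.cpp (float32 activations, bfloat16 weight and bias, weight laid out as [num_cols, num_rows]); the tensor sizes and threadgroup_size=256 are placeholder values.

import torch
from gptoss_kernels import f32_bf16w_matmul

num_tokens, num_cols, num_rows = 4, 512, 256
x = torch.randn(num_tokens, num_cols, dtype=torch.float32)
w = torch.randn(num_cols, num_rows, dtype=torch.bfloat16)
b = torch.randn(num_rows, dtype=torch.bfloat16)
out = torch.empty(num_tokens, num_rows, dtype=torch.float32)

# The wrapper forwards to the Metal kernel and returns the output tensor it was given.
y = f32_bf16w_matmul(x, w, b, out, num_tokens, num_cols, num_rows, threadgroup_size=256)
assert y is out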
build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc differ
 
build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc and b/build/torch28-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc differ
 
build/torch28-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_5341d17_dirty.abi3.so → _gptoss_kernels_9964bae_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fa19b7a232893fc5ac4ef189ae0973e3e672efac424580f68fd2873cb2a7fbc8
-size 291032
+oid sha256:b52d3924ac74e614664fd9ec72e9673807ed170e57277b81c1922c0b54a88a6a
+size 391752
build/torch28-metal-aarch64-darwin/gptoss_kernels/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _gptoss_kernels_5341d17_dirty
3
- ops = torch.ops._gptoss_kernels_5341d17_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_gptoss_kernels_5341d17_dirty::{op_name}"
 
1
  import torch
2
+ from . import _gptoss_kernels_9964bae_dirty
3
+ ops = torch.ops._gptoss_kernels_9964bae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_gptoss_kernels_9964bae_dirty::{op_name}"
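The _ops shim binds `ops` to the commit-suffixed extension module so callers never hard-code the build hash. Purely illustrative (the names resolve only once the matching .abi3.so is importable):

from gptoss_kernels._ops import ops, add_op_namespace_prefix

# Fully-qualified name used when registering or looking up the op schema.
qualified = add_op_namespace_prefix("f32_bf16w_matmul")
# qualified == "_gptoss_kernels_9964bae_dirty::f32_bf16w_matmul"

# The same op reached through the torch.ops namespace object.
matmul_op = ops.f32_bf16w_matmul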
build/torch29-metal-aarch64-darwin/gptoss_kernels/__init__.py CHANGED
@@ -1,8 +1,174 @@
 from ._ops import ops
 import torch
 
-def f32_bf16w_matmul(input: torch.Tensor, weight_bf16: torch.Tensor, bias_bf16: torch.Tensor, output: torch.Tensor, num_tokens: int, num_cols: int, num_rows: int, threadgroup_size: int) -> None:
-    ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, num_tokens, num_cols, num_rows, threadgroup_size)
+def f32_bf16w_matmul(input: torch.Tensor,
+                     weight_bf16: torch.Tensor,
+                     bias_bf16: torch.Tensor,
+                     output: torch.Tensor,
+                     num_tokens: int,
+                     num_cols: int,
+                     num_rows: int,
+                     threadgroup_size: int) -> torch.Tensor:
+    ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output,
+                         num_tokens, num_cols, num_rows, threadgroup_size)
     return output
 
-__all__ = ["f32_bf16w_matmul"]
+def bf16_f32_embeddings(token_ids: torch.Tensor,
+                        weight_bf16: torch.Tensor,
+                        output: torch.Tensor,
+                        threadgroup_size: int) -> torch.Tensor:
+    ops.bf16_f32_embeddings(token_ids, weight_bf16, output, threadgroup_size)
+    return output
+
+def f32_bf16w_rmsnorm(input: torch.Tensor,
+                      weight_bf16: torch.Tensor,
+                      output: torch.Tensor,
+                      epsilon: float) -> torch.Tensor:
+    ops.f32_bf16w_rmsnorm(input, weight_bf16, output, epsilon)
+    return output
+
+def f32_bf16w_dense_matmul_qkv(input: torch.Tensor,
+                               weight_bf16: torch.Tensor,
+                               bias_bf16: torch.Tensor,
+                               output: torch.Tensor) -> torch.Tensor:
+    ops.f32_bf16w_dense_matmul_qkv(input, weight_bf16, bias_bf16, output)
+    return output
+
+def f32_bf16w_dense_matmul_attn_output(input: torch.Tensor,
+                                       weight_bf16: torch.Tensor,
+                                       bias_bf16: torch.Tensor,
+                                       output: torch.Tensor) -> torch.Tensor:
+    ops.f32_bf16w_dense_matmul_attn_output(input, weight_bf16, bias_bf16, output)
+    return output
+
+def f32_bf16w_dense_matmul_mlp_gate(input: torch.Tensor,
+                                    weight_bf16: torch.Tensor,
+                                    bias_bf16: torch.Tensor,
+                                    output: torch.Tensor) -> torch.Tensor:
+    ops.f32_bf16w_dense_matmul_mlp_gate(input, weight_bf16, bias_bf16, output)
+    return output
+
+def f32_rope(activations: torch.Tensor,
+             rope_base: float,
+             interpolation_scale: float,
+             yarn_offset: float,
+             yarn_scale: float,
+             yarn_multiplier: float,
+             num_tokens: int,
+             num_q_heads: int,
+             num_kv_heads: int,
+             attn_head_dim: int,
+             token_offset: int,
+             threadgroup_size: int) -> torch.Tensor:
+    ops.f32_rope(activations, rope_base, interpolation_scale, yarn_offset,
+                 yarn_scale, yarn_multiplier, num_tokens, num_q_heads,
+                 num_kv_heads, attn_head_dim, token_offset, threadgroup_size)
+    return activations
+
+def f32_bf16w_matmul_qkv(input: torch.Tensor,
+                         weight_bf16: torch.Tensor,
+                         bias_bf16: torch.Tensor,
+                         output: torch.Tensor,
+                         kv_cache: torch.Tensor,
+                         kv_cache_offset_bytes: int,
+                         num_tokens: int,
+                         num_cols: int,
+                         num_q_heads: int,
+                         num_kv_heads: int,
+                         attn_head_dim: int,
+                         token_offset: int,
+                         max_tokens: int,
+                         rope_base: float,
+                         interpolation_scale: float,
+                         yarn_offset: float,
+                         yarn_scale: float,
+                         yarn_multiplier: float,
+                         threadgroup_size: int) -> torch.Tensor:
+    ops.f32_bf16w_matmul_qkv(input, weight_bf16, bias_bf16, output, kv_cache,
+                             kv_cache_offset_bytes, num_tokens, num_cols,
+                             num_q_heads, num_kv_heads, attn_head_dim,
+                             token_offset, max_tokens, rope_base,
+                             interpolation_scale, yarn_offset, yarn_scale,
+                             yarn_multiplier, threadgroup_size)
+    return output
+
+def f32_sdpa(q: torch.Tensor,
+             q_offset_bytes: int,
+             kv: torch.Tensor,
+             kv_offset_bytes: int,
+             s_bf16: torch.Tensor,
+             s_offset_bytes: int,
+             output: torch.Tensor,
+             output_offset_bytes: int,
+             window: int,
+             kv_stride: int,
+             num_q_tokens: int,
+             num_kv_tokens: int,
+             num_q_heads: int,
+             num_kv_heads: int,
+             head_dim: int) -> torch.Tensor:
+    ops.f32_sdpa(q, q_offset_bytes, kv, kv_offset_bytes, s_bf16, s_offset_bytes,
+                 output, output_offset_bytes, window, kv_stride,
+                 num_q_tokens, num_kv_tokens, num_q_heads, num_kv_heads, head_dim)
+    return output
+
+def f32_topk(scores: torch.Tensor,
+             expert_ids: torch.Tensor,
+             expert_scores: torch.Tensor,
+             num_tokens: int,
+             num_experts: int,
+             num_active_experts: int) -> None:
+    ops.f32_topk(scores, expert_ids, expert_scores,
+                 num_tokens, num_experts, num_active_experts)
+
+def expert_routing_metadata(expert_ids: torch.Tensor,
+                            expert_scores: torch.Tensor,
+                            expert_offsets: torch.Tensor,
+                            intra_expert_offsets: torch.Tensor,
+                            num_tokens: int,
+                            num_experts: int) -> None:
+    ops.expert_routing_metadata(expert_ids, expert_scores,
+                                expert_offsets, intra_expert_offsets,
+                                num_tokens, num_experts)
+
+def f32_scatter(input: torch.Tensor,
+                expert_ids: torch.Tensor,
+                expert_scores: torch.Tensor,
+                expert_offsets: torch.Tensor,
+                intra_expert_offsets: torch.Tensor,
+                output: torch.Tensor,
+                num_channels: int,
+                num_tokens: int,
+                num_active_experts: int) -> torch.Tensor:
+    ops.f32_scatter(input, expert_ids, expert_scores,
+                    expert_offsets, intra_expert_offsets,
+                    output, num_channels, num_tokens, num_active_experts)
+    return output
+
+def f32_bf16w_matmul_add(input: torch.Tensor,
+                         weight_bf16: torch.Tensor,
+                         bias_bf16: torch.Tensor,
+                         output: torch.Tensor,
+                         num_tokens: int,
+                         num_cols: int,
+                         num_rows: int,
+                         threadgroup_size: int) -> torch.Tensor:
+    ops.f32_bf16w_matmul_add(input, weight_bf16, bias_bf16, output,
+                             num_tokens, num_cols, num_rows, threadgroup_size)
+    return output
+
+__all__ = [
+    "f32_bf16w_matmul",
+    "bf16_f32_embeddings",
+    "f32_bf16w_rmsnorm",
+    "f32_bf16w_dense_matmul_qkv",
+    "f32_bf16w_dense_matmul_attn_output",
+    "f32_bf16w_dense_matmul_mlp_gate",
+    "f32_rope",
+    "f32_bf16w_matmul_qkv",
+    "f32_sdpa",
+    "f32_topk",
+    "expert_routing_metadata",
+    "f32_scatter",
+    "f32_bf16w_matmul_add",
+]
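For completeness, a sketch of chaining two of the new wrappers the way a forward step might, with placeholder dimensions; it relies only on the constraints visible in tensor_wrappers.cpp (int32 token ids, bfloat16 weights, float32 activations, embedding width divisible by 4) and an assumed threadgroup_size of 256.

import torch
from gptoss_kernels import bf16_f32_embeddings, f32_bf16w_rmsnorm

num_tokens, hidden, vocab = 8, 1024, 32000  # hidden must be divisible by 4
token_ids = torch.randint(0, vocab, (num_tokens,), dtype=torch.int32)
emb_weight = torch.randn(vocab, hidden, dtype=torch.bfloat16)
norm_weight = torch.randn(hidden, dtype=torch.bfloat16)

acts = torch.empty(num_tokens, hidden, dtype=torch.float32)
normed = torch.empty_like(acts)

bf16_f32_embeddings(token_ids, emb_weight, acts, threadgroup_size=256)
f32_bf16w_rmsnorm(acts, norm_weight, normed, epsilon=1e-5)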
build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/__init__.cpython-313.pyc differ
 
build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc CHANGED
Binary files a/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc and b/build/torch29-metal-aarch64-darwin/gptoss_kernels/__pycache__/_ops.cpython-313.pyc differ
 
build/torch29-metal-aarch64-darwin/gptoss_kernels/{_gptoss_kernels_5341d17_dirty.abi3.so → _gptoss_kernels_9964bae_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:679729e810bc2a360f49eed34299e2d63f0eb24489d0f6a032ee12175e7831a3
-size 292040
+oid sha256:dc170dbf45587f9a1091e9b6c92ab02ebe4dc3cdd13be8e56a9a8d3a353d8c86
+size 392840
build/torch29-metal-aarch64-darwin/gptoss_kernels/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _gptoss_kernels_5341d17_dirty
3
- ops = torch.ops._gptoss_kernels_5341d17_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_gptoss_kernels_5341d17_dirty::{op_name}"
 
1
  import torch
2
+ from . import _gptoss_kernels_9964bae_dirty
3
+ ops = torch.ops._gptoss_kernels_9964bae_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
+ return f"_gptoss_kernels_9964bae_dirty::{op_name}"
gptoss_kernels/source/tensor_wrappers.cpp CHANGED
@@ -1,6 +1,227 @@
 #include <internal/metal-kernels.h>
 #include <internal/metal.h>
-#include <ATen/Tensor.h>
 
 void f32_bf16w_matmul_torch(const at::Tensor &input,
                             const at::Tensor &weight_bf16,
@@ -25,53 +246,719 @@ void f32_bf16w_matmul_torch(const at::Tensor &input,
     TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_rows,
                 "output shape must be [num_tokens, num_rows]");
 
-    auto input_cpu = input.contiguous().to(at::kCPU);
     auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
-    auto bias_cpu = bias_bf16.contiguous().to(at::kCPU);
-    auto out_cpu = output.detach().to(at::kCPU).contiguous().clone();
 
-    gptoss_metal_device device{}; gptoss_metal_library library{};
-    gptoss_metal_function fn{}; gptoss_metal_command_queue cq{};
-    gptoss_metal_command_buffer cb{};
 
-    TORCH_CHECK(gptoss_metal_device_create_system_default(&device) == gptoss_status_success, "device_create failed");
-    TORCH_CHECK(gptoss_metal_library_create_default(&device, &library) == gptoss_status_success, "library_create failed");
-    TORCH_CHECK(gptoss_metal_function_create(&library, "gptoss_f32_bf16w_matmul", &fn) == gptoss_status_success, "function_create failed");
-    TORCH_CHECK(gptoss_metal_command_queue_create(&device, &cq) == gptoss_status_success, "cq_create failed");
-    TORCH_CHECK(gptoss_metal_command_buffer_create(&cq, &cb) == gptoss_status_success, "cb_create failed");
-
-    const size_t in_bytes = (size_t)num_tokens * (size_t)num_cols * sizeof(float);
-    const size_t wt_bytes = (size_t)num_rows * (size_t)num_cols * sizeof(uint16_t);
-    const size_t bs_bytes = (size_t)num_rows * sizeof(uint16_t);
-    const size_t out_bytes = (size_t)num_tokens * (size_t)num_rows * sizeof(float);
-
-    gptoss_metal_buffer in_buf{}, wt_buf{}, bs_buf{}, out_buf{}, ctrl_buf{};
-    TORCH_CHECK(gptoss_metal_buffer_wrap(&device, in_bytes, input_cpu.data_ptr(), &in_buf) == gptoss_status_success, "wrap input failed");
-    TORCH_CHECK(gptoss_metal_buffer_wrap(&device, wt_bytes, weight_cpu.data_ptr(), &wt_buf) == gptoss_status_success, "wrap weight failed");
-    TORCH_CHECK(gptoss_metal_buffer_wrap(&device, bs_bytes, bias_cpu.data_ptr(), &bs_buf) == gptoss_status_success, "wrap bias failed");
-    TORCH_CHECK(gptoss_metal_buffer_create(&device, out_bytes, nullptr, &out_buf) == gptoss_status_success, "alloc out failed");
-    uint32_t ctrl_zero = 0;
-    TORCH_CHECK(gptoss_metal_buffer_create(&device, sizeof(uint32_t), &ctrl_zero, &ctrl_buf) == gptoss_status_success, "alloc ctrl failed");
-
-    TORCH_CHECK(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
-        &cb, &fn, (size_t)threadgroup_size,
-        &in_buf, 0, &wt_buf, 0, &bs_buf, 0, &out_buf, 0, &ctrl_buf, 0,
-        (uint32_t)num_tokens, (uint32_t)num_cols, (uint32_t)num_rows) == gptoss_status_success, "encode failed");
-
-    TORCH_CHECK(gptoss_metal_command_buffer_commit(&cb) == gptoss_status_success, "commit failed");
-    TORCH_CHECK(gptoss_metal_command_buffer_wait_completion(&cb, nullptr) == gptoss_status_success, "wait failed");
-
-    std::memcpy(out_cpu.data_ptr(), out_buf.ptr, out_bytes);
-    output.copy_(out_cpu.to(output.device(), /*non_blocking=*/false, /*copy=*/true));
-
-    (void) gptoss_metal_command_buffer_release(&cb);
-    (void) gptoss_metal_command_queue_release(&cq);
-    (void) gptoss_metal_function_release(&fn);
-    (void) gptoss_metal_library_release(&library);
-    (void) gptoss_metal_device_release(&device);
-    (void) gptoss_metal_buffer_release(&ctrl_buf);
-    (void) gptoss_metal_buffer_release(&out_buf);
-    (void) gptoss_metal_buffer_release(&bs_buf);
-    (void) gptoss_metal_buffer_release(&wt_buf);
-    (void) gptoss_metal_buffer_release(&in_buf);
 }
 
+#include <ATen/Functions.h>
+#include <ATen/Tensor.h>
+
 #include <internal/metal-kernels.h>
 #include <internal/metal.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <utility>
+#include <vector>
+
+namespace {
+
+class MetalBuffer {
+public:
+    MetalBuffer() = default;
+    MetalBuffer(const MetalBuffer&) = delete;
+    MetalBuffer& operator=(const MetalBuffer&) = delete;
+
+    MetalBuffer(MetalBuffer&& other) noexcept
+        : buffer_(other.buffer_), has_value_(other.has_value_) {
+        other.buffer_ = {};
+        other.has_value_ = false;
+    }
+
+    MetalBuffer& operator=(MetalBuffer&& other) noexcept {
+        if (this != &other) {
+            reset();
+            buffer_ = other.buffer_;
+            has_value_ = other.has_value_;
+            other.buffer_ = {};
+            other.has_value_ = false;
+        }
+        return *this;
+    }
+
+    ~MetalBuffer() {
+        reset();
+    }
+
+    gptoss_metal_buffer* get() {
+        return &buffer_;
+    }
+
+    const gptoss_metal_buffer* get() const {
+        return &buffer_;
+    }
+
+    void* ptr() const {
+        return buffer_.ptr;
+    }
+
+    size_t size_bytes() const {
+        return buffer_.size;
+    }
+
+    bool valid() const {
+        return has_value_;
+    }
+
+    void wrap(const gptoss_metal_device* device, size_t size, const void* data) {
+        reset();
+        TORCH_CHECK(gptoss_metal_buffer_wrap(device, size, data, &buffer_) == gptoss_status_success,
+                    "metal_buffer_wrap failed");
+        has_value_ = true;
+    }
+
+    void create(const gptoss_metal_device* device, size_t size, const void* data = nullptr) {
+        reset();
+        TORCH_CHECK(gptoss_metal_buffer_create(device, size, data, &buffer_) == gptoss_status_success,
+                    "metal_buffer_create failed");
+        has_value_ = true;
+    }
+
+    void reset() {
+        if (has_value_) {
+            (void) gptoss_metal_buffer_release(&buffer_);
+            buffer_ = {};
+            has_value_ = false;
+        }
+    }
+
+private:
+    gptoss_metal_buffer buffer_{};
+    bool has_value_ = false;
+};
+
+template <typename EncodeFn>
+void run_metal_kernel(const char* kernel_symbol, EncodeFn&& encode_fn) {
+    gptoss_metal_device device{};
+    gptoss_metal_library library{};
+    gptoss_metal_function fn{};
+    gptoss_metal_command_queue cq{};
+    gptoss_metal_command_buffer cb{};
+
+    auto cleanup = [&]() {
+        (void) gptoss_metal_command_buffer_release(&cb);
+        (void) gptoss_metal_command_queue_release(&cq);
+        (void) gptoss_metal_function_release(&fn);
+        (void) gptoss_metal_library_release(&library);
+        (void) gptoss_metal_device_release(&device);
+    };
+
+    TORCH_CHECK(gptoss_metal_device_create_system_default(&device) == gptoss_status_success,
+                "device_create failed");
+    try {
+        TORCH_CHECK(gptoss_metal_library_create_default(&device, &library) == gptoss_status_success,
+                    "library_create failed");
+        TORCH_CHECK(gptoss_metal_function_create(&library, kernel_symbol, &fn) == gptoss_status_success,
+                    "function_create failed");
+        TORCH_CHECK(gptoss_metal_command_queue_create(&device, &cq) == gptoss_status_success,
+                    "cq_create failed");
+        TORCH_CHECK(gptoss_metal_command_buffer_create(&cq, &cb) == gptoss_status_success,
+                    "cb_create failed");
+
+        encode_fn(device, fn, cb);
+
+        TORCH_CHECK(gptoss_metal_command_buffer_commit(&cb) == gptoss_status_success,
+                    "commit failed");
+        TORCH_CHECK(gptoss_metal_command_buffer_wait_completion(&cb, nullptr) == gptoss_status_success,
+                    "wait failed");
+    } catch (...) {
+        cleanup();
+        throw;
+    }
+
+    cleanup();
+}
+
+at::Tensor to_cpu_contiguous(const at::Tensor& tensor) {
+    if (tensor.device().is_cpu() && tensor.is_contiguous()) {
+        return tensor;
+    }
+    return tensor.contiguous().to(at::kCPU);
+}
+
+at::Tensor empty_cpu_like(const at::Tensor& tensor) {
+    return at::empty_like(tensor, tensor.options().device(at::kCPU)).contiguous();
+}
+
+void copy_back(at::Tensor& dst, const at::Tensor& src_cpu) {
+    dst.copy_(src_cpu.to(dst.device(), /*non_blocking=*/false, /*copy=*/true));
+}
+
+void create_control_buffer(const gptoss_metal_device* device, MetalBuffer& buffer) {
+    struct gptoss_control ctrl {0};
+    buffer.create(device, sizeof(ctrl), &ctrl);
+}
+
+template <typename LaunchFn>
+void run_dense_matmul_bf16(const char* kernel_symbol,
+                           LaunchFn&& launch_fn,
+                           const at::Tensor& input,
+                           const at::Tensor& weight_bf16,
+                           const at::Tensor& bias_bf16,
+                           at::Tensor& output)
+{
+    TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
+    TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
+    TORCH_CHECK(bias_bf16.dtype() == at::kBFloat16, "bias must be bfloat16");
+    TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
+
+    TORCH_CHECK(input.dim() == 2, "input must be 2D");
+    TORCH_CHECK(weight_bf16.dim() == 2, "weight must be 2D");
+    TORCH_CHECK(bias_bf16.dim() == 1, "bias must be 1D");
+    TORCH_CHECK(output.dim() == 2, "output must be 2D");
+
+    const int64_t num_tokens = input.size(0);
+    const int64_t num_cols = input.size(1);
+    const int64_t num_rows = output.size(1);
+
+    TORCH_CHECK(output.size(0) == num_tokens,
+                "output first dimension must match number of tokens");
+    TORCH_CHECK(weight_bf16.size(0) == num_cols && weight_bf16.size(1) == num_rows,
+                "weight shape must be [num_cols, num_rows]");
+    TORCH_CHECK(bias_bf16.size(0) == num_rows,
+                "bias length must equal number of rows");
+
+    auto input_cpu = to_cpu_contiguous(input);
+    auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
+    auto bias_cpu = to_cpu_contiguous(bias_bf16);
+    auto out_cpu = empty_cpu_like(output);
+
+    const size_t in_bytes = static_cast<size_t>(input_cpu.numel()) * input_cpu.element_size();
+    const size_t weight_bytes = static_cast<size_t>(weight_cpu.numel()) * weight_cpu.element_size();
+    const size_t bias_bytes = static_cast<size_t>(bias_cpu.numel()) * bias_cpu.element_size();
+    const size_t out_bytes = static_cast<size_t>(out_cpu.numel()) * out_cpu.element_size();
+
+    MetalBuffer input_buf;
+    MetalBuffer weight_buf;
+    MetalBuffer bias_buf;
+    MetalBuffer out_buf;
+    MetalBuffer control_buf;
+
+    run_metal_kernel(kernel_symbol, [&](const gptoss_metal_device& device,
+                                        const gptoss_metal_function& fn,
+                                        gptoss_metal_command_buffer& cb) {
+        input_buf.wrap(&device, in_bytes, input_cpu.data_ptr());
+        weight_buf.wrap(&device, weight_bytes, weight_cpu.data_ptr());
+        bias_buf.wrap(&device, bias_bytes, bias_cpu.data_ptr());
+        out_buf.create(&device, out_bytes, nullptr);
+        create_control_buffer(&device, control_buf);
+
+        TORCH_CHECK(
+            launch_fn(
+                &cb, &fn,
+                input_buf.get(), 0,
+                weight_buf.get(), 0,
+                bias_buf.get(), 0,
+                out_buf.get(), 0,
+                control_buf.get(), 0,
+                static_cast<uint32_t>(num_tokens),
+                static_cast<uint32_t>(num_cols),
+                static_cast<uint32_t>(num_rows)) == gptoss_status_success,
+            "encode dense matmul failed");
+    });
+
+    std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
+    copy_back(output, out_cpu);
+}
+
+} // namespace
 
 void f32_bf16w_matmul_torch(const at::Tensor &input,
                             const at::Tensor &weight_bf16,
 
246
  TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_rows,
247
  "output shape must be [num_tokens, num_rows]");
248
 
249
+ auto input_cpu = to_cpu_contiguous(input);
250
  auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
251
+ auto bias_cpu = to_cpu_contiguous(bias_bf16);
252
+ auto out_cpu = empty_cpu_like(output);
253
 
254
+ const size_t in_bytes = static_cast<size_t>(num_tokens) * static_cast<size_t>(num_cols) * sizeof(float);
255
+ const size_t wt_bytes = static_cast<size_t>(num_rows) * static_cast<size_t>(num_cols) * sizeof(uint16_t);
256
+ const size_t bs_bytes = static_cast<size_t>(num_rows) * sizeof(uint16_t);
257
+ const size_t out_bytes = static_cast<size_t>(num_tokens) * static_cast<size_t>(num_rows) * sizeof(float);
258
+
259
+ MetalBuffer in_buf;
260
+ MetalBuffer wt_buf;
261
+ MetalBuffer bs_buf;
262
+ MetalBuffer out_buf;
263
+ MetalBuffer ctrl_buf;
264
+
265
+ run_metal_kernel("gptoss_f32_bf16w_matmul", [&](const gptoss_metal_device& device,
266
+ const gptoss_metal_function& fn,
267
+ gptoss_metal_command_buffer& cb) {
268
+ in_buf.wrap(&device, in_bytes, input_cpu.data_ptr());
269
+ wt_buf.wrap(&device, wt_bytes, weight_cpu.data_ptr());
270
+ bs_buf.wrap(&device, bs_bytes, bias_cpu.data_ptr());
271
+ out_buf.create(&device, out_bytes, nullptr);
272
+ create_control_buffer(&device, ctrl_buf);
273
+
274
+ TORCH_CHECK(
275
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
276
+ &cb, &fn, static_cast<size_t>(threadgroup_size),
277
+ in_buf.get(), 0,
278
+ wt_buf.get(), 0,
279
+ bs_buf.get(), 0,
280
+ out_buf.get(), 0,
281
+ ctrl_buf.get(), 0,
282
+ static_cast<uint32_t>(num_tokens),
283
+ static_cast<uint32_t>(num_cols),
284
+ static_cast<uint32_t>(num_rows)) == gptoss_status_success,
285
+ "encode failed");
286
+ });
287
+
288
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
289
+ copy_back(output, out_cpu);
290
+ }
291
+
292
+ void bf16_f32_embeddings_torch(const at::Tensor& token_ids,
293
+ const at::Tensor& weight_bf16,
294
+ at::Tensor& output,
295
+ int64_t threadgroup_size)
296
+ {
297
+ TORCH_CHECK(token_ids.dtype() == at::kInt || token_ids.dtype() == at::kLong,
298
+ "token_ids must be int32 or int64");
299
+ TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
300
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
301
+
302
+ TORCH_CHECK(token_ids.dim() == 1, "token_ids must be 1D");
303
+ TORCH_CHECK(weight_bf16.dim() == 2, "weight must be 2D");
304
+ TORCH_CHECK(output.dim() == 2, "output must be 2D");
305
+
306
+ const int64_t num_tokens = token_ids.size(0);
307
+ TORCH_CHECK(output.size(0) == num_tokens, "output first dimension must match num_tokens");
308
+ const int64_t num_channels = output.size(1);
309
+ TORCH_CHECK(num_channels % 4 == 0, "num_channels must be divisible by 4");
310
+ TORCH_CHECK(weight_bf16.size(1) == num_channels,
311
+ "weight second dimension must equal embedding dimension (num_channels)");
312
+
313
+ TORCH_CHECK(threadgroup_size >= 0, "threadgroup_size must be non-negative");
314
+
315
+ auto tokens_cpu = token_ids.dtype() == at::kInt
316
+ ? to_cpu_contiguous(token_ids)
317
+ : token_ids.to(at::kInt).contiguous().to(at::kCPU);
318
+ auto weight_cpu = to_cpu_contiguous(weight_bf16);
319
+ auto out_cpu = empty_cpu_like(output);
320
+
321
+ const size_t token_bytes = static_cast<size_t>(num_tokens) * sizeof(uint32_t);
322
+ const size_t weight_bytes = static_cast<size_t>(weight_cpu.numel()) * weight_cpu.element_size();
323
+ const size_t out_bytes = static_cast<size_t>(out_cpu.numel()) * out_cpu.element_size();
324
+
325
+ MetalBuffer tokens_buf;
326
+ MetalBuffer weight_buf;
327
+ MetalBuffer out_buf;
328
+ MetalBuffer control_buf;
329
+
330
+ run_metal_kernel("gptoss_bf16_f32_embeddings", [&](const gptoss_metal_device& device,
331
+ const gptoss_metal_function& fn,
332
+ gptoss_metal_command_buffer& cb) {
333
+ tokens_buf.wrap(&device, token_bytes, tokens_cpu.data_ptr());
334
+ weight_buf.wrap(&device, weight_bytes, weight_cpu.data_ptr());
335
+ out_buf.create(&device, out_bytes, nullptr);
336
+ create_control_buffer(&device, control_buf);
337
+
338
+ TORCH_CHECK(
339
+ gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
340
+ &cb, &fn, static_cast<size_t>(threadgroup_size),
341
+ tokens_buf.get(), 0,
342
+ weight_buf.get(), 0,
343
+ out_buf.get(), 0,
344
+ control_buf.get(), 0,
345
+ static_cast<uint32_t>(num_tokens),
346
+ static_cast<uint32_t>(num_channels)) == gptoss_status_success,
347
+ "encode embeddings failed");
348
+ });
349
+
350
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
351
+ copy_back(output, out_cpu);
352
+ }
353
+
354
+ void f32_bf16w_rmsnorm_torch(const at::Tensor& input,
355
+ const at::Tensor& weight_bf16,
356
+ at::Tensor& output,
357
+ double epsilon)
358
+ {
359
+ TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
360
+ TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
361
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
362
+
363
+ TORCH_CHECK(input.dim() == 2, "input must be 2D");
364
+ TORCH_CHECK(weight_bf16.dim() == 1, "weight must be 1D");
365
+ TORCH_CHECK(output.dim() == 2, "output must be 2D");
366
+
367
+ const int64_t num_tokens = input.size(0);
368
+ const int64_t num_channels = input.size(1);
369
+ TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_channels,
370
+ "output shape must match input shape");
371
+ TORCH_CHECK(weight_bf16.size(0) == num_channels,
372
+ "weight length must equal number of channels");
373
+ TORCH_CHECK(num_channels % 4 == 0, "num_channels must be divisible by 4");
374
+
375
+ auto input_cpu = to_cpu_contiguous(input);
376
+ auto weight_cpu = to_cpu_contiguous(weight_bf16);
377
+ auto out_cpu = empty_cpu_like(output);
378
+
379
+ const size_t in_bytes = static_cast<size_t>(input_cpu.numel()) * input_cpu.element_size();
380
+ const size_t weight_bytes = static_cast<size_t>(weight_cpu.numel()) * weight_cpu.element_size();
381
+ const size_t out_bytes = static_cast<size_t>(out_cpu.numel()) * out_cpu.element_size();
382
+
383
+ MetalBuffer input_buf;
384
+ MetalBuffer weight_buf;
385
+ MetalBuffer out_buf;
386
+ MetalBuffer control_buf;
387
+
388
+ run_metal_kernel("gptoss_f32_bf16w_rmsnorm", [&](const gptoss_metal_device& device,
389
+ const gptoss_metal_function& fn,
390
+ gptoss_metal_command_buffer& cb) {
391
+ input_buf.wrap(&device, in_bytes, input_cpu.data_ptr());
392
+ weight_buf.wrap(&device, weight_bytes, weight_cpu.data_ptr());
393
+ out_buf.create(&device, out_bytes, nullptr);
394
+ create_control_buffer(&device, control_buf);
395
+
396
+ TORCH_CHECK(
397
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
398
+ &cb, &fn,
399
+ input_buf.get(), 0,
400
+ weight_buf.get(), 0,
401
+ out_buf.get(), 0,
402
+ control_buf.get(), 0,
403
+ static_cast<uint32_t>(num_tokens),
404
+ static_cast<uint32_t>(num_channels),
405
+ static_cast<float>(epsilon)) == gptoss_status_success,
406
+ "encode rmsnorm failed");
407
+ });
408
+
409
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
410
+ copy_back(output, out_cpu);
411
+ }
412
+
413
+ void f32_bf16w_dense_matmul_qkv_torch(const at::Tensor& input,
414
+ const at::Tensor& weight_bf16,
415
+ const at::Tensor& bias_bf16,
416
+ at::Tensor& output)
417
+ {
418
+ run_dense_matmul_bf16(
419
+ "gptoss_f32_bf16w_dense_matmul_qkv",
420
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv,
421
+ input, weight_bf16, bias_bf16, output);
422
+ }
423
+
424
+ void f32_bf16w_dense_matmul_attn_output_torch(const at::Tensor& input,
425
+ const at::Tensor& weight_bf16,
426
+ const at::Tensor& bias_bf16,
427
+ at::Tensor& output)
428
+ {
429
+ run_dense_matmul_bf16(
430
+ "gptoss_f32_bf16w_dense_matmul_attn_output",
431
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output,
432
+ input, weight_bf16, bias_bf16, output);
433
+ }
434
+
435
+ void f32_bf16w_dense_matmul_mlp_gate_torch(const at::Tensor& input,
436
+ const at::Tensor& weight_bf16,
437
+ const at::Tensor& bias_bf16,
438
+ at::Tensor& output)
439
+ {
440
+ run_dense_matmul_bf16(
441
+ "gptoss_f32_bf16w_dense_matmul_mlp_gate",
442
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate,
443
+ input, weight_bf16, bias_bf16, output);
444
+ }
445
+
446
+ void f32_rope_torch(at::Tensor& activations,
447
+ double rope_base,
448
+ double interpolation_scale,
449
+ double yarn_offset,
450
+ double yarn_scale,
451
+ double yarn_multiplier,
452
+ int64_t num_tokens,
453
+ int64_t num_q_heads,
454
+ int64_t num_kv_heads,
455
+ int64_t attn_head_dim,
456
+ int64_t token_offset,
457
+ int64_t threadgroup_size)
458
+ {
459
+ TORCH_CHECK(activations.dtype() == at::kFloat, "activations must be float32");
460
+ TORCH_CHECK(num_tokens >= 0 && num_q_heads >= 0 && num_kv_heads >= 0 && attn_head_dim >= 0,
461
+ "shape parameters must be non-negative");
462
+ TORCH_CHECK(threadgroup_size >= 0, "threadgroup_size must be non-negative");
463
+
464
+ auto activations_cpu = to_cpu_contiguous(activations);
465
+ MetalBuffer activations_buf;
466
+ MetalBuffer control_buf;
467
+
468
+ const size_t activations_bytes = static_cast<size_t>(activations_cpu.numel()) * activations_cpu.element_size();
469
+
470
+ run_metal_kernel("gptoss_f32_rope", [&](const gptoss_metal_device& device,
471
+ const gptoss_metal_function& fn,
472
+ gptoss_metal_command_buffer& cb) {
473
+ activations_buf.wrap(&device, activations_bytes, activations_cpu.data_ptr());
474
+ create_control_buffer(&device, control_buf);
475
+
476
+ TORCH_CHECK(
477
+ gptoss_metal_command_buffer_encode_launch_f32_rope(
478
+ &cb, &fn,
479
+ static_cast<size_t>(threadgroup_size),
480
+ activations_buf.get(), 0,
481
+ control_buf.get(), 0,
482
+ static_cast<float>(rope_base),
483
+ static_cast<float>(interpolation_scale),
484
+ static_cast<float>(yarn_offset),
485
+ static_cast<float>(yarn_scale),
486
+ static_cast<float>(yarn_multiplier),
487
+ static_cast<uint32_t>(num_tokens),
488
+ static_cast<uint32_t>(num_q_heads),
489
+ static_cast<uint32_t>(num_kv_heads),
490
+ static_cast<uint32_t>(attn_head_dim),
491
+ static_cast<uint32_t>(token_offset)) == gptoss_status_success,
492
+ "encode rope failed");
493
+ });
494
+
495
+ copy_back(activations, activations_cpu);
496
+ }
497
+
498
+ void f32_bf16w_matmul_qkv_torch(const at::Tensor& input,
499
+ const at::Tensor& weight_bf16,
500
+ const at::Tensor& bias_bf16,
501
+ at::Tensor& output,
502
+ at::Tensor& kv_cache,
503
+ int64_t kv_cache_offset_bytes,
504
+ int64_t num_tokens,
505
+ int64_t num_cols,
506
+ int64_t num_q_heads,
507
+ int64_t num_kv_heads,
508
+ int64_t attn_head_dim,
509
+ int64_t token_offset,
510
+ int64_t max_tokens,
511
+ double rope_base,
512
+ double interpolation_scale,
513
+ double yarn_offset,
514
+ double yarn_scale,
515
+ double yarn_multiplier,
516
+ int64_t threadgroup_size)
517
+ {
518
+ TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
519
+ TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
520
+ TORCH_CHECK(bias_bf16.dtype() == at::kBFloat16, "bias must be bfloat16");
521
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
522
+ TORCH_CHECK(kv_cache.dtype() == at::kFloat, "kv_cache must be float32");
523
+
524
+ TORCH_CHECK(input.dim() == 2, "input must be 2D");
525
+ TORCH_CHECK(weight_bf16.dim() == 2, "weight must be 2D");
526
+ TORCH_CHECK(bias_bf16.dim() == 1, "bias must be 1D");
527
+ TORCH_CHECK(output.dim() == 2, "output must be 2D");
528
+
529
+ TORCH_CHECK(num_tokens >= 0 && num_cols >= 0 && num_q_heads >= 0 && num_kv_heads >= 0 && attn_head_dim >= 0 && max_tokens >= 0,
530
+ "shape parameters must be non-negative");
531
+ TORCH_CHECK(threadgroup_size >= 0, "threadgroup_size must be non-negative");
532
+ TORCH_CHECK(kv_cache_offset_bytes >= 0, "kv_cache_offset_bytes must be non-negative");
533
+
534
+ TORCH_CHECK(input.size(0) == num_tokens && input.size(1) == num_cols,
535
+ "input shape must be [num_tokens, num_cols]");
536
+ const int64_t num_rows = (num_q_heads + 2 * num_kv_heads) * attn_head_dim;
537
+ TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_rows,
538
+ "output shape must be [num_tokens, (num_q_heads + 2 * num_kv_heads) * attn_head_dim]");
539
+ TORCH_CHECK(weight_bf16.size(0) == num_cols && weight_bf16.size(1) == num_rows,
540
+ "weight shape must be [num_cols, (num_q_heads + 2 * num_kv_heads) * attn_head_dim]");
541
+ TORCH_CHECK(bias_bf16.size(0) == num_rows,
542
+ "bias length must equal output feature dimension");
543
+
544
+ auto input_cpu = to_cpu_contiguous(input);
545
+ auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
546
+ auto bias_cpu = to_cpu_contiguous(bias_bf16);
547
+ auto out_cpu = empty_cpu_like(output);
548
+ auto kv_cpu = to_cpu_contiguous(kv_cache);
549
+
550
+ const size_t in_bytes = static_cast<size_t>(input_cpu.numel()) * input_cpu.element_size();
551
+ const size_t weight_bytes = static_cast<size_t>(weight_cpu.numel()) * weight_cpu.element_size();
552
+ const size_t bias_bytes = static_cast<size_t>(bias_cpu.numel()) * bias_cpu.element_size();
553
+ const size_t out_bytes = static_cast<size_t>(out_cpu.numel()) * out_cpu.element_size();
554
+ const size_t kv_bytes = static_cast<size_t>(kv_cpu.numel()) * kv_cpu.element_size();
555
+
556
+ MetalBuffer input_buf;
557
+ MetalBuffer weight_buf;
558
+ MetalBuffer bias_buf;
559
+ MetalBuffer out_buf;
560
+ MetalBuffer kv_buf;
561
+ MetalBuffer control_buf;
562
+
563
+ run_metal_kernel("gptoss_f32_bf16w_matmul_qkv", [&](const gptoss_metal_device& device,
564
+ const gptoss_metal_function& fn,
565
+ gptoss_metal_command_buffer& cb) {
566
+ input_buf.wrap(&device, in_bytes, input_cpu.data_ptr());
567
+ weight_buf.wrap(&device, weight_bytes, weight_cpu.data_ptr());
568
+ bias_buf.wrap(&device, bias_bytes, bias_cpu.data_ptr());
569
+ out_buf.create(&device, out_bytes, nullptr);
570
+ kv_buf.wrap(&device, kv_bytes, kv_cpu.data_ptr());
571
+ create_control_buffer(&device, control_buf);
572
+
573
+ TORCH_CHECK(
574
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
575
+ &cb, &fn,
576
+ static_cast<size_t>(threadgroup_size),
577
+ input_buf.get(), 0,
578
+ weight_buf.get(), 0,
579
+ bias_buf.get(), 0,
580
+ out_buf.get(), 0,
581
+ kv_buf.get(), static_cast<size_t>(kv_cache_offset_bytes),
582
+ control_buf.get(), 0,
583
+ static_cast<uint32_t>(num_tokens),
584
+ static_cast<uint32_t>(num_cols),
585
+ static_cast<uint32_t>(num_q_heads),
586
+ static_cast<uint32_t>(num_kv_heads),
587
+ static_cast<uint32_t>(attn_head_dim),
588
+ static_cast<uint32_t>(token_offset),
589
+ static_cast<uint32_t>(max_tokens),
590
+ static_cast<float>(rope_base),
591
+ static_cast<float>(interpolation_scale),
592
+ static_cast<float>(yarn_offset),
593
+ static_cast<float>(yarn_scale),
594
+ static_cast<float>(yarn_multiplier)) == gptoss_status_success,
595
+ "encode matmul_qkv failed");
596
+ });
597
+
598
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
599
+ copy_back(output, out_cpu);
600
+ copy_back(kv_cache, kv_cpu);
601
+ }
602
+
603
+ void f32_sdpa_torch(const at::Tensor& q,
604
+ int64_t q_offset_bytes,
605
+ const at::Tensor& kv,
606
+ int64_t kv_offset_bytes,
607
+ const at::Tensor& s_bf16,
608
+ int64_t s_offset_bytes,
609
+ at::Tensor& output,
610
+ int64_t output_offset_bytes,
611
+ int64_t window,
612
+ int64_t kv_stride,
613
+ int64_t num_q_tokens,
614
+ int64_t num_kv_tokens,
615
+ int64_t num_q_heads,
616
+ int64_t num_kv_heads,
617
+ int64_t head_dim)
618
+ {
619
+ TORCH_CHECK(q.dtype() == at::kFloat, "q must be float32");
620
+ TORCH_CHECK(kv.dtype() == at::kFloat, "kv must be float32");
621
+ TORCH_CHECK(s_bf16.dtype() == at::kBFloat16, "s must be bfloat16");
622
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
623
+
624
+ TORCH_CHECK(q_offset_bytes >= 0 && kv_offset_bytes >= 0 && s_offset_bytes >= 0 && output_offset_bytes >= 0,
625
+ "offsets must be non-negative");
626
+ TORCH_CHECK(window >= 0 && kv_stride >= 0 && num_q_tokens >= 0 && num_kv_tokens >= 0 && num_q_heads >= 0 && num_kv_heads >= 0 && head_dim >= 0,
627
+ "shape parameters must be non-negative");
628
+
629
+ auto q_cpu = to_cpu_contiguous(q);
630
+ auto kv_cpu = to_cpu_contiguous(kv);
631
+ auto s_cpu = to_cpu_contiguous(s_bf16);
632
+ auto out_cpu = empty_cpu_like(output);
633
+
634
+ const size_t q_bytes = static_cast<size_t>(q_cpu.numel()) * q_cpu.element_size();
635
+ const size_t kv_bytes = static_cast<size_t>(kv_cpu.numel()) * kv_cpu.element_size();
636
+ const size_t s_bytes = static_cast<size_t>(s_cpu.numel()) * s_cpu.element_size();
637
+ const size_t out_bytes = static_cast<size_t>(out_cpu.numel()) * out_cpu.element_size();
638
+
639
+ MetalBuffer q_buf;
640
+ MetalBuffer kv_buf;
641
+ MetalBuffer s_buf;
642
+ MetalBuffer out_buf;
643
+ MetalBuffer control_buf;
644
+
645
+ run_metal_kernel("gptoss_f32_sdpa_q8_d64", [&](const gptoss_metal_device& device,
646
+ const gptoss_metal_function& fn,
647
+ gptoss_metal_command_buffer& cb) {
648
+ q_buf.wrap(&device, q_bytes, q_cpu.data_ptr());
649
+ kv_buf.wrap(&device, kv_bytes, kv_cpu.data_ptr());
650
+ s_buf.wrap(&device, s_bytes, s_cpu.data_ptr());
651
+ out_buf.create(&device, out_bytes, nullptr);
652
+ create_control_buffer(&device, control_buf);
653
+
654
+ TORCH_CHECK(
655
+ gptoss_metal_command_buffer_encode_launch_f32_sdpa(
656
+ &cb, &fn,
657
+ q_buf.get(), static_cast<size_t>(q_offset_bytes),
658
+ kv_buf.get(), static_cast<size_t>(kv_offset_bytes),
659
+ s_buf.get(), static_cast<size_t>(s_offset_bytes),
660
+ out_buf.get(), static_cast<size_t>(output_offset_bytes),
661
+ control_buf.get(), 0,
662
+ static_cast<uint32_t>(window),
663
+ static_cast<uint32_t>(kv_stride),
664
+ static_cast<uint32_t>(num_q_tokens),
665
+ static_cast<uint32_t>(num_kv_tokens),
666
+ static_cast<uint32_t>(num_q_heads),
667
+ static_cast<uint32_t>(num_kv_heads),
668
+ static_cast<uint32_t>(head_dim)) == gptoss_status_success,
669
+ "encode sdpa failed");
670
+ });
671
+
672
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
673
+ copy_back(output, out_cpu);
674
+ }
675
+
676
+ void f32_topk_torch(const at::Tensor& scores,
677
+ at::Tensor& expert_ids,
678
+ at::Tensor& expert_scores,
679
+ int64_t num_tokens,
680
+ int64_t num_experts,
681
+ int64_t num_active_experts)
682
+ {
683
+ TORCH_CHECK(scores.dtype() == at::kFloat, "scores must be float32");
684
+ TORCH_CHECK(expert_ids.dtype() == at::kInt, "expert_ids must be int32");
685
+ TORCH_CHECK(expert_scores.dtype() == at::kFloat, "expert_scores must be float32");
686
+
687
+ TORCH_CHECK(num_tokens >= 0 && num_experts >= 0 && num_active_experts >= 0,
688
+ "shape parameters must be non-negative");
689
+
690
+ TORCH_CHECK(scores.size(0) == num_tokens,
691
+ "scores first dimension must match num_tokens");
692
+ TORCH_CHECK(scores.numel() == num_tokens * num_experts,
693
+ "scores must have num_tokens * num_experts elements");
694
+ TORCH_CHECK(expert_ids.numel() == num_tokens * num_active_experts,
695
+ "expert_ids must have num_tokens * num_active_experts elements");
696
+ TORCH_CHECK(expert_scores.numel() == num_tokens * num_active_experts,
697
+ "expert_scores must have num_tokens * num_active_experts elements");
698
+
699
+ auto scores_cpu = to_cpu_contiguous(scores);
700
+ std::vector<gptoss_expert_prediction> predictions(static_cast<size_t>(num_tokens) * static_cast<size_t>(num_active_experts));
701
+
702
+ const size_t score_bytes = static_cast<size_t>(scores_cpu.numel()) * scores_cpu.element_size();
703
+ const size_t pred_bytes = predictions.size() * sizeof(gptoss_expert_prediction);
704
+
705
+ MetalBuffer score_buf;
706
+ MetalBuffer pred_buf;
707
+ MetalBuffer control_buf;
708
+
709
+ run_metal_kernel("gptoss_f32_topk_softmax_e128_k4", [&](const gptoss_metal_device& device,
710
+ const gptoss_metal_function& fn,
711
+ gptoss_metal_command_buffer& cb) {
712
+ score_buf.wrap(&device, score_bytes, scores_cpu.data_ptr());
713
+ pred_buf.wrap(&device, pred_bytes, predictions.data());
714
+ create_control_buffer(&device, control_buf);
715
+
716
+ TORCH_CHECK(
717
+ gptoss_metal_command_buffer_encode_launch_f32_topk(
718
+ &cb, &fn,
719
+ score_buf.get(), 0,
720
+ pred_buf.get(), 0,
721
+ control_buf.get(), 0,
722
+ static_cast<uint32_t>(num_tokens),
723
+ static_cast<uint32_t>(num_experts),
724
+ static_cast<uint32_t>(num_active_experts)) == gptoss_status_success,
725
+ "encode topk failed");
726
+ });
727
+
728
+ auto ids_cpu = expert_ids.to(at::kCPU).contiguous();
729
+ auto scores_out_cpu = expert_scores.to(at::kCPU).contiguous();
730
+ auto* ids_ptr = ids_cpu.data_ptr<int32_t>();
731
+ auto* scores_ptr = scores_out_cpu.data_ptr<float>();
732
+ const size_t total = predictions.size();
733
+ for (size_t i = 0; i < total; ++i) {
734
+ ids_ptr[i] = static_cast<int32_t>(predictions[i].expert_id);
735
+ scores_ptr[i] = predictions[i].score;
736
+ }
737
+ copy_back(expert_ids, ids_cpu);
738
+ copy_back(expert_scores, scores_out_cpu);
739
+ }
740
+
741
+ void expert_routing_metadata_torch(const at::Tensor& expert_ids,
742
+ const at::Tensor& expert_scores,
743
+ at::Tensor& expert_offsets,
744
+ at::Tensor& intra_expert_offsets,
745
+ int64_t num_tokens,
746
+ int64_t num_experts)
747
+ {
748
+ TORCH_CHECK(expert_ids.dtype() == at::kInt, "expert_ids must be int32");
749
+ TORCH_CHECK(expert_scores.dtype() == at::kFloat, "expert_scores must be float32");
750
+ TORCH_CHECK(expert_offsets.dtype() == at::kInt, "expert_offsets must be int32");
751
+ TORCH_CHECK(intra_expert_offsets.dtype() == at::kInt, "intra_expert_offsets must be int32");
752
+
753
+ TORCH_CHECK(num_tokens >= 0 && num_experts >= 0, "shape parameters must be non-negative");
754
+ TORCH_CHECK(expert_ids.numel() == num_tokens,
755
+ "expert_ids must have num_tokens elements");
756
+ TORCH_CHECK(expert_scores.numel() == num_tokens,
757
+ "expert_scores must have num_tokens elements");
758
+ TORCH_CHECK(intra_expert_offsets.numel() == num_tokens,
759
+ "intra_expert_offsets must have num_tokens elements");
760
+ TORCH_CHECK(expert_offsets.numel() == num_experts + 1,
761
+ "expert_offsets must have num_experts + 1 elements");
762
+
763
+ auto ids_cpu = to_cpu_contiguous(expert_ids);
764
+ auto scores_cpu = to_cpu_contiguous(expert_scores);
765
+ auto offsets_cpu = to_cpu_contiguous(expert_offsets);
766
+ auto intra_offsets_cpu = to_cpu_contiguous(intra_expert_offsets);
767
+
768
+ std::vector<gptoss_expert_prediction> predictions(static_cast<size_t>(num_tokens));
769
+ const auto* ids_ptr = ids_cpu.data_ptr<int32_t>();
770
+ const auto* scores_ptr = scores_cpu.data_ptr<float>();
771
+ for (int64_t i = 0; i < num_tokens; ++i) {
772
+ predictions[static_cast<size_t>(i)] = gptoss_expert_prediction {
773
+ .expert_id = static_cast<uint32_t>(ids_ptr[i]),
774
+ .score = scores_ptr[i],
775
+ };
776
+ }
777
+
778
+ const size_t pred_bytes = predictions.size() * sizeof(gptoss_expert_prediction);
779
+ const size_t offsets_bytes = static_cast<size_t>(offsets_cpu.numel()) * offsets_cpu.element_size();
780
+ const size_t intra_bytes = static_cast<size_t>(intra_offsets_cpu.numel()) * intra_offsets_cpu.element_size();
781
+
782
+ MetalBuffer pred_buf;
783
+ MetalBuffer offsets_buf;
784
+ MetalBuffer intra_offsets_buf;
785
+
786
+ run_metal_kernel("gptoss_f32_expert_routing_metadata", [&](const gptoss_metal_device& device,
787
+ const gptoss_metal_function& fn,
788
+ gptoss_metal_command_buffer& cb) {
789
+ pred_buf.wrap(&device, pred_bytes, predictions.data());
790
+ offsets_buf.wrap(&device, offsets_bytes, offsets_cpu.data_ptr());
791
+ intra_offsets_buf.wrap(&device, intra_bytes, intra_offsets_cpu.data_ptr());
792
+
793
+ TORCH_CHECK(
794
+ gptoss_metal_command_buffer_encode_launch_expert_routing_metadata(
795
+ &cb, &fn,
796
+ pred_buf.get(), 0,
797
+ offsets_buf.get(), 0,
798
+ intra_offsets_buf.get(), 0,
799
+ static_cast<uint32_t>(num_tokens),
800
+ static_cast<uint32_t>(num_experts)) == gptoss_status_success,
801
+ "encode expert_routing_metadata failed");
802
+ });
803
+
804
+ copy_back(expert_offsets, offsets_cpu);
805
+ copy_back(intra_expert_offsets, intra_offsets_cpu);
806
+ }
807
+
808
+ void f32_scatter_torch(const at::Tensor& input,
809
+ const at::Tensor& expert_ids,
810
+ const at::Tensor& expert_scores,
811
+ const at::Tensor& expert_offsets,
812
+ const at::Tensor& intra_expert_offsets,
813
+ at::Tensor& output,
814
+ int64_t num_channels,
815
+ int64_t num_tokens,
816
+ int64_t num_active_experts)
817
+ {
818
+ TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
819
+ TORCH_CHECK(expert_ids.dtype() == at::kInt, "expert_ids must be int32");
820
+ TORCH_CHECK(expert_scores.dtype() == at::kFloat, "expert_scores must be float32");
821
+ TORCH_CHECK(expert_offsets.dtype() == at::kInt, "expert_offsets must be int32");
822
+ TORCH_CHECK(intra_expert_offsets.dtype() == at::kInt, "intra_expert_offsets must be int32");
823
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
824
+
825
+ TORCH_CHECK(num_channels >= 0 && num_tokens >= 0 && num_active_experts >= 0,
826
+ "shape parameters must be non-negative");
827
+
828
+ TORCH_CHECK(input.numel() == static_cast<int64_t>(num_tokens / num_active_experts) * num_channels,
829
+ "input size mismatch");
830
+ TORCH_CHECK(expert_ids.numel() == num_tokens,
831
+ "expert_ids must have num_tokens elements");
832
+ TORCH_CHECK(expert_scores.numel() == num_tokens,
833
+ "expert_scores must have num_tokens elements");
834
+ TORCH_CHECK(intra_expert_offsets.numel() == num_tokens,
835
+ "intra_expert_offsets must have num_tokens elements");
836
+ TORCH_CHECK(output.numel() == num_tokens * num_channels / num_active_experts,
837
+ "output size mismatch");
838
+
839
+ auto input_cpu = to_cpu_contiguous(input);
840
+ auto expert_offsets_cpu = to_cpu_contiguous(expert_offsets);
841
+ auto intra_offsets_cpu = to_cpu_contiguous(intra_expert_offsets);
842
+ auto output_cpu = empty_cpu_like(output);
843
+
844
+ std::vector<gptoss_expert_prediction> predictions(static_cast<size_t>(num_tokens));
845
+ const auto* ids_ptr = expert_ids.to(at::kCPU).contiguous().data_ptr<int32_t>();
846
+ const auto* scores_ptr = expert_scores.to(at::kCPU).contiguous().data_ptr<float>();
847
+ for (int64_t i = 0; i < num_tokens; ++i) {
848
+ predictions[static_cast<size_t>(i)] = gptoss_expert_prediction {
849
+ .expert_id = static_cast<uint32_t>(ids_ptr[i]),
850
+ .score = scores_ptr[i],
851
+ };
852
+ }
853
+
854
+ const size_t input_bytes = static_cast<size_t>(input_cpu.numel()) * input_cpu.element_size();
855
+ const size_t pred_bytes = predictions.size() * sizeof(gptoss_expert_prediction);
856
+ const size_t offsets_bytes = static_cast<size_t>(expert_offsets_cpu.numel()) * expert_offsets_cpu.element_size();
857
+ const size_t intra_bytes = static_cast<size_t>(intra_offsets_cpu.numel()) * intra_offsets_cpu.element_size();
858
+ const size_t output_bytes = static_cast<size_t>(output_cpu.numel()) * output_cpu.element_size();
859
+
860
+ MetalBuffer input_buf;
861
+ MetalBuffer pred_buf;
862
+ MetalBuffer offsets_buf;
863
+ MetalBuffer intra_offsets_buf;
864
+ MetalBuffer output_buf;
865
+
866
+ run_metal_kernel("gptoss_f32_scatter_e4", [&](const gptoss_metal_device& device,
867
+ const gptoss_metal_function& fn,
868
+ gptoss_metal_command_buffer& cb) {
869
+ input_buf.wrap(&device, input_bytes, input_cpu.data_ptr());
870
+ pred_buf.wrap(&device, pred_bytes, predictions.data());
871
+ offsets_buf.wrap(&device, offsets_bytes, expert_offsets_cpu.data_ptr());
872
+ intra_offsets_buf.wrap(&device, intra_bytes, intra_offsets_cpu.data_ptr());
873
+ output_buf.create(&device, output_bytes, nullptr);
874
+
875
+ TORCH_CHECK(
876
+ gptoss_metal_command_buffer_encode_launch_f32_scatter(
877
+ &cb, &fn,
878
+ input_buf.get(), 0,
879
+ pred_buf.get(), 0,
880
+ offsets_buf.get(), 0,
881
+ intra_offsets_buf.get(), 0,
882
+ output_buf.get(), 0,
883
+ static_cast<uint32_t>(num_channels),
884
+ static_cast<uint32_t>(num_tokens / num_active_experts),
885
+ static_cast<uint32_t>(num_active_experts)) == gptoss_status_success,
886
+ "encode scatter failed");
887
+ });
888
+
889
+ std::memcpy(output_cpu.data_ptr(), output_buf.ptr(), output_bytes);
890
+ copy_back(output, output_cpu);
891
+ }
892
+
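For orientation, here is a minimal Python-side sketch (not part of this commit) of tensor shapes that satisfy the checks in f32_scatter_torch above, driven through the f32_scatter wrapper added in __init__.py; the expert_offsets length (num_experts + 1) and the random ids/offsets are illustrative assumptions.

```python
# Hedged sketch: shapes chosen to satisfy f32_scatter_torch's TORCH_CHECKs.
# expert_offsets length (num_experts + 1) is an assumption, not checked above.
import torch
import gptoss_kernels as gk

num_channels, num_tokens, num_active = 2880, 16, 4   # num_tokens counts (token, expert) pairs here
rows = num_tokens // num_active                       # input rows per the input.numel() check

inp = torch.randn(rows, num_channels, device="mps")
expert_ids = torch.randint(0, 32, (num_tokens,), dtype=torch.int32, device="mps")
expert_scores = torch.rand(num_tokens, device="mps")
expert_offsets = torch.zeros(33, dtype=torch.int32, device="mps")        # assumed: num_experts + 1
intra_offsets = torch.zeros(num_tokens, dtype=torch.int32, device="mps")
out = torch.empty(rows, num_channels, device="mps")                      # matches the output.numel() check

gk.f32_scatter(inp, expert_ids, expert_scores, expert_offsets,
               intra_offsets, out, num_channels, num_tokens, num_active)
```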
893
+ void f32_bf16w_matmul_add_torch(const at::Tensor& input,
894
+ const at::Tensor& weight_bf16,
895
+ const at::Tensor& bias_bf16,
896
+ at::Tensor& output,
897
+ int64_t num_tokens,
898
+ int64_t num_cols,
899
+ int64_t num_rows,
900
+ int64_t threadgroup_size)
901
+ {
902
+ TORCH_CHECK(input.dtype() == at::kFloat, "input must be float32");
903
+ TORCH_CHECK(weight_bf16.dtype() == at::kBFloat16, "weight must be bfloat16");
904
+ TORCH_CHECK(bias_bf16.dtype() == at::kBFloat16, "bias must be bfloat16");
905
+ TORCH_CHECK(output.dtype() == at::kFloat, "output must be float32");
906
+
907
+ TORCH_CHECK(input.dim() == 2, "input must be 2D");
908
+ TORCH_CHECK(weight_bf16.dim() == 2, "weight must be 2D");
909
+ TORCH_CHECK(bias_bf16.dim() == 1, "bias must be 1D");
910
+ TORCH_CHECK(output.dim() == 2, "output must be 2D");
911
+
912
+ TORCH_CHECK(input.size(0) == num_tokens && input.size(1) == num_cols,
913
+ "input shape must be [num_tokens, num_cols]");
914
+ TORCH_CHECK(weight_bf16.size(0) == num_cols && weight_bf16.size(1) == num_rows,
915
+ "weight shape must be [num_cols, num_rows]");
916
+ TORCH_CHECK(bias_bf16.size(0) == num_rows,
917
+ "bias length must equal num_rows");
918
+ TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_rows,
919
+ "output shape must be [num_tokens, num_rows]");
920
+
921
+ auto input_cpu = to_cpu_contiguous(input);
922
+ auto weight_cpu = weight_bf16.transpose(0, 1).contiguous().to(at::kCPU);
923
+ auto bias_cpu = to_cpu_contiguous(bias_bf16);
924
+ auto out_cpu = to_cpu_contiguous(output);
925
+
926
+ const size_t in_bytes = static_cast<size_t>(input_cpu.numel()) * input_cpu.element_size();
927
+ const size_t weight_bytes = static_cast<size_t>(weight_cpu.numel()) * weight_cpu.element_size();
928
+ const size_t bias_bytes = static_cast<size_t>(bias_cpu.numel()) * bias_cpu.element_size();
929
+ const size_t out_bytes = static_cast<size_t>(out_cpu.numel()) * out_cpu.element_size();
930
+
931
+ MetalBuffer input_buf;
932
+ MetalBuffer weight_buf;
933
+ MetalBuffer bias_buf;
934
+ MetalBuffer out_buf;
935
+ MetalBuffer control_buf;
936
+
937
+ run_metal_kernel("gptoss_f32_bf16w_matmul", [&](const gptoss_metal_device& device,
938
+ const gptoss_metal_function& fn,
939
+ gptoss_metal_command_buffer& cb) {
940
+ input_buf.wrap(&device, in_bytes, input_cpu.data_ptr());
941
+ weight_buf.wrap(&device, weight_bytes, weight_cpu.data_ptr());
942
+ bias_buf.wrap(&device, bias_bytes, bias_cpu.data_ptr());
943
+ out_buf.create(&device, out_bytes, nullptr);
944
+ std::memcpy(out_buf.ptr(), out_cpu.data_ptr(), out_bytes);
945
+ create_control_buffer(&device, control_buf);
946
+
947
+ TORCH_CHECK(
948
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
949
+ &cb, &fn,
950
+ static_cast<size_t>(threadgroup_size),
951
+ input_buf.get(), 0,
952
+ weight_buf.get(), 0,
953
+ bias_buf.get(), 0,
954
+ out_buf.get(), 0,
955
+ control_buf.get(), 0,
956
+ static_cast<uint32_t>(num_tokens),
957
+ static_cast<uint32_t>(num_cols),
958
+ static_cast<uint32_t>(num_rows)) == gptoss_status_success,
959
+ "encode matmul_add failed");
960
+ });
961
 
962
+ std::memcpy(out_cpu.data_ptr(), out_buf.ptr(), out_bytes);
963
+ copy_back(output, out_cpu);
964
  }
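The matmul_add binding above copies the existing output into the Metal buffer before the kernel is encoded, so the op reads output as an accumulator. A hedged Python sketch of that behaviour (shapes follow the checks above; the reference expression is my reading of the binding, and the bf16 weights mean only approximate agreement):

```python
# Hedged sketch: f32_bf16w_matmul_add accumulates x @ W + b into `acc` in place.
import torch
import gptoss_kernels as gk

tokens, cols, rows = 4, 2880, 2880
x = torch.randn(tokens, cols, device="mps")
w = torch.randn(cols, rows, device="mps", dtype=torch.bfloat16)   # [num_cols, num_rows] per the checks
b = torch.zeros(rows, device="mps", dtype=torch.bfloat16)
acc = torch.randn(tokens, rows, device="mps")                     # pre-existing accumulator

expected = acc + x @ w.float() + b.float()                         # assumed reference semantics
gk.f32_bf16w_matmul_add(x, w, b, acc, tokens, cols, rows, 256)
print((acc - expected).abs().max())                                # expect bf16-level error only
```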
torch-ext/gptoss_kernels/__init__.py CHANGED
@@ -1,8 +1,174 @@
1
  from ._ops import ops
2
  import torch
3
 
4
- def f32_bf16w_matmul(input: torch.Tensor, weight_bf16: torch.Tensor, bias_bf16: torch.Tensor, output: torch.Tensor, num_tokens: int, num_cols: int, num_rows: int, threadgroup_size: int) -> None:
5
- ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output, num_tokens, num_cols, num_rows, threadgroup_size)
6
  return output
7
 
8
- __all__ = ["f32_bf16w_matmul"]
1
  from ._ops import ops
2
  import torch
3
 
4
+ def f32_bf16w_matmul(input: torch.Tensor,
5
+ weight_bf16: torch.Tensor,
6
+ bias_bf16: torch.Tensor,
7
+ output: torch.Tensor,
8
+ num_tokens: int,
9
+ num_cols: int,
10
+ num_rows: int,
11
+ threadgroup_size: int) -> torch.Tensor:
12
+ ops.f32_bf16w_matmul(input, weight_bf16, bias_bf16, output,
13
+ num_tokens, num_cols, num_rows, threadgroup_size)
14
  return output
15
 
16
+ def bf16_f32_embeddings(token_ids: torch.Tensor,
17
+ weight_bf16: torch.Tensor,
18
+ output: torch.Tensor,
19
+ threadgroup_size: int) -> torch.Tensor:
20
+ ops.bf16_f32_embeddings(token_ids, weight_bf16, output, threadgroup_size)
21
+ return output
22
+
23
+ def f32_bf16w_rmsnorm(input: torch.Tensor,
24
+ weight_bf16: torch.Tensor,
25
+ output: torch.Tensor,
26
+ epsilon: float) -> torch.Tensor:
27
+ ops.f32_bf16w_rmsnorm(input, weight_bf16, output, epsilon)
28
+ return output
29
+
30
+ def f32_bf16w_dense_matmul_qkv(input: torch.Tensor,
31
+ weight_bf16: torch.Tensor,
32
+ bias_bf16: torch.Tensor,
33
+ output: torch.Tensor) -> torch.Tensor:
34
+ ops.f32_bf16w_dense_matmul_qkv(input, weight_bf16, bias_bf16, output)
35
+ return output
36
+
37
+ def f32_bf16w_dense_matmul_attn_output(input: torch.Tensor,
38
+ weight_bf16: torch.Tensor,
39
+ bias_bf16: torch.Tensor,
40
+ output: torch.Tensor) -> torch.Tensor:
41
+ ops.f32_bf16w_dense_matmul_attn_output(input, weight_bf16, bias_bf16, output)
42
+ return output
43
+
44
+ def f32_bf16w_dense_matmul_mlp_gate(input: torch.Tensor,
45
+ weight_bf16: torch.Tensor,
46
+ bias_bf16: torch.Tensor,
47
+ output: torch.Tensor) -> torch.Tensor:
48
+ ops.f32_bf16w_dense_matmul_mlp_gate(input, weight_bf16, bias_bf16, output)
49
+ return output
50
+
51
+ def f32_rope(activations: torch.Tensor,
52
+ rope_base: float,
53
+ interpolation_scale: float,
54
+ yarn_offset: float,
55
+ yarn_scale: float,
56
+ yarn_multiplier: float,
57
+ num_tokens: int,
58
+ num_q_heads: int,
59
+ num_kv_heads: int,
60
+ attn_head_dim: int,
61
+ token_offset: int,
62
+ threadgroup_size: int) -> torch.Tensor:
63
+ ops.f32_rope(activations, rope_base, interpolation_scale, yarn_offset,
64
+ yarn_scale, yarn_multiplier, num_tokens, num_q_heads,
65
+ num_kv_heads, attn_head_dim, token_offset, threadgroup_size)
66
+ return activations
67
+
68
+ def f32_bf16w_matmul_qkv(input: torch.Tensor,
69
+ weight_bf16: torch.Tensor,
70
+ bias_bf16: torch.Tensor,
71
+ output: torch.Tensor,
72
+ kv_cache: torch.Tensor,
73
+ kv_cache_offset_bytes: int,
74
+ num_tokens: int,
75
+ num_cols: int,
76
+ num_q_heads: int,
77
+ num_kv_heads: int,
78
+ attn_head_dim: int,
79
+ token_offset: int,
80
+ max_tokens: int,
81
+ rope_base: float,
82
+ interpolation_scale: float,
83
+ yarn_offset: float,
84
+ yarn_scale: float,
85
+ yarn_multiplier: float,
86
+ threadgroup_size: int) -> torch.Tensor:
87
+ ops.f32_bf16w_matmul_qkv(input, weight_bf16, bias_bf16, output, kv_cache,
88
+ kv_cache_offset_bytes, num_tokens, num_cols,
89
+ num_q_heads, num_kv_heads, attn_head_dim,
90
+ token_offset, max_tokens, rope_base,
91
+ interpolation_scale, yarn_offset, yarn_scale,
92
+ yarn_multiplier, threadgroup_size)
93
+ return output
94
+
95
+ def f32_sdpa(q: torch.Tensor,
96
+ q_offset_bytes: int,
97
+ kv: torch.Tensor,
98
+ kv_offset_bytes: int,
99
+ s_bf16: torch.Tensor,
100
+ s_offset_bytes: int,
101
+ output: torch.Tensor,
102
+ output_offset_bytes: int,
103
+ window: int,
104
+ kv_stride: int,
105
+ num_q_tokens: int,
106
+ num_kv_tokens: int,
107
+ num_q_heads: int,
108
+ num_kv_heads: int,
109
+ head_dim: int) -> torch.Tensor:
110
+ ops.f32_sdpa(q, q_offset_bytes, kv, kv_offset_bytes, s_bf16, s_offset_bytes,
111
+ output, output_offset_bytes, window, kv_stride,
112
+ num_q_tokens, num_kv_tokens, num_q_heads, num_kv_heads, head_dim)
113
+ return output
114
+
115
+ def f32_topk(scores: torch.Tensor,
116
+ expert_ids: torch.Tensor,
117
+ expert_scores: torch.Tensor,
118
+ num_tokens: int,
119
+ num_experts: int,
120
+ num_active_experts: int) -> None:
121
+ ops.f32_topk(scores, expert_ids, expert_scores,
122
+ num_tokens, num_experts, num_active_experts)
123
+
124
+ def expert_routing_metadata(expert_ids: torch.Tensor,
125
+ expert_scores: torch.Tensor,
126
+ expert_offsets: torch.Tensor,
127
+ intra_expert_offsets: torch.Tensor,
128
+ num_tokens: int,
129
+ num_experts: int) -> None:
130
+ ops.expert_routing_metadata(expert_ids, expert_scores,
131
+ expert_offsets, intra_expert_offsets,
132
+ num_tokens, num_experts)
133
+
134
+ def f32_scatter(input: torch.Tensor,
135
+ expert_ids: torch.Tensor,
136
+ expert_scores: torch.Tensor,
137
+ expert_offsets: torch.Tensor,
138
+ intra_expert_offsets: torch.Tensor,
139
+ output: torch.Tensor,
140
+ num_channels: int,
141
+ num_tokens: int,
142
+ num_active_experts: int) -> torch.Tensor:
143
+ ops.f32_scatter(input, expert_ids, expert_scores,
144
+ expert_offsets, intra_expert_offsets,
145
+ output, num_channels, num_tokens, num_active_experts)
146
+ return output
147
+
148
+ def f32_bf16w_matmul_add(input: torch.Tensor,
149
+ weight_bf16: torch.Tensor,
150
+ bias_bf16: torch.Tensor,
151
+ output: torch.Tensor,
152
+ num_tokens: int,
153
+ num_cols: int,
154
+ num_rows: int,
155
+ threadgroup_size: int) -> torch.Tensor:
156
+ ops.f32_bf16w_matmul_add(input, weight_bf16, bias_bf16, output,
157
+ num_tokens, num_cols, num_rows, threadgroup_size)
158
+ return output
159
+
160
+ __all__ = [
161
+ "f32_bf16w_matmul",
162
+ "bf16_f32_embeddings",
163
+ "f32_bf16w_rmsnorm",
164
+ "f32_bf16w_dense_matmul_qkv",
165
+ "f32_bf16w_dense_matmul_attn_output",
166
+ "f32_bf16w_dense_matmul_mlp_gate",
167
+ "f32_rope",
168
+ "f32_bf16w_matmul_qkv",
169
+ "f32_sdpa",
170
+ "f32_topk",
171
+ "expert_routing_metadata",
172
+ "f32_scatter",
173
+ "f32_bf16w_matmul_add",
174
+ ]
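For reference, a hedged sketch (not part of the commit) of the out-parameter convention these wrappers share: the caller allocates output, the op fills it on the MPS device, and the wrapper returns it for chaining. The rmsnorm weight shape and the matmul dimensions below are illustrative assumptions.

```python
# Hedged sketch: caller-allocated outputs, filled in place and returned.
import torch
import gptoss_kernels as gk

tokens, dim = 8, 2880
x = torch.randn(tokens, dim, device="mps")
gamma = torch.randn(dim, device="mps", dtype=torch.bfloat16)        # assumed rmsnorm weight shape
normed = gk.f32_bf16w_rmsnorm(x, gamma, torch.empty_like(x), 1e-5)

w = torch.randn(dim, dim, device="mps", dtype=torch.bfloat16)
b = torch.zeros(dim, device="mps", dtype=torch.bfloat16)
y = gk.f32_bf16w_matmul(normed, w, b, torch.empty(tokens, dim, device="mps"),
                        tokens, dim, dim, 256)
```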
torch-ext/gptoss_kernels/_ops.py DELETED
@@ -1,9 +0,0 @@
1
- import torch
2
- from . import _gptoss_kernels_3a886f8_dirty
3
- ops = torch.ops._gptoss_kernels_3a886f8_dirty
4
-
5
- def add_op_namespace_prefix(op_name: str):
6
- """
7
- Prefix op by namespace.
8
- """
9
- return f"_gptoss_kernels_3a886f8_dirty::{op_name}"
 
 
 
 
 
 
 
 
 
 
torch-ext/torch_binding.cpp CHANGED
@@ -3,8 +3,58 @@
3
  #include "registration.h"
4
 
5
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
6
- ops.def("f32_bf16w_matmul(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor output, int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
 
7
  ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch);
8
  }
9
 
10
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
3
  #include "registration.h"
4
 
5
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
6
+ ops.def("f32_bf16w_matmul(Tensor input, Tensor weight_bf16, Tensor bias_bf16, "
7
+ "Tensor! output, int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
8
  ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch);
9
+
10
+ ops.def("bf16_f32_embeddings(Tensor token_ids, Tensor weight_bf16, Tensor! output, "
11
+ "int threadgroup_size) -> ()");
12
+ ops.impl("bf16_f32_embeddings", torch::kMPS, &bf16_f32_embeddings_torch);
13
+
14
+ ops.def("f32_bf16w_rmsnorm(Tensor input, Tensor weight_bf16, Tensor! output, float epsilon) -> ()");
15
+ ops.impl("f32_bf16w_rmsnorm", torch::kMPS, &f32_bf16w_rmsnorm_torch);
16
+
17
+ ops.def("f32_bf16w_dense_matmul_qkv(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()");
18
+ ops.impl("f32_bf16w_dense_matmul_qkv", torch::kMPS, &f32_bf16w_dense_matmul_qkv_torch);
19
+
20
+ ops.def("f32_bf16w_dense_matmul_attn_output(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()");
21
+ ops.impl("f32_bf16w_dense_matmul_attn_output", torch::kMPS, &f32_bf16w_dense_matmul_attn_output_torch);
22
+
23
+ ops.def("f32_bf16w_dense_matmul_mlp_gate(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()");
24
+ ops.impl("f32_bf16w_dense_matmul_mlp_gate", torch::kMPS, &f32_bf16w_dense_matmul_mlp_gate_torch);
25
+
26
+ ops.def("f32_rope(Tensor! activations, float rope_base, float interpolation_scale, float yarn_offset, "
27
+ "float yarn_scale, float yarn_multiplier, int num_tokens, int num_q_heads, int num_kv_heads, "
28
+ "int attn_head_dim, int token_offset, int threadgroup_size) -> ()");
29
+ ops.impl("f32_rope", torch::kMPS, &f32_rope_torch);
30
+
31
+ ops.def("f32_bf16w_matmul_qkv(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, Tensor kv_cache, "
32
+ "int kv_cache_offset_bytes, int num_tokens, int num_cols, int num_q_heads, int num_kv_heads, "
33
+ "int attn_head_dim, int token_offset, int max_tokens, float rope_base, float interpolation_scale, "
34
+ "float yarn_offset, float yarn_scale, float yarn_multiplier, int threadgroup_size) -> ()");
35
+ ops.impl("f32_bf16w_matmul_qkv", torch::kMPS, &f32_bf16w_matmul_qkv_torch);
36
+
37
+ ops.def("f32_sdpa(Tensor q, int q_offset_bytes, Tensor kv, int kv_offset_bytes, Tensor s_bf16, int s_offset_bytes, "
38
+ "Tensor! output, int output_offset_bytes, int window, int kv_stride, int num_q_tokens, int num_kv_tokens, "
39
+ "int num_q_heads, int num_kv_heads, int head_dim) -> ()");
40
+ ops.impl("f32_sdpa", torch::kMPS, &f32_sdpa_torch);
41
+
42
+ ops.def("f32_topk(Tensor scores, Tensor expert_ids, Tensor expert_scores, int num_tokens, int num_experts, "
43
+ "int num_active_experts) -> ()");
44
+ ops.impl("f32_topk", torch::kMPS, &f32_topk_torch);
45
+
46
+ ops.def("expert_routing_metadata(Tensor expert_ids, Tensor expert_scores, Tensor expert_offsets, "
47
+ "Tensor intra_expert_offsets, int num_tokens, int num_experts) -> ()");
48
+ ops.impl("expert_routing_metadata", torch::kMPS, &expert_routing_metadata_torch);
49
+
50
+ ops.def("f32_scatter(Tensor input, Tensor expert_ids, Tensor expert_scores, Tensor expert_offsets, "
51
+ "Tensor intra_expert_offsets, Tensor! output, int num_channels, int num_tokens, "
52
+ "int num_active_experts) -> ()");
53
+ ops.impl("f32_scatter", torch::kMPS, &f32_scatter_torch);
54
+
55
+ ops.def("f32_bf16w_matmul_add(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, "
56
+ "int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
57
+ ops.impl("f32_bf16w_matmul_add", torch::kMPS, &f32_bf16w_matmul_add_torch);
58
  }
59
 
60
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
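A hedged note on what these registrations expose: each ops.def schema becomes a callable on the extension's torch.ops namespace (reached through the generated _ops module), and arguments marked Tensor! are the ones the MPS kernels write into. A small sketch, resolving the namespace indirectly because the hashed module name is regenerated per build:

```python
# Hedged sketch: calling a registered schema directly, bypassing the wrappers.
# The namespace object comes from the generated _ops module; its hashed name
# (e.g. _gptoss_kernels_..._dirty) changes per build, so it is not hard-coded here.
import torch
from gptoss_kernels._ops import ops

x = torch.randn(4, 2880, device="mps")
w = torch.randn(2880, 2880, device="mps", dtype=torch.bfloat16)
b = torch.zeros(2880, device="mps", dtype=torch.bfloat16)
out = torch.empty(4, 2880, device="mps")      # `Tensor! output` -> mutated in place
ops.f32_bf16w_matmul(x, w, b, out, 4, 2880, 2880, 256)
```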
torch-ext/torch_binding.h CHANGED
@@ -2,4 +2,118 @@
2
 
3
  #include <torch/torch.h>
4
 
5
- void f32_bf16w_matmul_torch(const at::Tensor &input, const at::Tensor &weight_bf16, const at::Tensor &bias_bf16, at::Tensor &output, int64_t num_tokens, int64_t num_cols, int64_t num_rows, int64_t threadgroup_size);
2
 
3
  #include <torch/torch.h>
4
 
5
+ void f32_bf16w_matmul_torch(const at::Tensor& input,
6
+ const at::Tensor& weight_bf16,
7
+ const at::Tensor& bias_bf16,
8
+ at::Tensor& output,
9
+ int64_t num_tokens,
10
+ int64_t num_cols,
11
+ int64_t num_rows,
12
+ int64_t threadgroup_size);
13
+
14
+ void bf16_f32_embeddings_torch(const at::Tensor& token_ids,
15
+ const at::Tensor& weight_bf16,
16
+ at::Tensor& output,
17
+ int64_t threadgroup_size);
18
+
19
+ void f32_bf16w_rmsnorm_torch(const at::Tensor& input,
20
+ const at::Tensor& weight_bf16,
21
+ at::Tensor& output,
22
+ double epsilon);
23
+
24
+ void f32_bf16w_dense_matmul_qkv_torch(const at::Tensor& input,
25
+ const at::Tensor& weight_bf16,
26
+ const at::Tensor& bias_bf16,
27
+ at::Tensor& output);
28
+
29
+ void f32_bf16w_dense_matmul_attn_output_torch(const at::Tensor& input,
30
+ const at::Tensor& weight_bf16,
31
+ const at::Tensor& bias_bf16,
32
+ at::Tensor& output);
33
+
34
+ void f32_bf16w_dense_matmul_mlp_gate_torch(const at::Tensor& input,
35
+ const at::Tensor& weight_bf16,
36
+ const at::Tensor& bias_bf16,
37
+ at::Tensor& output);
38
+
39
+ void f32_rope_torch(at::Tensor& activations,
40
+ double rope_base,
41
+ double interpolation_scale,
42
+ double yarn_offset,
43
+ double yarn_scale,
44
+ double yarn_multiplier,
45
+ int64_t num_tokens,
46
+ int64_t num_q_heads,
47
+ int64_t num_kv_heads,
48
+ int64_t attn_head_dim,
49
+ int64_t token_offset,
50
+ int64_t threadgroup_size);
51
+
52
+ void f32_bf16w_matmul_qkv_torch(const at::Tensor& input,
53
+ const at::Tensor& weight_bf16,
54
+ const at::Tensor& bias_bf16,
55
+ at::Tensor& output,
56
+ at::Tensor& kv_cache,
57
+ int64_t kv_cache_offset_bytes,
58
+ int64_t num_tokens,
59
+ int64_t num_cols,
60
+ int64_t num_q_heads,
61
+ int64_t num_kv_heads,
62
+ int64_t attn_head_dim,
63
+ int64_t token_offset,
64
+ int64_t max_tokens,
65
+ double rope_base,
66
+ double interpolation_scale,
67
+ double yarn_offset,
68
+ double yarn_scale,
69
+ double yarn_multiplier,
70
+ int64_t threadgroup_size);
71
+
72
+ void f32_sdpa_torch(const at::Tensor& q,
73
+ int64_t q_offset_bytes,
74
+ const at::Tensor& kv,
75
+ int64_t kv_offset_bytes,
76
+ const at::Tensor& s_bf16,
77
+ int64_t s_offset_bytes,
78
+ at::Tensor& output,
79
+ int64_t output_offset_bytes,
80
+ int64_t window,
81
+ int64_t kv_stride,
82
+ int64_t num_q_tokens,
83
+ int64_t num_kv_tokens,
84
+ int64_t num_q_heads,
85
+ int64_t num_kv_heads,
86
+ int64_t head_dim);
87
+
88
+ void f32_topk_torch(const at::Tensor& scores,
89
+ at::Tensor& expert_ids,
90
+ at::Tensor& expert_scores,
91
+ int64_t num_tokens,
92
+ int64_t num_experts,
93
+ int64_t num_active_experts);
94
+
95
+ void expert_routing_metadata_torch(const at::Tensor& expert_ids,
96
+ const at::Tensor& expert_scores,
97
+ at::Tensor& expert_offsets,
98
+ at::Tensor& intra_expert_offsets,
99
+ int64_t num_tokens,
100
+ int64_t num_experts);
101
+
102
+ void f32_scatter_torch(const at::Tensor& input,
103
+ const at::Tensor& expert_ids,
104
+ const at::Tensor& expert_scores,
105
+ const at::Tensor& expert_offsets,
106
+ const at::Tensor& intra_expert_offsets,
107
+ at::Tensor& output,
108
+ int64_t num_channels,
109
+ int64_t num_tokens,
110
+ int64_t num_active_experts);
111
+
112
+ void f32_bf16w_matmul_add_torch(const at::Tensor& input,
113
+ const at::Tensor& weight_bf16,
114
+ const at::Tensor& bias_bf16,
115
+ at::Tensor& output,
116
+ int64_t num_tokens,
117
+ int64_t num_cols,
118
+ int64_t num_rows,
119
+ int64_t threadgroup_size);
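These declarations complete the binding surface: output-like tensors are taken by non-const reference and mutated, mirroring the Tensor! annotations in the schemas. A final hedged sketch with f32_rope, whose activation layout (fused QKV rows of (num_q_heads + 2 * num_kv_heads) * attn_head_dim floats) is an assumption rather than something this diff specifies:

```python
# Hedged sketch: f32_rope rotates the Q/K portions of a fused QKV buffer in place.
# The row layout below is an assumption about the kernel, not taken from this diff.
import torch
import gptoss_kernels as gk

tokens, q_heads, kv_heads, head_dim = 2, 64, 8, 64
qkv = torch.randn(tokens, (q_heads + 2 * kv_heads) * head_dim, device="mps")
out = gk.f32_rope(qkv, 150000.0, 1.0, 0.0, 1.0, 1.0,
                  tokens, q_heads, kv_heads, head_dim, 0, 256)
assert out is qkv    # mutated in place and returned for chaining
```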