Shawn Tan committed on
Commit 4778870 · 1 Parent(s): 906c204

Build files.

.gitignore CHANGED
@@ -8,7 +8,6 @@ __pycache__/
 
 # Distribution / packaging
 .Python
-build/
 develop-eggs/
 dist/
 downloads/
build/torch-universal/scattermoe/__init__.py ADDED
@@ -0,0 +1,15 @@
+from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
+from . import parallel_experts
+from . import kernels
+from . import mlp
+from . import utils
+
+__all__ = [
+    "flatten_sort_count",
+    "parallel_linear",
+    "ParallelExperts",
+    "parallel_experts",
+    "kernels",
+    "mlp",
+    "utils"
+]
build/torch-universal/scattermoe/kernels/__init__.py ADDED
@@ -0,0 +1,3 @@
+from . import ops
+
+__all__ = ["ops"]
build/torch-universal/scattermoe/kernels/ops.py ADDED
@@ -0,0 +1,457 @@
+import torch
+import triton
+import triton.language as tl
+from typing import Optional
+
+BLOCK_M = 128
+ALLOW_TF32 = True
+
+
+
+@triton.jit
+def _compute_expert_block(
+    E_idx, E_mask,
+    M_in_idx,
+    N_block, N_mask,
+    X_ptr, stride_xm, stride_xk,
+    W_ptr, stride_we, stride_wk, stride_wn,
+    K,
+    acc,
+    no_k_mask,
+    BLOCK_K,
+    allow_tf32=True,
+):
+
+    K_block = tl.arange(0, BLOCK_K)
+    X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
+    W_blk_ptrs = W_ptr + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn + E_idx * stride_we
+    iters = tl.cdiv(K, BLOCK_K)
+
+    for K_block_id in range(iters):
+        if no_k_mask:
+            x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
+            w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
+        else:
+            K_mask = (K_block_id * BLOCK_K + K_block) < K
+            x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
+            w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
+
+        X_blk_ptrs += BLOCK_K * stride_xk
+        W_blk_ptrs += BLOCK_K * stride_wk
+        acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
+    return acc
+
+
+def _scatter2scatter_configs():
+    return [
+        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+    ]
+
+@triton.autotune(configs=_scatter2scatter_configs(), key=['M', 'N', 'K'], )
+@triton.heuristics({
+    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
+    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
+})
+@triton.jit
+def _scatter2scatter(
+    X_ptr, stride_xm: tl.constexpr, stride_xk: tl.constexpr,
+    W_ptr, stride_we, stride_wk: tl.constexpr, stride_wn: tl.constexpr,
+    Y_ptr, stride_ym: tl.constexpr, stride_yn: tl.constexpr,
+    B_ptr, stride_be: tl.constexpr, stride_bn: tl.constexpr,
+    grouped_idx_ptr, expert_idxs_ptr,
+    # block_start_idx_ptr,
+    FAN_OUT: tl.constexpr,
+    M, K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+    # OUT_M,
+    allow_tf32: tl.constexpr,
+    x_grouped: tl.constexpr, y_grouped: tl.constexpr,
+    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+
+    N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
+    M_block_id = pid // N_BLOCK_COUNT
+    N_block_id = pid % N_BLOCK_COUNT
+
+    M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+    N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+    N_mask = N_block < N
+    M_boundary_mask = M_block < (FAN_OUT * M)
+    E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
+
+    no_k_mask = K % BLOCK_K == 0
+
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+    E_first_idx = tl.min(E_idxs)
+    E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
+    M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
+    for E_idx in range(E_first_idx, E_last_idx + 1):
+        E_mask = E_idxs == E_idx
+        E_M_idx = M_idx
+        if x_grouped:
+            M_in_idx = M_block
+        else:
+            M_in_idx = E_M_idx // FAN_OUT
+        acc = _compute_expert_block(
+            E_idx, E_mask,
+            M_in_idx, N_block, N_mask,
+            X_ptr, stride_xm, stride_xk,
+            W_ptr, stride_we, stride_wk, stride_wn,
+            K,
+            acc,
+            no_k_mask,
+            BLOCK_K,
+            allow_tf32=allow_tf32,
+        )
+
+    if B_ptr is not None:
+        B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
+        acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
+
+    if y_grouped:
+        M_out_idx = M_block
+    else:
+        M_out_idx = M_idx
+    Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
+    tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
+
+def scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
+                    b=None,
+                    x_grouped=False, y_grouped=False,
+                    out=None):
+    assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
+    assert sorted_scattered_idxs.size(0) == X.size(0) * k
+    # Pre-kernel setup
+    y_dim = W.size(-1)
+    L_scattered = sorted_expert_idxs.size(0)
+    if out is None:
+        output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
+    else:
+        assert out.size(0) == L_scattered and out.size(1) == y_dim
+        output = out
+
+    scatter2scatter_compileable(output, W, X, k, sorted_expert_idxs, sorted_scattered_idxs,
+                                b, x_grouped, y_grouped)
+    return output
+
+
+@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
+def scatter2scatter_compileable(
+    output: torch.Tensor,
+    W: torch.Tensor,
+    X: torch.Tensor,
+    k: int,
+    sorted_expert_idxs: torch.Tensor,
+    sorted_scattered_idxs: torch.Tensor,
+    b: Optional[torch.Tensor],
+    x_grouped: bool, y_grouped: bool) -> None:
+    def grid(META):
+        grid_num = (
+            triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"]) *
+            triton.cdiv(META['N'], META['BLOCK_N']),
+        )
+        return grid_num
+
+    if b is None:
+        b = None
+        stride_be = stride_bk = 0
+    else:
+        stride_be, stride_bk = b.stride()
+
+    _scatter2scatter[grid](
+        # X_ptr, stride_xm, stride_xk,
+        X, X.stride(0), X.stride(1),
+        # W_ptr, stride_we, stride_wk, stride_wn,
+        W, W.stride(0), W.stride(1), W.stride(2),
+        # Y_ptr, stride_ym, stride_yn,
+        output, output.stride(0), output.stride(1),
+        # B_ptr, stride_be, stride_bk
+        b, stride_be, stride_bk,
+        grouped_idx_ptr=sorted_scattered_idxs,
+        expert_idxs_ptr=sorted_expert_idxs,
+        # block_start_idx_ptr=padded_block_idxs,
+        FAN_OUT=k,
+        M=X.size(0),
+        K=X.size(1),
+        N=output.size(1), E=W.size(0),
+        BLOCK_M=BLOCK_M,
+        ACC_TYPE=tl.float32,
+        allow_tf32=ALLOW_TF32,
+        x_grouped=x_grouped, y_grouped=y_grouped,
+    )
+
+
+def _config_XtY():
+    return [
+        triton.Config({'BLOCK_N': 128, 'BLOCK_K': 128, 'BLOCK_M': 32}, num_stages=4, num_warps=4),
+    ]
+
+def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
+    DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
+    DW = DWt.permute(0, 2, 1)
+    if has_bias:
+        Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
+    else:
+        Db = None
+    groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
+    return DW, Db
+
+
+@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW"})
+def groupXtY_compileable(
+    E: int,
+    DW: torch.Tensor,
+    Db: Optional[torch.Tensor],
+    DY: torch.Tensor,
+    X: torch.Tensor,
+    expert_offsets: torch.Tensor) -> None:
+    def grid(META):
+        grid = (
+            E * triton.cdiv(META['K'], META['BLOCK_K']),
+            triton.cdiv(META['N'], META['BLOCK_N']),
+        )
+        return grid
+
+    if Db is None:
+        stride_dbe = 0
+        stride_dbn = 0
+    else:
+        stride_dbe, stride_dbn = Db.stride()
+
+    _groupXtY[grid](
+        # DY_ptr, stride_dym, stride_dyk,
+        DY, DY.stride(0), DY.stride(1),
+        # X_ptr, stride_xm, stride_xn,
+        X, X.stride(0), X.stride(1),
+        # DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+        DW, DW.stride(0), DW.stride(1), DW.stride(2),
+        # Db_ptr, stride_dwe, stride_dbn,
+        Db, stride_dbe, stride_dbn,
+        # expert_offsets_ptr,
+        expert_offsets,
+        # K: tl.constexpr, N: tl.constexpr,
+        M=DY.size(0), N=DY.size(-1), K=X.size(-1),
+        # ACC_TYPE: tl.constexpr,
+        ACC_TYPE=tl.float32,
+        allow_tf32=ALLOW_TF32
+    )
+
+
+@triton.autotune(configs=_config_XtY(), key=['M', 'N', 'K'], )
+@triton.heuristics({
+    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0,
+    "NO_N_MASK": lambda args: (args['N'] % args['BLOCK_N']) == 0,
+})
+@triton.jit
+def _groupXtY(
+    DY_ptr, stride_dym, stride_dyk,
+    X_ptr, stride_xm, stride_xn,
+    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+    Db_ptr, stride_dbe, stride_dbn,
+    expert_offsets_ptr,
+    M, K: tl.constexpr, N: tl.constexpr,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    NO_K_MASK: tl.constexpr, NO_N_MASK: tl.constexpr
+):
+    pid0 = tl.program_id(axis=0)
+    pid1 = tl.program_id(axis=1)
+    num0 = tl.num_programs(0)
+    num1 = tl.num_programs(1)
+    # pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
+    pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
+
+    K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
+    E_idx = pid0 // K_BLOCK_COUNT
+    K_block_id = pid0 % K_BLOCK_COUNT
+    N_block_id = pid1
+
+    if E_idx == 0:
+        start_idx = 0
+    else:
+        start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
+    end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
+
+
+    if end_idx > start_idx:
+        M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
+
+        K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
+        K_mask = K_block < K
+        K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
+
+        N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+        N_mask = N_block < N
+        N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
+
+        M_idxs = M_block
+        xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
+        dy_blk_ptrs = DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
+        if (Db_ptr is not None) and (K_block_id == 0):
+            _xty_and_bias(
+                E_idx, start_idx, end_idx,
+                M_block,
+                K_block, K_mask, N_block, N_mask,
+                dy_blk_ptrs, stride_dym,
+                xt_blk_ptrs, stride_xm,
+                DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+                Db_ptr, stride_dbe, stride_dbn,
+                BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
+                allow_tf32, NO_K_MASK, NO_N_MASK,
+                compute_bias=True
+            )
+        else:
+            _xty_and_bias(
+                E_idx, start_idx, end_idx,
+                M_block,
+                K_block, K_mask, N_block, N_mask,
+                dy_blk_ptrs, stride_dym,
+                xt_blk_ptrs, stride_xm,
+                DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+                Db_ptr, stride_dbe, stride_dbn,
+                BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
+                allow_tf32, NO_K_MASK, NO_N_MASK,
+                compute_bias=False
+            )
+
+
+@triton.jit
+def _xty_and_bias(
+    E_idx, start_idx, end_idx,
+    M_block,
+    K_block, K_mask, N_block, N_mask,
+    dy_blk_ptrs, stride_dym,
+    xt_blk_ptrs, stride_xm,
+    DW_ptr, stride_dwe, stride_dwk, stride_dwn,
+    Db_ptr, stride_dbe, stride_dbn,
+    BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,
+    allow_tf32, NO_K_MASK, NO_N_MASK,
+    compute_bias: tl.constexpr
+):
+
+    if compute_bias:
+        db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
+    else:
+        db_acc = None
+
+    acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
+    iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
+    for i in range(0, iters):
+        M_mask = (i * BLOCK_M + M_block) < end_idx
+        if NO_K_MASK:
+            xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
+        else:
+            xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
+        if NO_N_MASK:
+            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
+        else:
+            dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
+
+        acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
+
+        xt_blk_ptrs += BLOCK_M * stride_xm
+        dy_blk_ptrs += BLOCK_M * stride_dym
+
+        if compute_bias:
+            db_acc += tl.sum(dy, axis=0)
+
+    DW_blk_ptrs = DW_ptr + E_idx * stride_dwe + K_block[:, None] * stride_dwk + N_block[None, :] * stride_dwn
+    acc = acc.to(DW_blk_ptrs.dtype.element_ty)
+    tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
+    if compute_bias:
+        Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
+        tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
+
+
+
+def _config_grouping():
+    return [
+        triton.Config({'BLOCK_N': 256, 'BLOCK_K': 128}, num_stages=4, num_warps=4),
+        # triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
+        # triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
+    ]
+
+def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
+    N = sorted_expert_idxs.size(0)
+    K = A.size(1)
+    assert A.size(0) * fan_out == N
+    if out is not None:
+        Y = out
+    else:
+        Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
+    group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
+    return Y
+
+
+@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
+def group_compileable(
+    A: torch.Tensor,
+    K: int,
+    N: int,
+    Y: torch.Tensor,
+    coeff: torch.Tensor, has_coeff: bool,
+    fan_out: int,
+    sorted_expert_idxs: torch.Tensor) -> None:
+    def grid(META):
+        grid_num = (triton.cdiv(META['N'], META['BLOCK_N']),)
+        return grid_num
+    _group[grid](
+        # A_ptr, stride_an, stride_ai,
+        A, A.stride(0), A.stride(1), has_coeff, coeff, fan_out,
+        # Y_ptr, stride_yn, stride_yk,
+        Y, Y.stride(0), Y.stride(1),
+        # grouped_idx_ptr,
+        sorted_expert_idxs,
+        # N: tl.constexpr, K: tl.constexpr,
+        N, K
+    )
+
+
+@triton.autotune(configs=_config_grouping(), key=['K'])
+@triton.heuristics({
+    "NO_K_MASK": lambda args: (args['K'] % args['BLOCK_K']) == 0
+})
+@triton.jit
+def _group(
+    src_ptr, stride_sn, stride_sk, has_coeff: tl.constexpr, coeff_ptr, FAN_OUT: tl.constexpr,
+    tgt_ptr, stride_tn, stride_ti,
+    grouped_idx_ptr,
+    N, K: tl.constexpr,
+    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    NO_K_MASK: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+
+    N_block_id = pid
+    N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
+    N_mask = N_blk < N
+    N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
+    N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
+
+    K_blk = tl.arange(0, BLOCK_K)
+    src_blk_ptrs = src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
+    tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
+
+    if has_coeff:
+        c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
+
+    iters = tl.cdiv(K, BLOCK_K)
+    for i in range(0, iters):
+        if NO_K_MASK or i < iters - 1:
+            block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
+            if has_coeff:
+                block *= c
+            tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
+
+        else:
+            K_mask = (i * BLOCK_K + K_blk) < K
+            mask = N_mask[:, None] & K_mask[None, :]
+            block = tl.load(src_blk_ptrs, mask=mask)
+            if has_coeff:
+                block *= c
+            tl.store(tgt_blk_ptrs, block, mask=mask)
+        src_blk_ptrs += BLOCK_K * stride_sk
+        tgt_blk_ptrs += BLOCK_K * stride_ti
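
Note: a minimal usage sketch of the wrappers above (not part of the committed files), assuming the package is importable as `scattermoe` on a CUDA machine with a recent PyTorch and Triton; the sizes below are illustrative only.

import torch
from scattermoe import flatten_sort_count
from scattermoe.kernels import ops

E, K, N, T, k = 8, 64, 128, 16, 2            # experts, in-dim, out-dim, tokens, top-k
X = torch.randn(T, K, device="cuda", dtype=torch.float16)
W = torch.randn(E, K, N, device="cuda", dtype=torch.float16)
expert_idxs = torch.randint(0, E, (T, k), device="cuda")

# Sort the token->expert assignments so each expert's rows are contiguous.
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
    flatten_sort_count(expert_idxs, num_experts=E)

# One fused grouped GEMM over all experts; the output has T * k rows.
Y = ops.scatter2scatter(X, W, sorted_expert_idxs, sorted_scattered_idxs, k)
print(Y.shape)  # (T * k, N)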
build/torch-universal/scattermoe/kernels/single.py ADDED
@@ -0,0 +1,59 @@
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _single2scatter(
+    X_ptr, stride_xm, stride_xk,
+    W_ptr, stride_we, stride_wk, stride_wn,
+    Y_ptr, stride_ym, stride_yn,
+    expert_idxs_ptr,
+    FAN_OUT: tl.constexpr,
+    K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
+    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+    ACC_TYPE: tl.constexpr,
+):
+    pid0 = tl.program_id(axis=0)
+    pid1 = tl.program_id(axis=1)
+
+    N_block_id = pid0
+    if FAN_OUT == 1:
+        in_idx = pid1
+    else:
+        in_idx = 0
+    out_idx = pid1
+
+    K_block = tl.arange(0, BLOCK_K)
+    N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
+    E_idx = tl.load(expert_idxs_ptr + pid1)
+    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
+    W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
+    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
+    for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
+        x = tl.load(X_blk_ptrs)
+        w = tl.load(W_blk_ptrs)
+        acc += tl.sum(x * w, axis=0)[None, :]
+        X_blk_ptrs += BLOCK_K * stride_xk
+        W_blk_ptrs += BLOCK_K * stride_wk
+    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
+    tl.store(Y_blk_ptrs, acc)
+
+def single2scatter(X, W, expert_idxs):
+    E, xdim, ydim = W.size()
+    k = expert_idxs.size(1)
+    assert X.size(0) == k or X.size(0) == 1
+    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
+    BLOCK_N = 128
+    BLOCK_K = 128
+    grid = ydim // BLOCK_N, k
+    _single2scatter[grid](
+        X, X.stride(0), X.stride(1),
+        W, W.stride(0), W.stride(1), W.stride(2),
+        Y, Y.stride(0), Y.stride(1),
+        expert_idxs,
+        FAN_OUT=Y.size(0) // X.size(0),
+        K=xdim, N=ydim, E=E,
+        BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
+        ACC_TYPE=tl.float32
+    )
+    return Y
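
Note: single2scatter handles the single-row input case (e.g. one token routed to k experts). A hedged sketch of its calling convention, assuming xdim and ydim are multiples of the fixed 128 block sizes it uses; dimensions here are illustrative only.

import torch
from scattermoe.kernels.single import single2scatter

E, xdim, ydim, k = 8, 256, 512, 2
X = torch.randn(1, xdim, device="cuda", dtype=torch.float16)        # one token
W = torch.randn(E, xdim, ydim, device="cuda", dtype=torch.float16)
expert_idxs = torch.randint(0, E, (1, k), device="cuda")            # its top-k experts

Y = single2scatter(X, W, expert_idxs)   # (k, ydim): one output row per selected expert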
build/torch-universal/scattermoe/mlp.py ADDED
@@ -0,0 +1,96 @@
+import torch
+from torch import nn
+
+from .parallel_experts import ParallelExperts, flatten_sort_count
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_experts,
+        top_k,
+        bias=False,
+        activation=None,
+    ):
+        super(MLP, self).__init__()
+
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.experts = ParallelExperts(num_experts, input_size, hidden_size, bias=bias)
+        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias)
+        self.top_k = min(top_k, self.num_experts)
+        self.activation = activation
+
+    def extra_repr(self):
+        return 'k={}'.format(self.top_k)
+
+    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
+        x_shape = x.size()
+        x = x.view(-1, x_shape[-1])
+        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
+            flatten_sort_count(expert_idxs, num_experts=self.num_experts)
+
+        h = self.experts(
+            x, self.top_k,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            grouped_out=True
+        )
+        h = self.activation(h)
+        y = self.output_experts(
+            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=True,
+            gates=expert_p,
+        )
+        y = y.view(*x_shape[:-1], y.size(-1))
+        return y
+
+class GLUMLP(nn.Module):
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_experts,
+        top_k,
+        bias=False,
+        activation=nn.SiLU(),
+    ):
+        super(GLUMLP, self).__init__()
+
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size, bias=bias)
+        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias)
+        self.top_k = min(top_k, self.num_experts)
+        self.activation = activation
+
+    def extra_repr(self):
+        return 'k={}'.format(self.top_k)
+
+    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
+        x_shape = x.size()
+        x = x.view(-1, x_shape[-1])
+        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
+            flatten_sort_count(expert_idxs, num_experts=self.num_experts)
+
+
+        h, gates = self.experts(
+            x, self.top_k,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            grouped_out=True
+        ).chunk(2, dim=-1)
+        h = self.activation(gates) * h
+        y = self.output_experts(
+            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=True,
+            gates=expert_p,
+        )
+        y = y.view(*x_shape[:-1], y.size(-1))
+        return y
+
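
Note: a minimal sketch of driving GLUMLP with router outputs, assuming top-k probabilities and indices are produced elsewhere; the softmax/topk below is only a stand-in router and the sizes are illustrative.

import torch
from scattermoe.mlp import GLUMLP

d_model, d_ff, E, k = 512, 1024, 8, 2
mlp = GLUMLP(input_size=d_model, hidden_size=d_ff, num_experts=E, top_k=k).cuda().half()

x = torch.randn(4, 16, d_model, device="cuda", dtype=torch.float16)
logits = torch.randn(4 * 16, E, device="cuda")                  # stand-in router logits
expert_p, expert_idxs = logits.softmax(dim=-1).topk(k, dim=-1)  # (tokens, k) each

y = mlp(x, expert_p.to(x.dtype), expert_idxs)                   # same shape as x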
build/torch-universal/scattermoe/parallel_experts.py ADDED
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+from . import kernels
+from typing import Optional
+
+@torch.library.custom_op("scattermoe::bincount", mutates_args={})
+def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
+    return x.bincount(minlength=minlength)
+
+@compileable_bincount.register_fake
+def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
+    return torch.empty(minlength, dtype=torch.long, device=x.device)
+
+@torch.compile
+def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
+    with torch.no_grad():
+        flattened_expert_idxs = expert_idxs.flatten()
+        sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
+        expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
+        expert_offsets = expert_counts.cumsum(-1)
+    return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
+
+
+
+class ParallelLinear(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        x: torch.Tensor, expert_weights: torch.Tensor, k: int,
+        sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
+        expert_offsets: torch.Tensor,
+        expert_biases: Optional[torch.Tensor]=None,
+        gates: Optional[torch.Tensor]=None,
+        grouped_in: bool =False, grouped_out: bool=False,
+    ):
+        with torch.device(x.device):
+            output = kernels.ops.scatter2scatter(
+                X=x, W=expert_weights,
+                b=expert_biases, k=k,
+                sorted_expert_idxs=sorted_expert_idxs,
+                sorted_scattered_idxs=sorted_scattered_idxs,
+                x_grouped=grouped_in, y_grouped=grouped_out
+            )
+            if gates is not None:
+                output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
+                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
+            else:
+                output_expanded = None
+
+            ctx.save_for_backward(
+                x, expert_weights,
+                expert_biases,
+                sorted_expert_idxs,
+                sorted_scattered_idxs,
+                expert_offsets,
+                gates,
+                output_expanded
+            )
+            ctx.grouped_in = grouped_in
+            ctx.grouped_out = grouped_out
+            ctx.k = k
+        return output
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        with torch.device(grad_out.device):
+            (x, expert_weights, expert_biases,
+             sorted_expert_idxs,
+             sorted_scattered_idxs,
+             expert_offsets,
+             gates, output_expanded) = ctx.saved_tensors
+            k = ctx.k
+            grouped_in = ctx.grouped_in
+            grouped_out = ctx.grouped_out
+            # print("backward")
+
+            if gates is not None:
+                # calculate gates gradient
+                # d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
+                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
+                gates_flat = gates.flatten()
+                gate_fan = gates.size(1)
+                grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later
+            else:
+                d_gates = None
+                gates_flat = None
+                gate_fan = 1
+                grouped_grad_out = None
+
+            if grouped_out:
+                grouped_grad_out = grad_out
+            else:
+                grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
+                                                     fan_out=gate_fan, coeff=gates_flat,
+                                                     out=grouped_grad_out)
+            if grouped_in:
+                grouped_x = x
+                d_expanded_input = None
+            else:
+                grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
+                d_expanded_input = grouped_x
+
+            d_weights, d_biases = kernels.ops.group_bwd_W(
+                DY=grouped_grad_out, X=grouped_x,
+                expert_offsets=expert_offsets,
+                E=expert_weights.size(0),
+                has_bias=expert_biases is not None
+            )
+
+
+            d_expanded_input = kernels.ops.scatter2scatter(
+                X=grouped_grad_out, x_grouped=True,
+                W=expert_weights.permute(0, 2, 1),
+                sorted_expert_idxs=sorted_expert_idxs,
+                sorted_scattered_idxs=sorted_scattered_idxs,
+                k=1,
+                y_grouped=grouped_in,
+                out=d_expanded_input  # Reuse grouped_x buffer
+            )
+
+            if k == 1:
+                d_input = d_expanded_input
+            else:
+                d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
+            # print("backward end.")
+            return (
+                # x, expert_weights,
+                d_input, d_weights,
+                # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
+                None, None, None, None,
+                # bias, gates
+                d_biases, d_gates,
+                # grouped_in, grouped_out,
+                None, None
+            )
+
+def parallel_linear(inputs, expert_weights, k,
+                    sorted_expert_idxs, sorted_scattered_idxs,
+                    expert_offsets,
+                    expert_biases=None,
+                    gates=None, grouped_in=False, grouped_out=False):
+    results = ParallelLinear.apply(inputs, expert_weights, k,
+                                   sorted_expert_idxs, sorted_scattered_idxs,
+                                   expert_offsets,
+                                   expert_biases,
+                                   gates, grouped_in, grouped_out)
+    return results
+
+class ParallelExperts(nn.Module):
+    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(num_experts, output_size))
+        else:
+            self.bias = None
+
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.output_size = output_size
+        self.reset_parameters()
+
+    def extra_repr(self):
+        return 'num_experts={}, input_size={}, output_size={}'.format(
+            self.num_experts, self.input_size, self.output_size)
+
+    def reset_parameters(self) -> None:
+        nn.init.normal_(self.weight, std=0.02)
+        if self.bias is not None:
+            nn.init.zeros_(self.bias)
+
+    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
+                expert_offsets,
+                gates=None, grouped_in=False, grouped_out=False):
+
+        results = parallel_linear(
+            inputs, self.weight.permute(0, 2, 1), k,
+            sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
+            expert_biases=self.bias,
+            gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
+        )
+        return results
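
Note: a minimal sketch of calling ParallelExperts directly with precomputed routing, using the same sort-and-offset bookkeeping from flatten_sort_count that mlp.py relies on above; gate values and sizes are illustrative only.

import torch
from scattermoe.parallel_experts import ParallelExperts, flatten_sort_count

E, d_in, d_out, T, k = 8, 256, 512, 32, 2
experts = ParallelExperts(E, d_in, d_out).cuda().half()

x = torch.randn(T, d_in, device="cuda", dtype=torch.float16)
expert_p = torch.rand(T, k, device="cuda", dtype=torch.float16)   # per-token gate values
expert_idxs = torch.randint(0, E, (T, k), device="cuda")

sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
    flatten_sort_count(expert_idxs, num_experts=E)

# Scatter tokens to their experts, then gate the k expert outputs back per token.
y = experts(x, k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets, gates=expert_p)
print(y.shape)  # (T, d_out)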
build/torch-universal/scattermoe/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+from . import replace_moe
+
+__all__ = ['replace_moe']
build/torch-universal/scattermoe/utils/replace_moe.py ADDED
@@ -0,0 +1,209 @@
+import torch
+from .. import parallel_linear, flatten_sort_count
+from torch.nn import functional as F
+from torch import nn
+
+import logging
+
+def replace_function(cls, fun_name):
+    def decorator(fun):
+        def _fun(*args, **kwargs):
+            filename = fun.__code__.co_filename
+            name = fun.__name__
+            logging.info(f"Replacing `{cls.__name__}.{fun_name}` with {filename}:{name}")
+            setattr(cls, fun_name, fun)
+            return fun(*args, **kwargs)
+        setattr(cls, fun_name, _fun)
+    return decorator
+
+try:
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
+    @replace_function(cls=GptOssExperts, fun_name='forward')
+    def gpt_oss_forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        k = router_indices.shape[1]
+        selected_weights = torch.gather(routing_weights, dim=1, index=router_indices)
+        router_indices = router_indices.flatten()
+        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
+            flatten_sort_count(router_indices, num_experts=self.num_experts)
+
+        gate_up = parallel_linear(
+            hidden_states, self.gate_up_proj, k,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            expert_biases=self.gate_up_proj_bias,
+            grouped_in=False, grouped_out=True,
+        )
+
+        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+        gate = gate.clamp(min=None, max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        gated_output_ = (up + 1) * glu
+
+        out_scattered = parallel_linear(
+            gated_output_, self.down_proj, 1,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            expert_biases=self.down_proj_bias,
+            grouped_in=True, grouped_out=False,
+            gates=selected_weights,
+        )
+
+        next_states = out_scattered.view(batch_size, -1, self.hidden_size)
+        return next_states
+except Exception:
+    logging.info("Failed to replace GptOssExperts")
+
+
+try:
+    from transformers.models.granitemoehybrid.modeling_granitemoehybrid import GraniteMoeHybridMoE
+    @replace_function(cls=GraniteMoeHybridMoE, fun_name='forward')
+    def granite_moe_forward(self, layer_input):
+        bsz, length, emb_size = layer_input.size()
+        layer_input = layer_input.reshape(-1, emb_size)
+        # compute the top_k routing decision
+        router_logits = self.router.layer(layer_input)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.router.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(layer_input.dtype)
+        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
+            flatten_sort_count(selected_experts, num_experts=self.router.num_experts)
+
+        # compute experts
+        gates, h = parallel_linear(
+            layer_input, self.input_linear.weight.transpose(2, 1),
+            self.router.top_k,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=False, grouped_out=True,
+        ).chunk(2, dim=-1)
+        h = self.activation(gates) * h
+        layer_output = parallel_linear(
+            h, self.output_linear.weight.transpose(2, 1),
+            1,
+            sorted_expert_idxs, sorted_scattered_idxs,
+            expert_offsets,
+            grouped_in=True, grouped_out=False,
+            gates=routing_weights
+        )
+        layer_output = layer_output.view(bsz, length, emb_size)
+        return layer_output, router_logits
+except Exception:
+    logging.info("Failed to replace GraniteMoeHybridMoE")
+
+# TODO consolidating params into tensor. OOMs.
+# try:
+#     from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
+#     fun = Qwen3MoeSparseMoeBlock.__init__
+#     def qwen3moe__init__(self: Qwen3MoeSparseMoeBlock, config):
+#         fun(self, config)
+
+#         weight_refs = {}
+#         def assemble_weights(self):
+#             down_proj_weights = [weight_refs[i, "down_proj"] for i in range(len(self.experts))]
+#             up_proj_weights = [weight_refs[i, "up_proj"] for i in range(len(self.experts))]
+#             gate_proj_weights = [weight_refs[i, "gate_proj"] for i in range(len(self.experts))]
+#             if (all(w is not None for w in down_proj_weights) and
+#                 all(w is not None for w in up_proj_weights) and
+#                 all(w is not None for w in gate_proj_weights)):
+#                 self.act_fn = self.experts[0].act_fn
+#                 mid_size = up_proj_weights[0].size(0)
+#                 self.down_proj = nn.Parameter(torch.stack(down_proj_weights))
+#                 gate_up_proj = torch.empty(
+#                     self.num_experts, 2 * mid_size, config.hidden_size,
+#                     dtype=up_proj_weights[0].dtype
+#                 )
+#                 for i in range(self.num_experts - 1, -1, -1):
+#                     gate_up_proj[i, :mid_size] = gate_proj_weights[i]
+#                     gate_up_proj[i, mid_size:] = up_proj_weights[i]
+#                     del self.experts[i]
+#                 self.gate_up_proj = nn.Parameter(gate_up_proj)
+#                 del self.experts

#         for i, e in enumerate(self.experts):
#             def weight_tracker(expert_id, expert_weight_name):
#                 def fun(module, err_msgs):
#                     id_tup = (expert_id, expert_weight_name)
#                     assert id_tup in weight_refs
#                     weight_refs[id_tup] = module.weight
#                     assemble_weights(self)
#                 return fun
#             e.gate_proj.register_load_state_dict_post_hook(weight_tracker(i, "gate_proj"))
#             weight_refs[i, "gate_proj"] = None
#             e.up_proj.register_load_state_dict_post_hook(weight_tracker(i, "up_proj"))
#             weight_refs[i, "up_proj"] = None
#             e.down_proj.register_load_state_dict_post_hook(weight_tracker(i, "down_proj"))
#             weight_refs[i, "down_proj"] = None

#         # def hook(module, state_dict, prefix, local_metadata):
#         #     local_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
#         #     expert_keys = [k for k in local_keys if 'mlp.' in k]
#         #     if len(expert_keys) == 0:  # deleted
#         #         assert prefix + "down_proj" in state_dict
#         #         assert prefix + "gate_up_proj" in state_dict
#         # self.register_state_dict_post_hook(hook)

#     state_dict_fun = Qwen3MoeSparseMoeBlock.state_dict
#     def qwen3moe_state_dict(self: Qwen3MoeSparseMoeBlock, *args, destination=None, prefix="", keep_vars=False):
#         destination = state_dict_fun(self, *args, destination, prefix, keep_vars)
#         local_keys = [k for k in destination.keys() if k.startswith(prefix)]
#         if ((prefix + "down_proj" in local_keys) and (prefix + "gate_up_proj" in local_keys)):
#             # need to break down
#             for i in range(self.num_experts):
#                 down_proj_name = prefix + f"experts.{i}.down_proj.weight"
#                 up_proj_name = prefix + f"experts.{i}.up_proj.weight"
#                 gate_proj_name = prefix + f"experts.{i}.gate_proj.weight"
#                 destination[down_proj_name] = destination[prefix + "down_proj"][i]
#                 gate_proj, up_proj = destination[prefix + "gate_up_proj"][i].chunk(2, dim=1)
#                 destination[gate_proj_name] = gate_proj.contiguous()
#                 destination[up_proj_name] = up_proj.contiguous()
#             del destination[prefix + "down_proj"]
#             del destination[prefix + "gate_up_proj"]

#         return destination
#     Qwen3MoeSparseMoeBlock.state_dict = qwen3moe_state_dict


#     Qwen3MoeSparseMoeBlock.__init__ = qwen3moe__init__

#     def qwen3moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
#         """ """
#         batch_size, sequence_length, hidden_dim = hidden_states.shape
#         hidden_states = hidden_states.view(-1, hidden_dim)
#         # router_logits: (batch * sequence_length, n_experts)
#         router_logits = self.gate(hidden_states)

#         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
#         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
#         if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
#             routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
#         # we cast back to the input dtype
#         routing_weights = routing_weights.to(hidden_states.dtype)
#         sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(selected_experts, num_experts=self.num_experts)

#         gate_up = parallel_linear(
#             hidden_states, self.gate_up_proj.transpose(1, 2), self.top_k,
#             sorted_expert_idxs, sorted_scattered_idxs,
#             expert_offsets,
#             grouped_in=False, grouped_out=True,
#         )

#         _gate, up = gate_up.chunk(2, dim=-1)
#         intermediate = self.act_fn(_gate) * up

#         final_hidden_states = parallel_linear(
#             intermediate, self.down_proj.transpose(1, 2), 1,
#             sorted_expert_idxs, sorted_scattered_idxs,
#             expert_offsets,
#             grouped_in=True, grouped_out=False,
#             gates=routing_weights,
#         )
#         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
#         return final_hidden_states, router_logits
#     Qwen3MoeSparseMoeBlock.forward = qwen3moe_forward
# except Exception as e:
#     pass
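
Note: importing scattermoe.utils.replace_moe patches the supported transformers MoE blocks lazily, on the first call to the patched method. A sketch of the same mechanism applied to a hypothetical class (Dummy and fast_forward are illustrative names, not part of the library):

import logging
from scattermoe.utils.replace_moe import replace_function

logging.basicConfig(level=logging.INFO)  # so the "Replacing ..." message is visible

class Dummy:                              # hypothetical stand-in for a model class
    def forward(self, x):
        return x

@replace_function(cls=Dummy, fun_name="forward")
def fast_forward(self, x):
    return x * 2

d = Dummy()
print(d.forward(3))   # first call logs the replacement, then dispatches to fast_forward -> 6
print(d.forward(3))   # subsequent calls go straight to fast_forward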