Shawn Tan committed
Commit: 192f087
Parent: d81c9ad

Remove mlp.
build/torch-universal/scattermoe/__init__.py CHANGED
@@ -1,7 +1,6 @@
 from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
 from . import parallel_experts
 from . import kernels
-from . import mlp
 from . import layers
 
 __all__ = [
@@ -10,6 +9,5 @@ __all__ = [
     "ParallelExperts",
     "parallel_experts",
     "kernels",
-    "mlp",
     "layers"
 ]
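
With the `mlp` submodule removed from the package's `__init__.py`, only the remaining names are importable from the top level. A hedged sketch of the import surface after this change, using only names visible in the diff above:

import scattermoe

scattermoe.ParallelExperts   # still re-exported at the top level
scattermoe.parallel_experts  # submodule, unchanged
scattermoe.kernels           # submodule, unchanged
scattermoe.layers            # submodule, unchanged
# `from scattermoe import mlp` would now raise ImportError, since the submodule was deleted.
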
build/torch-universal/scattermoe/layers.py CHANGED
@@ -48,5 +48,5 @@ class ScatterMoEGatedMLP(nn.Module):
             gates=routing_weights
         )
         layer_output = layer_output.view(bsz, length, emb_size)
-        return layer_output
+        return layer_output, router_logits
 
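
Since `ScatterMoEGatedMLP.forward` now returns `(layer_output, router_logits)` instead of a single tensor, call sites have to unpack a pair. The real module's constructor is not shown in this diff, so the sketch below uses a hypothetical stand-in purely to illustrate the new return contract:

import torch
from torch import nn

class DummyGatedMLP(nn.Module):
    # Hypothetical stand-in mirroring the updated (output, router_logits) return.
    def __init__(self, emb_size: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(emb_size, num_experts)
        self.proj = nn.Linear(emb_size, emb_size)

    def forward(self, hidden_states: torch.Tensor):
        router_logits = self.router(hidden_states)  # e.g. for an auxiliary load-balancing loss
        layer_output = self.proj(hidden_states)
        return layer_output, router_logits          # two values, as in the diff above

moe = DummyGatedMLP(emb_size=32, num_experts=4)
x = torch.randn(2, 5, 32)             # (bsz, length, emb_size)
layer_output, router_logits = moe(x)  # callers must now unpack the tuple
print(layer_output.shape, router_logits.shape)
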
build/torch-universal/scattermoe/mlp.py DELETED
@@ -1,96 +0,0 @@
-import torch
-from torch import nn
-
-from .parallel_experts import ParallelExperts, flatten_sort_count
-
-class MLP(nn.Module):
-    def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_experts,
-        top_k,
-        bias=False,
-        activation=None,
-    ):
-        super(MLP, self).__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.experts = ParallelExperts(num_experts, input_size, hidden_size, bias=bias)
-        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias)
-        self.top_k = min(top_k, self.num_experts)
-        self.activation = activation
-
-    def extra_repr(self):
-        return 'k={}'.format(self.top_k)
-
-    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
-        x_shape = x.size()
-        x = x.view(-1, x_shape[-1])
-        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-            flatten_sort_count(expert_idxs, num_experts=self.num_experts)
-
-        h = self.experts(
-            x, self.top_k,
-            sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_out=True
-        )
-        h = self.activation(h)
-        y = self.output_experts(
-            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_in=True,
-            gates=expert_p,
-        )
-        y = y.view(*x_shape[:-1], y.size(-1))
-        return y
-
-class GLUMLP(nn.Module):
-    def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_experts,
-        top_k,
-        bias=False,
-        activation=nn.SiLU(),
-    ):
-        super(GLUMLP, self).__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size, bias=bias)
-        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias)
-        self.top_k = min(top_k, self.num_experts)
-        self.activation = activation
-
-    def extra_repr(self):
-        return 'k={}'.format(self.top_k)
-
-    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
-        x_shape = x.size()
-        x = x.view(-1, x_shape[-1])
-        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-            flatten_sort_count(expert_idxs, num_experts=self.num_experts)
-
-
-        h, gates = self.experts(
-            x, self.top_k,
-            sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_out=True
-        ).chunk(2, dim=-1)
-        h = self.activation(gates) * h
-        y = self.output_experts(
-            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_in=True,
-            gates=expert_p,
-        )
-        y = y.view(*x_shape[:-1], y.size(-1))
-        return y
-
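
The deleted `GLUMLP` fused both halves of a gated linear unit into a single `ParallelExperts` projection and split the result with `chunk(2, dim=-1)`. A minimal sketch of that gating pattern using plain `nn.Linear` layers, so it runs without the scattermoe kernels (the sizes are illustrative assumptions):

import torch
from torch import nn

input_size, hidden_size = 64, 128
up_proj = nn.Linear(input_size, 2 * hidden_size, bias=False)  # plays the role of self.experts
down_proj = nn.Linear(hidden_size, input_size, bias=False)    # plays the role of self.output_experts
activation = nn.SiLU()                                        # the default activation in the deleted GLUMLP

x = torch.randn(4, input_size)
h, gates = up_proj(x).chunk(2, dim=-1)  # fused projection split into value and gate halves
y = down_proj(activation(gates) * h)    # SiLU-gated hidden state, projected back to input_size
print(y.shape)                          # torch.Size([4, 64])
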
torch-ext/scattermoe/__init__.py CHANGED
@@ -1,7 +1,6 @@
 from .parallel_experts import flatten_sort_count, parallel_linear, ParallelExperts
 from . import parallel_experts
 from . import kernels
-from . import mlp
 from . import layers
 
 __all__ = [
@@ -10,6 +9,5 @@ __all__ = [
     "ParallelExperts",
     "parallel_experts",
     "kernels",
-    "mlp",
     "layers"
 ]
torch-ext/scattermoe/layers.py CHANGED
@@ -48,5 +48,5 @@ class ScatterMoEGatedMLP(nn.Module):
             gates=routing_weights
         )
         layer_output = layer_output.view(bsz, length, emb_size)
-        return layer_output
+        return layer_output, router_logits
 
torch-ext/scattermoe/mlp.py DELETED
@@ -1,96 +0,0 @@
-import torch
-from torch import nn
-
-from .parallel_experts import ParallelExperts, flatten_sort_count
-
-class MLP(nn.Module):
-    def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_experts,
-        top_k,
-        bias=False,
-        activation=None,
-    ):
-        super(MLP, self).__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.experts = ParallelExperts(num_experts, input_size, hidden_size, bias=bias)
-        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias)
-        self.top_k = min(top_k, self.num_experts)
-        self.activation = activation
-
-    def extra_repr(self):
-        return 'k={}'.format(self.top_k)
-
-    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
-        x_shape = x.size()
-        x = x.view(-1, x_shape[-1])
-        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-            flatten_sort_count(expert_idxs, num_experts=self.num_experts)
-
-        h = self.experts(
-            x, self.top_k,
-            sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_out=True
-        )
-        h = self.activation(h)
-        y = self.output_experts(
-            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_in=True,
-            gates=expert_p,
-        )
-        y = y.view(*x_shape[:-1], y.size(-1))
-        return y
-
-class GLUMLP(nn.Module):
-    def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_experts,
-        top_k,
-        bias=False,
-        activation=nn.SiLU(),
-    ):
-        super(GLUMLP, self).__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.experts = ParallelExperts(num_experts, input_size, 2 * hidden_size, bias=bias)
-        self.output_experts = ParallelExperts(num_experts, hidden_size, input_size, bias=bias)
-        self.top_k = min(top_k, self.num_experts)
-        self.activation = activation
-
-    def extra_repr(self):
-        return 'k={}'.format(self.top_k)
-
-    def forward(self, x: torch.Tensor, expert_p: torch.Tensor, expert_idxs: torch.Tensor):
-        x_shape = x.size()
-        x = x.view(-1, x_shape[-1])
-        sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = \
-            flatten_sort_count(expert_idxs, num_experts=self.num_experts)
-
-
-        h, gates = self.experts(
-            x, self.top_k,
-            sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_out=True
-        ).chunk(2, dim=-1)
-        h = self.activation(gates) * h
-        y = self.output_experts(
-            h, 1, sorted_expert_idxs, sorted_scattered_idxs,
-            expert_offsets,
-            grouped_in=True,
-            gates=expert_p,
-        )
-        y = y.view(*x_shape[:-1], y.size(-1))
-        return y
-