Update core/data_architecture/sparse_ops.py
core/data_architecture/sparse_ops.py
CHANGED
@@ -18,10 +18,13 @@ import apex
 from apex import amp
 from apex.optimizers import FusedAdam

+# Assuming fused_ops is compiled and available
+import fused_ops  # Custom CUDA extension from fused_ops.cu
+
 class SparseLinear(nn.Module):
     """
     Sparse Linear Layer with Tensor Core Optimizations and Dynamic Pruning.
-
+    Integrates fused GEMM + ReLU CUDA kernel for GPU efficiency.
     """
     def __init__(self, in_features, out_features, sparsity=0.5, use_fp16=True, dynamic_pruning=False):
         super(SparseLinear, self).__init__()
@@ -29,7 +32,7 @@ class SparseLinear(nn.Module):
         self.out_features = out_features
         self.sparsity = sparsity
         self.use_fp16 = use_fp16
-        self.dynamic_pruning = dynamic_pruning
+        self.dynamic_pruning = dynamic_pruning

         # Initialize dense weight and bias
         self.weight = nn.Parameter(
@@ -39,7 +42,7 @@ class SparseLinear(nn.Module):
             torch.zeros(out_features, dtype=torch.float16 if use_fp16 else torch.float32)
         )

-        # Sparse mask
+        # Sparse mask
         self.register_buffer("mask", self.generate_mask())

     def generate_mask(self):
@@ -47,30 +50,29 @@
         Generates a binary mask based on weight magnitude for structured sparsity.
         """
         if self.dynamic_pruning:
-            # Dynamic pruning will recompute this in forward pass
             return torch.ones_like(self.weight)
         weights_abs = self.weight.abs()
         threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
-        return (weights_abs > threshold).float()
+        return (weights_abs > threshold).to(self.weight.dtype)

     def update_mask(self):
         """Update mask dynamically based on current weight magnitudes."""
         if self.dynamic_pruning:
             weights_abs = self.weight.abs()
             threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
-            self.mask.data = (weights_abs > threshold).float()
+            self.mask.data = (weights_abs > threshold).to(self.weight.dtype)

     def forward(self, x):
         if self.dynamic_pruning:
-            self.update_mask()
+            self.update_mask()

-        if self.use_fp16:
-            # Apply the sparse mask to the fp16 weights
-            pruned_weight = self.weight * self.mask
-            return F.linear(x, pruned_weight, self.bias)
+        if self.use_fp16 and x.is_cuda:
+            # Use fused CUDA kernel for GEMM + ReLU
+            return fused_ops.fused_sparse_gemm_relu(x, self.weight, self.mask, self.bias)
         else:
+            # Fallback to PyTorch
             pruned_weight = self.weight.float() * self.mask.float()
-            return F.linear(x.float(), pruned_weight, self.bias.float())
+            return F.relu(F.linear(x.float(), pruned_weight, self.bias.float()))


 class SparseConv2d(nn.Module):
@@ -84,7 +86,7 @@ class SparseConv2d(nn.Module):
         self.use_fp16 = use_fp16
         self.sparsity = sparsity
         self.dynamic_pruning = dynamic_pruning
-        self.block_size = block_size
+        self.block_size = block_size

         self.conv = nn.Conv2d(
             in_channels,
@@ -97,16 +99,11 @@ class SparseConv2d(nn.Module):
         self.register_buffer("mask", self.generate_mask())

     def generate_mask(self):
-        """
-        Generate a mask based on weight magnitude, optionally with block sparsity.
-        """
         weights = self.conv.weight
         if self.dynamic_pruning:
             return torch.ones_like(weights)
-
         weights_abs = weights.abs()
-        if self.block_size:
-            # Reshape weights into blocks and compute block-wise magnitude
+        if self.block_size:
             kh, kw = self.block_size
             weights_reshaped = weights_abs.view(weights_abs.size(0), weights_abs.size(1),
                                                 weights_abs.size(2) // kh, kh,
@@ -114,7 +111,6 @@ class SparseConv2d(nn.Module):
             block_magnitudes = weights_reshaped.norm(p=2, dim=(3, 4))
             threshold = torch.quantile(block_magnitudes.flatten(), self.sparsity)
             block_mask = (block_magnitudes > threshold).float()
-            # Expand block mask back to full weight shape
             mask = block_mask.unsqueeze(-1).unsqueeze(-1).expand_as(weights_reshaped).reshape_as(weights)
         else:
             threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
@@ -122,7 +118,6 @@ class SparseConv2d(nn.Module):
         return mask

     def update_mask(self):
-        """Update mask dynamically based on current weight magnitudes."""
         if self.dynamic_pruning:
             self.mask.data = self.generate_mask()

@@ -143,7 +138,7 @@ class SparseConv2d(nn.Module):
 class SparseMLP(nn.Module):
     """
     Sparse MLP with Tensor Core Acceleration and optional dynamic pruning.
-    Uses sparse linear layers
+    Uses sparse linear layers with fused ops for efficiency.
     """
     def __init__(self, input_dim, hidden_dim, output_dim, sparsity=0.5,
                  use_fp16=True, dynamic_pruning=False):
@@ -155,9 +150,41 @@ class SparseMLP(nn.Module):
     def forward(self, x):
         if self.use_fp16:
             with autocast():
-                x = F.relu(self.fc1(x))
+                x = self.fc1(x)  # Already includes ReLU from fused kernel
                 x = self.fc2(x)
             return x
         else:
-            x = F.relu(self.fc1(x))
-            return self.fc2(x)
+            x = self.fc1(x)  # Includes ReLU from fallback
+            return self.fc2(x)
+
+# Example training loop with Apex mixed precision and FusedAdam
+def train_sparse_mlp():
+    model = SparseMLP(784, 256, 10, sparsity=0.5, use_fp16=True).cuda()
+    optimizer = FusedAdam(model.parameters(), lr=0.001)
+
+    # Initialize Apex AMP
+    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
+
+    # Dummy data
+    inputs = torch.randn(32, 784).cuda()
+    targets = torch.randint(0, 10, (32,)).cuda()
+
+    # Training loop
+    for _ in range(100):
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = F.cross_entropy(outputs, targets)
+
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        optimizer.step()
+
+    # Export to ONNX
+    torch.onnx.export(model, inputs, "sparse_mlp.onnx", opset_version=12)
+
+    # Convert to TensorRT
+    model_trt = torch2trt(model, [inputs], fp16_mode=True)
+    return model_trt
+
+if __name__ == "__main__":
+    trt_model = train_sparse_mlp()
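The new import assumes fused_ops has already been compiled from fused_ops.cu. As a rough illustration only (the build path, compiler flags, and dispatch helper below are assumptions, not part of this change), the extension could be JIT-compiled and wrapped on the Python side like this, assuming fused_ops.cu exposes a fused_sparse_gemm_relu(input, weight, mask, bias) binding via PYBIND11_MODULE:

# Hypothetical loader for the fused_ops extension; nothing here is defined by the diff.
import torch.nn.functional as F
from torch.utils.cpp_extension import load

try:
    fused_ops = load(
        name="fused_ops",
        sources=["core/data_architecture/fused_ops.cu"],  # assumed location of the kernel source
        extra_cuda_cflags=["-O3"],
        verbose=True,
    )
except (OSError, RuntimeError):
    fused_ops = None  # nvcc or the .cu file is unavailable on this machine


def sparse_gemm_relu(x, weight, mask, bias):
    """Dispatch to the fused CUDA kernel when available, otherwise use plain PyTorch."""
    if fused_ops is not None and x.is_cuda:
        return fused_ops.fused_sparse_gemm_relu(x, weight, mask, bias)
    pruned = (weight * mask).to(x.dtype)
    return F.relu(F.linear(x, pruned, bias.to(x.dtype)))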
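A quick way to sanity-check SparseLinear is to instantiate it on the CPU fallback path and confirm the mask roughly hits the requested sparsity. The shapes below are arbitrary, and the snippet assumes the module imports cleanly (i.e. fused_ops is importable):

# Minimal smoke test for SparseLinear via the PyTorch fallback branch.
import torch

layer = SparseLinear(512, 256, sparsity=0.5, use_fp16=False)
x = torch.randn(8, 512)
y = layer(x)  # masked GEMM followed by ReLU
achieved = 1.0 - layer.mask.float().mean().item()
print(y.shape, "achieved sparsity:", round(achieved, 2))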
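Two assumptions behind the example training loop are worth calling out: autocast and torch2trt are used but not imported in any hunk shown here, so they presumably come from imports earlier in sparse_ops.py (e.g. from torch.cuda.amp import autocast and from torch2trt import torch2trt); and apex.amp has been deprecated upstream in favor of PyTorch's native AMP. If Apex is not available, a roughly equivalent loop with native torch.cuda.amp might look like the sketch below, where torch.optim.AdamW stands in for Apex's FusedAdam and the hyperparameters are illustrative:

# Alternative training loop using native torch.cuda.amp instead of apex.amp.
import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast

def train_sparse_mlp_native_amp(model, inputs, targets, steps=100, lr=1e-3):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)  # stand-in for FusedAdam
    scaler = GradScaler()  # scales the loss to avoid fp16 gradient underflow
    for _ in range(steps):
        optimizer.zero_grad()
        with autocast():
            loss = F.cross_entropy(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return model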