GeminiFan207 committed
Commit 6ae562b · verified · 1 Parent(s): 6088ebf

Update core/data_architecture/sparse_ops.py

Files changed (1)
  1. core/data_architecture/sparse_ops.py +52 -25
core/data_architecture/sparse_ops.py CHANGED
@@ -18,10 +18,13 @@ import apex
 from apex import amp
 from apex.optimizers import FusedAdam
 
+# Assuming fused_ops is compiled and available
+import fused_ops  # Custom CUDA extension from fused_ops.cu
+
 class SparseLinear(nn.Module):
     """
     Sparse Linear Layer with Tensor Core Optimizations and Dynamic Pruning.
-    Prunes weights based on magnitude to improve efficiency on GPU.
+    Integrates a fused GEMM + ReLU CUDA kernel for GPU efficiency.
     """
     def __init__(self, in_features, out_features, sparsity=0.5, use_fp16=True, dynamic_pruning=False):
         super(SparseLinear, self).__init__()
@@ -29,7 +32,7 @@ class SparseLinear(nn.Module):
         self.out_features = out_features
         self.sparsity = sparsity
         self.use_fp16 = use_fp16
-        self.dynamic_pruning = dynamic_pruning  # Toggle dynamic vs static pruning
+        self.dynamic_pruning = dynamic_pruning
 
         # Initialize dense weight and bias
         self.weight = nn.Parameter(
@@ -39,7 +42,7 @@ class SparseLinear(nn.Module):
             torch.zeros(out_features, dtype=torch.float16 if use_fp16 else torch.float32)
         )
 
-        # Sparse mask (static unless dynamic_pruning is enabled)
+        # Sparse mask
         self.register_buffer("mask", self.generate_mask())
 
     def generate_mask(self):
@@ -47,30 +50,29 @@ class SparseLinear(nn.Module):
         Generates a binary mask based on weight magnitude for structured sparsity.
         """
         if self.dynamic_pruning:
-            # Dynamic pruning will recompute this in forward pass
             return torch.ones_like(self.weight)
         weights_abs = self.weight.abs()
         threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
-        return (weights_abs > threshold).float()
+        return (weights_abs > threshold).to(self.weight.dtype)
 
     def update_mask(self):
         """Update mask dynamically based on current weight magnitudes."""
         if self.dynamic_pruning:
             weights_abs = self.weight.abs()
             threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
-            self.mask.data = (weights_abs > threshold).float()
+            self.mask.data = (weights_abs > threshold).to(self.weight.dtype)
 
     def forward(self, x):
         if self.dynamic_pruning:
-            self.update_mask()  # Recompute mask if dynamic pruning is enabled
+            self.update_mask()
 
-        if self.use_fp16:
-            with autocast():
-                pruned_weight = self.weight * self.mask
-                return F.linear(x, pruned_weight, self.bias)
+        if self.use_fp16 and x.is_cuda:
+            # Use fused CUDA kernel for GEMM + ReLU
+            return fused_ops.fused_sparse_gemm_relu(x, self.weight, self.mask, self.bias)
         else:
+            # Fallback to PyTorch
             pruned_weight = self.weight.float() * self.mask.float()
-            return F.linear(x.float(), pruned_weight, self.bias.float())
+            return F.relu(F.linear(x.float(), pruned_weight, self.bias.float()))
 
 
 class SparseConv2d(nn.Module):
@@ -84,7 +86,7 @@ class SparseConv2d(nn.Module):
         self.use_fp16 = use_fp16
         self.sparsity = sparsity
         self.dynamic_pruning = dynamic_pruning
-        self.block_size = block_size  # Optional block sparsity (e.g., (2, 2))
+        self.block_size = block_size
 
         self.conv = nn.Conv2d(
             in_channels,
@@ -97,16 +99,11 @@ class SparseConv2d(nn.Module):
         self.register_buffer("mask", self.generate_mask())
 
     def generate_mask(self):
-        """
-        Generate a mask based on weight magnitude, optionally with block sparsity.
-        """
        weights = self.conv.weight
        if self.dynamic_pruning:
            return torch.ones_like(weights)
-
        weights_abs = weights.abs()
-        if self.block_size:  # Block sparsity
-            # Reshape weights into blocks and compute block-wise magnitude
+        if self.block_size:
            kh, kw = self.block_size
            weights_reshaped = weights_abs.view(weights_abs.size(0), weights_abs.size(1),
                                                weights_abs.size(2) // kh, kh,
@@ -114,7 +111,6 @@ class SparseConv2d(nn.Module):
            block_magnitudes = weights_reshaped.norm(p=2, dim=(3, 4))
            threshold = torch.quantile(block_magnitudes.flatten(), self.sparsity)
            block_mask = (block_magnitudes > threshold).float()
-            # Expand block mask back to full weight shape
            mask = block_mask.unsqueeze(-1).unsqueeze(-1).expand_as(weights_reshaped).reshape_as(weights)
        else:
            threshold = torch.quantile(weights_abs.flatten(), self.sparsity)
@@ -122,7 +118,6 @@ class SparseConv2d(nn.Module):
        return mask
 
    def update_mask(self):
-        """Update mask dynamically based on current weight magnitudes."""
        if self.dynamic_pruning:
            self.mask.data = self.generate_mask()
 
@@ -143,7 +138,7 @@ class SparseConv2d(nn.Module):
 class SparseMLP(nn.Module):
     """
     Sparse MLP with Tensor Core Acceleration and optional dynamic pruning.
-    Uses sparse linear layers to reduce computation.
+    Uses sparse linear layers with fused ops for efficiency.
     """
     def __init__(self, input_dim, hidden_dim, output_dim, sparsity=0.5,
                  use_fp16=True, dynamic_pruning=False):
@@ -155,9 +150,41 @@ class SparseMLP(nn.Module):
     def forward(self, x):
         if self.use_fp16:
             with autocast():
-                x = F.relu(self.fc1(x))
+                x = self.fc1(x)  # Already includes ReLU from the fused kernel
                 x = self.fc2(x)
                 return x
         else:
-            x = F.relu(self.fc1(x.float()))
-            return self.fc2(x.float())
+            x = self.fc1(x)  # Includes ReLU from the fallback path
+            return self.fc2(x)
+
+# Example training loop with Apex mixed precision and FusedAdam
+def train_sparse_mlp():
+    model = SparseMLP(784, 256, 10, sparsity=0.5, use_fp16=True).cuda()
+    optimizer = FusedAdam(model.parameters(), lr=0.001)
+
+    # Initialize Apex AMP
+    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
+
+    # Dummy data
+    inputs = torch.randn(32, 784).cuda()
+    targets = torch.randint(0, 10, (32,)).cuda()
+
+    # Training loop
+    for _ in range(100):
+        optimizer.zero_grad()
+        outputs = model(inputs)
+        loss = F.cross_entropy(outputs, targets)
+
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        optimizer.step()
+
+    # Export to ONNX
+    torch.onnx.export(model, inputs, "sparse_mlp.onnx", opset_version=12)
+
+    # Convert to TensorRT
+    model_trt = torch2trt(model, [inputs], fp16_mode=True)
+    return model_trt
+
+if __name__ == "__main__":
+    trt_model = train_sparse_mlp()
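Note: a minimal pure-PyTorch sketch of what the new fused kernel is expected to compute, useful for unit-testing fused_ops against eager execution. The kernel name and argument order are taken from the diff above; the semantics below (masked GEMM followed by ReLU) are an assumption inferred from the FP32 fallback path, not the actual CUDA implementation.

import torch
import torch.nn.functional as F

def reference_sparse_gemm_relu(x, weight, mask, bias):
    # Assumed semantics of fused_ops.fused_sparse_gemm_relu:
    # apply the binary mask to the weights, run the GEMM, then apply ReLU.
    pruned_weight = weight * mask
    return F.relu(F.linear(x, pruned_weight, bias))

# Possible check against the custom kernel (looser tolerances for FP16):
# out_fused = fused_ops.fused_sparse_gemm_relu(x, layer.weight, layer.mask, layer.bias)
# out_ref = reference_sparse_gemm_relu(x, layer.weight, layer.mask, layer.bias)
# torch.testing.assert_close(out_fused, out_ref, rtol=1e-2, atol=1e-2)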
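Note: the diff assumes fused_ops has already been compiled from fused_ops.cu and is importable. One way to build and load such an extension is PyTorch's JIT C++/CUDA extension loader; the sketch below is an assumption (the source path is hypothetical), not the project's confirmed build setup. The training example also calls torch2trt, which is presumably imported (from torch2trt import torch2trt) outside the hunks shown here.

# Hypothetical JIT build of the custom CUDA extension referenced in the diff;
# the source path is assumed, and the module must expose fused_sparse_gemm_relu.
from torch.utils.cpp_extension import load

fused_ops = load(
    name="fused_ops",
    sources=["core/data_architecture/fused_ops.cu"],  # assumed location
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    verbose=True,
)

# Once loaded, SparseLinear.forward can call:
# fused_ops.fused_sparse_gemm_relu(x, self.weight, self.mask, self.bias)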