iamwyldecat committed on
Commit
d14fd4d
·
1 Parent(s): 704692b

feat(poly-norm): add perf test

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.so filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.so filter=lfs diff=lfs merge=lfs -text
37
+ tests/perf.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -24,3 +24,17 @@ x = torch.randn(10, 10)
24
 
25
  print(poly_norm(x))
26
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  print(poly_norm(x))
26
  ```
27
+
28
+ ## Performance
29
+
30
+ ### PolyNorm
31
+
32
+ - Test cases are from the Motif LLM
33
+ - You can reproduce the results with:
34
+ ```bash
35
+ cd tests
36
+ pytest --run-perf --do-plot
37
+ ```
38
+
39
+ ![PolyNorm Performance](./tests/perf.png)
40
+
build/torch26-cxx11-rocm62-x86_64-linux/activation/{_activation_883cc1c_dirty.abi3.so → _activation_704692b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a9d74188efdcb10158b338cf363749494f86e9712797722310f0a6ac5310efdd
3
  size 2401160
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:417cf142fb8234b05f7e5b0be321d3a95ceafd7c0b3e5d3469579a52d78ddb1e
3
  size 2401160
build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_883cc1c_dirty
3
- ops = torch.ops._activation_883cc1c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_883cc1c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_704692b_dirty
3
+ ops = torch.ops._activation_704692b_dirty
4
 
def add_op_namespace_prefix(op_name: str):
    """
    Return ``op_name`` qualified with this build's op namespace.

    The namespace is the same identifier used for ``torch.ops`` in this
    module, joined to the op name with ``::``.
    """
    qualified_name = f"_activation_704692b_dirty::{op_name}"
    return qualified_name
build/torch27-cxx11-rocm63-x86_64-linux/activation/{_activation_883cc1c_dirty.abi3.so → _activation_704692b_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:719fc6521c0824b253cb11ea9e564ef7835e2102e5bc6399cfdb69203d6d5c26
3
  size 2395176
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fe6163d88e95c0d6847b3fe993cd80de677f89cfde7fc4d5c3ec2d0d96c9de8
3
  size 2395176
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
- from . import _activation_883cc1c_dirty
3
- ops = torch.ops._activation_883cc1c_dirty
4
 
5
  def add_op_namespace_prefix(op_name: str):
6
  """
7
  Prefix op by namespace.
8
  """
9
- return f"_activation_883cc1c_dirty::{op_name}"
 
1
  import torch
2
+ from . import _activation_704692b_dirty
3
+ ops = torch.ops._activation_704692b_dirty
4
 
def add_op_namespace_prefix(op_name: str):
    """
    Return ``op_name`` qualified with this build's op namespace.

    The namespace is the same identifier used for ``torch.ops`` in this
    module, joined to the op name with ``::``.
    """
    qualified_name = f"_activation_704692b_dirty::{op_name}"
    return qualified_name
tests/conftest.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ import pytest
6
+
7
+ from .kernels.test_perf import PERF_RESULTS, PerfResult
8
+
9
+ logger = logging.getLogger(__name__)
10
+ DO_PLOT = False
11
+
12
+
13
def plot(perf_results: list[PerfResult], output_file: str = "perf_result.html"):
    """Write a grouped bar chart of kernel-vs-torch speedups to an HTML file.

    One bar group is drawn per entry in ``perf_results`` (labelled with its
    type, shape and dtype), followed by a final "Geometric Mean" group. The
    torch baseline is drawn as a constant 1.0 next to each kernel speedup.

    Args:
        perf_results: Collected measurements; each entry must expose
            ``type``, ``shape``, ``dtype`` and ``speedup``.
        output_file: Path of the HTML file to write. Defaults to
            ``"perf_result.html"`` (relative to the current directory).

    Returns:
        None. When ``perf_results`` is empty nothing is written, because the
        geometric mean of an empty list is undefined (it would be NaN).
    """
    if not perf_results:
        logger.warning("No performance results collected; skipping plot")
        return

    x_labels = [f"{r.type}, {r.shape}, {r.dtype}" for r in perf_results]
    kernel_speedup = [r.speedup for r in perf_results]

    # Geometric mean computed as exp(mean(log(x))); appended as a summary bar.
    geo_mean = float(np.exp(np.mean(np.log(kernel_speedup))))
    x_labels.append("Geometric Mean")
    kernel_speedup.append(geo_mean)

    # Torch is the baseline, so its "speedup" is 1 for every group,
    # including the geometric-mean group.
    torch_speedup = [1.0] * len(kernel_speedup)

    fig = go.Figure()

    # Both traces share the same styling apart from name, colour and values.
    traces = (
        ("Activation", "rgb(100, 100, 100)", kernel_speedup),
        ("Torch", "rgb(30, 30, 30)", torch_speedup),
    )
    for trace_name, color, values in traces:
        fig.add_trace(
            go.Bar(
                x=x_labels,
                y=values,
                name=trace_name,
                marker_color=color,
                text=[f"x{v:.2f}" for v in values],
                textfont=dict(size=14),
                textposition="outside",
            )
        )

    fig.update_layout(
        title=dict(
            text="<b>Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)</b>",
            font=dict(size=24),
        ),
        legend=dict(
            x=0.01,
            y=0.99,
            xanchor="left",
            yanchor="top",
            bgcolor="rgba(0,0,0,0)",
            bordercolor="black",
            borderwidth=1,
        ),
        font=dict(size=16),
        yaxis_title="Speedup (torch / activation)",
        barmode="group",
        bargroupgap=0,
        bargap=0.2,
        xaxis_tickangle=-45,
        template="plotly_white",
        yaxis_type="log",
        shapes=[
            # Black frame around the whole plotting area.
            dict(
                type="rect",
                xref="x",
                yref="paper",  # span the full y-axis range (0~1)
                x0=-0.5,
                x1=len(x_labels) - 0.5,
                y0=0,
                y1=1,
                line=dict(
                    color="black",
                    width=1.5,
                ),
                fillcolor="rgba(0,0,0,0)",  # transparent fill
                # NOTE(review): the previous comment here said "draw below
                # the bars", but the value is "above"; value kept as-is,
                # confirm which one was intended.
                layer="above",
            )
        ],
    )

    fig.write_html(output_file)
    logger.info("Plotting performance results to %s", output_file)
96
+
97
+
98
def pytest_addoption(parser):
    """Register the ``--run-perf`` and ``--do-plot`` command-line flags.

    Both are boolean switches that default to ``False``.
    """
    flag_help = (
        ("--run-perf", "Run perf tests"),
        ("--do-plot", "Plot performance results"),
    )
    for flag, help_text in flag_help:
        parser.addoption(flag, action="store_true", default=False, help=help_text)
105
+
106
+
107
@pytest.fixture
def do_plot(request):
    """Fixture returning the value of the ``--do-plot`` command-line option."""
    config = request.config
    return config.getoption("--do-plot")
110
+
111
+
112
def pytest_configure(config):
    """Validate the perf-related options and register the ``perf`` marker.

    Stores the ``--do-plot`` value in the module-level ``DO_PLOT`` flag so
    that it can be read after the session finishes.

    Args:
        config: The pytest config object.

    Raises:
        pytest.UsageError: If ``--do-plot`` is given without ``--run-perf``.
            ``UsageError`` is used rather than ``ValueError`` so pytest
            reports a command-line usage error instead of an internal error
            with a traceback.
    """
    global DO_PLOT
    DO_PLOT = config.getoption("--do-plot")
    run_perf = config.getoption("--run-perf")

    if DO_PLOT and not run_perf:
        raise pytest.UsageError(
            "Cannot plot performance results without running performance tests. "
            "Please use --run-perf option."
        )

    config.addinivalue_line("markers", "perf: mark test as performance-related")
124
+
125
+
126
def pytest_collection_modifyitems(config, items):
    """Skip whichever group of tests does not match the ``--run-perf`` mode.

    Without ``--run-perf`` the tests carrying the ``perf`` keyword are
    skipped; with ``--run-perf`` every test that does not carry it is skipped.
    """
    run_perf = config.getoption("--run-perf")

    skip_perf = pytest.mark.skip(reason="need --run-perf option to run")
    skip_normal = pytest.mark.skip(
        reason="normal tests skipped when --run-perf is used"
    )
    for item in items:
        is_perf_test = "perf" in item.keywords
        if is_perf_test == bool(run_perf):
            # The test belongs to the selected mode; leave it untouched.
            continue
        item.add_marker(skip_normal if run_perf else skip_perf)
138
+
139
+
140
def pytest_sessionfinish(session, exitstatus) -> None:
    """Plot the collected perf results when plotting was requested, else log them."""
    if not DO_PLOT:
        logger.info(PERF_RESULTS)
        return
    plot(PERF_RESULTS)
tests/kernels/test_perf.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+
4
+ import pytest
5
+ import torch
6
+
7
+ import activation
8
+
9
+ from .test_activation import poly_norm
10
+ from .utils import assert_close
11
+
12
+ CASES = [
13
+ ((1, 2048, 8192), torch.bfloat16),
14
+ ((1, 2048, 16384), torch.bfloat16),
15
+ ((1, 16384, 8192), torch.bfloat16),
16
+ ((1, 16384, 16384), torch.bfloat16),
17
+ ]
18
+ NUM_REP = 100
19
+
20
+
21
@dataclass
class PerfResult:
    """One timing comparison between the custom kernel and the torch reference."""

    # Which pass was measured: "forward" or "backward".
    type: str
    # Shape of the input tensor used for the measurement.
    shape: tuple
    # Element dtype of the input tensor.
    dtype: torch.dtype
    # Measured time of the custom kernel, in milliseconds.
    kernel_time_ms: float
    # Measured time of the torch reference, in milliseconds.
    torch_time_ms: float

    @property
    def speedup(self) -> float:
        """Torch time divided by kernel time; above 1 means the kernel is faster."""
        baseline_ms, candidate_ms = self.torch_time_ms, self.kernel_time_ms
        return baseline_ms / candidate_ms
32
+
33
+
34
+ PERF_RESULTS: list[PerfResult] = []
35
+
36
+
37
@pytest.mark.parametrize("cases", CASES)
@pytest.mark.perf
def test_poly_norm(
    cases: tuple,
    do_plot: bool,
) -> None:
    """Check PolyNorm against the torch reference, then time forward and backward.

    The test first asserts that the custom layer and the reference function
    agree on outputs and gradients, then appends one "forward" and one
    "backward" ``PerfResult`` to ``PERF_RESULTS``.

    Args:
        cases: A ``(shape, dtype)`` pair taken from ``CASES``.
        do_plot: Value of the ``do_plot`` fixture; not used in the body.
    """
    random.seed(12345)
    torch.manual_seed(12345)

    shape, dtype = cases
    eps = 1e-05

    # Use a scoped default device instead of torch.set_default_device("cuda"):
    # the latter changes process-wide state that was never restored and would
    # leak into every test that runs afterwards.
    with torch.device("cuda"):
        # x, weight and bias are leaf tensors, so their .grad is populated by
        # backward() without calling retain_grad() (a no-op on leaves).
        x = torch.randn(shape, dtype=dtype, requires_grad=True)
        weight = torch.randn(3, dtype=dtype, requires_grad=True)
        bias = torch.randn(1, dtype=dtype, requires_grad=True)

        # To separate gradient computation, clone the inputs for the reference.
        x_ref = x.detach().clone().requires_grad_(True)
        weight_ref = weight.detach().clone().requires_grad_(True)
        bias_ref = bias.detach().clone().requires_grad_(True)

        torch_fn = poly_norm
        layer = activation.layers.PolyNorm(eps)
        layer.weight = torch.nn.Parameter(weight)
        layer.bias = torch.nn.Parameter(bias)

        # Check correctness of the forward pass.
        mod_out = layer(x)
        ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
        assert_close(mod_out, ref_out)

        out_grad = torch.rand_like(ref_out)
        out_grad = out_grad / out_grad.norm()

        # retain_graph=True because the same graphs are reused by the
        # backward timing loop further down.
        ref_out.backward(out_grad, retain_graph=True)
        mod_out.backward(out_grad, retain_graph=True)

        # Check correctness of the backward pass.
        assert_close(x.grad, x_ref.grad)
        assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
        assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)

        def time_cuda(fn):
            """Return the average time of one ``fn()`` call in milliseconds."""
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            # Warm-up iterations, excluded from the measurement.
            for _ in range(5):
                fn()
            # Drain the warm-up work so the timed region starts from an idle queue.
            torch.cuda.synchronize()

            start.record()
            for _ in range(NUM_REP):
                fn()
            end.record()
            # elapsed_time() requires both events to have completed.
            torch.cuda.synchronize()
            return start.elapsed_time(end) / NUM_REP

        # (label, custom-kernel callable, torch-reference callable)
        benchmarks = (
            (
                "forward",
                lambda: layer(x),
                lambda: torch_fn(x_ref, weight_ref, bias_ref, eps),
            ),
            (
                "backward",
                lambda: mod_out.backward(out_grad, retain_graph=True),
                lambda: ref_out.backward(out_grad, retain_graph=True),
            ),
        )
        for kind, kernel_fn, reference_fn in benchmarks:
            kernel_time_ms = time_cuda(kernel_fn)
            torch_time_ms = time_cuda(reference_fn)
            PERF_RESULTS.append(
                PerfResult(
                    type=kind,
                    shape=shape,
                    dtype=dtype,
                    kernel_time_ms=kernel_time_ms,
                    torch_time_ms=torch_time_ms,
                )
            )
tests/perf.png ADDED

Git LFS Details

  • SHA256: 12f88f9ac4511cb37f38a34e3572e4347bd0c857144a4aaf64bd5981d6b50877
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
tests/perf_result.html ADDED
The diff for this file is too large to render. See raw diff