kernel
Commit e94ff91 (verified) · Parent(s): 74b8263
Committed by danieldk and YangKai0616

Add support for XPU (sycl) (#3)

- Add support for XPU(sycl) (b87db80d558d76f8193f293889e6aee81e9d3743)
- Revert changes to the build folder (f425aa577e0fb59ed08c568985b007ab686da778)
- Delete pyc files (acc75a3bf1a67e06659873b98909b4e997b3f45f)


Co-authored-by: Kai Yang <[email protected]>

build.toml CHANGED
@@ -9,3 +9,11 @@ src = ["torch-ext/torch_binding.cpp"]
 backend = "cuda"
 depends = ["torch"]
 src = ["rotary/rotary_cuda.cu"]
+
+[kernel.rotary_xpu]
+backend = "xpu"
+depends = ["torch"]
+src = [
+    "rotary-xpu/rotary_xpu.cpp",
+    "rotary-xpu/rotary_xpu.hpp",
+]
flake.lock CHANGED
@@ -17,11 +17,11 @@
     },
     "flake-compat_2": {
       "locked": {
-        "lastModified": 1733328505,
-        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
         "owner": "edolstra",
         "repo": "flake-compat",
-        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
         "type": "github"
       },
       "original": {
@@ -73,11 +73,11 @@
         "nixpkgs": "nixpkgs"
       },
       "locked": {
-        "lastModified": 1753354560,
-        "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
+        "lastModified": 1757493151,
+        "narHash": "sha256-eirWlcvs2rjZmU8JcF4CKN1IEnNfpQnGuf2qbK3IQh8=",
         "owner": "huggingface",
         "repo": "hf-nix",
-        "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
+        "rev": "503cd4eb9866103c983dbef93d9ad5db4fb6b415",
         "type": "github"
       },
       "original": {
@@ -98,33 +98,32 @@
         ]
       },
       "locked": {
-        "lastModified": 1753354632,
-        "narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
+        "lastModified": 1757570810,
+        "narHash": "sha256-YFWQwy2LKbhjdLW8wkyNkE/+Vbdn6qlJif2CKvBT9Qo=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "524b628fd8e58525dbd28455bffb0628092c5265",
+        "rev": "1201847af3ff757b65015c6e06b5bd75896d2d4b",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "torch-2.8",
         "repo": "kernel-builder",
         "type": "github"
       }
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1752785354,
-        "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
+        "lastModified": 1755963616,
+        "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
         "owner": "nixos",
         "repo": "nixpkgs",
-        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
+        "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
         "type": "github"
       },
       "original": {
         "owner": "nixos",
+        "ref": "nixos-unstable-small",
         "repo": "nixpkgs",
-        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
        "type": "github"
      }
    },
flake.nix CHANGED
@@ -1,15 +1,9 @@
 {
-  description = "Flake for rotary kernel";
-
+  description = "Flake for Torch kernel extension";
   inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
-
-  outputs =
-    {
-      self,
-      kernel-builder,
-    }:
+  outputs = { self, kernel-builder, }:
     kernel-builder.lib.genFlakeOutputs {
       path = ./.;
       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
rotary-xpu/rotary_xpu.cpp ADDED
@@ -0,0 +1,40 @@
+#include <torch/all.h>
+#include "rotary_xpu.hpp"
+
+void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
+                   torch::Tensor const &cos, torch::Tensor const &sin,
+                   torch::Tensor &out1, torch::Tensor &out2,
+                   bool const conj) {
+  auto iter = at::TensorIteratorConfig()
+                  .add_output(out1)
+                  .add_output(out2)
+                  .add_input(x1)
+                  .add_input(x2)
+                  .add_input(cos)
+                  .add_input(sin)
+                  .check_all_same_dtype(false)
+                  .promote_inputs_to_common_dtype(false)
+                  .build();
+
+  if (!conj) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel_xpu", [&] {
+      gpu_kernel_multiple_outputs(
+          iter, [] (scalar_t x1, scalar_t x2, scalar_t cos,
+                    scalar_t sin) -> std::tuple<scalar_t, scalar_t> {
+            scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
+            scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
+            return {out1, out2};
+          });
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel_xpu", [&] {
+      gpu_kernel_multiple_outputs(
+          iter, [] (scalar_t x1, scalar_t x2, scalar_t cos,
+                    scalar_t sin) -> std::tuple<scalar_t, scalar_t> {
+            scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
+            scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
+            return {out1, out2};
+          });
+    });
+  }
+}
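
Note: the lambdas above apply the standard 2D rotation used by rotary embeddings elementwise, computing in float and casting back to the input dtype; conj = true applies the inverse rotation:

```latex
\begin{aligned}
\texttt{conj}=\text{false}:&\quad
\begin{pmatrix}\text{out}_1\\ \text{out}_2\end{pmatrix}
=\begin{pmatrix}\cos\theta & -\sin\theta\\ \sin\theta & \cos\theta\end{pmatrix}
\begin{pmatrix}x_1\\ x_2\end{pmatrix}, \\
\texttt{conj}=\text{true}:&\quad
\begin{pmatrix}\text{out}_1\\ \text{out}_2\end{pmatrix}
=\begin{pmatrix}\cos\theta & \sin\theta\\ -\sin\theta & \cos\theta\end{pmatrix}
\begin{pmatrix}x_1\\ x_2\end{pmatrix}.
\end{aligned}
```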
rotary-xpu/rotary_xpu.hpp ADDED
@@ -0,0 +1,375 @@
+#include <ATen/core/TensorBody.h>
+#include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/TensorIterator.h>
+#include <sycl/sycl.hpp>
+#include <ATen/core/Array.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Exception.h>
+#include <c10/util/TypeCast.h>
+#include <cstdint>
+#include <type_traits>
+#include <array>
+#include <c10/core/ScalarType.h>
+#include <c10/xpu/XPUStream.h>
+#include <ATen/xpu/XPUContext.h>
+
+constexpr int MAX_DIMS = 12;
+
+struct LoadWithoutCast {
+  template <typename scalar_t>
+  C10_DEVICE scalar_t load(char* base_ptr, uint32_t offset, int arg) {
+    return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
+  }
+};
+
+struct StoreWithoutCast {
+  template <typename scalar_t>
+  C10_DEVICE void store(scalar_t value, char* base_ptr, uint32_t offset, int arg = 0) {
+    *(reinterpret_cast<scalar_t*>(base_ptr) + offset) = value;
+  }
+};
+
+template <template <int i> typename func, int end, int current = 0>
+struct static_unroll {
+  template <typename... Args>
+  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
+    func<current>::apply(std::forward<Args>(args)...);
+    static_unroll<func, end, current + 1>::with_args(args...);
+  }
+};
+
+template <template <int i> typename func, int end>
+struct static_unroll<func, end, end> {
+  template <typename... Args>
+  static inline C10_HOST_DEVICE void with_args(Args... args) {}
+};
+
+template <int current>
+struct multi_outputs_store_helper {
+  template <int ntensors, int num_outputs, typename... Args>
+  static C10_HOST_DEVICE void apply(
+      at::detail::Array<char*, ntensors> data,
+      at::detail::Array<uint32_t, num_outputs> offsets,
+      std::tuple<Args...> ret) {
+    using T = typename std::tuple_element<current, std::tuple<Args...>>::type;
+    T* to = reinterpret_cast<T*>(data[current]) + offsets[current];
+    *to = std::get<current>(ret);
+  }
+};
+
+template <int arg_index>
+struct unroll_load_helper {
+  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
+  static C10_DEVICE void apply(
+      policy_t& self,
+      args_t* args,
+      offset_t offset,
+      loader_t loader,
+      int j,
+      int num_outputs) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    std::get<arg_index>(args[j]) = loader.template load<arg_t>(
+        self.data[arg_index + num_outputs], offset[arg_index], arg_index);
+  }
+};
+
+template <int item_work_size, typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
+struct multi_outputs_unroll {
+  data_t data;
+  int remaining;
+  inp_calc_t input_offset_calculator;
+  out_calc_t output_offset_calculator;
+  LoadWithoutCast loader;
+  StoreWithoutCast storer;
+  int item_idx;
+  int group_idx;
+  int num_items_per_group;
+  int group_work_size;
+
+  multi_outputs_unroll(
+      data_t data,
+      int remaining,
+      inp_calc_t ic,
+      out_calc_t oc,
+      int item_idx,
+      int group_idx,
+      int num_items_per_group)
+      : data(data),
+        remaining(remaining),
+        input_offset_calculator(ic),
+        output_offset_calculator(oc),
+        item_idx(item_idx),
+        group_idx(group_idx),
+        num_items_per_group(num_items_per_group),
+        group_work_size(item_work_size * num_items_per_group) {}
+
+  inline bool check_inbounds(int item_work_elem) const {
+    return (item_idx + item_work_elem * num_items_per_group < remaining);
+  }
+
+  template <typename args_t>
+  inline void load(args_t* args) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    int item_idx_ = item_idx;
+    #pragma unroll
+    for (int i = 0; i < item_work_size; i++) {
+      if (item_idx_ >= remaining) {
+        return;
+      }
+      int linear_idx = item_idx_ + group_work_size * group_idx;
+      auto offset = input_offset_calculator.get(linear_idx);
+      static_unroll<unroll_load_helper, arity>::with_args(
+          *this, args, offset, loader, i, num_outputs);
+      item_idx_ += num_items_per_group;
+    }
+  }
+
+  template <typename return_t>
+  inline void store(return_t* from) {
+    int item_idx_ = item_idx;
+    #pragma unroll
+    for (int i = 0; i < item_work_size; i++) {
+      if (item_idx_ >= this->remaining) {
+        return;
+      }
+      int linear_idx = item_idx_ + group_work_size * group_idx;
+      auto offsets = this->output_offset_calculator.get(linear_idx);
+      static_unroll<multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
+      item_idx_ += num_items_per_group;
+    }
+  }
+};
+
+template <int item_work_size, typename func_t, typename policy_t>
+inline void elementwise_kernel_helper(func_t f, policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  return_t results[item_work_size];
+  args_t args[item_work_size];
+
+  policy.load(args);
+
+  #pragma unroll
+  for (int i = 0; i < item_work_size; i++) {
+    if (policy.check_inbounds(i)) {
+      results[i] = std::apply(f, args[i]);
+    }
+  }
+
+  policy.store(results);
+}
+
+template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
+struct UnrolledElementwiseForMultiOutputsKernel {
+  static constexpr int item_work_size = 4;
+
+  void operator()(sycl::nd_item<1> item_id) const {
+    int grpsz = item_id.get_local_range(0);
+    int grpid = item_id.get_group(0);
+    int lid = item_id.get_local_id(0);
+    int remaining = numel_ - item_work_size * grpsz * grpid;
+    auto policy = multi_outputs_unroll<item_work_size, array_t, in_calc_t, out_calc_t, num_outputs>(
+        data_, remaining, ic_, oc_, lid, grpid, grpsz);
+    elementwise_kernel_helper<item_work_size>(f_, policy);
+  };
+
+  UnrolledElementwiseForMultiOutputsKernel(int numel, func_t f, array_t data, in_calc_t ic, out_calc_t oc)
+      : numel_(numel), f_(f), data_(data), ic_(ic), oc_(oc) {}
+
+ private:
+  int numel_;
+  func_t f_;
+  array_t data_;
+  in_calc_t ic_;
+  out_calc_t oc_;
+};
+
+template <typename Value>
+struct IntDivider {
+  IntDivider() = default;
+  IntDivider(Value d) : divisor(d) {}
+
+  C10_HOST_DEVICE inline Value div(Value n) const {
+    return n / divisor;
+  }
+  C10_HOST_DEVICE inline Value mod(Value n) const {
+    return n % divisor;
+  }
+  C10_HOST_DEVICE inline auto divmod(Value n) const {
+    return std::make_pair(n / divisor, n % divisor);
+  }
+
+  Value divisor;
+};
+
+template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
+struct OffsetCalculator {
+  using stride_t = std::conditional_t<signed_strides, std::make_signed_t<index_t>, index_t>;
+  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;
+
+  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes = nullptr)
+      : dims(dims) {
+    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
+    for (int i = 0; i < dims; i++) {
+      sizes_[i] = IntDivider<index_t>(sizes[i]);
+      for (int arg = 0; arg < NARGS; arg++) {
+        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
+        strides_[i][arg] = strides[arg][i] / element_size;
+      }
+    }
+  }
+
+  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
+    offset_type offsets;
+    #pragma unroll
+    for (int arg = 0; arg < NARGS; arg++) {
+      offsets[arg] = 0;
+    }
+
+    #pragma unroll
+    for (int dim = 0; dim < MAX_DIMS; ++dim) {
+      if (dim == dims) {
+        break;
+      }
+      auto divmod = sizes_[dim].divmod(linear_idx);
+      linear_idx = divmod.first;
+
+      #pragma unroll
+      for (int arg = 0; arg < NARGS; arg++) {
+        offsets[arg] += divmod.second * strides_[dim][arg];
+      }
+    }
+    return offsets;
+  }
+
+  int dims;
+  IntDivider<index_t> sizes_[MAX_DIMS];
+  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
+};
+
+template <int N>
+static OffsetCalculator<N> make_input_offset_calculator(const at::TensorIteratorBase& iter) {
+  constexpr int array_size = std::max<int>(N, 1);
+  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
+  std::array<const int64_t*, array_size> strides;
+  int64_t element_sizes[array_size];
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i + iter.noutputs()).data();
+    element_sizes[i] = iter.element_size(i + iter.noutputs());
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+template <int num_outputs = 1>
+static OffsetCalculator<num_outputs> make_output_offset_calculator(const at::TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
+  std::array<const int64_t*, num_outputs> strides;
+  int64_t element_sizes[num_outputs];
+  for (int i = 0; i < num_outputs; i++) {
+    strides[i] = iter.strides(i).data();
+    element_sizes[i] = iter.element_size(i);
+  }
+  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+static inline int64_t syclMaxWorkItemsPerSubSlice(at::DeviceIndex dev_id = c10::xpu::getCurrentXPUStream().device_index()) {
+  auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
+  int64_t simd_width = dev_prop->sub_group_sizes[0];
+  int64_t eu_count = dev_prop->gpu_eu_count_per_subslice;
+  return simd_width * eu_count;
+}
+
+template<class T>
+T ceil_div(T dividend, T divisor) {
+  return (dividend + divisor - 1) / divisor;
+}
+
+template <typename ker_t>
+static inline void sycl_kernel_submit(int64_t global_range, int64_t local_range, ::sycl::queue q, ker_t ker) {
+  q.parallel_for(
+      sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)),
+      ker
+  );
+}
+
+template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
+static inline void launch_unrolled_kernel_for_multi_outputs(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    in_calc_t ic,
+    out_calc_t oc) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+
+  auto ker = UnrolledElementwiseForMultiOutputsKernel<num_outputs, func_t, array_t, in_calc_t, out_calc_t>(N, f, data, ic, oc);
+  using ker_t = decltype(ker);
+
+  int wg_sz = syclMaxWorkItemsPerSubSlice();
+  int num_wg = ceil_div<int>(N, ker_t::item_work_size * wg_sz);
+  sycl_kernel_submit(wg_sz * num_wg, wg_sz, c10::xpu::getCurrentXPUStream().queue(), ker);
+}
+
+template <int N>
+struct TrivialOffsetCalculator {
+  using offset_type = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
+
+  C10_HOST_DEVICE offset_type get(uint32_t linear_idx) const {
+    offset_type offsets;
+    #pragma unroll
+    for (int arg = 0; arg < N; arg++) {
+      offsets[arg] = linear_idx;
+    }
+    return offsets;
+  }
+};
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs_impl(at::TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using output_t = typename traits::result_type;
+  constexpr int num_outputs = std::tuple_size<output_t>::value;
+  constexpr int num_inputs = traits::arity;
+  constexpr int ntensors = num_outputs + num_inputs;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  if (iter.is_contiguous()) {
+    auto input_calc = TrivialOffsetCalculator<num_inputs>();
+    auto output_calc = TrivialOffsetCalculator<num_outputs>();
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  } else {
+    auto input_calc = make_input_offset_calculator<num_inputs>(iter);
+    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs(at::TensorIteratorBase& iter, const func_t& f) {
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_xpu());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel_multiple_outputs(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_multiple_outputs_impl(iter, f);
+}
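
Note: for the non-contiguous path, OffsetCalculator::get converts a flat TensorIterator index into per-tensor element offsets by repeated divmod over the iterator's shape (strides are pre-divided by the element size in the constructor). A rough Python sketch of that decomposition for a single tensor; the helper name and the example sizes/strides are illustrative, not part of the extension:

```python
# Illustrative sketch of the OffsetCalculator::get decomposition for one tensor.
# `sizes`/`strides` are per-dimension, fastest-varying dimension first, with
# strides expressed in elements (mirroring the C++ division by element_size).
def element_offset(linear_idx: int, sizes: list[int], strides: list[int]) -> int:
    offset = 0
    for size, stride in zip(sizes, strides):
        linear_idx, coord = divmod(linear_idx, size)  # peel off one dimension
        offset += coord * stride
    return offset

# Example: shape (4, 3) with element strides (2, 8): linear index 7 -> 3*2 + 1*8 = 14.
print(element_offset(7, sizes=[4, 3], strides=[2, 8]))  # 14
```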
tests/__init__.py ADDED
(empty file)
tests/test_rotary.py ADDED
@@ -0,0 +1,127 @@
+import pytest
+import torch
+
+from tests.utils import infer_device, supports_bfloat16
+from kernels import get_local_kernel
+from pathlib import Path
+
+# import rotary
+# from transformers.trainer_utils import set_seed
+# set_seed(42)
+
+# Set the local repo path, relative path
+repo_path = Path(__file__).parent.parent
+rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
+
+def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
+    assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
+
+    if not conj:
+        out1 = x1 * cos - x2 * sin
+        out2 = x1 * sin + x2 * cos
+    else:
+        out1 = x1 * cos + x2 * sin
+        out2 = -x1 * sin + x2 * cos
+    return out1, out2
+
+
+def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
+    """the wrapper for apply_rotary_torch"""
+    rotary_dim = cos.shape[-1]
+
+    # apply rotation encoding to Q
+    q1 = q[..., :rotary_dim]
+    q2 = q[..., rotary_dim : 2 * rotary_dim]
+    q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj)
+    q_out = torch.cat([q_out_1, q_out_2, q[..., 2 * rotary_dim:]], dim=-1)
+
+    # apply rotation encoding to K
+    k1 = k[..., :rotary_dim]
+    k2 = k[..., rotary_dim : 2 * rotary_dim]
+    k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj)
+    k_out = torch.cat([k_out_1, k_out_2, k[..., 2 * rotary_dim:]], dim=-1)
+
+    return q_out, k_out
+
+
+def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
+    """the wrapper for apply_rotary_kernel"""
+    rotary_dim = cos.shape[-1]
+
+    # apply rotation encoding to Q
+    q1 = q[..., :rotary_dim]
+    q2 = q[..., rotary_dim : 2 * rotary_dim]
+    rotary.apply_rotary(q1, q2, cos, sin, q1, q2, conj)
+
+    # apply rotation encoding to K
+    k1 = k[..., :rotary_dim]
+    k2 = k[..., rotary_dim : 2 * rotary_dim]
+    rotary.apply_rotary(k1, k2, cos, sin, k1, k2, conj)
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("nheads", [8, 16])
+@pytest.mark.parametrize("seqlen", [128, 256])
+@pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)])
+@pytest.mark.parametrize("qk_dim", [3, 4])
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float32, 1e-5, 1e-5),
+        pytest.param(
+            torch.bfloat16,
+            1e-1,
+            1e-5,
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+        ),
+    ],
+)
+@pytest.mark.parametrize("conj", [False, True])
+@pytest.mark.flaky(max_runs=2, min_passes=1)
+def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj):
+    device = infer_device()
+    if device is None:
+        pytest.skip("No suitable device found for testing")
+
+    if qk_dim == 4:
+        q_shape = (batch_size, seqlen, nheads, headdim)
+        cos_sin_shape = (seqlen, 1, rotary_dim)
+    elif qk_dim == 3:
+        q_shape = (batch_size * seqlen, nheads, headdim)
+        cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)
+
+    q_orig = torch.randn(q_shape, device=device, dtype=dtype)
+    k_orig = torch.randn(q_shape, device=device, dtype=dtype)
+    cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+    sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+
+    q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
+    q_torch, k_torch = q_orig.clone(), k_orig.clone()
+
+    q_torch_out, k_torch_out = apply_rotary_torch_wrapper(q_torch, k_torch, cos, sin, conj)
+    apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)
+
+    # verify the rotation results of Q and K are consistent
+    try:
+        assert torch.allclose(q_torch_out, q_kernel, atol=atol, rtol=rtol), "Rotary transformation results for Q do not match"
+    except AssertionError:
+        diff_q = torch.abs(q_torch_out - q_kernel)
+        max_diff_q = torch.max(diff_q)
+        print(f"Max difference for Q: {max_diff_q}")
+        raise
+    try:
+        assert torch.allclose(k_torch_out, k_kernel, atol=atol, rtol=rtol), "Rotary transformation results for K do not match"
+    except AssertionError:
+        diff_k = torch.abs(k_torch_out - k_kernel)
+        max_diff_k = torch.max(diff_k)
+        print(f"Max difference for K: {max_diff_k}")
+        raise
+
+    # verify the non-rotated part of Q and K remains unchanged
+    if (2 * rotary_dim) < headdim:
+        assert torch.equal(
+            q_kernel[..., 2 * rotary_dim:], q_orig[..., 2 * rotary_dim:]
+        ), "Non-rotated part of Q should be unchanged"
+        assert torch.equal(
+            k_kernel[..., 2 * rotary_dim:], k_orig[..., 2 * rotary_dim:]
+        ), "Non-rotated part of K should be unchanged"
tests/utils.py ADDED
@@ -0,0 +1,23 @@
+import torch
+
+
+def infer_device():
+    """
+    Get current device name based on available devices
+    """
+    if torch.cuda.is_available():  # Works for both Nvidia and AMD
+        return "cuda"
+    elif torch.xpu.is_available():
+        return "xpu"
+    else:
+        return None
+
+
+def supports_bfloat16():
+    device = infer_device()
+    if device == "cuda":
+        return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer
+    elif device == "xpu":
+        return True
+    else:
+        return False
torch-ext/torch_binding.cpp CHANGED
@@ -1,12 +1,17 @@
 #include <torch/all.h>
+
+#if defined(CUDA_KERNEL)
 #include <c10/cuda/CUDAGuard.h>
+#elif defined(XPU_KERNEL)
+#include <c10/core/DeviceGuard.h>
+#endif
 
 #include "registration.h"
 
-#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA || x.device().type() == torch::kXPU, #x " must be on CUDA or XPU")
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 
-void apply_rotary_cuda(torch::Tensor const &x1, torch::Tensor const &x2,
+void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
                    torch::Tensor const &cos, torch::Tensor const &sin,
                    torch::Tensor &out1, torch::Tensor &out2,
                    bool const conj);
@@ -27,16 +32,23 @@ void apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
   TORCH_CHECK(cos.sizes() == sin.sizes());
   TORCH_CHECK(out1.sizes() == out2.sizes());
 
+#if defined(CUDA_KERNEL)
   // Otherwise the kernel will be launched from cuda:0 device
   at::cuda::CUDAGuard device_guard{x1.device()};
-
-  apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
+#elif defined(XPU_KERNEL)
+  c10::DeviceGuard device_guard{x1.device()};
+#endif
+  _apply_rotary(x1, x2, cos, sin, out1, out2, conj);
 }
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("apply_rotary(Tensor x1, Tensor x2, Tensor cos, Tensor sin,"
           "Tensor! out1, Tensor! out2, bool conj) -> ()");
-  ops.impl("apply_rotary", torch::kCUDA, &apply_rotary);
+#if defined(CUDA_KERNEL)
+  ops.impl("apply_rotary", torch::kCUDA, &apply_rotary);
+#elif defined(XPU_KERNEL)
+  ops.impl("apply_rotary", torch::kXPU, &apply_rotary);
+#endif
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
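
Note: since the op is registered for torch::kCUDA or torch::kXPU depending on the build, Python-side usage is identical on both backends. A minimal usage sketch mirroring the test wrappers (assumptions: the kernels package is installed, the extension was built for the active device, the working directory is the repository root, and cos/sin are random here purely to show shapes):

```python
# Minimal usage sketch; see tests/test_rotary.py for the full equivalence check.
from pathlib import Path

import torch
from kernels import get_local_kernel

rotary = get_local_kernel(repo_path=Path("."), package_name="rotary")  # repo root assumed

device = "xpu" if torch.xpu.is_available() else "cuda"
rotary_dim = 32
q = torch.randn(2, 128, 8, 64, device=device)         # (batch, seqlen, nheads, headdim)
cos = torch.randn(128, 1, rotary_dim, device=device)  # (seqlen, 1, rotary_dim)
sin = torch.randn(128, 1, rotary_dim, device=device)

q1, q2 = q[..., :rotary_dim], q[..., rotary_dim : 2 * rotary_dim]
rotary.apply_rotary(q1, q2, cos, sin, q1, q2, False)  # rotates q1/q2 in place
```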