Add support for XPU (sycl) (#3)
- Add support for XPU (sycl) (b87db80d558d76f8193f293889e6aee81e9d3743)
- Revert changes to the build folder (f425aa577e0fb59ed08c568985b007ab686da778)
- Delete pyc files (acc75a3bf1a67e06659873b98909b4e997b3f45f)
Co-authored-by: Kai Yang <[email protected]>
- build.toml +8 -0
- flake.lock +13 -14
- flake.nix +3 -9
- rotary-xpu/rotary_xpu.cpp +40 -0
- rotary-xpu/rotary_xpu.hpp +375 -0
- tests/__init__.py +0 -0
- tests/test_rotary.py +127 -0
- tests/utils.py +23 -0
- torch-ext/torch_binding.cpp +17 -5
build.toml CHANGED
@@ -9,3 +9,11 @@ src = ["torch-ext/torch_binding.cpp"]
 backend = "cuda"
 depends = ["torch"]
 src = ["rotary/rotary_cuda.cu"]
+
+[kernel.rotary_xpu]
+backend = "xpu"
+depends = ["torch"]
+src = [
+  "rotary-xpu/rotary_xpu.cpp",
+  "rotary-xpu/rotary_xpu.hpp",
+]
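
The new `[kernel.rotary_xpu]` entry registers the SYCL sources as a second kernel next to the existing CUDA one, so kernel-builder compiles each backend from its own source list. As a minimal sketch of how the built kernel is consumed (run from the repo root; mirrors the loader used in tests/test_rotary.py and assumes an Intel GPU is available):

# Minimal sketch, assuming an XPU build and torch.xpu available;
# get_local_kernel is the same loader the tests use.
from pathlib import Path

import torch
from kernels import get_local_kernel

rotary = get_local_kernel(repo_path=Path("."), package_name="rotary")

x1 = torch.randn(128, 8, 32, device="xpu")
x2 = torch.randn(128, 8, 32, device="xpu")
cos = torch.randn(128, 1, 32, device="xpu")
sin = torch.randn(128, 1, 32, device="xpu")

# Outputs may alias the inputs, which rotates x1/x2 in place (as the tests do).
rotary.apply_rotary(x1, x2, cos, sin, x1, x2, False)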
flake.lock CHANGED
@@ -17,11 +17,11 @@
     },
     "flake-compat_2": {
       "locked": {
-        "lastModified": …,
-        "narHash": "sha256-…",
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
         "owner": "edolstra",
         "repo": "flake-compat",
-        "rev": "…",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
         "type": "github"
       },
       "original": {
@@ -73,11 +73,11 @@
         "nixpkgs": "nixpkgs"
       },
       "locked": {
-        "lastModified": …,
-        "narHash": "sha256-…",
+        "lastModified": 1757493151,
+        "narHash": "sha256-eirWlcvs2rjZmU8JcF4CKN1IEnNfpQnGuf2qbK3IQh8=",
         "owner": "huggingface",
         "repo": "hf-nix",
-        "rev": "…",
+        "rev": "503cd4eb9866103c983dbef93d9ad5db4fb6b415",
         "type": "github"
       },
       "original": {
@@ -98,33 +98,32 @@
         ]
       },
       "locked": {
-        "lastModified": …,
-        "narHash": "sha256-…",
+        "lastModified": 1757570810,
+        "narHash": "sha256-YFWQwy2LKbhjdLW8wkyNkE/+Vbdn6qlJif2CKvBT9Qo=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "…",
+        "rev": "1201847af3ff757b65015c6e06b5bd75896d2d4b",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "torch-2.8",
         "repo": "kernel-builder",
         "type": "github"
       }
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": …,
-        "narHash": "sha256-…",
+        "lastModified": 1755963616,
+        "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
         "owner": "nixos",
         "repo": "nixpkgs",
-        "rev": "…",
+        "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
         "type": "github"
       },
       "original": {
         "owner": "nixos",
+        "ref": "nixos-unstable-small",
         "repo": "nixpkgs",
-        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
         "type": "github"
       }
     },
flake.nix CHANGED
@@ -1,15 +1,9 @@
 {
-  description = "Flake for …";
-
+  description = "Flake for Torch kernel extension";
   inputs = {
-    kernel-builder.url = "github:huggingface/kernel-builder…";
+    kernel-builder.url = "github:huggingface/kernel-builder";
   };
-
-  outputs =
-    {
-      self,
-      kernel-builder,
-    }:
+  outputs = { self, kernel-builder, }:
   kernel-builder.lib.genFlakeOutputs {
     path = ./.;
     rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
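
The flake itself just delegates to kernel-builder.lib.genFlakeOutputs, while flake.lock above pins every input to an exact revision. Those pins can be inspected without Nix; a small Python illustration (not part of the PR), relying only on the standard flake.lock layout:

# Illustration only: list the pinned revisions recorded in flake.lock.
import json

with open("flake.lock") as f:
    lock = json.load(f)

for name, node in lock["nodes"].items():
    locked = node.get("locked", {})
    if "rev" in locked:
        print(f"{name}: {locked.get('owner')}/{locked.get('repo')} @ {locked['rev'][:12]}")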
rotary-xpu/rotary_xpu.cpp ADDED
@@ -0,0 +1,40 @@
+#include <torch/all.h>
+#include "rotary_xpu.hpp"
+
+void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
+                   torch::Tensor const &cos, torch::Tensor const &sin,
+                   torch::Tensor &out1, torch::Tensor &out2,
+                   bool const conj) {
+  auto iter = at::TensorIteratorConfig()
+                  .add_output(out1)
+                  .add_output(out2)
+                  .add_input(x1)
+                  .add_input(x2)
+                  .add_input(cos)
+                  .add_input(sin)
+                  .check_all_same_dtype(false)
+                  .promote_inputs_to_common_dtype(false)
+                  .build();
+
+  if (!conj) {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel_xpu", [&] {
+      gpu_kernel_multiple_outputs(
+          iter, [] (scalar_t x1, scalar_t x2, scalar_t cos,
+                    scalar_t sin) -> std::tuple<scalar_t, scalar_t> {
+            scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
+            scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
+            return {out1, out2};
+          });
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel_xpu", [&] {
+      gpu_kernel_multiple_outputs(
+          iter, [] (scalar_t x1, scalar_t x2, scalar_t cos,
+                    scalar_t sin) -> std::tuple<scalar_t, scalar_t> {
+            scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
+            scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
+            return {out1, out2};
+          });
+    });
+  }
+}
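
The lambda is a plain 2-D rotation, with conj selecting the inverse rotation; the products are accumulated in float even for half/bfloat16 inputs. A quick check (reference math only, not kernel code) that the two paths are inverses of each other when cos and sin come from the same angle:

# Reference math mirroring the device lambdas above: applying the rotation
# and then the conj rotation with the same angles is a round trip.
import torch

def apply_rotary_ref(x1, x2, cos, sin, conj=False):
    if not conj:
        return x1 * cos - x2 * sin, x1 * sin + x2 * cos
    return x1 * cos + x2 * sin, -x1 * sin + x2 * cos

t = torch.linspace(0, 1, 8)
cos, sin = torch.cos(t), torch.sin(t)
x1, x2 = torch.randn(8), torch.randn(8)

y1, y2 = apply_rotary_ref(x1, x2, cos, sin, conj=False)
z1, z2 = apply_rotary_ref(y1, y2, cos, sin, conj=True)
assert torch.allclose(z1, x1, atol=1e-6) and torch.allclose(z2, x2, atol=1e-6)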
rotary-xpu/rotary_xpu.hpp ADDED
@@ -0,0 +1,375 @@
+#include <ATen/core/TensorBody.h>
+#include <ATen/detail/FunctionTraits.h>
+#include <ATen/native/TensorIterator.h>
+#include <sycl/sycl.hpp>
+#include <ATen/core/Array.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/Exception.h>
+#include <c10/util/TypeCast.h>
+#include <cstdint>
+#include <type_traits>
+#include <array>
+#include <c10/core/ScalarType.h>
+#include <c10/xpu/XPUStream.h>
+#include <ATen/xpu/XPUContext.h>
+
+constexpr int MAX_DIMS = 12;
+
+struct LoadWithoutCast {
+  template <typename scalar_t>
+  C10_DEVICE scalar_t load(char* base_ptr, uint32_t offset, int arg) {
+    return c10::load(reinterpret_cast<scalar_t*>(base_ptr) + offset);
+  }
+};
+
+struct StoreWithoutCast {
+  template <typename scalar_t>
+  C10_DEVICE void store(scalar_t value, char* base_ptr, uint32_t offset, int arg = 0) {
+    *(reinterpret_cast<scalar_t*>(base_ptr) + offset) = value;
+  }
+};
+
+template <template <int i> typename func, int end, int current = 0>
+struct static_unroll {
+  template <typename... Args>
+  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
+    func<current>::apply(std::forward<Args>(args)...);
+    static_unroll<func, end, current + 1>::with_args(args...);
+  }
+};
+
+template <template <int i> typename func, int end>
+struct static_unroll<func, end, end> {
+  template <typename... Args>
+  static inline C10_HOST_DEVICE void with_args(Args... args) {}
+};
+
+template <int current>
+struct multi_outputs_store_helper {
+  template <int ntensors, int num_outputs, typename... Args>
+  static C10_HOST_DEVICE void apply(
+      at::detail::Array<char*, ntensors> data,
+      at::detail::Array<uint32_t, num_outputs> offsets,
+      std::tuple<Args...> ret) {
+    using T = typename std::tuple_element<current, std::tuple<Args...>>::type;
+    T* to = reinterpret_cast<T*>(data[current]) + offsets[current];
+    *to = std::get<current>(ret);
+  }
+};
+
+template <int arg_index>
+struct unroll_load_helper {
+  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
+  static C10_DEVICE void apply(
+      policy_t& self,
+      args_t* args,
+      offset_t offset,
+      loader_t loader,
+      int j,
+      int num_outputs) {
+    using arg_t = std::tuple_element_t<arg_index, args_t>;
+    std::get<arg_index>(args[j]) = loader.template load<arg_t>(
+        self.data[arg_index + num_outputs], offset[arg_index], arg_index);
+  }
+};
+
+template <int item_work_size, typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
+struct multi_outputs_unroll {
+  data_t data;
+  int remaining;
+  inp_calc_t input_offset_calculator;
+  out_calc_t output_offset_calculator;
+  LoadWithoutCast loader;
+  StoreWithoutCast storer;
+  int item_idx;
+  int group_idx;
+  int num_items_per_group;
+  int group_work_size;
+
+  multi_outputs_unroll(
+      data_t data,
+      int remaining,
+      inp_calc_t ic,
+      out_calc_t oc,
+      int item_idx,
+      int group_idx,
+      int num_items_per_group)
+      : data(data),
+        remaining(remaining),
+        input_offset_calculator(ic),
+        output_offset_calculator(oc),
+        item_idx(item_idx),
+        group_idx(group_idx),
+        num_items_per_group(num_items_per_group),
+        group_work_size(item_work_size * num_items_per_group) {}
+
+  inline bool check_inbounds(int item_work_elem) const {
+    return (item_idx + item_work_elem * num_items_per_group < remaining);
+  }
+
+  template <typename args_t>
+  inline void load(args_t* args) {
+    constexpr int arity = std::tuple_size<args_t>::value;
+    int item_idx_ = item_idx;
+#pragma unroll
+    for (int i = 0; i < item_work_size; i++) {
+      if (item_idx_ >= remaining) {
+        return;
+      }
+      int linear_idx = item_idx_ + group_work_size * group_idx;
+      auto offset = input_offset_calculator.get(linear_idx);
+      static_unroll<unroll_load_helper, arity>::with_args(
+          *this, args, offset, loader, i, num_outputs);
+      item_idx_ += num_items_per_group;
+    }
+  }
+
+  template <typename return_t>
+  inline void store(return_t* from) {
+    int item_idx_ = item_idx;
+#pragma unroll
+    for (int i = 0; i < item_work_size; i++) {
+      if (item_idx_ >= this->remaining) {
+        return;
+      }
+      int linear_idx = item_idx_ + group_work_size * group_idx;
+      auto offsets = this->output_offset_calculator.get(linear_idx);
+      static_unroll<multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
+      item_idx_ += num_items_per_group;
+    }
+  }
+};
+
+template <int item_work_size, typename func_t, typename policy_t>
+inline void elementwise_kernel_helper(func_t f, policy_t policy) {
+  using traits = function_traits<func_t>;
+  using return_t = typename traits::result_type;
+  using args_t = typename traits::ArgsTuple;
+
+  return_t results[item_work_size];
+  args_t args[item_work_size];
+
+  policy.load(args);
+
+#pragma unroll
+  for (int i = 0; i < item_work_size; i++) {
+    if (policy.check_inbounds(i)) {
+      results[i] = std::apply(f, args[i]);
+    }
+  }
+
+  policy.store(results);
+}
+
+template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
+struct UnrolledElementwiseForMultiOutputsKernel {
+  static constexpr int item_work_size = 4;
+
+  void operator()(sycl::nd_item<1> item_id) const {
+    int grpsz = item_id.get_local_range(0);
+    int grpid = item_id.get_group(0);
+    int lid = item_id.get_local_id(0);
+    int remaining = numel_ - item_work_size * grpsz * grpid;
+    auto policy = multi_outputs_unroll<item_work_size, array_t, in_calc_t, out_calc_t, num_outputs>(
+        data_, remaining, ic_, oc_, lid, grpid, grpsz);
+    elementwise_kernel_helper<item_work_size>(f_, policy);
+  };
+
+  UnrolledElementwiseForMultiOutputsKernel(int numel, func_t f, array_t data, in_calc_t ic, out_calc_t oc)
+      : numel_(numel), f_(f), data_(data), ic_(ic), oc_(oc) {}
+
+ private:
+  int numel_;
+  func_t f_;
+  array_t data_;
+  in_calc_t ic_;
+  out_calc_t oc_;
+};
+
+template <typename Value>
+struct IntDivider {
+  IntDivider() = default;
+  IntDivider(Value d) : divisor(d) {}
+
+  C10_HOST_DEVICE inline Value div(Value n) const {
+    return n / divisor;
+  }
+  C10_HOST_DEVICE inline Value mod(Value n) const {
+    return n % divisor;
+  }
+  C10_HOST_DEVICE inline auto divmod(Value n) const {
+    return std::make_pair(n / divisor, n % divisor);
+  }
+
+  Value divisor;
+};
+
+template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
+struct OffsetCalculator {
+  using stride_t = std::conditional_t<signed_strides, std::make_signed_t<index_t>, index_t>;
+  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;
+
+  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes = nullptr)
+      : dims(dims) {
+    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
+    for (int i = 0; i < dims; i++) {
+      sizes_[i] = IntDivider<index_t>(sizes[i]);
+      for (int arg = 0; arg < NARGS; arg++) {
+        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
+        strides_[i][arg] = strides[arg][i] / element_size;
+      }
+    }
+  }
+
+  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
+    offset_type offsets;
+#pragma unroll
+    for (int arg = 0; arg < NARGS; arg++) {
+      offsets[arg] = 0;
+    }
+
+#pragma unroll
+    for (int dim = 0; dim < MAX_DIMS; ++dim) {
+      if (dim == dims) {
+        break;
+      }
+      auto divmod = sizes_[dim].divmod(linear_idx);
+      linear_idx = divmod.first;
+
+#pragma unroll
+      for (int arg = 0; arg < NARGS; arg++) {
+        offsets[arg] += divmod.second * strides_[dim][arg];
+      }
+    }
+    return offsets;
+  }
+
+  int dims;
+  IntDivider<index_t> sizes_[MAX_DIMS];
+  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
+};
+
+template <int N>
+static OffsetCalculator<N> make_input_offset_calculator(const at::TensorIteratorBase& iter) {
+  constexpr int array_size = std::max<int>(N, 1);
+  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
+  std::array<const int64_t*, array_size> strides;
+  int64_t element_sizes[array_size];
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i + iter.noutputs()).data();
+    element_sizes[i] = iter.element_size(i + iter.noutputs());
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+template <int num_outputs = 1>
+static OffsetCalculator<num_outputs> make_output_offset_calculator(const at::TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
+  std::array<const int64_t*, num_outputs> strides;
+  int64_t element_sizes[num_outputs];
+  for (int i = 0; i < num_outputs; i++) {
+    strides[i] = iter.strides(i).data();
+    element_sizes[i] = iter.element_size(i);
+  }
+  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
+}
+
+static inline int64_t syclMaxWorkItemsPerSubSlice(at::DeviceIndex dev_id = c10::xpu::getCurrentXPUStream().device_index()) {
+  auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
+  int64_t simd_width = dev_prop->sub_group_sizes[0];
+  int64_t eu_count = dev_prop->gpu_eu_count_per_subslice;
+  return simd_width * eu_count;
+}
+
+template<class T>
+T ceil_div(T dividend, T divisor) {
+  return (dividend + divisor - 1) / divisor;
+}
+
+template <typename ker_t>
+static inline void sycl_kernel_submit(int64_t global_range, int64_t local_range, ::sycl::queue q, ker_t ker) {
+  q.parallel_for(
+      sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)),
+      ker);
+}
+
+template <int num_outputs, typename func_t, typename array_t, typename in_calc_t, typename out_calc_t>
+static inline void launch_unrolled_kernel_for_multi_outputs(
+    int64_t N,
+    const func_t& f,
+    array_t data,
+    in_calc_t ic,
+    out_calc_t oc) {
+  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
+
+  auto ker = UnrolledElementwiseForMultiOutputsKernel<num_outputs, func_t, array_t, in_calc_t, out_calc_t>(N, f, data, ic, oc);
+  using ker_t = decltype(ker);
+
+  int wg_sz = syclMaxWorkItemsPerSubSlice();
+  int num_wg = ceil_div<int>(N, ker_t::item_work_size * wg_sz);
+  sycl_kernel_submit(wg_sz * num_wg, wg_sz, c10::xpu::getCurrentXPUStream().queue(), ker);
+}
+
+template <int N>
+struct TrivialOffsetCalculator {
+  using offset_type = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
+
+  C10_HOST_DEVICE offset_type get(uint32_t linear_idx) const {
+    offset_type offsets;
+#pragma unroll
+    for (int arg = 0; arg < N; arg++) {
+      offsets[arg] = linear_idx;
+    }
+    return offsets;
+  }
+};
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs_impl(at::TensorIteratorBase& iter, const func_t& f) {
+  using traits = function_traits<func_t>;
+  using output_t = typename traits::result_type;
+  constexpr int num_outputs = std::tuple_size<output_t>::value;
+  constexpr int num_inputs = traits::arity;
+  constexpr int ntensors = num_outputs + num_inputs;
+
+  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);
+
+  at::detail::Array<char*, ntensors> data;
+  for (int i = 0; i < ntensors; i++) {
+    data[i] = (char*)iter.data_ptr(i);
+  }
+
+  int64_t numel = iter.numel();
+
+  if (iter.is_contiguous()) {
+    auto input_calc = TrivialOffsetCalculator<num_inputs>();
+    auto output_calc = TrivialOffsetCalculator<num_outputs>();
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  } else {
+    auto input_calc = make_input_offset_calculator<num_inputs>(iter);
+    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
+    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
+  }
+}
+
+template <typename func_t>
+void gpu_kernel_multiple_outputs(at::TensorIteratorBase& iter, const func_t& f) {
+  for (int arg = 0; arg < iter.ntensors(); arg++) {
+    TORCH_INTERNAL_ASSERT(iter.device(arg).is_xpu());
+  }
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  if (!iter.can_use_32bit_indexing()) {
+    for (auto& sub_iter : iter.with_32bit_indexing()) {
+      gpu_kernel_multiple_outputs(sub_iter, f);
+    }
+    return;
+  }
+
+  gpu_kernel_multiple_outputs_impl(iter, f);
+}
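
The header is essentially a SYCL port of the multi-output TensorIterator machinery from PyTorch's CUDA loops: each work item handles up to item_work_size = 4 elements, and OffsetCalculator turns a linear element index into per-tensor offsets by peeling dimensions off with div/mod against the iterator's shape. A small Python model of that index arithmetic (shapes and strides below are illustrative, not taken from the PR):

# Model of OffsetCalculator::get: convert a linear index into per-tensor
# element offsets from sizes and element strides (fastest-varying dim first).
def offsets_for(linear_idx, sizes, strides_per_tensor):
    offsets = [0] * len(strides_per_tensor)
    for dim, size in enumerate(sizes):
        linear_idx, idx = divmod(linear_idx, size)
        for t, strides in enumerate(strides_per_tensor):
            offsets[t] += idx * strides[dim]
    return offsets

# A contiguous 2x3 row-major tensor has element strides (1, 3) with dims
# ordered fastest-first, so the offset equals the linear index.
print(offsets_for(4, sizes=[3, 2], strides_per_tensor=[(1, 3)]))  # [4]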
tests/__init__.py ADDED
File without changes
tests/test_rotary.py ADDED
@@ -0,0 +1,127 @@
+import pytest
+import torch
+
+from tests.utils import infer_device, supports_bfloat16
+from kernels import get_local_kernel
+from pathlib import Path
+
+# import rotary
+# from transformers.trainer_utils import set_seed
+# set_seed(42)
+
+# Set the local repo path (relative to this file).
+repo_path = Path(__file__).parent.parent
+rotary = get_local_kernel(repo_path=repo_path, package_name="rotary")
+
+def apply_rotary_torch(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool = False):
+    assert x1.shape == x2.shape, "x1 and x2 must have the same shape"
+
+    if not conj:
+        out1 = x1 * cos - x2 * sin
+        out2 = x1 * sin + x2 * cos
+    else:
+        out1 = x1 * cos + x2 * sin
+        out2 = -x1 * sin + x2 * cos
+    return out1, out2
+
+
+def apply_rotary_torch_wrapper(q, k, cos, sin, conj: bool = False):
+    """Wrapper around apply_rotary_torch."""
+    rotary_dim = cos.shape[-1]
+
+    # Apply the rotary encoding to Q.
+    q1 = q[..., :rotary_dim]
+    q2 = q[..., rotary_dim : 2 * rotary_dim]
+    q_out_1, q_out_2 = apply_rotary_torch(q1, q2, cos, sin, conj)
+    q_out = torch.cat([q_out_1, q_out_2, q[..., 2 * rotary_dim:]], dim=-1)
+
+    # Apply the rotary encoding to K.
+    k1 = k[..., :rotary_dim]
+    k2 = k[..., rotary_dim : 2 * rotary_dim]
+    k_out_1, k_out_2 = apply_rotary_torch(k1, k2, cos, sin, conj)
+    k_out = torch.cat([k_out_1, k_out_2, k[..., 2 * rotary_dim:]], dim=-1)
+
+    return q_out, k_out
+
+
+def apply_rotary_kernel_wrapper(q, k, cos, sin, conj: bool = False):
+    """Wrapper around the rotary kernel; rotates q and k in place."""
+    rotary_dim = cos.shape[-1]
+
+    # Apply the rotary encoding to Q.
+    q1 = q[..., :rotary_dim]
+    q2 = q[..., rotary_dim : 2 * rotary_dim]
+    rotary.apply_rotary(q1, q2, cos, sin, q1, q2, conj)
+
+    # Apply the rotary encoding to K.
+    k1 = k[..., :rotary_dim]
+    k2 = k[..., rotary_dim : 2 * rotary_dim]
+    rotary.apply_rotary(k1, k2, cos, sin, k1, k2, conj)
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize("nheads", [8, 16])
+@pytest.mark.parametrize("seqlen", [128, 256])
+@pytest.mark.parametrize("headdim, rotary_dim", [(64, 32), (128, 64), (64, 30)])
+@pytest.mark.parametrize("qk_dim", [3, 4])
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float32, 1e-5, 1e-5),
+        pytest.param(
+            torch.bfloat16,
+            1e-1,
+            1e-5,
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+        ),
+    ],
+)
+@pytest.mark.parametrize("conj", [False, True])
+@pytest.mark.flaky(max_runs=2, min_passes=1)
+def test_rotary_equivalence(batch_size, nheads, seqlen, headdim, rotary_dim, qk_dim, dtype, atol, rtol, conj):
+    device = infer_device()
+    if device is None:
+        pytest.skip("No suitable device found for testing")
+
+    if qk_dim == 4:
+        q_shape = (batch_size, seqlen, nheads, headdim)
+        cos_sin_shape = (seqlen, 1, rotary_dim)
+    elif qk_dim == 3:
+        q_shape = (batch_size * seqlen, nheads, headdim)
+        cos_sin_shape = (batch_size * seqlen, 1, rotary_dim)
+
+    q_orig = torch.randn(q_shape, device=device, dtype=dtype)
+    k_orig = torch.randn(q_shape, device=device, dtype=dtype)
+    cos = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+    sin = torch.randn(cos_sin_shape, device=device, dtype=dtype)
+
+    q_kernel, k_kernel = q_orig.clone(), k_orig.clone()
+    q_torch, k_torch = q_orig.clone(), k_orig.clone()
+
+    q_torch_out, k_torch_out = apply_rotary_torch_wrapper(q_torch, k_torch, cos, sin, conj)
+    apply_rotary_kernel_wrapper(q_kernel, k_kernel, cos, sin, conj)
+
+    # Verify the rotated parts of Q and K match the reference.
+    try:
+        assert torch.allclose(q_torch_out, q_kernel, atol=atol, rtol=rtol), "Rotary transformation results for Q do not match"
+    except AssertionError:
+        diff_q = torch.abs(q_torch_out - q_kernel)
+        max_diff_q = torch.max(diff_q)
+        print(f"Max difference for Q: {max_diff_q}")
+        raise
+    try:
+        assert torch.allclose(k_torch_out, k_kernel, atol=atol, rtol=rtol), "Rotary transformation results for K do not match"
+    except AssertionError:
+        diff_k = torch.abs(k_torch_out - k_kernel)
+        max_diff_k = torch.max(diff_k)
+        print(f"Max difference for K: {max_diff_k}")
+        raise
+
+    # Verify the non-rotated part of Q and K remains unchanged.
+    if (2 * rotary_dim) < headdim:
+        assert torch.equal(
+            q_kernel[..., 2 * rotary_dim:], q_orig[..., 2 * rotary_dim:]
+        ), "Non-rotated part of Q should be unchanged"
+        assert torch.equal(
+            k_kernel[..., 2 * rotary_dim:], k_orig[..., 2 * rotary_dim:]
+        ), "Non-rotated part of K should be unchanged"
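
The suite checks the kernel against the pure-PyTorch reference across dtypes, layouts, and both rotation directions, and verifies that the tail of the head dimension is untouched when 2 * rotary_dim < headdim. A single configuration can also be run by hand along these lines (a sketch, assuming a CUDA or XPU device and a local build):

# Sketch of one configuration from the test matrix, run by hand.
import torch
from tests.test_rotary import apply_rotary_kernel_wrapper, apply_rotary_torch_wrapper

device = "xpu" if torch.xpu.is_available() else "cuda"
q = torch.randn(2, 128, 8, 64, device=device)   # (batch, seqlen, nheads, headdim)
k = torch.randn(2, 128, 8, 64, device=device)
cos = torch.randn(128, 1, 32, device=device)    # (seqlen, 1, rotary_dim)
sin = torch.randn(128, 1, 32, device=device)

q_ref, k_ref = apply_rotary_torch_wrapper(q.clone(), k.clone(), cos, sin)
apply_rotary_kernel_wrapper(q, k, cos, sin)     # rotates q/k in place
assert torch.allclose(q_ref, q, atol=1e-5, rtol=1e-5)
assert torch.allclose(k_ref, k, atol=1e-5, rtol=1e-5)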
tests/utils.py ADDED
@@ -0,0 +1,23 @@
+import torch
+
+
+def infer_device():
+    """
+    Get the current device name based on the available devices.
+    """
+    if torch.cuda.is_available():  # Works for both NVIDIA and AMD.
+        return "cuda"
+    elif torch.xpu.is_available():
+        return "xpu"
+    else:
+        return None
+
+
+def supports_bfloat16():
+    device = infer_device()
+    if device == "cuda":
+        return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer
+    elif device == "xpu":
+        return True
+    else:
+        return False
torch-ext/torch_binding.cpp CHANGED
@@ -1,12 +1,17 @@
 #include <torch/all.h>
+
+#if defined(CUDA_KERNEL)
 #include <c10/cuda/CUDAGuard.h>
+#elif defined(XPU_KERNEL)
+#include <c10/core/DeviceGuard.h>
+#endif
 
 #include "registration.h"
 
-#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA || x.device().type() == torch::kXPU, #x " must be on CUDA or XPU")
 #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 
-void …(torch::Tensor const &x1, torch::Tensor const &x2,
+void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
                    torch::Tensor const &cos, torch::Tensor const &sin,
                    torch::Tensor &out1, torch::Tensor &out2,
                    bool const conj);
@@ -27,16 +32,23 @@ void apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
   TORCH_CHECK(cos.sizes() == sin.sizes());
   TORCH_CHECK(out1.sizes() == out2.sizes());
 
+#if defined(CUDA_KERNEL)
   // Otherwise the kernel will be launched from cuda:0 device
   at::cuda::CUDAGuard device_guard{x1.device()};
-  …
-  …
+#elif defined(XPU_KERNEL)
+  c10::DeviceGuard device_guard{x1.device()};
+#endif
+  _apply_rotary(x1, x2, cos, sin, out1, out2, conj);
 }
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("apply_rotary(Tensor x1, Tensor x2, Tensor cos, Tensor sin,"
           "Tensor! out1, Tensor! out2, bool conj) -> ()");
-  …
+#if defined(CUDA_KERNEL)
+  ops.impl("apply_rotary", torch::kCUDA, &apply_rotary);
+#elif defined(XPU_KERNEL)
+  ops.impl("apply_rotary", torch::kXPU, &apply_rotary);
+#endif
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
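
With the backend chosen at compile time via CUDA_KERNEL / XPU_KERNEL (macros assumed here to be defined by the build system), both builds register the same op schema, so callers dispatch purely on the device of the tensors they pass. A hedged sketch of the out-of-place call through torch.ops; the `_rotary` namespace is a hypothetical stand-in for the generated extension name (the tests instead load it via kernels.get_local_kernel):

# Hedged sketch, assuming the extension library has already been loaded
# and registered itself under the hypothetical namespace `_rotary`.
import torch

x1 = torch.randn(16, 32, device="xpu")
x2 = torch.randn(16, 32, device="xpu")
cos = torch.randn(16, 32, device="xpu")
sin = torch.randn(16, 32, device="xpu")
out1, out2 = torch.empty_like(x1), torch.empty_like(x2)

# The op writes into out1/out2 and returns nothing, matching the schema above.
torch.ops._rotary.apply_rotary(x1, x2, cos, sin, out1, out2, False)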