kernel
Kaixuanliu commited on
Commit
6069b37
·
1 Parent(s): d7f24a6

add `apply_rotary_pos_emb` API to make it adapt to transformers

Browse files
Files changed (2) hide show
  1. rotary-xpu/rotary_xpu.cpp +38 -0
  2. rotary/rotary_cuda.cu +38 -0
rotary-xpu/rotary_xpu.cpp CHANGED
@@ -38,3 +38,41 @@ void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
38
  });
39
  }
40
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  });
39
  }
40
  }
41
+
42
+ std::tuple<torch::Tensor, torch::Tensor> apply_rotary_pos_emb(
43
+ torch::Tensor const &q, torch::Tensor const &k,
44
+ torch::Tensor const &cos, torch::Tensor const &sin,
45
+ torch::Tensor const &position_ids, int64_t unsqueeze_dim) {
46
+
47
+ // Handle unsqueeze_dim parameter
48
+ auto cos_unsqueezed = cos.unsqueeze(unsqueeze_dim);
49
+ auto sin_unsqueezed = sin.unsqueeze(unsqueeze_dim);
50
+
51
+ // Clone inputs since we'll modify them
52
+ auto q_rotated = q.clone();
53
+ auto k_rotated = k.clone();
54
+
55
+ // Get half dimension for rotation
56
+ int64_t half_dim = q.size(-1) / 2;
57
+
58
+ // Split Q and K for rotation
59
+ auto q1 = q_rotated.slice(-1, 0, half_dim);
60
+ auto q2 = q_rotated.slice(-1, half_dim, q.size(-1));
61
+ auto k1 = k_rotated.slice(-1, 0, half_dim);
62
+ auto k2 = k_rotated.slice(-1, half_dim, k.size(-1));
63
+
64
+ // Make sure cos/sin match the half dimension
65
+ auto cos_final = cos_unsqueezed;
66
+ auto sin_final = sin_unsqueezed;
67
+ if (cos_unsqueezed.size(-1) != half_dim) {
68
+ // Trim cos/sin to match half_dim
69
+ cos_final = cos_unsqueezed.slice(-1, 0, half_dim);
70
+ sin_final = sin_unsqueezed.slice(-1, 0, half_dim);
71
+ }
72
+
73
+ // Apply rotary embedding using our kernel
74
+ _apply_rotary(q1, q2, cos_final, sin_final, q1, q2, false);
75
+ _apply_rotary(k1, k2, cos_final, sin_final, k1, k2, false);
76
+
77
+ return std::make_tuple(q_rotated, k_rotated);
78
+ }
rotary/rotary_cuda.cu CHANGED
@@ -43,3 +43,41 @@ void _apply_rotary(torch::Tensor const &x1, torch::Tensor const &x2,
43
  });
44
  }
45
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  });
44
  }
45
  }
46
+
47
+ std::tuple<torch::Tensor, torch::Tensor> apply_rotary_pos_emb(
48
+ torch::Tensor const &q, torch::Tensor const &k,
49
+ torch::Tensor const &cos, torch::Tensor const &sin,
50
+ torch::Tensor const &position_ids, int64_t unsqueeze_dim) {
51
+
52
+ // Handle unsqueeze_dim parameter
53
+ auto cos_unsqueezed = cos.unsqueeze(unsqueeze_dim);
54
+ auto sin_unsqueezed = sin.unsqueeze(unsqueeze_dim);
55
+
56
+ // Clone inputs since we'll modify them
57
+ auto q_rotated = q.clone();
58
+ auto k_rotated = k.clone();
59
+
60
+ // Get half dimension for rotation
61
+ int64_t half_dim = q.size(-1) / 2;
62
+
63
+ // Split Q and K for rotation
64
+ auto q1 = q_rotated.slice(-1, 0, half_dim);
65
+ auto q2 = q_rotated.slice(-1, half_dim, q.size(-1));
66
+ auto k1 = k_rotated.slice(-1, 0, half_dim);
67
+ auto k2 = k_rotated.slice(-1, half_dim, k.size(-1));
68
+
69
+ // Make sure cos/sin match the half dimension
70
+ auto cos_final = cos_unsqueezed;
71
+ auto sin_final = sin_unsqueezed;
72
+ if (cos_unsqueezed.size(-1) != half_dim) {
73
+ // Trim cos/sin to match half_dim
74
+ cos_final = cos_unsqueezed.slice(-1, 0, half_dim);
75
+ sin_final = sin_unsqueezed.slice(-1, 0, half_dim);
76
+ }
77
+
78
+ // Apply rotary embedding using our kernel
79
+ _apply_rotary(q1, q2, cos_final, sin_final, q1, q2, false);
80
+ _apply_rotary(k1, k2, cos_final, sin_final, k1, k2, false);
81
+
82
+ return std::make_tuple(q_rotated, k_rotated);
83
+ }