drbh commited on
Commit
ab4cc6a
·
1 Parent(s): d774688

fix: revise bindings and wrapper typing

Browse files
flash_attn/flash_api.cpp CHANGED
@@ -1476,10 +1476,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
1476
  } // namespace FLASH_NAMESPACE
1477
 
1478
  // NOTE: wrap the namespaced functions so all types are doubles and longs
1479
- std::vector<at::Tensor>
1480
- mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
1481
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
1482
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
1483
  const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
1484
  const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1485
  const double p_dropout,
@@ -1511,17 +1511,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqle
1511
  return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
1512
  }
1513
 
1514
- std::vector<at::Tensor>
1515
- mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
1516
- const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
1517
- const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
1518
- const std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
1519
- const at::Tensor &cu_seqlens_q, // b+1
1520
- const at::Tensor &cu_seqlens_k, // b+1
1521
- const std::optional<at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
1522
- const std::optional<const at::Tensor> &leftpad_k_, // batch_size
1523
- const std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
1524
- const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
1525
  const int64_t max_seqlen_q,
1526
  const int64_t max_seqlen_k,
1527
  const double p_dropout,
@@ -1532,7 +1532,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
1532
  const int64_t window_size_right,
1533
  const double softcap,
1534
  const bool return_softmax,
1535
- const std::optional<at::Generator> gen_) {
 
1536
  auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
1537
  // Prepare the optional arguments as non-const references.
1538
  std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
@@ -1564,26 +1565,26 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
1564
  softcap_float, return_softmax, gen);
1565
  }
1566
 
1567
- std::vector<at::Tensor>
1568
- mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
1569
- const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1570
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
1571
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
1572
- const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
1573
- const at::Tensor &softmax_lse, // b x h x seqlen_q
1574
- const std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
1575
- const std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
1576
- const std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
1577
- const std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1578
- const double p_dropout, // probability to drop
1579
  const double softmax_scale,
1580
  const bool is_causal,
1581
  const int64_t window_size_left,
1582
  const int64_t window_size_right,
1583
  const double softcap,
1584
  const bool deterministic,
1585
- std::optional<at::Generator> gen_,
1586
- std::optional<at::Tensor> &rng_state) {
1587
 
1588
  auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
1589
 
@@ -1600,6 +1601,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
1600
  int window_size_left_int = static_cast<int>(window_size_left);
1601
  int window_size_right_int = static_cast<int>(window_size_right);
1602
 
 
 
 
 
 
 
 
1603
  return FLASH_NAMESPACE::mha_bwd(
1604
  const_cast<at::Tensor &>(dout),
1605
  q, k, v, out, softmax_lse,
@@ -1608,23 +1616,23 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
1608
  is_causal,
1609
  window_size_left_int, window_size_right_int,
1610
  softcap_float, deterministic,
1611
- gen, rng_state);
1612
  }
1613
 
1614
 
1615
- std::vector<at::Tensor>
1616
- mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
1617
- const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1618
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
1619
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
1620
- const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
1621
- const at::Tensor &softmax_lse, // b x h x seqlen_q
1622
- const std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
1623
- const std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
1624
- const std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
1625
- const at::Tensor &cu_seqlens_q, // batch_size + 1
1626
- const at::Tensor &cu_seqlens_k, // batch_size + 1
1627
- const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
1628
  const int64_t max_seqlen_q,
1629
  const int64_t max_seqlen_k,
1630
  const double p_dropout,
@@ -1635,8 +1643,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
1635
  const int64_t window_size_right,
1636
  const double softcap,
1637
  const bool deterministic,
1638
- std::optional<at::Generator> gen_,
1639
- std::optional<at::Tensor> &rng_state) {
1640
 
1641
  auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
1642
 
@@ -1655,6 +1663,14 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
1655
  int window_size_left_int = static_cast<int>(window_size_left);
1656
  int window_size_right_int = static_cast<int>(window_size_right);
1657
 
 
 
 
 
 
 
 
 
1658
  return FLASH_NAMESPACE::mha_varlen_bwd(
1659
  const_cast<at::Tensor &>(dout),
1660
  q, k, v, out, softmax_lse,
@@ -1666,31 +1682,30 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
1666
  zero_tensors, is_causal,
1667
  window_size_left_int, window_size_right_int,
1668
  softcap_float, deterministic,
1669
- gen, rng_state);
1670
  }
1671
 
1672
- std::vector<at::Tensor>
1673
- mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1674
- const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1675
- const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1676
- const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
1677
- const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
1678
- const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
1679
- const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
1680
- const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
1681
- const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
1682
- const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
1683
- const c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
1684
- const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1685
- const c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
1686
  const double softmax_scale,
1687
  bool is_causal,
1688
  const int64_t window_size_left,
1689
  const int64_t window_size_right,
1690
  const double softcap,
1691
  bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
1692
- const int64_t num_splits
1693
- ) {
1694
 
1695
  // Prepare the optional arguments as const references where needed
1696
  std::optional<const at::Tensor> k = k_.has_value() ? std::optional<const at::Tensor>(k_.value()) : std::nullopt;
 
1476
  } // namespace FLASH_NAMESPACE
1477
 
1478
  // NOTE: wrap the namespaced functions so all types are doubles and longs
1479
+ std::vector<torch::Tensor>
1480
+ mha_fwd(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
1481
+ const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
1482
+ const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
1483
  const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
1484
  const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1485
  const double p_dropout,
 
1511
  return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
1512
  }
1513
 
1514
+ std::vector<torch::Tensor>
1515
+ mha_varlen_fwd(const torch::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
1516
+ const torch::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
1517
+ const torch::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
1518
+ const c10::optional<torch::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
1519
+ const torch::Tensor &cu_seqlens_q, // b+1
1520
+ const torch::Tensor &cu_seqlens_k, // b+1
1521
+ const c10::optional<torch::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
1522
+ const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
1523
+ const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
1524
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
1525
  const int64_t max_seqlen_q,
1526
  const int64_t max_seqlen_k,
1527
  const double p_dropout,
 
1532
  const int64_t window_size_right,
1533
  const double softcap,
1534
  const bool return_softmax,
1535
+ const c10::optional<at::Generator> gen_) {
1536
+
1537
  auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
1538
  // Prepare the optional arguments as non-const references.
1539
  std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
 
1565
  softcap_float, return_softmax, gen);
1566
  }
1567
 
1568
+ std::vector<torch::Tensor>
1569
+ mha_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
1570
+ const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1571
+ const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
1572
+ const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
1573
+ const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
1574
+ const torch::Tensor &softmax_lse, // b x h x seqlen_q
1575
+ const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
1576
+ const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
1577
+ const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
1578
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1579
+ const double p_dropout, // probability to drop
1580
  const double softmax_scale,
1581
  const bool is_causal,
1582
  const int64_t window_size_left,
1583
  const int64_t window_size_right,
1584
  const double softcap,
1585
  const bool deterministic,
1586
+ c10::optional<torch::Generator> gen_,
1587
+ const c10::optional<torch::Tensor> &rng_state) {
1588
 
1589
  auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
1590
 
 
1601
  int window_size_left_int = static_cast<int>(window_size_left);
1602
  int window_size_right_int = static_cast<int>(window_size_right);
1603
 
1604
+ // TODO: avoid copying rng_state if possible
1605
+ // Create a non-const copy of rng_state
1606
+ std::optional<at::Tensor> rng_state_copy;
1607
+ if (rng_state.has_value()) {
1608
+ rng_state_copy = rng_state.value().clone();
1609
+ }
1610
+
1611
  return FLASH_NAMESPACE::mha_bwd(
1612
  const_cast<at::Tensor &>(dout),
1613
  q, k, v, out, softmax_lse,
 
1616
  is_causal,
1617
  window_size_left_int, window_size_right_int,
1618
  softcap_float, deterministic,
1619
+ gen, rng_state_copy);
1620
  }
1621
 
1622
 
1623
+ std::vector<torch::Tensor>
1624
+ mha_varlen_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
1625
+ const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1626
+ const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
1627
+ const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
1628
+ const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
1629
+ const torch::Tensor &softmax_lse, // b x h x seqlen_q
1630
+ const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
1631
+ const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
1632
+ const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
1633
+ const torch::Tensor &cu_seqlens_q, // batch_size + 1
1634
+ const torch::Tensor &cu_seqlens_k, // batch_size + 1
1635
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
1636
  const int64_t max_seqlen_q,
1637
  const int64_t max_seqlen_k,
1638
  const double p_dropout,
 
1643
  const int64_t window_size_right,
1644
  const double softcap,
1645
  const bool deterministic,
1646
+ c10::optional<torch::Generator> gen_,
1647
+ const c10::optional<torch::Tensor> &rng_state) {
1648
 
1649
  auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
1650
 
 
1663
  int window_size_left_int = static_cast<int>(window_size_left);
1664
  int window_size_right_int = static_cast<int>(window_size_right);
1665
 
1666
+
1667
+ // TODO: avoid copying rng_state if possible
1668
+ // Create a non-const copy of rng_state
1669
+ std::optional<at::Tensor> rng_state_copy;
1670
+ if (rng_state.has_value()) {
1671
+ rng_state_copy = rng_state.value().clone();
1672
+ }
1673
+
1674
  return FLASH_NAMESPACE::mha_varlen_bwd(
1675
  const_cast<at::Tensor &>(dout),
1676
  q, k, v, out, softmax_lse,
 
1682
  zero_tensors, is_causal,
1683
  window_size_left_int, window_size_right_int,
1684
  softcap_float, deterministic,
1685
+ gen, rng_state_copy);
1686
  }
1687
 
1688
+ std::vector<torch::Tensor>
1689
+ mha_fwd_kvcache(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
1690
+ const torch::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1691
+ const torch::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
1692
+ const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
1693
+ const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
1694
+ const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
1695
+ const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
1696
+ const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
1697
+ const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
1698
+ const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
1699
+ const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
1700
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
1701
+ const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
1702
  const double softmax_scale,
1703
  bool is_causal,
1704
  const int64_t window_size_left,
1705
  const int64_t window_size_right,
1706
  const double softcap,
1707
  bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
1708
+ const int64_t num_splits) {
 
1709
 
1710
  // Prepare the optional arguments as const references where needed
1711
  std::optional<const at::Tensor> k = k_.has_value() ? std::optional<const at::Tensor>(k_.value()) : std::nullopt;
torch-ext/torch_binding.cpp CHANGED
@@ -17,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
17
  ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
18
  ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
19
 
20
- ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
21
  ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
22
 
23
  ops.def("mha_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
 
17
  ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
18
  ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
19
 
20
+ ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor? seqused_k_, Tensor? leftpad_k_, Tensor? block_table_, Tensor? alibi_slopes_, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
21
  ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
22
 
23
  ops.def("mha_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
torch-ext/torch_binding.h CHANGED
@@ -2,11 +2,11 @@
2
 
3
  #include <torch/torch.h>
4
 
5
- std::vector<at::Tensor>
6
- mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
7
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
8
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
9
- const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
10
  const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
11
  const double p_dropout,
12
  const double softmax_scale,
@@ -15,4 +15,101 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x roun
15
  const int64_t window_size_right,
16
  const double softcap,
17
  const bool return_softmax,
18
- const c10::optional<at::Generator> gen_);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  #include <torch/torch.h>
4
 
5
+ std::vector<torch::Tensor>
6
+ mha_fwd(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
7
+ const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
8
+ const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
9
+ const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
10
  const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
11
  const double p_dropout,
12
  const double softmax_scale,
 
15
  const int64_t window_size_right,
16
  const double softcap,
17
  const bool return_softmax,
18
+ const c10::optional<at::Generator> gen_);
19
+
20
+ std::vector<torch::Tensor>
21
+ mha_varlen_fwd(
22
+ const torch::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
23
+ const torch::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
24
+ const torch::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
25
+ const c10::optional<torch::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
26
+ const torch::Tensor &cu_seqlens_q, // b+1
27
+ const torch::Tensor &cu_seqlens_k, // b+1
28
+ const c10::optional<torch::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
29
+ const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
30
+ const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
31
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
32
+ const int64_t max_seqlen_q,
33
+ const int64_t max_seqlen_k,
34
+ const double p_dropout,
35
+ const double softmax_scale,
36
+ const bool zero_tensors,
37
+ const bool is_causal,
38
+ const int64_t window_size_left,
39
+ const int64_t window_size_right,
40
+ const double softcap,
41
+ const bool return_softmax,
42
+ const c10::optional<at::Generator> gen_);
43
+
44
+
45
+ std::vector<torch::Tensor>
46
+ mha_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
47
+ const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
48
+ const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
49
+ const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
50
+ const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
51
+ const torch::Tensor &softmax_lse, // b x h x seqlen_q
52
+ const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
53
+ const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
54
+ const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
55
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
56
+ const double p_dropout, // probability to drop
57
+ const double softmax_scale,
58
+ const bool is_causal,
59
+ const int64_t window_size_left,
60
+ const int64_t window_size_right,
61
+ const double softcap,
62
+ const bool deterministic,
63
+ c10::optional<at::Generator> gen_,
64
+ const c10::optional<torch::Tensor> &rng_state);
65
+
66
+
67
+ std::vector<torch::Tensor>
68
+ mha_varlen_bwd(
69
+ const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
70
+ const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
71
+ const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
72
+ const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
73
+ const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
74
+ const torch::Tensor &softmax_lse, // b x h x seqlen_q
75
+ const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
76
+ const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
77
+ const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
78
+ const torch::Tensor &cu_seqlens_q, // batch_size + 1
79
+ const torch::Tensor &cu_seqlens_k, // batch_size + 1
80
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
81
+ const int64_t max_seqlen_q,
82
+ const int64_t max_seqlen_k,
83
+ const double p_dropout,
84
+ const double softmax_scale,
85
+ const bool zero_tensors,
86
+ const bool is_causal,
87
+ const int64_t window_size_left,
88
+ const int64_t window_size_right,
89
+ const double softcap,
90
+ const bool deterministic,
91
+ c10::optional<at::Generator> gen_,
92
+ const c10::optional<torch::Tensor> &rng_state);
93
+
94
+ std::vector<torch::Tensor>
95
+ mha_fwd_kvcache(
96
+ const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
97
+ const torch::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
98
+ const torch::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
99
+ const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
100
+ const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
101
+ const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
102
+ const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
103
+ const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
104
+ const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
105
+ const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
106
+ const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
107
+ const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
108
+ const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
109
+ const double softmax_scale,
110
+ bool is_causal,
111
+ const int64_t window_size_left,
112
+ const int64_t window_size_right,
113
+ const double softcap,
114
+ bool is_rotary_interleaved,
115
+ const int64_t num_splits);