drbh
commited on
Commit
·
ab4cc6a
1
Parent(s):
d774688
fix: revise bindings and wrapper typing
Browse files- flash_attn/flash_api.cpp +78 -63
- torch-ext/torch_binding.cpp +1 -1
- torch-ext/torch_binding.h +103 -6
flash_attn/flash_api.cpp
CHANGED
@@ -1476,10 +1476,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
1476 |
} // namespace FLASH_NAMESPACE
|
1477 |
|
1478 |
// NOTE: wrap the namespaced functions so all types are doubles and longs
|
1479 |
-
std::vector<
|
1480 |
-
mha_fwd(const
|
1481 |
-
const
|
1482 |
-
const
|
1483 |
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1484 |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1485 |
const double p_dropout,
|
@@ -1511,17 +1511,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqle
|
|
1511 |
return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
1512 |
}
|
1513 |
|
1514 |
-
std::vector<
|
1515 |
-
mha_varlen_fwd(
|
1516 |
-
const
|
1517 |
-
const
|
1518 |
-
const
|
1519 |
-
const
|
1520 |
-
const
|
1521 |
-
const
|
1522 |
-
const
|
1523 |
-
const
|
1524 |
-
const
|
1525 |
const int64_t max_seqlen_q,
|
1526 |
const int64_t max_seqlen_k,
|
1527 |
const double p_dropout,
|
@@ -1532,7 +1532,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
|
1532 |
const int64_t window_size_right,
|
1533 |
const double softcap,
|
1534 |
const bool return_softmax,
|
1535 |
-
const
|
|
|
1536 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1537 |
// Prepare the optional arguments as non-const references.
|
1538 |
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
@@ -1564,26 +1565,26 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
|
1564 |
softcap_float, return_softmax, gen);
|
1565 |
}
|
1566 |
|
1567 |
-
std::vector<
|
1568 |
-
mha_bwd(const
|
1569 |
-
const
|
1570 |
-
const
|
1571 |
-
const
|
1572 |
-
const
|
1573 |
-
const
|
1574 |
-
const
|
1575 |
-
const
|
1576 |
-
const
|
1577 |
-
const
|
1578 |
-
const double p_dropout,
|
1579 |
const double softmax_scale,
|
1580 |
const bool is_causal,
|
1581 |
const int64_t window_size_left,
|
1582 |
const int64_t window_size_right,
|
1583 |
const double softcap,
|
1584 |
const bool deterministic,
|
1585 |
-
|
1586 |
-
|
1587 |
|
1588 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1589 |
|
@@ -1600,6 +1601,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
1600 |
int window_size_left_int = static_cast<int>(window_size_left);
|
1601 |
int window_size_right_int = static_cast<int>(window_size_right);
|
1602 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1603 |
return FLASH_NAMESPACE::mha_bwd(
|
1604 |
const_cast<at::Tensor &>(dout),
|
1605 |
q, k, v, out, softmax_lse,
|
@@ -1608,23 +1616,23 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
1608 |
is_causal,
|
1609 |
window_size_left_int, window_size_right_int,
|
1610 |
softcap_float, deterministic,
|
1611 |
-
gen,
|
1612 |
}
|
1613 |
|
1614 |
|
1615 |
-
std::vector<
|
1616 |
-
mha_varlen_bwd(const
|
1617 |
-
const
|
1618 |
-
const
|
1619 |
-
const
|
1620 |
-
const
|
1621 |
-
const
|
1622 |
-
const
|
1623 |
-
const
|
1624 |
-
const
|
1625 |
-
const
|
1626 |
-
const
|
1627 |
-
const
|
1628 |
const int64_t max_seqlen_q,
|
1629 |
const int64_t max_seqlen_k,
|
1630 |
const double p_dropout,
|
@@ -1635,8 +1643,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
1635 |
const int64_t window_size_right,
|
1636 |
const double softcap,
|
1637 |
const bool deterministic,
|
1638 |
-
|
1639 |
-
|
1640 |
|
1641 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1642 |
|
@@ -1655,6 +1663,14 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
1655 |
int window_size_left_int = static_cast<int>(window_size_left);
|
1656 |
int window_size_right_int = static_cast<int>(window_size_right);
|
1657 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1658 |
return FLASH_NAMESPACE::mha_varlen_bwd(
|
1659 |
const_cast<at::Tensor &>(dout),
|
1660 |
q, k, v, out, softmax_lse,
|
@@ -1666,31 +1682,30 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
|
|
1666 |
zero_tensors, is_causal,
|
1667 |
window_size_left_int, window_size_right_int,
|
1668 |
softcap_float, deterministic,
|
1669 |
-
gen,
|
1670 |
}
|
1671 |
|
1672 |
-
std::vector<
|
1673 |
-
mha_fwd_kvcache(const
|
1674 |
-
const
|
1675 |
-
const
|
1676 |
-
const c10::optional<torch::Tensor> &k_,
|
1677 |
-
const c10::optional<torch::Tensor> &v_,
|
1678 |
-
const c10::optional<torch::Tensor> &seqlens_k_,
|
1679 |
-
const c10::optional<torch::Tensor> &rotary_cos_,
|
1680 |
-
const c10::optional<torch::Tensor> &rotary_sin_,
|
1681 |
-
const c10::optional<torch::Tensor> &cache_batch_idx_,
|
1682 |
-
const c10::optional<torch::Tensor> &leftpad_k_,
|
1683 |
-
const c10::optional<
|
1684 |
-
const c10::optional<
|
1685 |
-
const c10::optional<
|
1686 |
const double softmax_scale,
|
1687 |
bool is_causal,
|
1688 |
const int64_t window_size_left,
|
1689 |
const int64_t window_size_right,
|
1690 |
const double softcap,
|
1691 |
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
1692 |
-
const int64_t num_splits
|
1693 |
-
) {
|
1694 |
|
1695 |
// Prepare the optional arguments as const references where needed
|
1696 |
std::optional<const at::Tensor> k = k_.has_value() ? std::optional<const at::Tensor>(k_.value()) : std::nullopt;
|
|
|
1476 |
} // namespace FLASH_NAMESPACE
|
1477 |
|
1478 |
// NOTE: wrap the namespaced functions so all types are doubles and longs
|
1479 |
+
std::vector<torch::Tensor>
|
1480 |
+
mha_fwd(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1481 |
+
const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
1482 |
+
const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
1483 |
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1484 |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1485 |
const double p_dropout,
|
|
|
1511 |
return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
1512 |
}
|
1513 |
|
1514 |
+
std::vector<torch::Tensor>
|
1515 |
+
mha_varlen_fwd(const torch::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
1516 |
+
const torch::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
|
1517 |
+
const torch::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
|
1518 |
+
const c10::optional<torch::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
1519 |
+
const torch::Tensor &cu_seqlens_q, // b+1
|
1520 |
+
const torch::Tensor &cu_seqlens_k, // b+1
|
1521 |
+
const c10::optional<torch::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
1522 |
+
const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
|
1523 |
+
const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
1524 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
1525 |
const int64_t max_seqlen_q,
|
1526 |
const int64_t max_seqlen_k,
|
1527 |
const double p_dropout,
|
|
|
1532 |
const int64_t window_size_right,
|
1533 |
const double softcap,
|
1534 |
const bool return_softmax,
|
1535 |
+
const c10::optional<at::Generator> gen_) {
|
1536 |
+
|
1537 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1538 |
// Prepare the optional arguments as non-const references.
|
1539 |
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
|
|
1565 |
softcap_float, return_softmax, gen);
|
1566 |
}
|
1567 |
|
1568 |
+
std::vector<torch::Tensor>
|
1569 |
+
mha_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
|
1570 |
+
const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
1571 |
+
const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
1572 |
+
const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
1573 |
+
const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
1574 |
+
const torch::Tensor &softmax_lse, // b x h x seqlen_q
|
1575 |
+
const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
1576 |
+
const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
1577 |
+
const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
1578 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1579 |
+
const double p_dropout, // probability to drop
|
1580 |
const double softmax_scale,
|
1581 |
const bool is_causal,
|
1582 |
const int64_t window_size_left,
|
1583 |
const int64_t window_size_right,
|
1584 |
const double softcap,
|
1585 |
const bool deterministic,
|
1586 |
+
c10::optional<torch::Generator> gen_,
|
1587 |
+
const c10::optional<torch::Tensor> &rng_state) {
|
1588 |
|
1589 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1590 |
|
|
|
1601 |
int window_size_left_int = static_cast<int>(window_size_left);
|
1602 |
int window_size_right_int = static_cast<int>(window_size_right);
|
1603 |
|
1604 |
+
// TODO: avoid copying rng_state if possible
|
1605 |
+
// Create a non-const copy of rng_state
|
1606 |
+
std::optional<at::Tensor> rng_state_copy;
|
1607 |
+
if (rng_state.has_value()) {
|
1608 |
+
rng_state_copy = rng_state.value().clone();
|
1609 |
+
}
|
1610 |
+
|
1611 |
return FLASH_NAMESPACE::mha_bwd(
|
1612 |
const_cast<at::Tensor &>(dout),
|
1613 |
q, k, v, out, softmax_lse,
|
|
|
1616 |
is_causal,
|
1617 |
window_size_left_int, window_size_right_int,
|
1618 |
softcap_float, deterministic,
|
1619 |
+
gen, rng_state_copy);
|
1620 |
}
|
1621 |
|
1622 |
|
1623 |
+
std::vector<torch::Tensor>
|
1624 |
+
mha_varlen_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
|
1625 |
+
const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
1626 |
+
const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
1627 |
+
const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
1628 |
+
const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
1629 |
+
const torch::Tensor &softmax_lse, // b x h x seqlen_q
|
1630 |
+
const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
1631 |
+
const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
1632 |
+
const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
1633 |
+
const torch::Tensor &cu_seqlens_q, // batch_size + 1
|
1634 |
+
const torch::Tensor &cu_seqlens_k, // batch_size + 1
|
1635 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
1636 |
const int64_t max_seqlen_q,
|
1637 |
const int64_t max_seqlen_k,
|
1638 |
const double p_dropout,
|
|
|
1643 |
const int64_t window_size_right,
|
1644 |
const double softcap,
|
1645 |
const bool deterministic,
|
1646 |
+
c10::optional<torch::Generator> gen_,
|
1647 |
+
const c10::optional<torch::Tensor> &rng_state) {
|
1648 |
|
1649 |
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1650 |
|
|
|
1663 |
int window_size_left_int = static_cast<int>(window_size_left);
|
1664 |
int window_size_right_int = static_cast<int>(window_size_right);
|
1665 |
|
1666 |
+
|
1667 |
+
// TODO: avoid copying rng_state if possible
|
1668 |
+
// Create a non-const copy of rng_state
|
1669 |
+
std::optional<at::Tensor> rng_state_copy;
|
1670 |
+
if (rng_state.has_value()) {
|
1671 |
+
rng_state_copy = rng_state.value().clone();
|
1672 |
+
}
|
1673 |
+
|
1674 |
return FLASH_NAMESPACE::mha_varlen_bwd(
|
1675 |
const_cast<at::Tensor &>(dout),
|
1676 |
q, k, v, out, softmax_lse,
|
|
|
1682 |
zero_tensors, is_causal,
|
1683 |
window_size_left_int, window_size_right_int,
|
1684 |
softcap_float, deterministic,
|
1685 |
+
gen, rng_state_copy);
|
1686 |
}
|
1687 |
|
1688 |
+
std::vector<torch::Tensor>
|
1689 |
+
mha_fwd_kvcache(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
1690 |
+
const torch::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
1691 |
+
const torch::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
1692 |
+
const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
|
1693 |
+
const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
|
1694 |
+
const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
|
1695 |
+
const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
1696 |
+
const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
1697 |
+
const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
|
1698 |
+
const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
|
1699 |
+
const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
1700 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1701 |
+
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
1702 |
const double softmax_scale,
|
1703 |
bool is_causal,
|
1704 |
const int64_t window_size_left,
|
1705 |
const int64_t window_size_right,
|
1706 |
const double softcap,
|
1707 |
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
1708 |
+
const int64_t num_splits) {
|
|
|
1709 |
|
1710 |
// Prepare the optional arguments as const references where needed
|
1711 |
std::optional<const at::Tensor> k = k_.has_value() ? std::optional<const at::Tensor>(k_.value()) : std::nullopt;
|
torch-ext/torch_binding.cpp
CHANGED
@@ -17,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
17 |
ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
18 |
ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
|
19 |
|
20 |
-
ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
21 |
ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
|
22 |
|
23 |
ops.def("mha_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
|
|
|
17 |
ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
18 |
ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
|
19 |
|
20 |
+
ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor? seqused_k_, Tensor? leftpad_k_, Tensor? block_table_, Tensor? alibi_slopes_, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
21 |
ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
|
22 |
|
23 |
ops.def("mha_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
|
torch-ext/torch_binding.h
CHANGED
@@ -2,11 +2,11 @@
|
|
2 |
|
3 |
#include <torch/torch.h>
|
4 |
|
5 |
-
std::vector<
|
6 |
-
mha_fwd(const
|
7 |
-
const
|
8 |
-
const
|
9 |
-
const c10::optional<torch::Tensor> &out_,
|
10 |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
11 |
const double p_dropout,
|
12 |
const double softmax_scale,
|
@@ -15,4 +15,101 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x roun
|
|
15 |
const int64_t window_size_right,
|
16 |
const double softcap,
|
17 |
const bool return_softmax,
|
18 |
-
const c10::optional<at::Generator> gen_);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
#include <torch/torch.h>
|
4 |
|
5 |
+
std::vector<torch::Tensor>
|
6 |
+
mha_fwd(const torch::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
7 |
+
const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
8 |
+
const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
9 |
+
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
10 |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
11 |
const double p_dropout,
|
12 |
const double softmax_scale,
|
|
|
15 |
const int64_t window_size_right,
|
16 |
const double softcap,
|
17 |
const bool return_softmax,
|
18 |
+
const c10::optional<at::Generator> gen_);
|
19 |
+
|
20 |
+
std::vector<torch::Tensor>
|
21 |
+
mha_varlen_fwd(
|
22 |
+
const torch::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
23 |
+
const torch::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
|
24 |
+
const torch::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
|
25 |
+
const c10::optional<torch::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
26 |
+
const torch::Tensor &cu_seqlens_q, // b+1
|
27 |
+
const torch::Tensor &cu_seqlens_k, // b+1
|
28 |
+
const c10::optional<torch::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
29 |
+
const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
|
30 |
+
const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
31 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
32 |
+
const int64_t max_seqlen_q,
|
33 |
+
const int64_t max_seqlen_k,
|
34 |
+
const double p_dropout,
|
35 |
+
const double softmax_scale,
|
36 |
+
const bool zero_tensors,
|
37 |
+
const bool is_causal,
|
38 |
+
const int64_t window_size_left,
|
39 |
+
const int64_t window_size_right,
|
40 |
+
const double softcap,
|
41 |
+
const bool return_softmax,
|
42 |
+
const c10::optional<at::Generator> gen_);
|
43 |
+
|
44 |
+
|
45 |
+
std::vector<torch::Tensor>
|
46 |
+
mha_bwd(const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
|
47 |
+
const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
48 |
+
const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
49 |
+
const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
50 |
+
const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
51 |
+
const torch::Tensor &softmax_lse, // b x h x seqlen_q
|
52 |
+
const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
53 |
+
const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
54 |
+
const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
55 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
56 |
+
const double p_dropout, // probability to drop
|
57 |
+
const double softmax_scale,
|
58 |
+
const bool is_causal,
|
59 |
+
const int64_t window_size_left,
|
60 |
+
const int64_t window_size_right,
|
61 |
+
const double softcap,
|
62 |
+
const bool deterministic,
|
63 |
+
c10::optional<at::Generator> gen_,
|
64 |
+
const c10::optional<torch::Tensor> &rng_state);
|
65 |
+
|
66 |
+
|
67 |
+
std::vector<torch::Tensor>
|
68 |
+
mha_varlen_bwd(
|
69 |
+
const torch::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
|
70 |
+
const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
71 |
+
const torch::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
72 |
+
const torch::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
73 |
+
const torch::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
74 |
+
const torch::Tensor &softmax_lse, // b x h x seqlen_q
|
75 |
+
const c10::optional<torch::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
76 |
+
const c10::optional<torch::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
77 |
+
const c10::optional<torch::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
78 |
+
const torch::Tensor &cu_seqlens_q, // batch_size + 1
|
79 |
+
const torch::Tensor &cu_seqlens_k, // batch_size + 1
|
80 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
81 |
+
const int64_t max_seqlen_q,
|
82 |
+
const int64_t max_seqlen_k,
|
83 |
+
const double p_dropout,
|
84 |
+
const double softmax_scale,
|
85 |
+
const bool zero_tensors,
|
86 |
+
const bool is_causal,
|
87 |
+
const int64_t window_size_left,
|
88 |
+
const int64_t window_size_right,
|
89 |
+
const double softcap,
|
90 |
+
const bool deterministic,
|
91 |
+
c10::optional<at::Generator> gen_,
|
92 |
+
const c10::optional<torch::Tensor> &rng_state);
|
93 |
+
|
94 |
+
std::vector<torch::Tensor>
|
95 |
+
mha_fwd_kvcache(
|
96 |
+
const torch::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
97 |
+
const torch::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
98 |
+
const torch::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
99 |
+
const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
|
100 |
+
const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
|
101 |
+
const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
|
102 |
+
const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
103 |
+
const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
104 |
+
const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
|
105 |
+
const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
|
106 |
+
const c10::optional<torch::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
107 |
+
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
108 |
+
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
109 |
+
const double softmax_scale,
|
110 |
+
bool is_causal,
|
111 |
+
const int64_t window_size_left,
|
112 |
+
const int64_t window_size_right,
|
113 |
+
const double softcap,
|
114 |
+
bool is_rotary_interleaved,
|
115 |
+
const int64_t num_splits);
|