drbh committed
Commit 4080f9c · 1 Parent(s): b0d3c12

fix: adjust types

Files changed (1)
  1. flash_attn/flash_api.cpp +102 -32
flash_attn/flash_api.cpp CHANGED
@@ -1507,45 +1507,61 @@ mha_fwd(const at::Tensor &q, // batch_size x seqle
     float softcap_float = static_cast<float>(softcap);
     int window_size_left_int = static_cast<int>(window_size_left);
     int window_size_right_int = static_cast<int>(window_size_right);
-
+
     return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
 }
 
 std::vector<at::Tensor>
-mha_varlen_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
-               const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
-               const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
-               const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
-               const at::Tensor &cu_seqlens_q, // batch_size + 1
-               const at::Tensor &cu_seqlens_k, // batch_size + 1
+mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
+               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_>
+               const std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &cu_seqlens_q, // b+1
+               const at::Tensor &cu_seqlens_k, // b+1
+               const std::optional<at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
+               const std::optional<const at::Tensor> &leftpad_k_, // batch_size
+               const std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+               const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                const int64_t max_seqlen_q,
                const int64_t max_seqlen_k,
                const double p_dropout,
                const double softmax_scale,
-               bool is_causal,
+               const bool zero_tensors,
+               const bool is_causal,
                const int64_t window_size_left,
                const int64_t window_size_right,
                const double softcap,
                const bool return_softmax,
-               const c10::optional<at::Generator> gen_) {
-
+               const std::optional<at::Generator> gen_) {
     auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
-
     // Prepare the optional arguments as non-const references.
     std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
-
+    std::optional<at::Tensor> seqused_k = seqused_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(seqused_k_.value())) : std::nullopt;
+    std::optional<const at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<const at::Tensor>(leftpad_k_.value()) : std::nullopt;
+    std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
+    std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
+
     if (!out.has_value()){
         out = torch::empty_like(q);
     }
-
     // Convert double to float and int64_t to int.
     float p_dropout_float = static_cast<float>(p_dropout);
     float softmax_scale_float = static_cast<float>(softmax_scale);
     float softcap_float = static_cast<float>(softcap);
+    int max_seqlen_q_int = static_cast<int>(max_seqlen_q);
+    int max_seqlen_k_int = static_cast<int>(max_seqlen_k);
     int window_size_left_int = static_cast<int>(window_size_left);
     int window_size_right_int = static_cast<int>(window_size_right);
-
-    return FLASH_NAMESPACE::mha_varlen_fwd(const_cast<at::Tensor &>(q), k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
+
+    return FLASH_NAMESPACE::mha_varlen_fwd(
+        const_cast<at::Tensor &>(q), k, v, out,
+        cu_seqlens_q, cu_seqlens_k,
+        seqused_k, leftpad_k, block_table, alibi_slopes,
+        max_seqlen_q_int, max_seqlen_k_int,
+        p_dropout_float, softmax_scale_float,
+        zero_tensors, is_causal,
+        window_size_left_int, window_size_right_int,
+        softcap_float, return_softmax, gen);
 }
 
 std::vector<at::Tensor>
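The new mha_varlen_fwd wrapper repeats one pattern for every optional tensor argument: a const std::optional<at::Tensor> & coming from the op schema is rebuilt as a plain std::optional<at::Tensor> that the FLASH_NAMESPACE entry point can take. The standalone sketch below isolates that pattern; std::string stands in for at::Tensor so it compiles without PyTorch, and the helper name unwrap_optional is illustrative, not from the diff.

#include <iostream>
#include <optional>
#include <string>

// Illustrative stand-in for the "prepare the optional arguments as non-const
// references" step: rebuild a const optional reference as a mutable
// std::optional<T>. For at::Tensor the copy made here would be shallow
// (tensors share storage); std::string is used so this sketch is self-contained.
template <typename T>
std::optional<T> unwrap_optional(const std::optional<T> &maybe) {
    return maybe.has_value()
        ? std::optional<T>(const_cast<T &>(maybe.value()))
        : std::nullopt;
}

int main() {
    const std::optional<std::string> out_ = std::nullopt;
    std::optional<std::string> out = unwrap_optional(out_);
    if (!out.has_value()) {
        out = std::string("allocated");  // mirrors out = torch::empty_like(q)
    }
    std::cout << *out << "\n";           // prints "allocated"
}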
@@ -1570,7 +1586,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
         std::optional<at::Tensor> &rng_state) {
 
     auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
-
+
     // Prepare the optional arguments as non-const references.
     std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
    std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
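Both backward wrappers default the optional RNG generator the same way, with gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator()). A minimal sketch of that std::optional::value_or pattern follows; FakeGenerator and default_generator are illustrative stand-ins, not part of the diff or of PyTorch.

#include <iostream>
#include <optional>

// Stand-in for the generator-defaulting line above.
struct FakeGenerator { int seed; };

FakeGenerator default_generator() { return FakeGenerator{42}; }

int main() {
    std::optional<FakeGenerator> gen_;                       // caller passed nothing
    FakeGenerator gen = gen_.value_or(default_generator());  // falls back to default
    std::cout << gen.seed << "\n";                           // prints 42
}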
@@ -1584,7 +1600,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q
     int window_size_left_int = static_cast<int>(window_size_left);
     int window_size_right_int = static_cast<int>(window_size_right);
 
-    return FLASH_NAMESPACE::mha_bwd(const_cast<at::Tensor &>(dout), q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, deterministic, gen, rng_state);
+    return FLASH_NAMESPACE::mha_bwd(
+        const_cast<at::Tensor &>(dout),
+        q, k, v, out, softmax_lse,
+        dq, dk, dv, alibi_slopes,
+        p_dropout_float, softmax_scale_float,
+        is_causal,
+        window_size_left_int, window_size_right_int,
+        softcap_float, deterministic,
+        gen, rng_state);
 }
 
 
@@ -1595,12 +1619,17 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
                const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &softmax_lse, // b x h x seqlen_q
+               const std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
+               const std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
+               const std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &cu_seqlens_q, // batch_size + 1
                const at::Tensor &cu_seqlens_k, // batch_size + 1
+               const std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                const int64_t max_seqlen_q,
                const int64_t max_seqlen_k,
                const double p_dropout,
                const double softmax_scale,
+               const bool zero_tensors,
                const bool is_causal,
                const int64_t window_size_left,
                const int64_t window_size_right,
@@ -1608,17 +1637,36 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q
                const bool deterministic,
                std::optional<at::Generator> gen_,
                std::optional<at::Tensor> &rng_state) {
-
+
     auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
-
+
+    // Prepare the optional arguments as non-const references.
+    std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
+    std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
+    std::optional<at::Tensor> dv = dv_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dv_.value())) : std::nullopt;
+    std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
+
     // Convert double to float and int64_t to int.
     float p_dropout_float = static_cast<float>(p_dropout);
     float softmax_scale_float = static_cast<float>(softmax_scale);
     float softcap_float = static_cast<float>(softcap);
+    int max_seqlen_q_int = static_cast<int>(max_seqlen_q);
+    int max_seqlen_k_int = static_cast<int>(max_seqlen_k);
     int window_size_left_int = static_cast<int>(window_size_left);
     int window_size_right_int = static_cast<int>(window_size_right);
 
-    return FLASH_NAMESPACE::mha_varlen_bwd(const_cast<at::Tensor &>(dout), q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, deterministic, gen, rng_state);
+    return FLASH_NAMESPACE::mha_varlen_bwd(
+        const_cast<at::Tensor &>(dout),
+        q, k, v, out, softmax_lse,
+        dq, dk, dv,
+        cu_seqlens_q, cu_seqlens_k,
+        alibi_slopes,
+        max_seqlen_q_int, max_seqlen_k_int,
+        p_dropout_float, softmax_scale_float,
+        zero_tensors, is_causal,
+        window_size_left_int, window_size_right_int,
+        softcap_float, deterministic,
+        gen, rng_state);
 }
 
 std::vector<at::Tensor>
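The new max_seqlen_q_int / max_seqlen_k_int locals follow the same rule as the other scalars in these wrappers: the schema-facing signatures use int64_t and double (what the PyTorch dispatcher passes for Python int and float), while the underlying kernels take int and float, so each value is narrowed with a static_cast just before the FLASH_NAMESPACE call. The sketch below shows that narrowing in isolation; the range assert is an illustrative guard added here, the wrappers in the diff use a bare static_cast.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>

// Sketch of the scalar narrowing done before each FLASH_NAMESPACE call:
// int64_t -> int and double -> float.
int narrow_to_int(int64_t v) {
    assert(v >= std::numeric_limits<int>::min() &&
           v <= std::numeric_limits<int>::max());
    return static_cast<int>(v);
}

int main() {
    int64_t max_seqlen_q = 4096;
    double softmax_scale = 0.125;
    int max_seqlen_q_int = narrow_to_int(max_seqlen_q);             // 4096
    float softmax_scale_float = static_cast<float>(softmax_scale);  // 0.125f
    std::cout << max_seqlen_q_int << " " << softmax_scale_float << "\n";
}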
@@ -1643,25 +1691,47 @@ mha_fwd_kvcache(const at::Tensor &q, // batch
                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                const int64_t num_splits
                ) {
-
-    // Prepare the optional arguments as non-const references.
-    std::optional<at::Tensor> k = k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(k_.value())) : std::nullopt;
-    std::optional<at::Tensor> v = v_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(v_.value())) : std::nullopt;
-    std::optional<at::Tensor> seqlens_k = seqlens_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(seqlens_k_.value())) : std::nullopt;
-    std::optional<at::Tensor> rotary_cos = rotary_cos_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(rotary_cos_.value())) : std::nullopt;
-    std::optional<at::Tensor> rotary_sin = rotary_sin_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(rotary_sin_.value())) : std::nullopt;
-    std::optional<at::Tensor> cache_batch_idx = cache_batch_idx_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(cache_batch_idx_.value())) : std::nullopt;
-    std::optional<at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::at::Tensor &>(leftpad_k_.value())) : std::nullopt;
+
+    // Prepare the optional arguments as const references where needed
+    std::optional<const at::Tensor> k = k_.has_value() ? std::optional<const at::Tensor>(k_.value()) : std::nullopt;
+    std::optional<const at::Tensor> v = v_.has_value() ? std::optional<const at::Tensor>(v_.value()) : std::nullopt;
+    std::optional<const at::Tensor> seqlens_k = seqlens_k_.has_value() ? std::optional<const at::Tensor>(seqlens_k_.value()) : std::nullopt;
+    std::optional<const at::Tensor> rotary_cos = rotary_cos_.has_value() ? std::optional<const at::Tensor>(rotary_cos_.value()) : std::nullopt;
+    std::optional<const at::Tensor> rotary_sin = rotary_sin_.has_value() ? std::optional<const at::Tensor>(rotary_sin_.value()) : std::nullopt;
+    std::optional<const at::Tensor> cache_batch_idx = cache_batch_idx_.has_value() ? std::optional<const at::Tensor>(cache_batch_idx_.value()) : std::nullopt;
+    std::optional<const at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<const at::Tensor>(leftpad_k_.value()) : std::nullopt;
+
+    // For non-const tensors
     std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
     std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
     std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
-
+
+    if (!out.has_value()){
+        out = torch::empty_like(q);
+    }
+
     // Convert double to float and int64_t to int.
     float softmax_scale_float = static_cast<float>(softmax_scale);
     float softcap_float = static_cast<float>(softcap);
     int window_size_left_int = static_cast<int>(window_size_left);
     int window_size_right_int = static_cast<int>(window_size_right);
     int num_splits_int = static_cast<int>(num_splits);
-
-    return FLASH_NAMESPACE::mha_fwd_kvcache(const_cast<at::Tensor &>(q), kcache, vcache, k, v, seqlens_k, rotary_cos, rotary_sin, cache_batch_idx, leftpad_k, block_table, alibi_slopes, out, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, is_rotary_interleaved, num_splits_int);
+
+    return FLASH_NAMESPACE::mha_fwd_kvcache(
+        const_cast<at::Tensor &>(q),
+        kcache, vcache,
+        k, v,
+        seqlens_k,
+        rotary_cos, rotary_sin,
+        cache_batch_idx,
+        leftpad_k,
+        block_table, alibi_slopes,
+        out,
+        softmax_scale_float,
+        is_causal,
+        window_size_left_int, window_size_right_int,
+        softcap_float,
+        is_rotary_interleaved,
+        num_splits_int
+    );
 }
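The kvcache wrapper now distinguishes read-only inputs (k, v, seqlens_k, rotary_cos, rotary_sin, cache_batch_idx, leftpad_k), which are repackaged as std::optional<const at::Tensor> with no const_cast, from tensors the kernel may write (out, block_table, alibi_slopes), which stay std::optional<at::Tensor>. The sketch below shows the read-only path in isolation; std::string stands in for at::Tensor and as_const_optional is an illustrative helper name, not from the diff.

#include <iostream>
#include <optional>
#include <string>

// Repackage a const optional reference as std::optional<const T>, so the
// value can be handed on without any const_cast.
template <typename T>
std::optional<const T> as_const_optional(const std::optional<T> &maybe) {
    return maybe.has_value()
        ? std::optional<const T>(maybe.value())
        : std::nullopt;
}

int main() {
    std::optional<std::string> rotary_cos_ = std::string("cos-table");
    std::optional<const std::string> rotary_cos = as_const_optional(rotary_cos_);
    std::cout << (rotary_cos.has_value() ? *rotary_cos : std::string("<none>")) << "\n";
}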
 