drbh
commited on
Commit
·
b0d3c12
1
Parent(s):
9002ff5
fix: expand build combinations and include all files
Browse files- build.toml +74 -75
- flash_attn/flash_api.cpp +159 -4
- flash_attn/src/static_switch.h +23 -28
- torch-ext/torch_binding.cpp +3 -0
build.toml
CHANGED
@@ -33,99 +33,98 @@ src = [
|
|
33 |
"flash_attn/src/static_switch.h",
|
34 |
"flash_attn/src/utils.h",
|
35 |
|
36 |
-
##
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
"flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
|
55 |
"flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
|
56 |
"flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
|
57 |
"flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
"flash_attn/src/flash_bwd_kernel.h",
|
67 |
"flash_attn/src/flash_bwd_launch_template.h",
|
68 |
"flash_attn/src/flash_bwd_preprocess_kernel.h",
|
69 |
|
70 |
-
##
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
# "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
|
88 |
"flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
|
89 |
"flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
|
90 |
"flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
|
91 |
"flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
"flash_attn/src/flash_fwd_kernel.h",
|
101 |
"flash_attn/src/flash_fwd_launch_template.h",
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
"flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
|
119 |
"flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
|
120 |
"flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
|
121 |
"flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
]
|
131 |
depends = ["torch", "cutlass_3_6"]
|
|
|
33 |
"flash_attn/src/static_switch.h",
|
34 |
"flash_attn/src/utils.h",
|
35 |
|
36 |
+
## bwd kernels
|
37 |
|
38 |
+
"flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
|
39 |
+
"flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
|
40 |
+
"flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
|
41 |
+
"flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
|
42 |
+
"flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
|
43 |
+
"flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
|
44 |
+
"flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
|
45 |
+
"flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
|
46 |
+
"flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
|
47 |
+
"flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
|
48 |
+
"flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
|
49 |
+
"flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
|
50 |
+
"flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
|
51 |
+
"flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
|
52 |
+
"flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
|
53 |
+
"flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
|
54 |
"flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
|
55 |
"flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
|
56 |
"flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
|
57 |
"flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
|
58 |
+
"flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
|
59 |
+
"flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
|
60 |
+
"flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
|
61 |
+
"flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
|
62 |
+
"flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
|
63 |
+
"flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
|
64 |
+
"flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
|
65 |
+
"flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
|
66 |
"flash_attn/src/flash_bwd_kernel.h",
|
67 |
"flash_attn/src/flash_bwd_launch_template.h",
|
68 |
"flash_attn/src/flash_bwd_preprocess_kernel.h",
|
69 |
|
70 |
+
## fwd kernels
|
71 |
+
"flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
|
72 |
+
"flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
|
73 |
+
"flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
|
74 |
+
"flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
|
75 |
+
"flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
|
76 |
+
"flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
|
77 |
+
"flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
|
78 |
+
"flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
|
79 |
+
"flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
|
80 |
+
"flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
|
81 |
+
"flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
|
82 |
+
"flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
|
83 |
+
"flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
|
84 |
+
"flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
|
85 |
+
"flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
|
86 |
+
"flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
|
|
|
87 |
"flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
|
88 |
"flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
|
89 |
"flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
|
90 |
"flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
|
91 |
+
"flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
|
92 |
+
"flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
|
93 |
+
"flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
|
94 |
+
"flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
|
95 |
+
"flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
|
96 |
+
"flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
|
97 |
+
"flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
|
98 |
+
"flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
|
99 |
"flash_attn/src/flash_fwd_kernel.h",
|
100 |
"flash_attn/src/flash_fwd_launch_template.h",
|
101 |
+
"flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
|
102 |
+
"flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
|
103 |
+
"flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
|
104 |
+
"flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
|
105 |
+
"flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
|
106 |
+
"flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
|
107 |
+
"flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
|
108 |
+
"flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
|
109 |
+
"flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
|
110 |
+
"flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
|
111 |
+
"flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
|
112 |
+
"flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
|
113 |
+
"flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
|
114 |
+
"flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
|
115 |
+
"flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
|
116 |
+
"flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
|
117 |
"flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
|
118 |
"flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
|
119 |
"flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
|
120 |
"flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
|
121 |
+
"flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
|
122 |
+
"flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
|
123 |
+
"flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
|
124 |
+
"flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
|
125 |
+
"flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
|
126 |
+
"flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
|
127 |
+
"flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
|
128 |
+
"flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
|
129 |
]
|
130 |
depends = ["torch", "cutlass_3_6"]
|
flash_attn/flash_api.cpp
CHANGED
@@ -1477,10 +1477,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
|
1477 |
|
1478 |
// NOTE: wrap the namespaced functions so all types are doubles and longs
|
1479 |
std::vector<at::Tensor>
|
1480 |
-
mha_fwd(const at::Tensor &q,
|
1481 |
-
const at::Tensor &k,
|
1482 |
-
const at::Tensor &v,
|
1483 |
-
const c10::optional<torch::Tensor> &out_,
|
1484 |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1485 |
const double p_dropout,
|
1486 |
const double softmax_scale,
|
@@ -1509,4 +1509,159 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x roun
|
|
1509 |
int window_size_right_int = static_cast<int>(window_size_right);
|
1510 |
|
1511 |
return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1512 |
}
|
|
|
1477 |
|
1478 |
// NOTE: wrap the namespaced functions so all types are doubles and longs
|
1479 |
std::vector<at::Tensor>
|
1480 |
+
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1481 |
+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
1482 |
+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
1483 |
+
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1484 |
const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1485 |
const double p_dropout,
|
1486 |
const double softmax_scale,
|
|
|
1509 |
int window_size_right_int = static_cast<int>(window_size_right);
|
1510 |
|
1511 |
return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
1512 |
+
}
|
1513 |
+
|
1514 |
+
std::vector<at::Tensor>
|
1515 |
+
mha_varlen_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1516 |
+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
1517 |
+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
|
1518 |
+
const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
|
1519 |
+
const at::Tensor &cu_seqlens_q, // batch_size + 1
|
1520 |
+
const at::Tensor &cu_seqlens_k, // batch_size + 1
|
1521 |
+
const int64_t max_seqlen_q,
|
1522 |
+
const int64_t max_seqlen_k,
|
1523 |
+
const double p_dropout,
|
1524 |
+
const double softmax_scale,
|
1525 |
+
bool is_causal,
|
1526 |
+
const int64_t window_size_left,
|
1527 |
+
const int64_t window_size_right,
|
1528 |
+
const double softcap,
|
1529 |
+
const bool return_softmax,
|
1530 |
+
const c10::optional<at::Generator> gen_) {
|
1531 |
+
|
1532 |
+
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1533 |
+
|
1534 |
+
// Prepare the optional arguments as non-const references.
|
1535 |
+
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
1536 |
+
|
1537 |
+
if (!out.has_value()){
|
1538 |
+
out = torch::empty_like(q);
|
1539 |
+
}
|
1540 |
+
|
1541 |
+
// Convert double to float and int64_t to int.
|
1542 |
+
float p_dropout_float = static_cast<float>(p_dropout);
|
1543 |
+
float softmax_scale_float = static_cast<float>(softmax_scale);
|
1544 |
+
float softcap_float = static_cast<float>(softcap);
|
1545 |
+
int window_size_left_int = static_cast<int>(window_size_left);
|
1546 |
+
int window_size_right_int = static_cast<int>(window_size_right);
|
1547 |
+
|
1548 |
+
return FLASH_NAMESPACE::mha_varlen_fwd(const_cast<at::Tensor &>(q), k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
|
1549 |
+
}
|
1550 |
+
|
1551 |
+
std::vector<at::Tensor>
|
1552 |
+
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
|
1553 |
+
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
1554 |
+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
1555 |
+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
1556 |
+
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
1557 |
+
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
1558 |
+
const std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
1559 |
+
const std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
1560 |
+
const std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
1561 |
+
const std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1562 |
+
const double p_dropout, // probability to drop
|
1563 |
+
const double softmax_scale,
|
1564 |
+
const bool is_causal,
|
1565 |
+
const int64_t window_size_left,
|
1566 |
+
const int64_t window_size_right,
|
1567 |
+
const double softcap,
|
1568 |
+
const bool deterministic,
|
1569 |
+
std::optional<at::Generator> gen_,
|
1570 |
+
std::optional<at::Tensor> &rng_state) {
|
1571 |
+
|
1572 |
+
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1573 |
+
|
1574 |
+
// Prepare the optional arguments as non-const references.
|
1575 |
+
std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
|
1576 |
+
std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
|
1577 |
+
std::optional<at::Tensor> dv = dv_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dv_.value())) : std::nullopt;
|
1578 |
+
std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
|
1579 |
+
|
1580 |
+
// Convert double to float and int64_t to int.
|
1581 |
+
float p_dropout_float = static_cast<float>(p_dropout);
|
1582 |
+
float softmax_scale_float = static_cast<float>(softmax_scale);
|
1583 |
+
float softcap_float = static_cast<float>(softcap);
|
1584 |
+
int window_size_left_int = static_cast<int>(window_size_left);
|
1585 |
+
int window_size_right_int = static_cast<int>(window_size_right);
|
1586 |
+
|
1587 |
+
return FLASH_NAMESPACE::mha_bwd(const_cast<at::Tensor &>(dout), q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, deterministic, gen, rng_state);
|
1588 |
+
}
|
1589 |
+
|
1590 |
+
|
1591 |
+
std::vector<at::Tensor>
|
1592 |
+
mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
|
1593 |
+
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
1594 |
+
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
1595 |
+
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
1596 |
+
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
1597 |
+
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
1598 |
+
const at::Tensor &cu_seqlens_q, // batch_size + 1
|
1599 |
+
const at::Tensor &cu_seqlens_k, // batch_size + 1
|
1600 |
+
const int64_t max_seqlen_q,
|
1601 |
+
const int64_t max_seqlen_k,
|
1602 |
+
const double p_dropout,
|
1603 |
+
const double softmax_scale,
|
1604 |
+
const bool is_causal,
|
1605 |
+
const int64_t window_size_left,
|
1606 |
+
const int64_t window_size_right,
|
1607 |
+
const double softcap,
|
1608 |
+
const bool deterministic,
|
1609 |
+
std::optional<at::Generator> gen_,
|
1610 |
+
std::optional<at::Tensor> &rng_state) {
|
1611 |
+
|
1612 |
+
auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
|
1613 |
+
|
1614 |
+
// Convert double to float and int64_t to int.
|
1615 |
+
float p_dropout_float = static_cast<float>(p_dropout);
|
1616 |
+
float softmax_scale_float = static_cast<float>(softmax_scale);
|
1617 |
+
float softcap_float = static_cast<float>(softcap);
|
1618 |
+
int window_size_left_int = static_cast<int>(window_size_left);
|
1619 |
+
int window_size_right_int = static_cast<int>(window_size_right);
|
1620 |
+
|
1621 |
+
return FLASH_NAMESPACE::mha_varlen_bwd(const_cast<at::Tensor &>(dout), q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, deterministic, gen, rng_state);
|
1622 |
+
}
|
1623 |
+
|
1624 |
+
std::vector<at::Tensor>
|
1625 |
+
mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
1626 |
+
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
1627 |
+
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
1628 |
+
const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
|
1629 |
+
const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
|
1630 |
+
const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
|
1631 |
+
const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
1632 |
+
const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
1633 |
+
const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
|
1634 |
+
const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
|
1635 |
+
const c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
1636 |
+
const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
1637 |
+
const c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
1638 |
+
const double softmax_scale,
|
1639 |
+
bool is_causal,
|
1640 |
+
const int64_t window_size_left,
|
1641 |
+
const int64_t window_size_right,
|
1642 |
+
const double softcap,
|
1643 |
+
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
1644 |
+
const int64_t num_splits
|
1645 |
+
) {
|
1646 |
+
|
1647 |
+
// Prepare the optional arguments as non-const references.
|
1648 |
+
std::optional<at::Tensor> k = k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(k_.value())) : std::nullopt;
|
1649 |
+
std::optional<at::Tensor> v = v_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(v_.value())) : std::nullopt;
|
1650 |
+
std::optional<at::Tensor> seqlens_k = seqlens_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(seqlens_k_.value())) : std::nullopt;
|
1651 |
+
std::optional<at::Tensor> rotary_cos = rotary_cos_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(rotary_cos_.value())) : std::nullopt;
|
1652 |
+
std::optional<at::Tensor> rotary_sin = rotary_sin_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(rotary_sin_.value())) : std::nullopt;
|
1653 |
+
std::optional<at::Tensor> cache_batch_idx = cache_batch_idx_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(cache_batch_idx_.value())) : std::nullopt;
|
1654 |
+
std::optional<at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::at::Tensor &>(leftpad_k_.value())) : std::nullopt;
|
1655 |
+
std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
|
1656 |
+
std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
|
1657 |
+
std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
|
1658 |
+
|
1659 |
+
// Convert double to float and int64_t to int.
|
1660 |
+
float softmax_scale_float = static_cast<float>(softmax_scale);
|
1661 |
+
float softcap_float = static_cast<float>(softcap);
|
1662 |
+
int window_size_left_int = static_cast<int>(window_size_left);
|
1663 |
+
int window_size_right_int = static_cast<int>(window_size_right);
|
1664 |
+
int num_splits_int = static_cast<int>(num_splits);
|
1665 |
+
|
1666 |
+
return FLASH_NAMESPACE::mha_fwd_kvcache(const_cast<at::Tensor &>(q), kcache, vcache, k, v, seqlens_k, rotary_cos, rotary_sin, cache_batch_idx, leftpad_k, block_table, alibi_slopes, out, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, is_rotary_interleaved, num_splits_int);
|
1667 |
}
|
flash_attn/src/static_switch.h
CHANGED
@@ -87,33 +87,28 @@
|
|
87 |
} \
|
88 |
}()
|
89 |
|
90 |
-
// #define HEADDIM_SWITCH(HEADDIM, ...) \
|
91 |
-
// [&] { \
|
92 |
-
// if (HEADDIM <= 32) { \
|
93 |
-
// constexpr static int kHeadDim = 32; \
|
94 |
-
// return __VA_ARGS__(); \
|
95 |
-
// } else if (HEADDIM <= 64) { \
|
96 |
-
// constexpr static int kHeadDim = 64; \
|
97 |
-
// return __VA_ARGS__(); \
|
98 |
-
// } else if (HEADDIM <= 96) { \
|
99 |
-
// constexpr static int kHeadDim = 96; \
|
100 |
-
// return __VA_ARGS__(); \
|
101 |
-
// } else if (HEADDIM <= 128) { \
|
102 |
-
// constexpr static int kHeadDim = 128; \
|
103 |
-
// return __VA_ARGS__(); \
|
104 |
-
// } else if (HEADDIM <= 160) { \
|
105 |
-
// constexpr static int kHeadDim = 160; \
|
106 |
-
// return __VA_ARGS__(); \
|
107 |
-
// } else if (HEADDIM <= 192) { \
|
108 |
-
// constexpr static int kHeadDim = 192; \
|
109 |
-
// return __VA_ARGS__(); \
|
110 |
-
// } else if (HEADDIM <= 256) { \
|
111 |
-
// constexpr static int kHeadDim = 256; \
|
112 |
-
// return __VA_ARGS__(); \
|
113 |
-
// } \
|
114 |
-
// }()
|
115 |
#define HEADDIM_SWITCH(HEADDIM, ...) \
|
116 |
-
[&] {
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
}()
|
|
|
87 |
} \
|
88 |
}()
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
#define HEADDIM_SWITCH(HEADDIM, ...) \
|
91 |
+
[&] { \
|
92 |
+
if (HEADDIM <= 32) { \
|
93 |
+
constexpr static int kHeadDim = 32; \
|
94 |
+
return __VA_ARGS__(); \
|
95 |
+
} else if (HEADDIM <= 64) { \
|
96 |
+
constexpr static int kHeadDim = 64; \
|
97 |
+
return __VA_ARGS__(); \
|
98 |
+
} else if (HEADDIM <= 96) { \
|
99 |
+
constexpr static int kHeadDim = 96; \
|
100 |
+
return __VA_ARGS__(); \
|
101 |
+
} else if (HEADDIM <= 128) { \
|
102 |
+
constexpr static int kHeadDim = 128; \
|
103 |
+
return __VA_ARGS__(); \
|
104 |
+
} else if (HEADDIM <= 160) { \
|
105 |
+
constexpr static int kHeadDim = 160; \
|
106 |
+
return __VA_ARGS__(); \
|
107 |
+
} else if (HEADDIM <= 192) { \
|
108 |
+
constexpr static int kHeadDim = 192; \
|
109 |
+
return __VA_ARGS__(); \
|
110 |
+
} else if (HEADDIM <= 256) { \
|
111 |
+
constexpr static int kHeadDim = 256; \
|
112 |
+
return __VA_ARGS__(); \
|
113 |
+
} \
|
114 |
}()
|
torch-ext/torch_binding.cpp
CHANGED
@@ -16,6 +16,9 @@
|
|
16 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
17 |
ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
18 |
ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
|
|
|
|
|
|
|
19 |
}
|
20 |
|
21 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
16 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
17 |
ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
18 |
ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
|
19 |
+
|
20 |
+
ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
|
21 |
+
ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
|
22 |
}
|
23 |
|
24 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|