kernel
drbh committed
Commit b0d3c12 · 1 parent: 9002ff5

fix: expand build combinations and include all files

build.toml CHANGED
@@ -33,99 +33,98 @@ src = [
  "flash_attn/src/static_switch.h",
  "flash_attn/src/utils.h",

- ## TODO: include bwd kernels
+ ## bwd kernels

- # "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
  "flash_attn/src/flash_bwd_kernel.h",
  "flash_attn/src/flash_bwd_launch_template.h",
  "flash_attn/src/flash_bwd_preprocess_kernel.h",

- ## TODO: include fwd kernels
-
- # "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
+ ## fwd kernels
+ "flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_kernel.h",
  "flash_attn/src/flash_fwd_launch_template.h",
- # "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
  "flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
- # "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
+ "flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
  ]
  depends = ["torch", "cutlass_3_6"]
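The newly enabled .cu files follow a single naming scheme: one translation unit per combination of kernel family (flash_fwd, flash_fwd_split, flash_bwd), head dimension (32, 64, 96, 128, 160, 192, 256), dtype (fp16, bf16), and causal vs. non-causal masking, all targeting sm80. A minimal C++ sketch of that enumeration, for reference only; the generator below is illustrative and is not part of the repo, which checks the file list in directly.

#include <cstdio>

// Illustrative only: enumerates the kernel source names now listed in build.toml.
// 3 families x 7 head dims x 2 dtypes x 2 mask variants = 84 .cu files for sm80.
int main() {
    const char *families[] = {"flash_fwd", "flash_fwd_split", "flash_bwd"};
    const int hdims[] = {32, 64, 96, 128, 160, 192, 256};
    const char *dtypes[] = {"fp16", "bf16"};
    for (const char *family : families)
        for (int hdim : hdims)
            for (const char *dtype : dtypes)
                for (const char *mask : {"_causal", ""})
                    std::printf("flash_attn/src/%s_hdim%d_%s%s_sm80.cu\n",
                                family, hdim, dtype, mask);
    return 0;
}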
flash_attn/flash_api.cpp CHANGED
@@ -1477,10 +1477,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

  // NOTE: wrap the namespaced functions so all types are doubles and longs
  std::vector<at::Tensor>
- mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
- const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
- const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
- const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+ const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  const c10::optional<torch::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  const double p_dropout,
  const double softmax_scale,
@@ -1509,4 +1509,159 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x roun
  int window_size_right_int = static_cast<int>(window_size_right);

  return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
+ }
+
+ std::vector<at::Tensor>
+ mha_varlen_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
+ const c10::optional<torch::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
+ const at::Tensor &cu_seqlens_q, // batch_size + 1
+ const at::Tensor &cu_seqlens_k, // batch_size + 1
+ const int64_t max_seqlen_q,
+ const int64_t max_seqlen_k,
+ const double p_dropout,
+ const double softmax_scale,
+ bool is_causal,
+ const int64_t window_size_left,
+ const int64_t window_size_right,
+ const double softcap,
+ const bool return_softmax,
+ const c10::optional<at::Generator> gen_) {
+
+ auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
+
+ // Prepare the optional arguments as non-const references.
+ std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
+
+ if (!out.has_value()){
+ out = torch::empty_like(q);
+ }
+
+ // Convert double to float and int64_t to int.
+ float p_dropout_float = static_cast<float>(p_dropout);
+ float softmax_scale_float = static_cast<float>(softmax_scale);
+ float softcap_float = static_cast<float>(softcap);
+ int window_size_left_int = static_cast<int>(window_size_left);
+ int window_size_right_int = static_cast<int>(window_size_right);
+
+ return FLASH_NAMESPACE::mha_varlen_fwd(const_cast<at::Tensor &>(q), k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
+ }
+
+ std::vector<at::Tensor>
+ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
+ const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
+ const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor &softmax_lse, // b x h x seqlen_q
+ const std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
+ const std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
+ const std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
+ const std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+ const double p_dropout, // probability to drop
+ const double softmax_scale,
+ const bool is_causal,
+ const int64_t window_size_left,
+ const int64_t window_size_right,
+ const double softcap,
+ const bool deterministic,
+ std::optional<at::Generator> gen_,
+ std::optional<at::Tensor> &rng_state) {
+
+ auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
+
+ // Prepare the optional arguments as non-const references.
+ std::optional<at::Tensor> dq = dq_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dq_.value())) : std::nullopt;
+ std::optional<at::Tensor> dk = dk_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dk_.value())) : std::nullopt;
+ std::optional<at::Tensor> dv = dv_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(dv_.value())) : std::nullopt;
+ std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
+
+ // Convert double to float and int64_t to int.
+ float p_dropout_float = static_cast<float>(p_dropout);
+ float softmax_scale_float = static_cast<float>(softmax_scale);
+ float softcap_float = static_cast<float>(softcap);
+ int window_size_left_int = static_cast<int>(window_size_left);
+ int window_size_right_int = static_cast<int>(window_size_right);
+
+ return FLASH_NAMESPACE::mha_bwd(const_cast<at::Tensor &>(dout), q, k, v, out, softmax_lse, dq, dk, dv, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, deterministic, gen, rng_state);
+ }
+
+
+ std::vector<at::Tensor>
+ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
+ const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
+ const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
+ const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor &softmax_lse, // b x h x seqlen_q
+ const at::Tensor &cu_seqlens_q, // batch_size + 1
+ const at::Tensor &cu_seqlens_k, // batch_size + 1
+ const int64_t max_seqlen_q,
+ const int64_t max_seqlen_k,
+ const double p_dropout,
+ const double softmax_scale,
+ const bool is_causal,
+ const int64_t window_size_left,
+ const int64_t window_size_right,
+ const double softcap,
+ const bool deterministic,
+ std::optional<at::Generator> gen_,
+ std::optional<at::Tensor> &rng_state) {
+
+ auto gen = gen_.value_or(at::cuda::detail::getDefaultCUDAGenerator());
+
+ // Convert double to float and int64_t to int.
+ float p_dropout_float = static_cast<float>(p_dropout);
+ float softmax_scale_float = static_cast<float>(softmax_scale);
+ float softcap_float = static_cast<float>(softcap);
+ int window_size_left_int = static_cast<int>(window_size_left);
+ int window_size_right_int = static_cast<int>(window_size_right);
+
+ return FLASH_NAMESPACE::mha_varlen_bwd(const_cast<at::Tensor &>(dout), q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, deterministic, gen, rng_state);
+ }
+
+ std::vector<at::Tensor>
+ mha_fwd_kvcache(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+ const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+ const c10::optional<torch::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
+ const c10::optional<torch::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
+ const c10::optional<torch::Tensor> &seqlens_k_, // batch_size
+ const c10::optional<torch::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+ const c10::optional<torch::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+ const c10::optional<torch::Tensor> &cache_batch_idx_, // indices to index into the KV cache
+ const c10::optional<torch::Tensor> &leftpad_k_, // batch_size
+ const c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
+ const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
+ const c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
+ const double softmax_scale,
+ bool is_causal,
+ const int64_t window_size_left,
+ const int64_t window_size_right,
+ const double softcap,
+ bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+ const int64_t num_splits
+ ) {
+
+ // Prepare the optional arguments as non-const references.
+ std::optional<at::Tensor> k = k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(k_.value())) : std::nullopt;
+ std::optional<at::Tensor> v = v_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(v_.value())) : std::nullopt;
+ std::optional<at::Tensor> seqlens_k = seqlens_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(seqlens_k_.value())) : std::nullopt;
+ std::optional<at::Tensor> rotary_cos = rotary_cos_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(rotary_cos_.value())) : std::nullopt;
+ std::optional<at::Tensor> rotary_sin = rotary_sin_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(rotary_sin_.value())) : std::nullopt;
+ std::optional<at::Tensor> cache_batch_idx = cache_batch_idx_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(cache_batch_idx_.value())) : std::nullopt;
+ std::optional<at::Tensor> leftpad_k = leftpad_k_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(leftpad_k_.value())) : std::nullopt;
+ std::optional<at::Tensor> block_table = block_table_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(block_table_.value())) : std::nullopt;
+ std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
+ std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
+
+ // Convert double to float and int64_t to int.
+ float softmax_scale_float = static_cast<float>(softmax_scale);
+ float softcap_float = static_cast<float>(softcap);
+ int window_size_left_int = static_cast<int>(window_size_left);
+ int window_size_right_int = static_cast<int>(window_size_right);
+ int num_splits_int = static_cast<int>(num_splits);
+
+ return FLASH_NAMESPACE::mha_fwd_kvcache(const_cast<at::Tensor &>(q), kcache, vcache, k, v, seqlens_k, rotary_cos, rotary_sin, cache_batch_idx, leftpad_k, block_table, alibi_slopes, out, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, is_rotary_interleaved, num_splits_int);
  }
flash_attn/src/static_switch.h CHANGED
@@ -87,33 +87,28 @@
  } \
  }()

- // #define HEADDIM_SWITCH(HEADDIM, ...) \
- // [&] { \
- // if (HEADDIM <= 32) { \
- // constexpr static int kHeadDim = 32; \
- // return __VA_ARGS__(); \
- // } else if (HEADDIM <= 64) { \
- // constexpr static int kHeadDim = 64; \
- // return __VA_ARGS__(); \
- // } else if (HEADDIM <= 96) { \
- // constexpr static int kHeadDim = 96; \
- // return __VA_ARGS__(); \
- // } else if (HEADDIM <= 128) { \
- // constexpr static int kHeadDim = 128; \
- // return __VA_ARGS__(); \
- // } else if (HEADDIM <= 160) { \
- // constexpr static int kHeadDim = 160; \
- // return __VA_ARGS__(); \
- // } else if (HEADDIM <= 192) { \
- // constexpr static int kHeadDim = 192; \
- // return __VA_ARGS__(); \
- // } else if (HEADDIM <= 256) { \
- // constexpr static int kHeadDim = 256; \
- // return __VA_ARGS__(); \
- // } \
- // }()
  #define HEADDIM_SWITCH(HEADDIM, ...) \
- [&] { \
- constexpr static int kHeadDim = 32; \
- return __VA_ARGS__(); \
+ [&] { \
+ if (HEADDIM <= 32) { \
+ constexpr static int kHeadDim = 32; \
+ return __VA_ARGS__(); \
+ } else if (HEADDIM <= 64) { \
+ constexpr static int kHeadDim = 64; \
+ return __VA_ARGS__(); \
+ } else if (HEADDIM <= 96) { \
+ constexpr static int kHeadDim = 96; \
+ return __VA_ARGS__(); \
+ } else if (HEADDIM <= 128) { \
+ constexpr static int kHeadDim = 128; \
+ return __VA_ARGS__(); \
+ } else if (HEADDIM <= 160) { \
+ constexpr static int kHeadDim = 160; \
+ return __VA_ARGS__(); \
+ } else if (HEADDIM <= 192) { \
+ constexpr static int kHeadDim = 192; \
+ return __VA_ARGS__(); \
+ } else if (HEADDIM <= 256) { \
+ constexpr static int kHeadDim = 256; \
+ return __VA_ARGS__(); \
+ } \
  }()
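The restored HEADDIM_SWITCH maps a runtime head dimension onto a compile-time constant so the launch templates can instantiate the matching kernel; the previous placeholder body always pinned kHeadDim to 32. A self-contained sketch of the dispatch pattern follows; run_fwd is a hypothetical stand-in for the real launch templates, and the macro is a trimmed copy of the one restored above.

#include <cstdio>

// Trimmed copy of the HEADDIM_SWITCH pattern: pick a constexpr head dim
// from a runtime value and invoke the callback with that constant in scope.
#define HEADDIM_SWITCH(HEADDIM, ...)          \
  [&] {                                       \
    if (HEADDIM <= 32) {                      \
      constexpr static int kHeadDim = 32;     \
      return __VA_ARGS__();                   \
    } else if (HEADDIM <= 64) {               \
      constexpr static int kHeadDim = 64;     \
      return __VA_ARGS__();                   \
    } else {                                  \
      constexpr static int kHeadDim = 128;    \
      return __VA_ARGS__();                   \
    }                                         \
  }()

// Stand-in for the kernel launchers instantiated per head dimension.
template <int kHeadDim>
void run_fwd() { std::printf("dispatching kernel for head dim %d\n", kHeadDim); }

int main() {
    int head_dim_at_runtime = 48;  // e.g. taken from tensor shapes at runtime
    HEADDIM_SWITCH(head_dim_at_runtime, [&] { run_fwd<kHeadDim>(); });  // picks 64
    return 0;
}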
torch-ext/torch_binding.cpp CHANGED
@@ -16,6 +16,9 @@
  TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
  ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
+
+ ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
+ ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
  }

  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
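For context on the wrapper layer: a TORCH_LIBRARY schema exposes only boxed scalar types, so float arguments arrive in C++ as double and int arguments as int64_t, which is why the wrappers in flash_attn/flash_api.cpp narrow them before calling into FLASH_NAMESPACE (the "// NOTE: wrap the namespaced functions so all types are doubles and longs" comment). The sketch below shows the mapping for the newly registered op against its schema string; it is a declaration only, mirroring the definition added in flash_api.cpp.

#include <torch/torch.h>

// How each token in the "mha_varlen_fwd(...)" schema maps onto the C++ signature.
std::vector<at::Tensor>                                  // -> Tensor[]
mha_varlen_fwd(const at::Tensor &q,                      // Tensor! q
               const at::Tensor &k,                      // Tensor! k
               const at::Tensor &v,                      // Tensor! v
               const c10::optional<torch::Tensor> &out_, // Tensor? out_
               const at::Tensor &cu_seqlens_q,           // Tensor cu_seqlens_q
               const at::Tensor &cu_seqlens_k,           // Tensor cu_seqlens_k
               const int64_t max_seqlen_q,               // int max_seqlen_q
               const int64_t max_seqlen_k,               // int max_seqlen_k
               const double p_dropout,                   // float p_dropout
               const double softmax_scale,               // float softmax_scale
               bool is_causal,                           // bool is_causal
               const int64_t window_size_left,           // int window_size_left
               const int64_t window_size_right,          // int window_size_right
               const double softcap,                     // float softcap
               const bool return_softmax,                // bool return_softmax
               const c10::optional<at::Generator> gen_); // Generator? gen_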