drbh committed
Commit eda872e · Parent(s): 4080f9c

feat: add remaining torch binding def and impls
torch-ext/torch_binding.cpp
CHANGED
@@ -19,6 +19,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
   ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
+
+  ops.def("mha_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
+  ops.impl("mha_bwd", torch::kCUDA, &mha_bwd);
+
+  ops.def("mha_varlen_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor? alibi_slopes_, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
+  ops.impl("mha_varlen_bwd", torch::kCUDA, &mha_varlen_bwd);
+
+  ops.def("mha_fwd_kvcache(Tensor! q, Tensor! kcache, Tensor! vcache, Tensor? k_, Tensor? v_, Tensor? seqlens_k_, Tensor? rotary_cos_, Tensor? rotary_sin_, Tensor? cache_batch_idx_, Tensor? leftpad_k_, Tensor? block_table_, Tensor? alibi_slopes_, Tensor? out_, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
+  ops.impl("mha_fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
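Note on the schema strings above: each signature registered with ops.def has to match the C++ function bound with ops.impl. In TORCH_LIBRARY schemas, Tensor! maps to a mutable at::Tensor reference, Tensor? to an optional tensor, float to double, int to int64_t, Generator? to an optional generator, and -> Tensor[] to std::vector<at::Tensor>. The sketch below shows what a declaration consistent with the mha_bwd schema could look like; it is an assumption for illustration, not the upstream flash-attention signature, and the exact reference/const/optional-passing conventions in the real source may differ.

// Sketch only: a C++ declaration consistent with the "mha_bwd" schema registered above.
// The actual implementation lives elsewhere in this extension; parameter-passing details
// here are assumptions.
#include <optional>
#include <vector>

#include <ATen/ATen.h>

std::vector<at::Tensor> mha_bwd(
    at::Tensor &dout,                        // Tensor!   dout
    at::Tensor &q,                           // Tensor!   q
    at::Tensor &k,                           // Tensor!   k
    at::Tensor &v,                           // Tensor!   v
    at::Tensor &out,                         // Tensor!   out
    at::Tensor &softmax_lse,                 // Tensor!   softmax_lse
    std::optional<at::Tensor> dq_,           // Tensor?   dq_
    std::optional<at::Tensor> dk_,           // Tensor?   dk_
    std::optional<at::Tensor> dv_,           // Tensor?   dv_
    std::optional<at::Tensor> alibi_slopes_, // Tensor?   alibi_slopes_
    double p_dropout,                        // float     p_dropout
    double softmax_scale,                    // float     softmax_scale
    bool is_causal,                          // bool      is_causal
    int64_t window_size_left,                // int       window_size_left
    int64_t window_size_right,               // int       window_size_right
    double softcap,                          // float     softcap
    bool deterministic,                      // bool      deterministic
    std::optional<at::Generator> gen_,       // Generator? gen_
    std::optional<at::Tensor> rng_state);    // Tensor?   rng_state

Once the extension is built and loaded, ops registered this way are also reachable from Python as torch.ops.<namespace>.mha_bwd(...), where the namespace is whatever TORCH_EXTENSION_NAME expands to at build time.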