kernel
drbh committed on
Commit
eda872e
·
1 Parent(s): 4080f9c

feat: add remaining torch binding def and impls

Browse files
Files changed (1) hide show
  1. torch-ext/torch_binding.cpp +9 -0
torch-ext/torch_binding.cpp CHANGED
@@ -19,6 +19,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
19
 
20
  ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
21
  ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
 
 
 
 
 
 
 
 
 
22
  }
23
 
24
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
 
19
 
20
  ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
21
  ops.impl("mha_varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
22
+
23
+ ops.def("mha_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor? alibi_slopes_, float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
24
+ ops.impl("mha_bwd", torch::kCUDA, &mha_bwd);
25
+
26
+ ops.def("mha_varlen_bwd(Tensor! dout, Tensor! q, Tensor! k, Tensor! v, Tensor! out, Tensor! softmax_lse, Tensor? dq_, Tensor? dk_, Tensor? dv_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor? alibi_slopes_, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, bool is_causal, int window_size_left, int window_size_right, float softcap, bool deterministic, Generator? gen_, Tensor? rng_state) -> Tensor[]");
27
+ ops.impl("mha_varlen_bwd", torch::kCUDA, &mha_varlen_bwd);
28
+
29
+ ops.def("mha_fwd_kvcache(Tensor! q, Tensor! kcache, Tensor! vcache, Tensor? k_, Tensor? v_, Tensor? seqlens_k_, Tensor? rotary_cos_, Tensor? rotary_sin_, Tensor? cache_batch_idx_, Tensor? leftpad_k_, Tensor? block_table_, Tensor? alibi_slopes_, Tensor? out_, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
30
+ ops.impl("mha_fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
31
  }
32
 
33
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)