// mamba-ssm/torch-ext/torch_binding.h
// Commit 23d26f4 ("Import mamba-ssm kernels"), authored by danieldk.
#pragma once
#include <torch/torch.h>
// Declaration of the selective-scan entry point named `selective_scan_fwd`
// (judging by the `_fwd` suffix, presumably the forward pass — confirm against
// the definition, which is not part of this header).
//
// Parameters:
//   u, delta, A, B, C    - required tensors, borrowed by const reference.
//   D_, z_, delta_bias_  - optional tensors, borrowed by const reference;
//                          each may be an empty optional.
//   delta_softplus       - boolean flag passed to the implementation.
//                          NOTE(review): presumably toggles a softplus on
//                          `delta` — verify in the definition.
//
// Returns a std::vector<at::Tensor>. The number and order of the returned
// tensors are determined by the definition and cannot be read off this header.
std::vector<at::Tensor>
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
const c10::optional<at::Tensor> &D_,
const c10::optional<at::Tensor> &z_,
const c10::optional<at::Tensor> &delta_bias_,
bool delta_softplus);
// Declaration of the selective-scan entry point named `selective_scan_bwd`
// (judging by the `_bwd` suffix and the `dout` argument, presumably the
// backward/gradient pass — confirm against the definition, which is not part
// of this header).
//
// Parameters:
//   u, delta, A, B, C    - required tensors, borrowed by const reference.
//   D_, z_, delta_bias_  - optional tensors, borrowed by const reference;
//                          each may be an empty optional.
//   dout                 - required tensor, borrowed by const reference.
//                          NOTE(review): presumably the gradient w.r.t. the
//                          output — verify in the definition.
//   x_, out_             - optional tensors, borrowed by const reference.
//   dz_                  - optional tensor taken BY VALUE, unlike every other
//                          optional in this signature.
//                          NOTE(review): this must match the definition's
//                          signature exactly; confirm the by-value passing is
//                          intentional before changing it.
//   delta_softplus       - boolean flag passed to the implementation.
//                          NOTE(review): presumably toggles a softplus on
//                          `delta` — verify in the definition.
//   recompute_out_z      - boolean flag passed to the implementation.
//                          NOTE(review): presumably requests recomputation of
//                          an output involving `z_` — verify in the definition.
//
// Returns a std::vector<at::Tensor>. The number and order of the returned
// tensors are determined by the definition and cannot be read off this header.
std::vector<at::Tensor>
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
const c10::optional<at::Tensor> &D_,
const c10::optional<at::Tensor> &z_,
const c10::optional<at::Tensor> &delta_bias_,
const at::Tensor &dout,
const c10::optional<at::Tensor> &x_,
const c10::optional<at::Tensor> &out_,
c10::optional<at::Tensor> dz_,
bool delta_softplus,
bool recompute_out_z);