#pragma once

#include <torch/torch.h>

std::vector<at::Tensor> |
|
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, |
|
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, |
|
const c10::optional<at::Tensor> &D_, |
|
const c10::optional<at::Tensor> &z_, |
|
const c10::optional<at::Tensor> &delta_bias_, |
|
bool delta_softplus); |
std::vector<at::Tensor> |
|
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, |
|
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, |
|
const c10::optional<at::Tensor> &D_, |
|
const c10::optional<at::Tensor> &z_, |
|
const c10::optional<at::Tensor> &delta_bias_, |
|
const at::Tensor &dout, |
|
const c10::optional<at::Tensor> &x_, |
|
const c10::optional<at::Tensor> &out_, |
|
c10::optional<at::Tensor> dz_, |
|
bool delta_softplus, |
|
bool recompute_out_z); |