File size: 1,052 Bytes
23d26f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
#pragma once
#include <torch/torch.h>
/// Forward entry point of the selective-scan operator.
///
/// NOTE(review): the name suggests the Mamba selective-scan (S6) recurrence,
/// but the math, the expected tensor shapes/dtypes/devices, and the number and
/// order of the returned tensors are all fixed by the implementation, not by
/// this header — confirm there before relying on any of them.
///
/// @param u               Required input tensor.
/// @param delta           Required input tensor (presumably the per-step
///                        discretization step — TODO confirm).
/// @param A               Required input tensor.
/// @param B               Required input tensor.
/// @param C               Required input tensor.
/// @param D_              Optional tensor; may be passed as an empty optional.
/// @param z_              Optional tensor; may be passed as an empty optional.
/// @param delta_bias_     Optional tensor; may be passed as an empty optional
///                        (presumably a bias added to `delta` — TODO confirm).
/// @param delta_softplus  Flag (presumably enables a softplus on `delta` —
///                        TODO confirm).
/// @return Vector of result tensors; see the implementation for its contents.
auto selective_scan_fwd(const at::Tensor &u,
                        const at::Tensor &delta,
                        const at::Tensor &A,
                        const at::Tensor &B,
                        const at::Tensor &C,
                        const c10::optional<at::Tensor> &D_,
                        const c10::optional<at::Tensor> &z_,
                        const c10::optional<at::Tensor> &delta_bias_,
                        bool delta_softplus) -> std::vector<at::Tensor>;
/// Backward entry point of the selective-scan operator.
///
/// Accepts the same inputs as selective_scan_fwd (u, delta, A, B, C, D_, z_,
/// delta_bias_, delta_softplus) plus `dout` and a few optional extras.
/// NOTE(review): what is differentiated, and the number and order of the
/// returned tensors, are defined by the implementation — confirm there.
///
/// @param u                Required input tensor (same role as in forward).
/// @param delta            Required input tensor (same role as in forward).
/// @param A                Required input tensor (same role as in forward).
/// @param B                Required input tensor (same role as in forward).
/// @param C                Required input tensor (same role as in forward).
/// @param D_               Optional tensor; may be passed as an empty optional.
/// @param z_               Optional tensor; may be passed as an empty optional.
/// @param delta_bias_      Optional tensor; may be passed as an empty optional.
/// @param dout             Required tensor (presumably the upstream gradient
///                         of the forward output — TODO confirm).
/// @param x_               Optional tensor (presumably state saved by the
///                         forward pass — TODO confirm).
/// @param out_             Optional tensor (presumably the forward output —
///                         TODO confirm).
/// @param dz_              Optional tensor, taken BY VALUE unlike the other
///                         optionals. Keep it that way: switching to a const
///                         reference here would declare a different overload
///                         than the out-of-view definition and break linking.
///                         Presumably by value so the implementation can
///                         allocate/reassign it — TODO confirm.
/// @param delta_softplus   Flag (same role as in forward).
/// @param recompute_out_z  Flag (presumably requests recomputing output/z
///                         quantities — TODO confirm).
/// @return Vector of result tensors; see the implementation for its contents.
auto selective_scan_bwd(const at::Tensor &u,
                        const at::Tensor &delta,
                        const at::Tensor &A,
                        const at::Tensor &B,
                        const at::Tensor &C,
                        const c10::optional<at::Tensor> &D_,
                        const c10::optional<at::Tensor> &z_,
                        const c10::optional<at::Tensor> &delta_bias_,
                        const at::Tensor &dout,
                        const c10::optional<at::Tensor> &x_,
                        const c10::optional<at::Tensor> &out_,
                        c10::optional<at::Tensor> dz_,
                        bool delta_softplus,
                        bool recompute_out_z) -> std::vector<at::Tensor>;
|