Commit
·
23d26f4
0
Parent(s):
Import mamba-ssm kernels
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +14 -0
- build.toml +34 -0
- selective-scan/reverse_scan.cuh +415 -0
- selective-scan/selective_scan.cpp +497 -0
- selective-scan/selective_scan.h +101 -0
- selective-scan/selective_scan_bwd_bf16_complex.cu +9 -0
- selective-scan/selective_scan_bwd_bf16_real.cu +9 -0
- selective-scan/selective_scan_bwd_fp16_complex.cu +9 -0
- selective-scan/selective_scan_bwd_fp16_real.cu +9 -0
- selective-scan/selective_scan_bwd_fp32_complex.cu +9 -0
- selective-scan/selective_scan_bwd_fp32_real.cu +9 -0
- selective-scan/selective_scan_bwd_kernel.cuh +561 -0
- selective-scan/selective_scan_common.h +255 -0
- selective-scan/selective_scan_fwd_bf16.cu +10 -0
- selective-scan/selective_scan_fwd_fp16.cu +10 -0
- selective-scan/selective_scan_fwd_fp32.cu +10 -0
- selective-scan/selective_scan_fwd_kernel.cuh +376 -0
- selective-scan/static_switch.h +25 -0
- selective-scan/uninitialized_copy.cuh +77 -0
- tests/ops/test_selective_scan.py +247 -0
- tests/ops/triton/test_layernorm_gated.py +103 -0
- tests/ops/triton/test_selective_state_update.py +201 -0
- tests/ops/triton/test_ssd.py +78 -0
- tests/test_generation.py +113 -0
- torch-ext/mamba_ssm/__init__.py +14 -0
- torch-ext/mamba_ssm/distributed/__init__.py +0 -0
- torch-ext/mamba_ssm/distributed/distributed_utils.py +144 -0
- torch-ext/mamba_ssm/distributed/tensor_parallel.py +326 -0
- torch-ext/mamba_ssm/models/__init__.py +0 -0
- torch-ext/mamba_ssm/models/config_mamba.py +18 -0
- torch-ext/mamba_ssm/models/mixer_seq_simple.py +338 -0
- torch-ext/mamba_ssm/modules/__init__.py +0 -0
- torch-ext/mamba_ssm/modules/block.py +107 -0
- torch-ext/mamba_ssm/modules/mamba2.py +502 -0
- torch-ext/mamba_ssm/modules/mamba2_simple.py +229 -0
- torch-ext/mamba_ssm/modules/mamba_simple.py +339 -0
- torch-ext/mamba_ssm/modules/mha.py +294 -0
- torch-ext/mamba_ssm/modules/mlp.py +34 -0
- torch-ext/mamba_ssm/modules/ssd_minimal.py +111 -0
- torch-ext/mamba_ssm/ops/__init__.py +0 -0
- torch-ext/mamba_ssm/ops/selective_scan_interface.py +659 -0
- torch-ext/mamba_ssm/ops/triton/__init__.py +0 -0
- torch-ext/mamba_ssm/ops/triton/k_activations.py +169 -0
- torch-ext/mamba_ssm/ops/triton/layer_norm.py +1166 -0
- torch-ext/mamba_ssm/ops/triton/layernorm_gated.py +437 -0
- torch-ext/mamba_ssm/ops/triton/selective_state_update.py +389 -0
- torch-ext/mamba_ssm/ops/triton/softplus.py +15 -0
- torch-ext/mamba_ssm/ops/triton/ssd_bmm.py +262 -0
- torch-ext/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
- torch-ext/mamba_ssm/ops/triton/ssd_chunk_state.py +2012 -0
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
---
|
4 |
+
|
5 |
+
## Mamba
|
6 |
+
|
7 |
+
Mamba state space kernels + models from [state-spaces/mamba](https://github.com/state-spaces/mamba).
|
8 |
+
|
9 |
+
## Warning
|
10 |
+
|
11 |
+
Some functionality is dependent on `einops` and `transformers`, however we
|
12 |
+
currently don't have any way of defining these dependencies yet. The scope
|
13 |
+
of the Hub kernel is probably too large (should maybe only contain the
|
14 |
+
selective-scan and Triton kernels).
|
build.toml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[general]
|
2 |
+
version = "0.0.1"
|
3 |
+
|
4 |
+
[torch]
|
5 |
+
name = "mamba_ssm"
|
6 |
+
src = [
|
7 |
+
"torch-ext/registration.h",
|
8 |
+
"torch-ext/torch_binding.cpp",
|
9 |
+
"torch-ext/torch_binding.h"
|
10 |
+
]
|
11 |
+
pyroot = "torch-ext"
|
12 |
+
|
13 |
+
[kernel.selective_scan]
|
14 |
+
capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
|
15 |
+
src = [
|
16 |
+
"selective-scan/reverse_scan.cuh",
|
17 |
+
"selective-scan/selective_scan.cpp",
|
18 |
+
"selective-scan/selective_scan.h",
|
19 |
+
"selective-scan/selective_scan_bwd_bf16_complex.cu",
|
20 |
+
"selective-scan/selective_scan_bwd_bf16_real.cu",
|
21 |
+
"selective-scan/selective_scan_bwd_fp16_complex.cu",
|
22 |
+
"selective-scan/selective_scan_bwd_fp16_real.cu",
|
23 |
+
"selective-scan/selective_scan_bwd_fp32_complex.cu",
|
24 |
+
"selective-scan/selective_scan_bwd_fp32_real.cu",
|
25 |
+
"selective-scan/selective_scan_bwd_kernel.cuh",
|
26 |
+
"selective-scan/selective_scan_common.h",
|
27 |
+
"selective-scan/selective_scan_fwd_bf16.cu",
|
28 |
+
"selective-scan/selective_scan_fwd_fp16.cu",
|
29 |
+
"selective-scan/selective_scan_fwd_fp32.cu",
|
30 |
+
"selective-scan/selective_scan_fwd_kernel.cuh",
|
31 |
+
"selective-scan/static_switch.h",
|
32 |
+
"selective-scan/uninitialized_copy.cuh",
|
33 |
+
]
|
34 |
+
depends = [ "torch" ]
|
selective-scan/reverse_scan.cuh
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#ifndef USE_ROCM
|
8 |
+
#include <cub/config.cuh>
|
9 |
+
|
10 |
+
#include <cub/util_ptx.cuh>
|
11 |
+
#include <cub/util_type.cuh>
|
12 |
+
#include <cub/block/block_raking_layout.cuh>
|
13 |
+
// #include <cub/detail/uninitialized_copy.cuh>
|
14 |
+
#else
|
15 |
+
#include <hipcub/hipcub.hpp>
|
16 |
+
namespace cub = hipcub;
|
17 |
+
#endif
|
18 |
+
#include "uninitialized_copy.cuh"
|
19 |
+
|
20 |
+
/**
|
21 |
+
* Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
|
22 |
+
*/
|
23 |
+
template <
|
24 |
+
int LENGTH,
|
25 |
+
typename T,
|
26 |
+
typename ReductionOp>
|
27 |
+
__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
|
28 |
+
static_assert(LENGTH > 0);
|
29 |
+
T retval = input[LENGTH - 1];
|
30 |
+
#pragma unroll
|
31 |
+
for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
|
32 |
+
return retval;
|
33 |
+
}
|
34 |
+
|
35 |
+
/**
|
36 |
+
* Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
37 |
+
*/
|
38 |
+
template <
|
39 |
+
int LENGTH,
|
40 |
+
typename T,
|
41 |
+
typename ScanOp>
|
42 |
+
__device__ __forceinline__ T ThreadReverseScanInclusive(
|
43 |
+
const T (&input)[LENGTH],
|
44 |
+
T (&output)[LENGTH],
|
45 |
+
ScanOp scan_op,
|
46 |
+
const T postfix)
|
47 |
+
{
|
48 |
+
T inclusive = postfix;
|
49 |
+
#pragma unroll
|
50 |
+
for (int i = LENGTH - 1; i >= 0; --i) {
|
51 |
+
inclusive = scan_op(inclusive, input[i]);
|
52 |
+
output[i] = inclusive;
|
53 |
+
}
|
54 |
+
return inclusive;
|
55 |
+
}
|
56 |
+
|
57 |
+
/**
|
58 |
+
* Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
|
59 |
+
*/
|
60 |
+
template <
|
61 |
+
int LENGTH,
|
62 |
+
typename T,
|
63 |
+
typename ScanOp>
|
64 |
+
__device__ __forceinline__ T ThreadReverseScanExclusive(
|
65 |
+
const T (&input)[LENGTH],
|
66 |
+
T (&output)[LENGTH],
|
67 |
+
ScanOp scan_op,
|
68 |
+
const T postfix)
|
69 |
+
{
|
70 |
+
// Careful, output maybe be aliased to input
|
71 |
+
T exclusive = postfix;
|
72 |
+
T inclusive;
|
73 |
+
#pragma unroll
|
74 |
+
for (int i = LENGTH - 1; i >= 0; --i) {
|
75 |
+
inclusive = scan_op(exclusive, input[i]);
|
76 |
+
output[i] = exclusive;
|
77 |
+
exclusive = inclusive;
|
78 |
+
}
|
79 |
+
return inclusive;
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
/**
|
84 |
+
* \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
|
85 |
+
*
|
86 |
+
* LOGICAL_WARP_THREADS must be a power-of-two
|
87 |
+
*/
|
88 |
+
template <
|
89 |
+
typename T, ///< Data type being scanned
|
90 |
+
int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
|
91 |
+
>
|
92 |
+
struct WarpReverseScan {
|
93 |
+
//---------------------------------------------------------------------
|
94 |
+
// Constants and type definitions
|
95 |
+
//---------------------------------------------------------------------
|
96 |
+
|
97 |
+
/// Whether the logical warp size and the PTX warp size coincide
|
98 |
+
|
99 |
+
// In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()
|
100 |
+
// While in cub, it's defined as a macro that takes a redundant unused argument.
|
101 |
+
#ifndef USE_ROCM
|
102 |
+
#define WARP_THREADS CUB_WARP_THREADS(0)
|
103 |
+
#else
|
104 |
+
#define WARP_THREADS HIPCUB_WARP_THREADS
|
105 |
+
#endif
|
106 |
+
static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
|
107 |
+
/// The number of warp scan steps
|
108 |
+
static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
|
109 |
+
static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
|
110 |
+
|
111 |
+
|
112 |
+
//---------------------------------------------------------------------
|
113 |
+
// Thread fields
|
114 |
+
//---------------------------------------------------------------------
|
115 |
+
|
116 |
+
/// Lane index in logical warp
|
117 |
+
unsigned int lane_id;
|
118 |
+
|
119 |
+
/// Logical warp index in 32-thread physical warp
|
120 |
+
unsigned int warp_id;
|
121 |
+
|
122 |
+
/// 32-thread physical warp member mask of logical warp
|
123 |
+
unsigned int member_mask;
|
124 |
+
|
125 |
+
//---------------------------------------------------------------------
|
126 |
+
// Construction
|
127 |
+
//---------------------------------------------------------------------
|
128 |
+
|
129 |
+
/// Constructor
|
130 |
+
explicit __device__ __forceinline__
|
131 |
+
WarpReverseScan()
|
132 |
+
: lane_id(cub::LaneId())
|
133 |
+
, warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
|
134 |
+
, member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
|
135 |
+
{
|
136 |
+
if (!IS_ARCH_WARP) {
|
137 |
+
lane_id = lane_id % LOGICAL_WARP_THREADS;
|
138 |
+
}
|
139 |
+
}
|
140 |
+
|
141 |
+
|
142 |
+
/// Broadcast
|
143 |
+
__device__ __forceinline__ T Broadcast(
|
144 |
+
T input, ///< [in] The value to broadcast
|
145 |
+
int src_lane) ///< [in] Which warp lane is to do the broadcasting
|
146 |
+
{
|
147 |
+
return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
|
148 |
+
}
|
149 |
+
|
150 |
+
|
151 |
+
/// Inclusive scan
|
152 |
+
template <typename ScanOpT>
|
153 |
+
__device__ __forceinline__ void InclusiveReverseScan(
|
154 |
+
T input, ///< [in] Calling thread's input item.
|
155 |
+
T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
156 |
+
ScanOpT scan_op) ///< [in] Binary scan operator
|
157 |
+
{
|
158 |
+
inclusive_output = input;
|
159 |
+
#pragma unroll
|
160 |
+
for (int STEP = 0; STEP < STEPS; STEP++) {
|
161 |
+
int offset = 1 << STEP;
|
162 |
+
T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
163 |
+
inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
|
164 |
+
);
|
165 |
+
// Perform scan op if from a valid peer
|
166 |
+
inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
|
167 |
+
? inclusive_output : scan_op(temp, inclusive_output);
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
/// Exclusive scan
|
172 |
+
// Get exclusive from inclusive
|
173 |
+
template <typename ScanOpT>
|
174 |
+
__device__ __forceinline__ void ExclusiveReverseScan(
|
175 |
+
T input, ///< [in] Calling thread's input item.
|
176 |
+
T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
|
177 |
+
ScanOpT scan_op, ///< [in] Binary scan operator
|
178 |
+
T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
|
179 |
+
{
|
180 |
+
T inclusive_output;
|
181 |
+
InclusiveReverseScan(input, inclusive_output, scan_op);
|
182 |
+
warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
|
183 |
+
// initial value unknown
|
184 |
+
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
185 |
+
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
186 |
+
);
|
187 |
+
}
|
188 |
+
|
189 |
+
/**
|
190 |
+
* \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
|
191 |
+
*/
|
192 |
+
template <typename ScanOpT>
|
193 |
+
__device__ __forceinline__ void ReverseScan(
|
194 |
+
T input, ///< [in] Calling thread's input item.
|
195 |
+
T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
|
196 |
+
T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
|
197 |
+
ScanOpT scan_op) ///< [in] Binary scan operator
|
198 |
+
{
|
199 |
+
InclusiveReverseScan(input, inclusive_output, scan_op);
|
200 |
+
// initial value unknown
|
201 |
+
exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
|
202 |
+
inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
|
203 |
+
);
|
204 |
+
}
|
205 |
+
|
206 |
+
};
|
207 |
+
|
208 |
+
/**
|
209 |
+
* \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
|
210 |
+
*/
|
211 |
+
template <
|
212 |
+
typename T, ///< Data type being scanned
|
213 |
+
int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
|
214 |
+
bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
|
215 |
+
>
|
216 |
+
struct BlockReverseScan {
|
217 |
+
//---------------------------------------------------------------------
|
218 |
+
// Types and constants
|
219 |
+
//---------------------------------------------------------------------
|
220 |
+
|
221 |
+
/// Constants
|
222 |
+
/// The thread block size in threads
|
223 |
+
static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
|
224 |
+
|
225 |
+
/// Layout type for padded thread block raking grid
|
226 |
+
using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
|
227 |
+
// The number of reduction elements is not a multiple of the number of raking threads for now
|
228 |
+
static_assert(BlockRakingLayout::UNGUARDED);
|
229 |
+
|
230 |
+
/// Number of raking threads
|
231 |
+
static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
|
232 |
+
/// Number of raking elements per warp synchronous raking thread
|
233 |
+
static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
|
234 |
+
/// Cooperative work can be entirely warp synchronous
|
235 |
+
static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
|
236 |
+
|
237 |
+
/// WarpReverseScan utility type
|
238 |
+
using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
|
239 |
+
|
240 |
+
/// Shared memory storage layout type
|
241 |
+
struct _TempStorage {
|
242 |
+
typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
|
243 |
+
};
|
244 |
+
|
245 |
+
|
246 |
+
/// Alias wrapper allowing storage to be unioned
|
247 |
+
struct TempStorage : cub::Uninitialized<_TempStorage> {};
|
248 |
+
|
249 |
+
|
250 |
+
//---------------------------------------------------------------------
|
251 |
+
// Per-thread fields
|
252 |
+
//---------------------------------------------------------------------
|
253 |
+
|
254 |
+
// Thread fields
|
255 |
+
_TempStorage &temp_storage;
|
256 |
+
unsigned int linear_tid;
|
257 |
+
T cached_segment[SEGMENT_LENGTH];
|
258 |
+
|
259 |
+
|
260 |
+
//---------------------------------------------------------------------
|
261 |
+
// Utility methods
|
262 |
+
//---------------------------------------------------------------------
|
263 |
+
|
264 |
+
/// Performs upsweep raking reduction, returning the aggregate
|
265 |
+
template <typename ScanOp>
|
266 |
+
__device__ __forceinline__ T Upsweep(ScanOp scan_op) {
|
267 |
+
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
268 |
+
// Read data into registers
|
269 |
+
#pragma unroll
|
270 |
+
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
271 |
+
T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
|
272 |
+
#pragma unroll
|
273 |
+
for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
|
274 |
+
raking_partial = scan_op(raking_partial, cached_segment[i]);
|
275 |
+
}
|
276 |
+
return raking_partial;
|
277 |
+
}
|
278 |
+
|
279 |
+
|
280 |
+
/// Performs exclusive downsweep raking scan
|
281 |
+
template <typename ScanOp>
|
282 |
+
__device__ __forceinline__ void ExclusiveDownsweep(
|
283 |
+
ScanOp scan_op,
|
284 |
+
T raking_partial)
|
285 |
+
{
|
286 |
+
T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
|
287 |
+
// Read data back into registers
|
288 |
+
if (!MEMOIZE) {
|
289 |
+
#pragma unroll
|
290 |
+
for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
|
291 |
+
}
|
292 |
+
ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
|
293 |
+
// Write data back to smem
|
294 |
+
#pragma unroll
|
295 |
+
for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
|
296 |
+
}
|
297 |
+
|
298 |
+
|
299 |
+
//---------------------------------------------------------------------
|
300 |
+
// Constructors
|
301 |
+
//---------------------------------------------------------------------
|
302 |
+
|
303 |
+
/// Constructor
|
304 |
+
__device__ __forceinline__ BlockReverseScan(
|
305 |
+
TempStorage &temp_storage)
|
306 |
+
:
|
307 |
+
temp_storage(temp_storage.Alias()),
|
308 |
+
linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
|
309 |
+
{}
|
310 |
+
|
311 |
+
|
312 |
+
/// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
313 |
+
template <
|
314 |
+
typename ScanOp,
|
315 |
+
typename BlockPostfixCallbackOp>
|
316 |
+
__device__ __forceinline__ void ExclusiveReverseScan(
|
317 |
+
T input, ///< [in] Calling thread's input item
|
318 |
+
T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
|
319 |
+
ScanOp scan_op, ///< [in] Binary scan operator
|
320 |
+
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
|
321 |
+
{
|
322 |
+
if (WARP_SYNCHRONOUS) {
|
323 |
+
// Short-circuit directly to warp-synchronous scan
|
324 |
+
T block_aggregate;
|
325 |
+
WarpReverseScan warp_scan;
|
326 |
+
warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
|
327 |
+
// Obtain warp-wide postfix in lane0, then broadcast to other lanes
|
328 |
+
T block_postfix = block_postfix_callback_op(block_aggregate);
|
329 |
+
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
330 |
+
exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
|
331 |
+
} else {
|
332 |
+
// Place thread partial into shared memory raking grid
|
333 |
+
T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
|
334 |
+
detail::uninitialized_copy(placement_ptr, input);
|
335 |
+
cub::CTA_SYNC();
|
336 |
+
// Reduce parallelism down to just raking threads
|
337 |
+
if (linear_tid < RAKING_THREADS) {
|
338 |
+
WarpReverseScan warp_scan;
|
339 |
+
// Raking upsweep reduction across shared partials
|
340 |
+
T upsweep_partial = Upsweep(scan_op);
|
341 |
+
// Warp-synchronous scan
|
342 |
+
T exclusive_partial, block_aggregate;
|
343 |
+
warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
|
344 |
+
// Obtain block-wide postfix in lane0, then broadcast to other lanes
|
345 |
+
T block_postfix = block_postfix_callback_op(block_aggregate);
|
346 |
+
block_postfix = warp_scan.Broadcast(block_postfix, 0);
|
347 |
+
// Update postfix with warpscan exclusive partial
|
348 |
+
T downsweep_postfix = linear_tid == RAKING_THREADS - 1
|
349 |
+
? block_postfix : scan_op(block_postfix, exclusive_partial);
|
350 |
+
// Exclusive raking downsweep scan
|
351 |
+
ExclusiveDownsweep(scan_op, downsweep_postfix);
|
352 |
+
}
|
353 |
+
cub::CTA_SYNC();
|
354 |
+
// Grab thread postfix from shared memory
|
355 |
+
exclusive_output = *placement_ptr;
|
356 |
+
|
357 |
+
// // Compute warp scan in each warp.
|
358 |
+
// // The exclusive output from the last lane in each warp is invalid.
|
359 |
+
// T inclusive_output;
|
360 |
+
// WarpReverseScan warp_scan;
|
361 |
+
// warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
|
362 |
+
|
363 |
+
// // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
|
364 |
+
// T block_aggregate;
|
365 |
+
// T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
|
366 |
+
|
367 |
+
// // Apply warp postfix to our lane's partial
|
368 |
+
// if (warp_id != 0) {
|
369 |
+
// exclusive_output = scan_op(warp_postfix, exclusive_output);
|
370 |
+
// if (lane_id == 0) { exclusive_output = warp_postfix; }
|
371 |
+
// }
|
372 |
+
|
373 |
+
// // Use the first warp to determine the thread block postfix, returning the result in lane0
|
374 |
+
// if (warp_id == 0) {
|
375 |
+
// T block_postfix = block_postfix_callback_op(block_aggregate);
|
376 |
+
// if (lane_id == 0) {
|
377 |
+
// // Share the postfix with all threads
|
378 |
+
// detail::uninitialized_copy(&temp_storage.block_postfix,
|
379 |
+
// block_postfix);
|
380 |
+
|
381 |
+
// exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
|
382 |
+
// }
|
383 |
+
// }
|
384 |
+
|
385 |
+
// cub::CTA_SYNC();
|
386 |
+
|
387 |
+
// // Incorporate thread block postfix into outputs
|
388 |
+
// T block_postfix = temp_storage.block_postfix;
|
389 |
+
// if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
|
390 |
+
}
|
391 |
+
}
|
392 |
+
|
393 |
+
|
394 |
+
/**
|
395 |
+
* \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
|
396 |
+
*/
|
397 |
+
template <
|
398 |
+
int ITEMS_PER_THREAD,
|
399 |
+
typename ScanOp,
|
400 |
+
typename BlockPostfixCallbackOp>
|
401 |
+
__device__ __forceinline__ void InclusiveReverseScan(
|
402 |
+
T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
|
403 |
+
T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
|
404 |
+
ScanOp scan_op, ///< [in] Binary scan functor
|
405 |
+
BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
|
406 |
+
{
|
407 |
+
// Reduce consecutive thread items in registers
|
408 |
+
T thread_postfix = ThreadReverseReduce(input, scan_op);
|
409 |
+
// Exclusive thread block-scan
|
410 |
+
ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
|
411 |
+
// Inclusive scan in registers with postfix as seed
|
412 |
+
ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
|
413 |
+
}
|
414 |
+
|
415 |
+
};
|
selective-scan/selective_scan.cpp
ADDED
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#include <c10/cuda/CUDAGuard.h>
|
6 |
+
#include <c10/cuda/CUDAStream.h>
|
7 |
+
#include <torch/all.h>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
#include "selective_scan.h"
|
11 |
+
|
12 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
13 |
+
|
14 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
15 |
+
if (ITYPE == at::ScalarType::Half) { \
|
16 |
+
using input_t = at::Half; \
|
17 |
+
__VA_ARGS__(); \
|
18 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
19 |
+
using input_t = at::BFloat16; \
|
20 |
+
__VA_ARGS__(); \
|
21 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
22 |
+
using input_t = float; \
|
23 |
+
__VA_ARGS__(); \
|
24 |
+
} else { \
|
25 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
26 |
+
}
|
27 |
+
|
28 |
+
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
29 |
+
if (WTYPE == at::ScalarType::Half) { \
|
30 |
+
using weight_t = at::Half; \
|
31 |
+
__VA_ARGS__(); \
|
32 |
+
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
33 |
+
using weight_t = at::BFloat16; \
|
34 |
+
__VA_ARGS__(); \
|
35 |
+
} else if (WTYPE == at::ScalarType::Float) { \
|
36 |
+
using weight_t = float; \
|
37 |
+
__VA_ARGS__(); \
|
38 |
+
} else { \
|
39 |
+
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
40 |
+
}
|
41 |
+
|
42 |
+
#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
|
43 |
+
if (WTYPE == at::ScalarType::Float) { \
|
44 |
+
using weight_t = float; \
|
45 |
+
__VA_ARGS__(); \
|
46 |
+
} else if (WTYPE == at::ScalarType::ComplexFloat) { \
|
47 |
+
using weight_t = c10::complex<float>; \
|
48 |
+
__VA_ARGS__(); \
|
49 |
+
} else { \
|
50 |
+
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
51 |
+
}
|
52 |
+
|
53 |
+
template<typename input_t, typename weight_t>
|
54 |
+
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream);
|
55 |
+
|
56 |
+
template <typename input_t, typename weight_t>
|
57 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream);
|
58 |
+
|
59 |
+
void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
60 |
+
// sizes
|
61 |
+
const size_t batch,
|
62 |
+
const size_t dim,
|
63 |
+
const size_t seqlen,
|
64 |
+
const size_t dstate,
|
65 |
+
const size_t n_groups,
|
66 |
+
const size_t n_chunks,
|
67 |
+
const bool is_variable_B,
|
68 |
+
const bool is_variable_C,
|
69 |
+
// device pointers
|
70 |
+
const at::Tensor u,
|
71 |
+
const at::Tensor delta,
|
72 |
+
const at::Tensor A,
|
73 |
+
const at::Tensor B,
|
74 |
+
const at::Tensor C,
|
75 |
+
const at::Tensor out,
|
76 |
+
const at::Tensor z,
|
77 |
+
const at::Tensor out_z,
|
78 |
+
void* D_ptr,
|
79 |
+
void* delta_bias_ptr,
|
80 |
+
void* x_ptr,
|
81 |
+
bool has_z,
|
82 |
+
bool delta_softplus) {
|
83 |
+
|
84 |
+
// Reset the parameters
|
85 |
+
memset(¶ms, 0, sizeof(params));
|
86 |
+
|
87 |
+
params.batch = batch;
|
88 |
+
params.dim = dim;
|
89 |
+
params.seqlen = seqlen;
|
90 |
+
params.dstate = dstate;
|
91 |
+
params.n_groups = n_groups;
|
92 |
+
params.n_chunks = n_chunks;
|
93 |
+
params.dim_ngroups_ratio = dim / n_groups;
|
94 |
+
|
95 |
+
params.delta_softplus = delta_softplus;
|
96 |
+
|
97 |
+
params.is_variable_B = is_variable_B;
|
98 |
+
params.is_variable_C = is_variable_C;
|
99 |
+
|
100 |
+
// Set the pointers and strides.
|
101 |
+
params.u_ptr = u.data_ptr();
|
102 |
+
params.delta_ptr = delta.data_ptr();
|
103 |
+
params.A_ptr = A.data_ptr();
|
104 |
+
params.B_ptr = B.data_ptr();
|
105 |
+
params.C_ptr = C.data_ptr();
|
106 |
+
params.D_ptr = D_ptr;
|
107 |
+
params.delta_bias_ptr = delta_bias_ptr;
|
108 |
+
params.out_ptr = out.data_ptr();
|
109 |
+
params.x_ptr = x_ptr;
|
110 |
+
params.z_ptr = has_z ? z.data_ptr() : nullptr;
|
111 |
+
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
|
112 |
+
// All stride are in elements, not bytes.
|
113 |
+
params.A_d_stride = A.stride(0);
|
114 |
+
params.A_dstate_stride = A.stride(1);
|
115 |
+
if (!is_variable_B) {
|
116 |
+
params.B_d_stride = B.stride(0);
|
117 |
+
} else {
|
118 |
+
params.B_batch_stride = B.stride(0);
|
119 |
+
params.B_group_stride = B.stride(1);
|
120 |
+
}
|
121 |
+
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
|
122 |
+
if (!is_variable_C) {
|
123 |
+
params.C_d_stride = C.stride(0);
|
124 |
+
} else {
|
125 |
+
params.C_batch_stride = C.stride(0);
|
126 |
+
params.C_group_stride = C.stride(1);
|
127 |
+
}
|
128 |
+
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
|
129 |
+
params.u_batch_stride = u.stride(0);
|
130 |
+
params.u_d_stride = u.stride(1);
|
131 |
+
params.delta_batch_stride = delta.stride(0);
|
132 |
+
params.delta_d_stride = delta.stride(1);
|
133 |
+
if (has_z) {
|
134 |
+
params.z_batch_stride = z.stride(0);
|
135 |
+
params.z_d_stride = z.stride(1);
|
136 |
+
params.out_z_batch_stride = out_z.stride(0);
|
137 |
+
params.out_z_d_stride = out_z.stride(1);
|
138 |
+
}
|
139 |
+
params.out_batch_stride = out.stride(0);
|
140 |
+
params.out_d_stride = out.stride(1);
|
141 |
+
}
|
142 |
+
|
143 |
+
void set_ssm_params_bwd(SSMParamsBwd ¶ms,
|
144 |
+
// sizes
|
145 |
+
const size_t batch,
|
146 |
+
const size_t dim,
|
147 |
+
const size_t seqlen,
|
148 |
+
const size_t dstate,
|
149 |
+
const size_t n_groups,
|
150 |
+
const size_t n_chunks,
|
151 |
+
const bool is_variable_B,
|
152 |
+
const bool is_variable_C,
|
153 |
+
// device pointers
|
154 |
+
const at::Tensor u,
|
155 |
+
const at::Tensor delta,
|
156 |
+
const at::Tensor A,
|
157 |
+
const at::Tensor B,
|
158 |
+
const at::Tensor C,
|
159 |
+
const at::Tensor z,
|
160 |
+
const at::Tensor out,
|
161 |
+
const at::Tensor out_z,
|
162 |
+
void* D_ptr,
|
163 |
+
void* delta_bias_ptr,
|
164 |
+
void* x_ptr,
|
165 |
+
const at::Tensor dout,
|
166 |
+
const at::Tensor du,
|
167 |
+
const at::Tensor ddelta,
|
168 |
+
const at::Tensor dA,
|
169 |
+
const at::Tensor dB,
|
170 |
+
const at::Tensor dC,
|
171 |
+
const at::Tensor dz,
|
172 |
+
void* dD_ptr,
|
173 |
+
void* ddelta_bias_ptr,
|
174 |
+
bool has_z,
|
175 |
+
bool delta_softplus,
|
176 |
+
bool recompute_out_z) {
|
177 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
|
178 |
+
set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
179 |
+
u, delta, A, B, C, has_z ? out : dout,
|
180 |
+
has_z ? z : dout,
|
181 |
+
// If not recompute_out_z, pass dout instead of out_z.
|
182 |
+
// This won't be used by the bwd kernel
|
183 |
+
recompute_out_z ? out_z : dout,
|
184 |
+
D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
|
185 |
+
if (!recompute_out_z) { params.out_z_ptr = nullptr; }
|
186 |
+
|
187 |
+
// Set the pointers and strides.
|
188 |
+
params.dout_ptr = dout.data_ptr();
|
189 |
+
params.du_ptr = du.data_ptr();
|
190 |
+
params.dA_ptr = dA.data_ptr();
|
191 |
+
params.dB_ptr = dB.data_ptr();
|
192 |
+
params.dC_ptr = dC.data_ptr();
|
193 |
+
params.dD_ptr = dD_ptr;
|
194 |
+
params.ddelta_ptr = ddelta.data_ptr();
|
195 |
+
params.ddelta_bias_ptr = ddelta_bias_ptr;
|
196 |
+
params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
|
197 |
+
// All stride are in elements, not bytes.
|
198 |
+
params.dout_batch_stride = dout.stride(0);
|
199 |
+
params.dout_d_stride = dout.stride(1);
|
200 |
+
params.dA_d_stride = dA.stride(0);
|
201 |
+
params.dA_dstate_stride = dA.stride(1);
|
202 |
+
if (!is_variable_B) {
|
203 |
+
params.dB_d_stride = dB.stride(0);
|
204 |
+
} else {
|
205 |
+
params.dB_batch_stride = dB.stride(0);
|
206 |
+
params.dB_group_stride = dB.stride(1);
|
207 |
+
}
|
208 |
+
params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
|
209 |
+
if (!is_variable_C) {
|
210 |
+
params.dC_d_stride = dC.stride(0);
|
211 |
+
} else {
|
212 |
+
params.dC_batch_stride = dC.stride(0);
|
213 |
+
params.dC_group_stride = dC.stride(1);
|
214 |
+
}
|
215 |
+
params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
|
216 |
+
params.du_batch_stride = du.stride(0);
|
217 |
+
params.du_d_stride = du.stride(1);
|
218 |
+
params.ddelta_batch_stride = ddelta.stride(0);
|
219 |
+
params.ddelta_d_stride = ddelta.stride(1);
|
220 |
+
if (has_z) {
|
221 |
+
params.dz_batch_stride = dz.stride(0);
|
222 |
+
params.dz_d_stride = dz.stride(1);
|
223 |
+
}
|
224 |
+
}
|
225 |
+
|
226 |
+
std::vector<at::Tensor>
|
227 |
+
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
|
228 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
229 |
+
const c10::optional<at::Tensor> &D_,
|
230 |
+
const c10::optional<at::Tensor> &z_,
|
231 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
232 |
+
bool delta_softplus) {
|
233 |
+
auto input_type = u.scalar_type();
|
234 |
+
auto weight_type = A.scalar_type();
|
235 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
236 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
237 |
+
|
238 |
+
const bool is_variable_B = B.dim() >= 3;
|
239 |
+
const bool is_variable_C = C.dim() >= 3;
|
240 |
+
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
241 |
+
|
242 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
243 |
+
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
244 |
+
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
245 |
+
|
246 |
+
TORCH_CHECK(u.is_cuda());
|
247 |
+
TORCH_CHECK(delta.is_cuda());
|
248 |
+
TORCH_CHECK(A.is_cuda());
|
249 |
+
TORCH_CHECK(B.is_cuda());
|
250 |
+
TORCH_CHECK(C.is_cuda());
|
251 |
+
|
252 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
253 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
254 |
+
|
255 |
+
const auto sizes = u.sizes();
|
256 |
+
const int batch_size = sizes[0];
|
257 |
+
const int dim = sizes[1];
|
258 |
+
const int seqlen = sizes[2];
|
259 |
+
const int dstate = A.size(1);
|
260 |
+
const int n_groups = is_variable_B ? B.size(1) : 1;
|
261 |
+
|
262 |
+
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
263 |
+
|
264 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
265 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
266 |
+
CHECK_SHAPE(A, dim, dstate);
|
267 |
+
if (!is_variable_B) {
|
268 |
+
CHECK_SHAPE(B, dim, dstate);
|
269 |
+
} else {
|
270 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
271 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
272 |
+
}
|
273 |
+
if (!is_variable_C) {
|
274 |
+
CHECK_SHAPE(C, dim, dstate);
|
275 |
+
} else {
|
276 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
277 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
278 |
+
}
|
279 |
+
|
280 |
+
if (D_.has_value()) {
|
281 |
+
auto D = D_.value();
|
282 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
283 |
+
TORCH_CHECK(D.is_cuda());
|
284 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
285 |
+
CHECK_SHAPE(D, dim);
|
286 |
+
}
|
287 |
+
|
288 |
+
if (delta_bias_.has_value()) {
|
289 |
+
auto delta_bias = delta_bias_.value();
|
290 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
291 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
292 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
293 |
+
CHECK_SHAPE(delta_bias, dim);
|
294 |
+
}
|
295 |
+
|
296 |
+
at::Tensor z, out_z;
|
297 |
+
const bool has_z = z_.has_value();
|
298 |
+
if (has_z) {
|
299 |
+
z = z_.value();
|
300 |
+
TORCH_CHECK(z.scalar_type() == input_type);
|
301 |
+
TORCH_CHECK(z.is_cuda());
|
302 |
+
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
303 |
+
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
304 |
+
out_z = torch::empty_like(z);
|
305 |
+
}
|
306 |
+
|
307 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
308 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
309 |
+
// at::Tensor out = torch::empty_like(u);
|
310 |
+
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
311 |
+
at::Tensor out = torch::empty_like(delta);
|
312 |
+
at::Tensor x;
|
313 |
+
x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
|
314 |
+
|
315 |
+
SSMParamsBase params;
|
316 |
+
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
317 |
+
u, delta, A, B, C, out, z, out_z,
|
318 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
319 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
320 |
+
x.data_ptr(),
|
321 |
+
has_z,
|
322 |
+
delta_softplus);
|
323 |
+
|
324 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
325 |
+
// Cast to char to avoid compiler warning about narrowing
|
326 |
+
at::cuda::CUDAGuard device_guard{u.device()};
|
327 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
328 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
329 |
+
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
|
330 |
+
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
331 |
+
});
|
332 |
+
});
|
333 |
+
std::vector<at::Tensor> result = {out, x};
|
334 |
+
if (has_z) { result.push_back(out_z); }
|
335 |
+
return result;
|
336 |
+
}
|
337 |
+
|
338 |
+
std::vector<at::Tensor>
|
339 |
+
selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
|
340 |
+
const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
|
341 |
+
const c10::optional<at::Tensor> &D_,
|
342 |
+
const c10::optional<at::Tensor> &z_,
|
343 |
+
const c10::optional<at::Tensor> &delta_bias_,
|
344 |
+
const at::Tensor &dout,
|
345 |
+
const c10::optional<at::Tensor> &x_,
|
346 |
+
const c10::optional<at::Tensor> &out_,
|
347 |
+
c10::optional<at::Tensor> dz_,
|
348 |
+
bool delta_softplus,
|
349 |
+
bool recompute_out_z) {
|
350 |
+
auto input_type = u.scalar_type();
|
351 |
+
auto weight_type = A.scalar_type();
|
352 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
353 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
|
354 |
+
|
355 |
+
const bool is_variable_B = B.dim() >= 3;
|
356 |
+
const bool is_variable_C = C.dim() >= 3;
|
357 |
+
const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
|
358 |
+
|
359 |
+
TORCH_CHECK(delta.scalar_type() == input_type);
|
360 |
+
TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
|
361 |
+
TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
|
362 |
+
TORCH_CHECK(dout.scalar_type() == input_type);
|
363 |
+
|
364 |
+
TORCH_CHECK(u.is_cuda());
|
365 |
+
TORCH_CHECK(delta.is_cuda());
|
366 |
+
TORCH_CHECK(A.is_cuda());
|
367 |
+
TORCH_CHECK(B.is_cuda());
|
368 |
+
TORCH_CHECK(C.is_cuda());
|
369 |
+
TORCH_CHECK(dout.is_cuda());
|
370 |
+
|
371 |
+
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
|
372 |
+
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
|
373 |
+
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
|
374 |
+
|
375 |
+
const auto sizes = u.sizes();
|
376 |
+
const int batch_size = sizes[0];
|
377 |
+
const int dim = sizes[1];
|
378 |
+
const int seqlen = sizes[2];
|
379 |
+
const int dstate = A.size(1);
|
380 |
+
const int n_groups = is_variable_B ? B.size(1) : 1;
|
381 |
+
|
382 |
+
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
|
383 |
+
|
384 |
+
CHECK_SHAPE(u, batch_size, dim, seqlen);
|
385 |
+
CHECK_SHAPE(delta, batch_size, dim, seqlen);
|
386 |
+
CHECK_SHAPE(A, dim, dstate);
|
387 |
+
if (!is_variable_B) {
|
388 |
+
CHECK_SHAPE(B, dim, dstate);
|
389 |
+
} else {
|
390 |
+
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
|
391 |
+
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
|
392 |
+
}
|
393 |
+
if (!is_variable_C) {
|
394 |
+
CHECK_SHAPE(C, dim, dstate);
|
395 |
+
} else {
|
396 |
+
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
|
397 |
+
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
|
398 |
+
}
|
399 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
400 |
+
|
401 |
+
if (D_.has_value()) {
|
402 |
+
auto D = D_.value();
|
403 |
+
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
|
404 |
+
TORCH_CHECK(D.is_cuda());
|
405 |
+
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
|
406 |
+
CHECK_SHAPE(D, dim);
|
407 |
+
}
|
408 |
+
|
409 |
+
if (delta_bias_.has_value()) {
|
410 |
+
auto delta_bias = delta_bias_.value();
|
411 |
+
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
|
412 |
+
TORCH_CHECK(delta_bias.is_cuda());
|
413 |
+
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
|
414 |
+
CHECK_SHAPE(delta_bias, dim);
|
415 |
+
}
|
416 |
+
|
417 |
+
at::Tensor z, out, dz, out_z;
|
418 |
+
const bool has_z = z_.has_value();
|
419 |
+
if (has_z) {
|
420 |
+
z = z_.value();
|
421 |
+
TORCH_CHECK(z.scalar_type() == input_type);
|
422 |
+
TORCH_CHECK(z.is_cuda());
|
423 |
+
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
424 |
+
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
425 |
+
|
426 |
+
TORCH_CHECK(out_.has_value());
|
427 |
+
out = out_.value();
|
428 |
+
TORCH_CHECK(out.scalar_type() == input_type);
|
429 |
+
TORCH_CHECK(out.is_cuda());
|
430 |
+
TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
|
431 |
+
CHECK_SHAPE(out, batch_size, dim, seqlen);
|
432 |
+
|
433 |
+
if (dz_.has_value()) {
|
434 |
+
dz = dz_.value();
|
435 |
+
TORCH_CHECK(dz.scalar_type() == input_type);
|
436 |
+
TORCH_CHECK(dz.is_cuda());
|
437 |
+
TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
|
438 |
+
CHECK_SHAPE(dz, batch_size, dim, seqlen);
|
439 |
+
} else {
|
440 |
+
dz = torch::empty_like(z);
|
441 |
+
}
|
442 |
+
if (recompute_out_z) {
|
443 |
+
out_z = torch::empty_like(out);
|
444 |
+
}
|
445 |
+
}
|
446 |
+
|
447 |
+
const int n_chunks = (seqlen + 2048 - 1) / 2048;
|
448 |
+
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
|
449 |
+
if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
|
450 |
+
if (x_.has_value()) {
|
451 |
+
auto x = x_.value();
|
452 |
+
TORCH_CHECK(x.scalar_type() == weight_type);
|
453 |
+
TORCH_CHECK(x.is_cuda());
|
454 |
+
TORCH_CHECK(x.is_contiguous());
|
455 |
+
CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
|
456 |
+
}
|
457 |
+
|
458 |
+
at::Tensor du = torch::empty_like(u);
|
459 |
+
at::Tensor ddelta = torch::empty_like(delta);
|
460 |
+
at::Tensor dA = torch::zeros_like(A);
|
461 |
+
at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
|
462 |
+
at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
|
463 |
+
at::Tensor dD;
|
464 |
+
if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
|
465 |
+
at::Tensor ddelta_bias;
|
466 |
+
if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
|
467 |
+
|
468 |
+
SSMParamsBwd params;
|
469 |
+
set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
|
470 |
+
u, delta, A, B, C, z, out, out_z,
|
471 |
+
D_.has_value() ? D_.value().data_ptr() : nullptr,
|
472 |
+
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
|
473 |
+
x_.has_value() ? x_.value().data_ptr() : nullptr,
|
474 |
+
dout, du, ddelta, dA, dB, dC, dz,
|
475 |
+
D_.has_value() ? dD.data_ptr() : nullptr,
|
476 |
+
delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
|
477 |
+
has_z, delta_softplus, recompute_out_z);
|
478 |
+
|
479 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
480 |
+
// Cast to char to avoid compiler warning about narrowing
|
481 |
+
at::cuda::CUDAGuard device_guard{u.device()};
|
482 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
483 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
|
484 |
+
DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
|
485 |
+
selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
|
486 |
+
});
|
487 |
+
});
|
488 |
+
std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
|
489 |
+
if (has_z) { result.push_back(dz); }
|
490 |
+
if (recompute_out_z) { result.push_back(out_z); }
|
491 |
+
return result;
|
492 |
+
}
|
493 |
+
|
494 |
+
//PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
495 |
+
// m.def("fwd", &selective_scan_fwd, "Selective scan forward");
|
496 |
+
// m.def("bwd", &selective_scan_bwd, "Selective scan backward");
|
497 |
+
//}
|
selective-scan/selective_scan.h
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
8 |
+
|
9 |
+
struct SSMScanParamsBase {
|
10 |
+
using index_t = uint32_t;
|
11 |
+
|
12 |
+
int batch, seqlen, n_chunks;
|
13 |
+
index_t a_batch_stride;
|
14 |
+
index_t b_batch_stride;
|
15 |
+
index_t out_batch_stride;
|
16 |
+
|
17 |
+
// Common data pointers.
|
18 |
+
void *__restrict__ a_ptr;
|
19 |
+
void *__restrict__ b_ptr;
|
20 |
+
void *__restrict__ out_ptr;
|
21 |
+
void *__restrict__ x_ptr;
|
22 |
+
};
|
23 |
+
|
24 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
25 |
+
|
26 |
+
struct SSMParamsBase {
|
27 |
+
using index_t = uint32_t;
|
28 |
+
|
29 |
+
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
30 |
+
int dim_ngroups_ratio;
|
31 |
+
bool is_variable_B;
|
32 |
+
bool is_variable_C;
|
33 |
+
|
34 |
+
bool delta_softplus;
|
35 |
+
|
36 |
+
index_t A_d_stride;
|
37 |
+
index_t A_dstate_stride;
|
38 |
+
index_t B_batch_stride;
|
39 |
+
index_t B_d_stride;
|
40 |
+
index_t B_dstate_stride;
|
41 |
+
index_t B_group_stride;
|
42 |
+
index_t C_batch_stride;
|
43 |
+
index_t C_d_stride;
|
44 |
+
index_t C_dstate_stride;
|
45 |
+
index_t C_group_stride;
|
46 |
+
index_t u_batch_stride;
|
47 |
+
index_t u_d_stride;
|
48 |
+
index_t delta_batch_stride;
|
49 |
+
index_t delta_d_stride;
|
50 |
+
index_t z_batch_stride;
|
51 |
+
index_t z_d_stride;
|
52 |
+
index_t out_batch_stride;
|
53 |
+
index_t out_d_stride;
|
54 |
+
index_t out_z_batch_stride;
|
55 |
+
index_t out_z_d_stride;
|
56 |
+
|
57 |
+
// Common data pointers.
|
58 |
+
void *__restrict__ A_ptr;
|
59 |
+
void *__restrict__ B_ptr;
|
60 |
+
void *__restrict__ C_ptr;
|
61 |
+
void *__restrict__ D_ptr;
|
62 |
+
void *__restrict__ u_ptr;
|
63 |
+
void *__restrict__ delta_ptr;
|
64 |
+
void *__restrict__ delta_bias_ptr;
|
65 |
+
void *__restrict__ out_ptr;
|
66 |
+
void *__restrict__ x_ptr;
|
67 |
+
void *__restrict__ z_ptr;
|
68 |
+
void *__restrict__ out_z_ptr;
|
69 |
+
};
|
70 |
+
|
71 |
+
struct SSMParamsBwd: public SSMParamsBase {
|
72 |
+
index_t dout_batch_stride;
|
73 |
+
index_t dout_d_stride;
|
74 |
+
index_t dA_d_stride;
|
75 |
+
index_t dA_dstate_stride;
|
76 |
+
index_t dB_batch_stride;
|
77 |
+
index_t dB_group_stride;
|
78 |
+
index_t dB_d_stride;
|
79 |
+
index_t dB_dstate_stride;
|
80 |
+
index_t dC_batch_stride;
|
81 |
+
index_t dC_group_stride;
|
82 |
+
index_t dC_d_stride;
|
83 |
+
index_t dC_dstate_stride;
|
84 |
+
index_t du_batch_stride;
|
85 |
+
index_t du_d_stride;
|
86 |
+
index_t dz_batch_stride;
|
87 |
+
index_t dz_d_stride;
|
88 |
+
index_t ddelta_batch_stride;
|
89 |
+
index_t ddelta_d_stride;
|
90 |
+
|
91 |
+
// Common data pointers.
|
92 |
+
void *__restrict__ dout_ptr;
|
93 |
+
void *__restrict__ dA_ptr;
|
94 |
+
void *__restrict__ dB_ptr;
|
95 |
+
void *__restrict__ dC_ptr;
|
96 |
+
void *__restrict__ dD_ptr;
|
97 |
+
void *__restrict__ du_ptr;
|
98 |
+
void *__restrict__ dz_ptr;
|
99 |
+
void *__restrict__ ddelta_ptr;
|
100 |
+
void *__restrict__ ddelta_bias_ptr;
|
101 |
+
};
|
selective-scan/selective_scan_bwd_bf16_complex.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_bwd_bf16_real.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_bwd_fp16_complex.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_bwd_fp16_real.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_bwd_fp32_complex.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_bwd_fp32_real.cu
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_bwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_bwd_kernel.cuh
ADDED
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include <c10/util/BFloat16.h>
|
8 |
+
#include <c10/util/Half.h>
|
9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
10 |
+
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
|
11 |
+
|
12 |
+
#ifndef USE_ROCM
|
13 |
+
#include <cub/block/block_load.cuh>
|
14 |
+
#include <cub/block/block_store.cuh>
|
15 |
+
#include <cub/block/block_scan.cuh>
|
16 |
+
#include <cub/block/block_reduce.cuh>
|
17 |
+
#else
|
18 |
+
#include <hipcub/hipcub.hpp>
|
19 |
+
namespace cub = hipcub;
|
20 |
+
#endif
|
21 |
+
|
22 |
+
#include "selective_scan.h"
|
23 |
+
#include "selective_scan_common.h"
|
24 |
+
#include "reverse_scan.cuh"
|
25 |
+
#include "static_switch.h"
|
26 |
+
|
27 |
+
template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
|
28 |
+
template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
|
29 |
+
template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
|
30 |
+
|
31 |
+
template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
|
32 |
+
bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
|
33 |
+
struct Selective_Scan_bwd_kernel_traits {
|
34 |
+
static_assert(kNItems_ % 4 == 0);
|
35 |
+
using input_t = input_t_;
|
36 |
+
using weight_t = weight_t_;
|
37 |
+
static constexpr int kNThreads = kNThreads_;
|
38 |
+
static constexpr int kNItems = kNItems_;
|
39 |
+
static constexpr int kNBytes = sizeof(input_t);
|
40 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
41 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
42 |
+
static_assert(kNItems % kNElts == 0);
|
43 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
44 |
+
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
|
45 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
46 |
+
static constexpr bool kIsVariableB = kIsVariableB_;
|
47 |
+
static constexpr bool kIsVariableC = kIsVariableC_;
|
48 |
+
static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
|
49 |
+
static constexpr bool kHasZ = kHasZ_;
|
50 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
|
51 |
+
// For complex this would lead to massive register spilling, so we keep it at 2.
|
52 |
+
static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
|
53 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
54 |
+
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
|
55 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
56 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
57 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
58 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
59 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
60 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
61 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
62 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
63 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
64 |
+
using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
|
65 |
+
using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
|
66 |
+
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
67 |
+
using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
|
68 |
+
using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
|
69 |
+
|
70 |
+
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
71 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
72 |
+
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
73 |
+
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
74 |
+
sizeof(typename BlockStoreT::TempStorage),
|
75 |
+
sizeof(typename BlockStoreVecT::TempStorage)});
|
76 |
+
static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
|
77 |
+
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
|
78 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
|
79 |
+
};
|
80 |
+
|
81 |
+
template<typename Ktraits>
|
82 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
83 |
+
void selective_scan_bwd_kernel(SSMParamsBwd params) {
|
84 |
+
constexpr bool kIsComplex = Ktraits::kIsComplex;
|
85 |
+
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
86 |
+
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
87 |
+
constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
|
88 |
+
constexpr bool kHasZ = Ktraits::kHasZ;
|
89 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
90 |
+
constexpr int kNItems = Ktraits::kNItems;
|
91 |
+
using input_t = typename Ktraits::input_t;
|
92 |
+
using weight_t = typename Ktraits::weight_t;
|
93 |
+
using scan_t = typename Ktraits::scan_t;
|
94 |
+
|
95 |
+
// Shared memory.
|
96 |
+
extern __shared__ char smem_[];
|
97 |
+
// cast to lvalue reference of expected type
|
98 |
+
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
99 |
+
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
100 |
+
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
101 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
102 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
103 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
104 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
105 |
+
auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
106 |
+
auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
|
107 |
+
auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
|
108 |
+
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
|
109 |
+
auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
|
110 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
|
111 |
+
auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
|
112 |
+
weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
|
113 |
+
scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
|
114 |
+
weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
|
115 |
+
weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
|
116 |
+
|
117 |
+
const int batch_id = blockIdx.x;
|
118 |
+
const int dim_id = blockIdx.y;
|
119 |
+
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
120 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
121 |
+
+ dim_id * params.u_d_stride;
|
122 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
123 |
+
+ dim_id * params.delta_d_stride;
|
124 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
125 |
+
+ dim_id * params.dout_d_stride;
|
126 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
|
127 |
+
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
|
128 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
129 |
+
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
|
130 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
131 |
+
weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
|
132 |
+
weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
|
133 |
+
+ (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
|
134 |
+
weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
|
135 |
+
+ (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
|
136 |
+
float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
|
137 |
+
float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
|
138 |
+
float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
|
139 |
+
float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
|
140 |
+
scan_t *x = params.x_ptr == nullptr
|
141 |
+
? nullptr
|
142 |
+
: reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
|
143 |
+
float dD_val = 0;
|
144 |
+
float ddelta_bias_val = 0;
|
145 |
+
|
146 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
147 |
+
u += (params.n_chunks - 1) * kChunkSize;
|
148 |
+
delta += (params.n_chunks - 1) * kChunkSize;
|
149 |
+
dout += (params.n_chunks - 1) * kChunkSize;
|
150 |
+
Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
|
151 |
+
Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
|
152 |
+
for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
|
153 |
+
input_t u_vals[kNItems];
|
154 |
+
input_t delta_vals_load[kNItems];
|
155 |
+
input_t dout_vals_load[kNItems];
|
156 |
+
__syncthreads();
|
157 |
+
load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
158 |
+
u -= kChunkSize;
|
159 |
+
__syncthreads();
|
160 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
161 |
+
// Will reload delta at the same location if kDeltaSoftplus
|
162 |
+
if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
|
163 |
+
__syncthreads();
|
164 |
+
load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
165 |
+
dout -= kChunkSize;
|
166 |
+
|
167 |
+
float dout_vals[kNItems], delta_vals[kNItems];
|
168 |
+
#pragma unroll
|
169 |
+
for (int i = 0; i < kNItems; ++i) {
|
170 |
+
dout_vals[i] = float(dout_vals_load[i]);
|
171 |
+
delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
|
172 |
+
if constexpr (kDeltaSoftplus) {
|
173 |
+
delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
|
174 |
+
}
|
175 |
+
}
|
176 |
+
|
177 |
+
if constexpr (kHasZ) {
|
178 |
+
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
179 |
+
+ dim_id * params.z_d_stride + chunk * kChunkSize;
|
180 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
181 |
+
+ dim_id * params.out_d_stride + chunk * kChunkSize;
|
182 |
+
input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
|
183 |
+
+ dim_id * params.dz_d_stride + chunk * kChunkSize;
|
184 |
+
input_t z_vals[kNItems], out_vals[kNItems];
|
185 |
+
__syncthreads();
|
186 |
+
load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
187 |
+
__syncthreads();
|
188 |
+
load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
189 |
+
float dz_vals[kNItems], z_silu_vals[kNItems];
|
190 |
+
#pragma unroll
|
191 |
+
for (int i = 0; i < kNItems; ++i) {
|
192 |
+
float z_val = z_vals[i];
|
193 |
+
float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
|
194 |
+
z_silu_vals[i] = z_val * z_sigmoid_val;
|
195 |
+
dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
|
196 |
+
* (1.0f + z_val * (1.0f - z_sigmoid_val));
|
197 |
+
dout_vals[i] *= z_silu_vals[i];
|
198 |
+
}
|
199 |
+
__syncthreads();
|
200 |
+
store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
201 |
+
if (params.out_z_ptr != nullptr) { // Recompute and store out_z
|
202 |
+
float out_z_vals[kNItems];
|
203 |
+
#pragma unroll
|
204 |
+
for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
|
205 |
+
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
|
206 |
+
// printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
|
207 |
+
// }
|
208 |
+
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
209 |
+
+ dim_id * params.out_z_d_stride + chunk * kChunkSize;
|
210 |
+
__syncthreads();
|
211 |
+
store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
212 |
+
}
|
213 |
+
}
|
214 |
+
|
215 |
+
float du_vals[kNItems];
|
216 |
+
#pragma unroll
|
217 |
+
for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
|
218 |
+
#pragma unroll
|
219 |
+
for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
|
220 |
+
|
221 |
+
float ddelta_vals[kNItems] = {0};
|
222 |
+
__syncthreads();
|
223 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
224 |
+
const weight_t A_val = A[state_idx * params.A_dstate_stride];
|
225 |
+
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
226 |
+
weight_t A_scaled;
|
227 |
+
constexpr float kLog2e = M_LOG2E;
|
228 |
+
if constexpr (!kIsComplex) {
|
229 |
+
A_scaled = A_val * kLog2e;
|
230 |
+
} else {
|
231 |
+
A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
|
232 |
+
}
|
233 |
+
weight_t B_val, C_val;
|
234 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
235 |
+
if constexpr (!kIsVariableB) {
|
236 |
+
B_val = B[state_idx * params.B_dstate_stride];
|
237 |
+
} else {
|
238 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
239 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
240 |
+
}
|
241 |
+
if constexpr (!kIsVariableC) {
|
242 |
+
C_val = C[state_idx * params.C_dstate_stride];
|
243 |
+
} else {
|
244 |
+
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
245 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
246 |
+
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
247 |
+
}
|
248 |
+
// const weight_t A_val = smem_a[state_idx];
|
249 |
+
scan_t thread_data[kNItems], thread_reverse_data[kNItems];
|
250 |
+
if constexpr (!kIsComplex) {
|
251 |
+
#pragma unroll
|
252 |
+
for (int i = 0; i < kNItems; ++i) {
|
253 |
+
const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
|
254 |
+
thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
255 |
+
if (i == 0) {
|
256 |
+
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
|
257 |
+
} else {
|
258 |
+
thread_reverse_data[i - 1].x = delta_a_exp;
|
259 |
+
}
|
260 |
+
thread_reverse_data[i].y = dout_vals[i] *
|
261 |
+
(!kIsVariableC
|
262 |
+
? (!kIsVariableB ? B_val * C_val : C_val)
|
263 |
+
: (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
|
264 |
+
}
|
265 |
+
__syncthreads();
|
266 |
+
thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
|
267 |
+
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
|
268 |
+
: smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
|
269 |
+
// Initialize running total
|
270 |
+
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
|
271 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
272 |
+
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
273 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
274 |
+
);
|
275 |
+
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
|
276 |
+
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
277 |
+
typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
278 |
+
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
279 |
+
);
|
280 |
+
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
281 |
+
weight_t dA_val = 0, dBC_val = 0;
|
282 |
+
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
283 |
+
#pragma unroll
|
284 |
+
for (int i = 0; i < kNItems; ++i) {
|
285 |
+
const float dx = thread_reverse_data[i].y;
|
286 |
+
const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
|
287 |
+
du_vals[i] += ddelta_u * delta_vals[i];
|
288 |
+
const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
|
289 |
+
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
|
290 |
+
dA_val += dx * delta_vals[i] * a;
|
291 |
+
if constexpr (!kIsVariableB || !kIsVariableC) {
|
292 |
+
if constexpr (!kIsVariableB) { // dBC_val is dB_val
|
293 |
+
dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
|
294 |
+
} else { // dBC_val is dC_val
|
295 |
+
dBC_val += dout_vals[i] * thread_data[i].y;
|
296 |
+
}
|
297 |
+
}
|
298 |
+
if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
|
299 |
+
if constexpr (kIsVariableC) {
|
300 |
+
dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
|
301 |
+
}
|
302 |
+
}
|
303 |
+
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
304 |
+
if constexpr (kIsVariableB || kIsVariableC) {
|
305 |
+
if constexpr (kIsVariableB) {
|
306 |
+
typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
|
307 |
+
}
|
308 |
+
if constexpr (kIsVariableC) {
|
309 |
+
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
|
310 |
+
typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
|
311 |
+
}
|
312 |
+
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
|
313 |
+
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
314 |
+
weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
|
315 |
+
#pragma unroll
|
316 |
+
for (int i = 0; i < kNItems; ++i) {
|
317 |
+
if (i * kNThreads < seqlen_remaining) {
|
318 |
+
if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
|
319 |
+
if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
|
320 |
+
}
|
321 |
+
}
|
322 |
+
}
|
323 |
+
if constexpr (!kIsVariableB || !kIsVariableC) {
|
324 |
+
float2 dA_dBC_val = make_float2(dA_val, dBC_val);
|
325 |
+
dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
|
326 |
+
dA_val = dA_dBC_val.x;
|
327 |
+
if (threadIdx.x == 0) {
|
328 |
+
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
|
329 |
+
}
|
330 |
+
} else {
|
331 |
+
dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
|
332 |
+
}
|
333 |
+
if (threadIdx.x == 0) {
|
334 |
+
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
335 |
+
}
|
336 |
+
} else {
|
337 |
+
#pragma unroll
|
338 |
+
for (int i = 0; i < kNItems; ++i) {
|
339 |
+
// Pytorch's implementation of complex exp (which calls thrust) is very slow
|
340 |
+
complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
|
341 |
+
weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
|
342 |
+
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
|
343 |
+
if (i == 0) {
|
344 |
+
smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
|
345 |
+
} else {
|
346 |
+
thread_reverse_data[i - 1].x = delta_a_exp.real_;
|
347 |
+
thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
|
348 |
+
}
|
349 |
+
complex_t dout_BC = 2 * dout_vals[i]
|
350 |
+
* conj(!kIsVariableC
|
351 |
+
? (!kIsVariableB ? B_val * C_val : C_val)
|
352 |
+
: (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
|
353 |
+
thread_reverse_data[i].z = dout_BC.real_;
|
354 |
+
thread_reverse_data[i].w = dout_BC.imag_;
|
355 |
+
}
|
356 |
+
__syncthreads();
|
357 |
+
complex_t delta_a_exp = threadIdx.x == kNThreads - 1
|
358 |
+
? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
|
359 |
+
: smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
|
360 |
+
thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
|
361 |
+
thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
|
362 |
+
// Initialize running total
|
363 |
+
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
364 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
365 |
+
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
366 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
367 |
+
);
|
368 |
+
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
369 |
+
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
|
370 |
+
typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
|
371 |
+
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
|
372 |
+
);
|
373 |
+
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
|
374 |
+
weight_t dA_val = 0, dBC_val = 0;
|
375 |
+
weight_t dB_vals[kNItems], dC_vals[kNItems];
|
376 |
+
#pragma unroll
|
377 |
+
for (int i = 0; i < kNItems; ++i) {
|
378 |
+
complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
|
379 |
+
complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
|
380 |
+
float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
|
381 |
+
if constexpr (!kIsVariableB || !kIsVariableC) {
|
382 |
+
if constexpr (!kIsVariableB) { // dBC_val is dB_val
|
383 |
+
dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
|
384 |
+
} else { // dBC_val is dC_val
|
385 |
+
dBC_val += (2 * dout_vals[i]) * conj(x);
|
386 |
+
}
|
387 |
+
}
|
388 |
+
const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
|
389 |
+
du_vals[i] += ddelta_u * delta_vals[i];
|
390 |
+
ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
|
391 |
+
dA_val += delta_vals[i] * dx * a_conj;
|
392 |
+
if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
|
393 |
+
if constexpr (kIsVariableC) {
|
394 |
+
dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
|
395 |
+
}
|
396 |
+
}
|
397 |
+
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
|
398 |
+
if constexpr (kIsVariableB || kIsVariableC) {
|
399 |
+
float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
|
400 |
+
if constexpr (kIsVariableB) {
|
401 |
+
#pragma unroll
|
402 |
+
for (int i = 0; i < kNItems; ++i) {
|
403 |
+
dB_vals_f[i * 2] = dB_vals[i].real_;
|
404 |
+
dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
|
405 |
+
}
|
406 |
+
typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
|
407 |
+
}
|
408 |
+
if constexpr (kIsVariableC) {
|
409 |
+
#pragma unroll
|
410 |
+
for (int i = 0; i < kNItems; ++i) {
|
411 |
+
dC_vals_f[i * 2] = dC_vals[i].real_;
|
412 |
+
dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
|
413 |
+
}
|
414 |
+
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
|
415 |
+
typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
|
416 |
+
}
|
417 |
+
const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
|
418 |
+
float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
|
419 |
+
float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
|
420 |
+
#pragma unroll
|
421 |
+
for (int i = 0; i < kNItems * 2; ++i) {
|
422 |
+
if (i * kNThreads < seqlen_remaining) {
|
423 |
+
if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
|
424 |
+
if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
|
425 |
+
}
|
426 |
+
}
|
427 |
+
}
|
428 |
+
if constexpr (!kIsVariableB || !kIsVariableC) {
|
429 |
+
float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
|
430 |
+
dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
|
431 |
+
dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
|
432 |
+
dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
|
433 |
+
if (threadIdx.x == 0) {
|
434 |
+
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
|
435 |
+
}
|
436 |
+
} else {
|
437 |
+
dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
|
438 |
+
}
|
439 |
+
if (threadIdx.x == 0) {
|
440 |
+
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
|
441 |
+
}
|
442 |
+
}
|
443 |
+
}
|
444 |
+
|
445 |
+
if constexpr (kDeltaSoftplus) {
|
446 |
+
__syncthreads();
|
447 |
+
input_t delta_vals_load[kNItems];
|
448 |
+
load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
|
449 |
+
delta -= kChunkSize;
|
450 |
+
#pragma unroll
|
451 |
+
for (int i = 0; i < kNItems; ++i) {
|
452 |
+
float delta_val = float(delta_vals_load[i]) + delta_bias;
|
453 |
+
float delta_val_neg_exp = expf(-delta_val);
|
454 |
+
ddelta_vals[i] = delta_val <= 20.f
|
455 |
+
? ddelta_vals[i] / (1.f + delta_val_neg_exp)
|
456 |
+
: ddelta_vals[i];
|
457 |
+
}
|
458 |
+
}
|
459 |
+
for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
|
460 |
+
|
461 |
+
input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
|
462 |
+
+ dim_id * params.du_d_stride + chunk * kChunkSize;
|
463 |
+
input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
|
464 |
+
+ dim_id * params.ddelta_d_stride + chunk * kChunkSize;
|
465 |
+
__syncthreads();
|
466 |
+
store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
467 |
+
__syncthreads();
|
468 |
+
store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
|
469 |
+
|
470 |
+
Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
|
471 |
+
Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
|
472 |
+
}
|
473 |
+
if (params.dD_ptr != nullptr) {
|
474 |
+
dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
|
475 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
|
476 |
+
}
|
477 |
+
if (params.ddelta_bias_ptr != nullptr) {
|
478 |
+
__syncthreads();
|
479 |
+
ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
|
480 |
+
if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
|
481 |
+
}
|
482 |
+
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
483 |
+
gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
|
484 |
+
weight_t dBC_val;
|
485 |
+
if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
|
486 |
+
if constexpr (!kIsVariableB) {
|
487 |
+
gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
|
488 |
+
!kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
|
489 |
+
}
|
490 |
+
if constexpr (!kIsVariableC) {
|
491 |
+
gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
|
492 |
+
!kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
|
493 |
+
}
|
494 |
+
}
|
495 |
+
}
|
496 |
+
|
497 |
+
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
498 |
+
void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
499 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
500 |
+
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
|
501 |
+
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
|
502 |
+
BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
|
503 |
+
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
504 |
+
using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
|
505 |
+
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
|
506 |
+
// TODO: check this
|
507 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
|
508 |
+
|
509 |
+
dim3 grid(params.batch, params.dim);
|
510 |
+
|
511 |
+
auto kernel = &selective_scan_bwd_kernel<Ktraits>;
|
512 |
+
|
513 |
+
if (kSmemSize >= 48 * 1024) {
|
514 |
+
|
515 |
+
#ifndef USE_ROCM
|
516 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
517 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
518 |
+
#else
|
519 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
520 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
521 |
+
std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
522 |
+
#endif
|
523 |
+
|
524 |
+
}
|
525 |
+
|
526 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
527 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
528 |
+
});
|
529 |
+
});
|
530 |
+
});
|
531 |
+
});
|
532 |
+
});
|
533 |
+
}
|
534 |
+
|
535 |
+
template<typename input_t, typename weight_t>
|
536 |
+
void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) {
|
537 |
+
|
538 |
+
#ifndef USE_ROCM
|
539 |
+
if (params.seqlen <= 128) {
|
540 |
+
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
|
541 |
+
} else if (params.seqlen <= 256) {
|
542 |
+
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
|
543 |
+
} else if (params.seqlen <= 512) {
|
544 |
+
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
|
545 |
+
} else if (params.seqlen <= 1024) {
|
546 |
+
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
|
547 |
+
} else {
|
548 |
+
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
|
549 |
+
}
|
550 |
+
#else
|
551 |
+
if (params.seqlen <= 256) {
|
552 |
+
selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
|
553 |
+
} else if (params.seqlen <= 512) {
|
554 |
+
selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
|
555 |
+
} else if (params.seqlen <= 1024) {
|
556 |
+
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
|
557 |
+
} else {
|
558 |
+
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
|
559 |
+
}
|
560 |
+
#endif
|
561 |
+
}
|
selective-scan/selective_scan_common.h
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#ifndef USE_ROCM
|
8 |
+
#include <cuda_bf16.h>
|
9 |
+
#else
|
10 |
+
#include <hip/hip_bf16.h>
|
11 |
+
#endif
|
12 |
+
#include <cuda_fp16.h>
|
13 |
+
#include <c10/util/complex.h> // For scalar_value_type
|
14 |
+
|
15 |
+
|
16 |
+
#ifndef USE_ROCM
|
17 |
+
|
18 |
+
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
19 |
+
{
|
20 |
+
return std::max(ilist);
|
21 |
+
}
|
22 |
+
|
23 |
+
template<typename T>
|
24 |
+
constexpr T constexpr_min(T a, T b) {
|
25 |
+
return std::min(a, b);
|
26 |
+
}
|
27 |
+
|
28 |
+
#else
|
29 |
+
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
30 |
+
{
|
31 |
+
return *std::max_element(ilist.begin(), ilist.end());
|
32 |
+
}
|
33 |
+
|
34 |
+
template<typename T>
|
35 |
+
constexpr T constexpr_min(T a, T b) {
|
36 |
+
return a < b ? a : b;
|
37 |
+
}
|
38 |
+
#endif
|
39 |
+
|
40 |
+
|
41 |
+
#define MAX_DSTATE 256
|
42 |
+
|
43 |
+
using complex_t = c10::complex<float>;
|
44 |
+
|
45 |
+
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
46 |
+
return {a.x + b.x, a.y + b.y};
|
47 |
+
}
|
48 |
+
|
49 |
+
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
50 |
+
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
51 |
+
}
|
52 |
+
|
53 |
+
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
54 |
+
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
55 |
+
}
|
56 |
+
|
57 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
58 |
+
|
59 |
+
template<int BYTES> struct BytesToType {};
|
60 |
+
|
61 |
+
template<> struct BytesToType<16> {
|
62 |
+
using Type = uint4;
|
63 |
+
static_assert(sizeof(Type) == 16);
|
64 |
+
};
|
65 |
+
|
66 |
+
template<> struct BytesToType<8> {
|
67 |
+
using Type = uint64_t;
|
68 |
+
static_assert(sizeof(Type) == 8);
|
69 |
+
};
|
70 |
+
|
71 |
+
template<> struct BytesToType<4> {
|
72 |
+
using Type = uint32_t;
|
73 |
+
static_assert(sizeof(Type) == 4);
|
74 |
+
};
|
75 |
+
|
76 |
+
template<> struct BytesToType<2> {
|
77 |
+
using Type = uint16_t;
|
78 |
+
static_assert(sizeof(Type) == 2);
|
79 |
+
};
|
80 |
+
|
81 |
+
template<> struct BytesToType<1> {
|
82 |
+
using Type = uint8_t;
|
83 |
+
static_assert(sizeof(Type) == 1);
|
84 |
+
};
|
85 |
+
|
86 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
87 |
+
|
88 |
+
template<typename scalar_t, int N>
|
89 |
+
struct Converter{
|
90 |
+
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
91 |
+
#pragma unroll
|
92 |
+
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
93 |
+
}
|
94 |
+
};
|
95 |
+
|
96 |
+
template<int N>
|
97 |
+
struct Converter<at::Half, N>{
|
98 |
+
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
99 |
+
static_assert(N % 2 == 0);
|
100 |
+
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
101 |
+
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
102 |
+
#pragma unroll
|
103 |
+
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
104 |
+
}
|
105 |
+
};
|
106 |
+
|
107 |
+
#if __CUDA_ARCH__ >= 800
|
108 |
+
template<int N>
|
109 |
+
struct Converter<at::BFloat16, N>{
|
110 |
+
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
111 |
+
static_assert(N % 2 == 0);
|
112 |
+
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
113 |
+
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
114 |
+
#pragma unroll
|
115 |
+
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
116 |
+
}
|
117 |
+
};
|
118 |
+
#endif
|
119 |
+
|
120 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
121 |
+
|
122 |
+
// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
|
123 |
+
// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
|
124 |
+
__device__ __forceinline__ complex_t cexp2f(complex_t z) {
|
125 |
+
float t = exp2f(z.real_);
|
126 |
+
float c, s;
|
127 |
+
sincosf(z.imag_, &s, &c);
|
128 |
+
return complex_t(c * t, s * t);
|
129 |
+
}
|
130 |
+
|
131 |
+
__device__ __forceinline__ complex_t cexpf(complex_t z) {
|
132 |
+
float t = expf(z.real_);
|
133 |
+
float c, s;
|
134 |
+
sincosf(z.imag_, &s, &c);
|
135 |
+
return complex_t(c * t, s * t);
|
136 |
+
}
|
137 |
+
|
138 |
+
template<typename scalar_t> struct SSMScanOp;
|
139 |
+
|
140 |
+
template<>
|
141 |
+
struct SSMScanOp<float> {
|
142 |
+
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
143 |
+
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
144 |
+
}
|
145 |
+
};
|
146 |
+
|
147 |
+
template<>
|
148 |
+
struct SSMScanOp<complex_t> {
|
149 |
+
__device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
|
150 |
+
complex_t a0 = complex_t(ab0.x, ab0.y);
|
151 |
+
complex_t b0 = complex_t(ab0.z, ab0.w);
|
152 |
+
complex_t a1 = complex_t(ab1.x, ab1.y);
|
153 |
+
complex_t b1 = complex_t(ab1.z, ab1.w);
|
154 |
+
complex_t out_a = a1 * a0;
|
155 |
+
complex_t out_b = a1 * b0 + b1;
|
156 |
+
return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
|
157 |
+
}
|
158 |
+
};
|
159 |
+
|
160 |
+
// A stateful callback functor that maintains a running prefix to be applied
|
161 |
+
// during consecutive scan operations.
|
162 |
+
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
163 |
+
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
164 |
+
scan_t running_prefix;
|
165 |
+
// Constructor
|
166 |
+
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
167 |
+
// Callback operator to be entered by the first warp of threads in the block.
|
168 |
+
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
169 |
+
__device__ scan_t operator()(scan_t block_aggregate) {
|
170 |
+
scan_t old_prefix = running_prefix;
|
171 |
+
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
172 |
+
return old_prefix;
|
173 |
+
}
|
174 |
+
};
|
175 |
+
|
176 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
177 |
+
|
178 |
+
template<typename Ktraits>
|
179 |
+
inline __device__ void load_input(typename Ktraits::input_t *u,
|
180 |
+
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
181 |
+
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
182 |
+
int seqlen) {
|
183 |
+
if constexpr (Ktraits::kIsEvenLen) {
|
184 |
+
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
185 |
+
using vec_t = typename Ktraits::vec_t;
|
186 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
187 |
+
reinterpret_cast<vec_t*>(u),
|
188 |
+
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
189 |
+
#ifdef USE_ROCM
|
190 |
+
, Ktraits::kNThreads * Ktraits::kNLoads
|
191 |
+
#endif
|
192 |
+
|
193 |
+
);
|
194 |
+
} else {
|
195 |
+
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
196 |
+
}
|
197 |
+
}
|
198 |
+
|
199 |
+
template<typename Ktraits>
|
200 |
+
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
201 |
+
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
202 |
+
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
203 |
+
int seqlen) {
|
204 |
+
constexpr int kNItems = Ktraits::kNItems;
|
205 |
+
if constexpr (!Ktraits::kIsComplex) {
|
206 |
+
typename Ktraits::input_t B_vals_load[kNItems];
|
207 |
+
if constexpr (Ktraits::kIsEvenLen) {
|
208 |
+
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
209 |
+
using vec_t = typename Ktraits::vec_t;
|
210 |
+
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
211 |
+
reinterpret_cast<vec_t*>(Bvar),
|
212 |
+
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
213 |
+
);
|
214 |
+
} else {
|
215 |
+
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
216 |
+
}
|
217 |
+
// #pragma unroll
|
218 |
+
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
219 |
+
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
220 |
+
} else {
|
221 |
+
typename Ktraits::input_t B_vals_load[kNItems * 2];
|
222 |
+
if constexpr (Ktraits::kIsEvenLen) {
|
223 |
+
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
224 |
+
using vec_t = typename Ktraits::vec_t;
|
225 |
+
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
226 |
+
reinterpret_cast<vec_t*>(Bvar),
|
227 |
+
reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
|
228 |
+
);
|
229 |
+
} else {
|
230 |
+
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
231 |
+
}
|
232 |
+
#pragma unroll
|
233 |
+
for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
|
234 |
+
}
|
235 |
+
}
|
236 |
+
|
237 |
+
template<typename Ktraits>
|
238 |
+
inline __device__ void store_output(typename Ktraits::input_t *out,
|
239 |
+
const float (&out_vals)[Ktraits::kNItems],
|
240 |
+
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
241 |
+
int seqlen) {
|
242 |
+
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
243 |
+
#pragma unroll
|
244 |
+
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
245 |
+
if constexpr (Ktraits::kIsEvenLen) {
|
246 |
+
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
247 |
+
using vec_t = typename Ktraits::vec_t;
|
248 |
+
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
249 |
+
reinterpret_cast<vec_t*>(out),
|
250 |
+
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
251 |
+
);
|
252 |
+
} else {
|
253 |
+
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
254 |
+
}
|
255 |
+
}
|
selective-scan/selective_scan_fwd_bf16.cu
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_fwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
10 |
+
template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_fwd_fp16.cu
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_fwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
10 |
+
template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_fwd_fp32.cu
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
// Split into multiple files to compile in paralell
|
6 |
+
|
7 |
+
#include "selective_scan_fwd_kernel.cuh"
|
8 |
+
|
9 |
+
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase ¶ms, cudaStream_t stream);
|
10 |
+
template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase ¶ms, cudaStream_t stream);
|
selective-scan/selective_scan_fwd_kernel.cuh
ADDED
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2023, Tri Dao.
|
3 |
+
******************************************************************************/
|
4 |
+
|
5 |
+
#pragma once
|
6 |
+
|
7 |
+
#include <c10/util/BFloat16.h>
|
8 |
+
#include <c10/util/Half.h>
|
9 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
10 |
+
|
11 |
+
#ifndef USE_ROCM
|
12 |
+
#include <cub/block/block_load.cuh>
|
13 |
+
#include <cub/block/block_store.cuh>
|
14 |
+
#include <cub/block/block_scan.cuh>
|
15 |
+
#else
|
16 |
+
#include <hipcub/hipcub.hpp>
|
17 |
+
namespace cub = hipcub;
|
18 |
+
#endif
|
19 |
+
|
20 |
+
#include "selective_scan.h"
|
21 |
+
#include "selective_scan_common.h"
|
22 |
+
#include "static_switch.h"
|
23 |
+
|
24 |
+
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
25 |
+
bool kIsVariableB_, bool kIsVariableC_,
|
26 |
+
bool kHasZ_, typename input_t_, typename weight_t_>
|
27 |
+
struct Selective_Scan_fwd_kernel_traits {
|
28 |
+
static_assert(kNItems_ % 4 == 0);
|
29 |
+
using input_t = input_t_;
|
30 |
+
using weight_t = weight_t_;
|
31 |
+
static constexpr int kNThreads = kNThreads_;
|
32 |
+
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
33 |
+
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
34 |
+
static constexpr int kNItems = kNItems_;
|
35 |
+
static constexpr int kNRows = kNRows_;
|
36 |
+
static constexpr int kNBytes = sizeof(input_t);
|
37 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
38 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
39 |
+
static_assert(kNItems % kNElts == 0);
|
40 |
+
static constexpr int kNLoads = kNItems / kNElts;
|
41 |
+
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
|
42 |
+
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
43 |
+
static constexpr bool kIsVariableB = kIsVariableB_;
|
44 |
+
static constexpr bool kIsVariableC = kIsVariableC_;
|
45 |
+
static constexpr bool kHasZ = kHasZ_;
|
46 |
+
|
47 |
+
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
48 |
+
|
49 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
50 |
+
using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
|
51 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
52 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
53 |
+
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
54 |
+
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
55 |
+
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
|
56 |
+
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
57 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
58 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
59 |
+
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
60 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
61 |
+
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
62 |
+
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
63 |
+
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
64 |
+
sizeof(typename BlockLoadVecT::TempStorage),
|
65 |
+
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
66 |
+
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
67 |
+
sizeof(typename BlockStoreT::TempStorage),
|
68 |
+
sizeof(typename BlockStoreVecT::TempStorage)});
|
69 |
+
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
70 |
+
};
|
71 |
+
|
72 |
+
template<typename Ktraits>
|
73 |
+
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
|
74 |
+
void selective_scan_fwd_kernel(SSMParamsBase params) {
|
75 |
+
constexpr bool kIsComplex = Ktraits::kIsComplex;
|
76 |
+
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
|
77 |
+
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
|
78 |
+
constexpr bool kHasZ = Ktraits::kHasZ;
|
79 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
80 |
+
constexpr int kNItems = Ktraits::kNItems;
|
81 |
+
constexpr int kNRows = Ktraits::kNRows;
|
82 |
+
constexpr bool kDirectIO = Ktraits::kDirectIO;
|
83 |
+
using input_t = typename Ktraits::input_t;
|
84 |
+
using weight_t = typename Ktraits::weight_t;
|
85 |
+
using scan_t = typename Ktraits::scan_t;
|
86 |
+
|
87 |
+
// Shared memory.
|
88 |
+
extern __shared__ char smem_[];
|
89 |
+
// cast to lvalue reference of expected type
|
90 |
+
// char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
|
91 |
+
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
|
92 |
+
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
|
93 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
94 |
+
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
|
95 |
+
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
|
96 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
97 |
+
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
98 |
+
// weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
|
99 |
+
// weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
|
100 |
+
scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
|
101 |
+
|
102 |
+
const int batch_id = blockIdx.x;
|
103 |
+
const int dim_id = blockIdx.y;
|
104 |
+
const int group_id = dim_id / (params.dim_ngroups_ratio);
|
105 |
+
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
|
106 |
+
+ dim_id * kNRows * params.u_d_stride;
|
107 |
+
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
|
108 |
+
+ dim_id * kNRows * params.delta_d_stride;
|
109 |
+
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
|
110 |
+
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
|
111 |
+
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
|
112 |
+
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
113 |
+
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
|
114 |
+
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
|
115 |
+
|
116 |
+
float D_val[kNRows] = {0};
|
117 |
+
if (params.D_ptr != nullptr) {
|
118 |
+
#pragma unroll
|
119 |
+
for (int r = 0; r < kNRows; ++r) {
|
120 |
+
D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
|
121 |
+
}
|
122 |
+
}
|
123 |
+
float delta_bias[kNRows] = {0};
|
124 |
+
if (params.delta_bias_ptr != nullptr) {
|
125 |
+
#pragma unroll
|
126 |
+
for (int r = 0; r < kNRows; ++r) {
|
127 |
+
delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
|
128 |
+
}
|
129 |
+
}
|
130 |
+
|
131 |
+
// for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
|
132 |
+
// smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
|
133 |
+
// smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
|
134 |
+
// }
|
135 |
+
|
136 |
+
constexpr int kChunkSize = kNThreads * kNItems;
|
137 |
+
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
|
138 |
+
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
|
139 |
+
__syncthreads();
|
140 |
+
#pragma unroll
|
141 |
+
for (int r = 0; r < kNRows; ++r) {
|
142 |
+
if constexpr (!kDirectIO) {
|
143 |
+
if (r > 0) { __syncthreads(); }
|
144 |
+
}
|
145 |
+
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
|
146 |
+
if constexpr (!kDirectIO) { __syncthreads(); }
|
147 |
+
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
|
148 |
+
}
|
149 |
+
u += kChunkSize;
|
150 |
+
delta += kChunkSize;
|
151 |
+
|
152 |
+
float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
|
153 |
+
#pragma unroll
|
154 |
+
for (int r = 0; r < kNRows; ++r) {
|
155 |
+
#pragma unroll
|
156 |
+
for (int i = 0; i < kNItems; ++i) {
|
157 |
+
float u_val = float(u_vals[r][i]);
|
158 |
+
delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
|
159 |
+
if (params.delta_softplus) {
|
160 |
+
delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
|
161 |
+
}
|
162 |
+
delta_u_vals[r][i] = delta_vals[r][i] * u_val;
|
163 |
+
out_vals[r][i] = D_val[r] * u_val;
|
164 |
+
}
|
165 |
+
}
|
166 |
+
|
167 |
+
__syncthreads();
|
168 |
+
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
|
169 |
+
weight_t A_val[kNRows];
|
170 |
+
#pragma unroll
|
171 |
+
for (int r = 0; r < kNRows; ++r) {
|
172 |
+
A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
|
173 |
+
// Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
|
174 |
+
constexpr float kLog2e = M_LOG2E;
|
175 |
+
if constexpr (!kIsComplex) {
|
176 |
+
A_val[r] *= kLog2e;
|
177 |
+
} else {
|
178 |
+
A_val[r].real_ *= kLog2e;
|
179 |
+
}
|
180 |
+
}
|
181 |
+
// This variable holds B * C if both B and C are constant across seqlen. If only B varies
|
182 |
+
// across seqlen, this holds C. If only C varies across seqlen, this holds B.
|
183 |
+
// If both B and C vary, this is unused.
|
184 |
+
weight_t BC_val[kNRows];
|
185 |
+
weight_t B_vals[kNItems], C_vals[kNItems];
|
186 |
+
if constexpr (kIsVariableB) {
|
187 |
+
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
|
188 |
+
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
189 |
+
if constexpr (!kIsVariableC) {
|
190 |
+
#pragma unroll
|
191 |
+
for (int r = 0; r < kNRows; ++r) {
|
192 |
+
BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
193 |
+
}
|
194 |
+
}
|
195 |
+
}
|
196 |
+
if constexpr (kIsVariableC) {
|
197 |
+
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
|
198 |
+
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
|
199 |
+
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
|
200 |
+
if constexpr (!kIsVariableB) {
|
201 |
+
#pragma unroll
|
202 |
+
for (int r = 0; r < kNRows; ++r) {
|
203 |
+
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
|
204 |
+
}
|
205 |
+
}
|
206 |
+
}
|
207 |
+
if constexpr (!kIsVariableB && !kIsVariableC) {
|
208 |
+
#pragma unroll
|
209 |
+
for (int r = 0; r < kNRows; ++r) {
|
210 |
+
BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
|
211 |
+
}
|
212 |
+
}
|
213 |
+
|
214 |
+
#pragma unroll
|
215 |
+
for (int r = 0; r < kNRows; ++r) {
|
216 |
+
if (r > 0) { __syncthreads(); } // Scan could be using the same smem
|
217 |
+
scan_t thread_data[kNItems];
|
218 |
+
#pragma unroll
|
219 |
+
for (int i = 0; i < kNItems; ++i) {
|
220 |
+
if constexpr (!kIsComplex) {
|
221 |
+
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
|
222 |
+
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
|
223 |
+
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
224 |
+
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
225 |
+
thread_data[i] = make_float2(1.f, 0.f);
|
226 |
+
}
|
227 |
+
}
|
228 |
+
} else {
|
229 |
+
// Pytorch's implementation of complex exp (which calls thrust) is very slow
|
230 |
+
complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
|
231 |
+
weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
|
232 |
+
thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
|
233 |
+
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
|
234 |
+
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
|
235 |
+
thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
|
236 |
+
}
|
237 |
+
}
|
238 |
+
}
|
239 |
+
}
|
240 |
+
// Initialize running total
|
241 |
+
scan_t running_prefix;
|
242 |
+
if constexpr (!kIsComplex) {
|
243 |
+
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
|
244 |
+
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
|
245 |
+
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
|
246 |
+
} else {
|
247 |
+
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
|
248 |
+
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
|
249 |
+
}
|
250 |
+
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
251 |
+
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
252 |
+
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
|
253 |
+
);
|
254 |
+
// There's a syncthreads in the scan op, so we don't need to sync here.
|
255 |
+
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
|
256 |
+
if (threadIdx.x == 0) {
|
257 |
+
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
258 |
+
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
|
259 |
+
}
|
260 |
+
#pragma unroll
|
261 |
+
for (int i = 0; i < kNItems; ++i) {
|
262 |
+
const weight_t C_val = !kIsVariableC
|
263 |
+
? BC_val[r]
|
264 |
+
: (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
|
265 |
+
if constexpr (!kIsComplex) {
|
266 |
+
out_vals[r][i] += thread_data[i].y * C_val;
|
267 |
+
} else {
|
268 |
+
out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
|
269 |
+
}
|
270 |
+
}
|
271 |
+
}
|
272 |
+
}
|
273 |
+
|
274 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
275 |
+
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
|
276 |
+
__syncthreads();
|
277 |
+
#pragma unroll
|
278 |
+
for (int r = 0; r < kNRows; ++r) {
|
279 |
+
if constexpr (!kDirectIO) {
|
280 |
+
if (r > 0) { __syncthreads(); }
|
281 |
+
}
|
282 |
+
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
283 |
+
}
|
284 |
+
|
285 |
+
if constexpr (kHasZ) {
|
286 |
+
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
|
287 |
+
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
|
288 |
+
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
|
289 |
+
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
|
290 |
+
#pragma unroll
|
291 |
+
for (int r = 0; r < kNRows; ++r) {
|
292 |
+
input_t z_vals[kNItems];
|
293 |
+
__syncthreads();
|
294 |
+
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
|
295 |
+
#pragma unroll
|
296 |
+
for (int i = 0; i < kNItems; ++i) {
|
297 |
+
float z_val = z_vals[i];
|
298 |
+
out_vals[r][i] *= z_val / (1 + expf(-z_val));
|
299 |
+
}
|
300 |
+
__syncthreads();
|
301 |
+
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
|
302 |
+
}
|
303 |
+
}
|
304 |
+
|
305 |
+
Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
|
306 |
+
Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
|
307 |
+
}
|
308 |
+
}
|
309 |
+
|
310 |
+
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
|
311 |
+
void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
312 |
+
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
|
313 |
+
// processing 1 row.
|
314 |
+
constexpr int kNRows = 1;
|
315 |
+
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
316 |
+
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
|
317 |
+
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
|
318 |
+
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
319 |
+
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
|
320 |
+
|
321 |
+
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
322 |
+
dim3 grid(params.batch, params.dim / kNRows);
|
323 |
+
|
324 |
+
// Had to change this substantially since potentially the hip
|
325 |
+
// interface for setting kernel launch attributes is slightly different from
|
326 |
+
// cuda's. In particualar, it seems to expect a plain const void * pointer.
|
327 |
+
|
328 |
+
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
329 |
+
|
330 |
+
|
331 |
+
if (kSmemSize >= 48 * 1024) {
|
332 |
+
#ifndef USE_ROCM
|
333 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
334 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
335 |
+
#else
|
336 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
337 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
338 |
+
std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
339 |
+
#endif
|
340 |
+
}
|
341 |
+
|
342 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
343 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
344 |
+
});
|
345 |
+
});
|
346 |
+
});
|
347 |
+
});
|
348 |
+
}
|
349 |
+
|
350 |
+
template<typename input_t, typename weight_t>
|
351 |
+
void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) {
|
352 |
+
|
353 |
+
#ifndef USE_ROCM
|
354 |
+
if (params.seqlen <= 128) {
|
355 |
+
selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
|
356 |
+
} else if (params.seqlen <= 256) {
|
357 |
+
selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
|
358 |
+
} else if (params.seqlen <= 512) {
|
359 |
+
selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
|
360 |
+
} else if (params.seqlen <= 1024) {
|
361 |
+
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
362 |
+
} else {
|
363 |
+
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
364 |
+
}
|
365 |
+
#else
|
366 |
+
if (params.seqlen <= 256) {
|
367 |
+
selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
|
368 |
+
} else if (params.seqlen <= 512) {
|
369 |
+
selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
|
370 |
+
} else if (params.seqlen <= 1024) {
|
371 |
+
selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
|
372 |
+
} else {
|
373 |
+
selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
|
374 |
+
}
|
375 |
+
#endif
|
376 |
+
}
|
selective-scan/static_switch.h
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
2 |
+
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
3 |
+
|
4 |
+
#pragma once
|
5 |
+
|
6 |
+
/// @param COND - a boolean expression to switch by
|
7 |
+
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
8 |
+
/// @param ... - code to execute for true and false
|
9 |
+
///
|
10 |
+
/// Usage:
|
11 |
+
/// ```
|
12 |
+
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
13 |
+
/// some_function<BoolConst>(...);
|
14 |
+
/// });
|
15 |
+
/// ```
|
16 |
+
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
17 |
+
[&] { \
|
18 |
+
if (COND) { \
|
19 |
+
constexpr bool CONST_NAME = true; \
|
20 |
+
return __VA_ARGS__(); \
|
21 |
+
} else { \
|
22 |
+
constexpr bool CONST_NAME = false; \
|
23 |
+
return __VA_ARGS__(); \
|
24 |
+
} \
|
25 |
+
}()
|
selective-scan/uninitialized_copy.cuh
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/******************************************************************************
|
2 |
+
* Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Redistribution and use in source and binary forms, with or without
|
5 |
+
* modification, are permitted provided that the following conditions are met:
|
6 |
+
* * Redistributions of source code must retain the above copyright
|
7 |
+
* notice, this list of conditions and the following disclaimer.
|
8 |
+
* * Redistributions in binary form must reproduce the above copyright
|
9 |
+
* notice, this list of conditions and the following disclaimer in the
|
10 |
+
* documentation and/or other materials provided with the distribution.
|
11 |
+
* * Neither the name of the NVIDIA CORPORATION nor the
|
12 |
+
* names of its contributors may be used to endorse or promote products
|
13 |
+
* derived from this software without specific prior written permission.
|
14 |
+
*
|
15 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
16 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
17 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
18 |
+
* ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
19 |
+
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
20 |
+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
21 |
+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
22 |
+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
23 |
+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
24 |
+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
25 |
+
*
|
26 |
+
******************************************************************************/
|
27 |
+
|
28 |
+
#pragma once
|
29 |
+
|
30 |
+
#ifndef USE_ROCM
|
31 |
+
#include <cub/config.cuh>
|
32 |
+
|
33 |
+
#include <cuda/std/type_traits>
|
34 |
+
#else
|
35 |
+
#include <hipcub/hipcub.hpp>
|
36 |
+
// Map ::cuda::std to the standard std namespace
|
37 |
+
namespace cuda {
|
38 |
+
namespace std = ::std;
|
39 |
+
}
|
40 |
+
#endif
|
41 |
+
|
42 |
+
|
43 |
+
namespace detail
|
44 |
+
{
|
45 |
+
|
46 |
+
#if defined(_NVHPC_CUDA)
|
47 |
+
template <typename T, typename U>
|
48 |
+
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
49 |
+
{
|
50 |
+
// NVBug 3384810
|
51 |
+
new (ptr) T(::cuda::std::forward<U>(val));
|
52 |
+
}
|
53 |
+
#else
|
54 |
+
template <typename T,
|
55 |
+
typename U,
|
56 |
+
typename ::cuda::std::enable_if<
|
57 |
+
::cuda::std::is_trivially_copyable<T>::value,
|
58 |
+
int
|
59 |
+
>::type = 0>
|
60 |
+
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
61 |
+
{
|
62 |
+
*ptr = ::cuda::std::forward<U>(val);
|
63 |
+
}
|
64 |
+
|
65 |
+
template <typename T,
|
66 |
+
typename U,
|
67 |
+
typename ::cuda::std::enable_if<
|
68 |
+
!::cuda::std::is_trivially_copyable<T>::value,
|
69 |
+
int
|
70 |
+
>::type = 0>
|
71 |
+
__host__ __device__ void uninitialized_copy(T *ptr, U &&val)
|
72 |
+
{
|
73 |
+
new (ptr) T(::cuda::std::forward<U>(val));
|
74 |
+
}
|
75 |
+
#endif
|
76 |
+
|
77 |
+
} // namespace detail
|
tests/ops/test_selective_scan.py
ADDED
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2023, Tri Dao.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import pytest
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
|
12 |
+
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
|
13 |
+
|
14 |
+
|
15 |
+
# @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
|
16 |
+
@pytest.mark.parametrize('wtype', [torch.float32])
|
17 |
+
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
|
18 |
+
@pytest.mark.parametrize('itype', [torch.float32])
|
19 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
|
20 |
+
@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
|
21 |
+
# @pytest.mark.parametrize('seqlen', [128])
|
22 |
+
# @pytest.mark.parametrize("return_last_state", [False, True])
|
23 |
+
@pytest.mark.parametrize("return_last_state", [True])
|
24 |
+
# @pytest.mark.parametrize('has_delta_bias', [False, True])
|
25 |
+
@pytest.mark.parametrize('has_delta_bias', [True])
|
26 |
+
# @pytest.mark.parametrize('delta_softplus', [False, True])
|
27 |
+
@pytest.mark.parametrize('delta_softplus', [True])
|
28 |
+
# @pytest.mark.parametrize('has_z', [False, True])
|
29 |
+
@pytest.mark.parametrize('has_z', [True])
|
30 |
+
# @pytest.mark.parametrize('has_D', [False, True])
|
31 |
+
@pytest.mark.parametrize('has_D', [True])
|
32 |
+
@pytest.mark.parametrize("varBC_groups", [1, 2])
|
33 |
+
# @pytest.mark.parametrize("varBC_groups", [1])
|
34 |
+
# @pytest.mark.parametrize("is_variable_C", [False, True])
|
35 |
+
@pytest.mark.parametrize("is_variable_C", [True])
|
36 |
+
# @pytest.mark.parametrize("is_variable_B", [False, True])
|
37 |
+
@pytest.mark.parametrize("is_variable_B", [True])
|
38 |
+
def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
|
39 |
+
delta_softplus, return_last_state, seqlen, itype, wtype):
|
40 |
+
if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
|
41 |
+
pytest.skip() # This config is not applicable
|
42 |
+
device = 'cuda'
|
43 |
+
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
44 |
+
if itype == torch.bfloat16:
|
45 |
+
rtol, atol = 3e-2, 5e-2
|
46 |
+
rtolw, atolw = (1e-3, 1e-3)
|
47 |
+
if has_z: # If we have z, the errors on the weights seem higher
|
48 |
+
rtolw = max(rtolw, rtol)
|
49 |
+
atolw = max(atolw, atol)
|
50 |
+
# set seed
|
51 |
+
torch.random.manual_seed(0)
|
52 |
+
batch_size = 2
|
53 |
+
dim = 4
|
54 |
+
dstate = 8
|
55 |
+
is_complex = wtype == torch.complex64
|
56 |
+
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
|
57 |
+
if not is_variable_B:
|
58 |
+
B_shape = (dim, dstate)
|
59 |
+
elif varBC_groups == 1:
|
60 |
+
B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
|
61 |
+
else:
|
62 |
+
B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
|
63 |
+
B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
|
64 |
+
requires_grad=True)
|
65 |
+
if not is_variable_C:
|
66 |
+
C_shape = (dim, dstate)
|
67 |
+
elif varBC_groups == 1:
|
68 |
+
C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
|
69 |
+
else:
|
70 |
+
C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
|
71 |
+
C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
|
72 |
+
requires_grad=True)
|
73 |
+
if has_D:
|
74 |
+
D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
75 |
+
else:
|
76 |
+
D = None
|
77 |
+
if has_z:
|
78 |
+
z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
|
79 |
+
else:
|
80 |
+
z = None
|
81 |
+
if has_delta_bias:
|
82 |
+
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
|
83 |
+
else:
|
84 |
+
delta_bias = None
|
85 |
+
u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
|
86 |
+
delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
|
87 |
+
A_ref = A.detach().clone().requires_grad_()
|
88 |
+
B_ref = B.detach().clone().requires_grad_()
|
89 |
+
C_ref = C.detach().clone().requires_grad_()
|
90 |
+
D_ref = D.detach().clone().requires_grad_() if D is not None else None
|
91 |
+
z_ref = z.detach().clone().requires_grad_() if z is not None else None
|
92 |
+
u_ref = u.detach().clone().requires_grad_()
|
93 |
+
delta_ref = delta.detach().clone().requires_grad_()
|
94 |
+
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
|
95 |
+
out, *rest = selective_scan_fn(
|
96 |
+
u, delta, A, B, C, D, z=z,
|
97 |
+
delta_bias=delta_bias, delta_softplus=delta_softplus,
|
98 |
+
return_last_state=return_last_state
|
99 |
+
)
|
100 |
+
if return_last_state:
|
101 |
+
state = rest[0]
|
102 |
+
out_ref, *rest = selective_scan_ref(
|
103 |
+
u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
|
104 |
+
delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
|
105 |
+
return_last_state=return_last_state
|
106 |
+
)
|
107 |
+
if return_last_state:
|
108 |
+
state_ref = rest[0]
|
109 |
+
# dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
110 |
+
# dt_u = delta * u
|
111 |
+
|
112 |
+
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
113 |
+
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
114 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
115 |
+
if return_last_state:
|
116 |
+
print(f'State max diff: {(state - state_ref).abs().max().item()}')
|
117 |
+
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
118 |
+
|
119 |
+
g = torch.randn_like(out)
|
120 |
+
out_ref.backward(g)
|
121 |
+
out.backward(g)
|
122 |
+
|
123 |
+
print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
|
124 |
+
print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
|
125 |
+
print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
|
126 |
+
print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
|
127 |
+
print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
|
128 |
+
if has_D:
|
129 |
+
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
|
130 |
+
if has_z:
|
131 |
+
print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
|
132 |
+
if has_delta_bias:
|
133 |
+
print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
|
134 |
+
|
135 |
+
assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
|
136 |
+
assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
|
137 |
+
assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
|
138 |
+
assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
|
139 |
+
atol=atolw if not is_variable_B else atol)
|
140 |
+
assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
|
141 |
+
atol=atolw if not is_variable_C else atol)
|
142 |
+
if has_D:
|
143 |
+
assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
|
144 |
+
if has_z:
|
145 |
+
assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
|
146 |
+
if has_delta_bias:
|
147 |
+
assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
|
148 |
+
|
149 |
+
|
150 |
+
@pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
|
151 |
+
# @pytest.mark.parametrize('wtype', [torch.complex64])
|
152 |
+
# @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
|
153 |
+
@pytest.mark.parametrize('itype', [torch.float32])
|
154 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
|
155 |
+
@pytest.mark.parametrize('seqlen', [128])
|
156 |
+
@pytest.mark.parametrize("is_variable_C", [False, True])
|
157 |
+
# @pytest.mark.parametrize("is_variable_C", [False])
|
158 |
+
@pytest.mark.parametrize("is_variable_B", [False, True])
|
159 |
+
# @pytest.mark.parametrize("is_variable_B", [True])
|
160 |
+
def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
|
161 |
+
device = 'cuda'
|
162 |
+
rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
163 |
+
if itype == torch.bfloat16:
|
164 |
+
rtol, atol = 3e-2, 5e-2
|
165 |
+
rtolw, atolw = (1e-3, 1e-3)
|
166 |
+
# If we have z, the errors on the weights seem higher
|
167 |
+
rtolw = max(rtolw, rtol)
|
168 |
+
atolw = max(atolw, atol)
|
169 |
+
# set seed
|
170 |
+
torch.random.manual_seed(0)
|
171 |
+
batch_size = 2
|
172 |
+
dim = 768
|
173 |
+
dstate = 8
|
174 |
+
dt_rank = 48
|
175 |
+
is_complex = wtype == torch.complex64
|
176 |
+
xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
|
177 |
+
conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
|
178 |
+
conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
179 |
+
x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
|
180 |
+
* (1 if not is_complex else 2),
|
181 |
+
dim, device=device, dtype=itype, requires_grad=True)
|
182 |
+
delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
|
183 |
+
out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
|
184 |
+
out_proj_bias = None
|
185 |
+
A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
|
186 |
+
B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
|
187 |
+
if not is_variable_B else None)
|
188 |
+
C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
|
189 |
+
if not is_variable_C else None)
|
190 |
+
D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
191 |
+
delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
|
192 |
+
B_proj_bias = None
|
193 |
+
C_proj_bias = None
|
194 |
+
xz_ref = xz.detach().clone().requires_grad_()
|
195 |
+
conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
|
196 |
+
conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
|
197 |
+
x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
|
198 |
+
delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
|
199 |
+
out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
|
200 |
+
out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
|
201 |
+
if out_proj_bias is not None else None)
|
202 |
+
A_ref = A.detach().clone().requires_grad_()
|
203 |
+
B_ref = B.detach().clone().requires_grad_() if B is not None else None
|
204 |
+
C_ref = C.detach().clone().requires_grad_() if C is not None else None
|
205 |
+
D_ref = D.detach().clone().requires_grad_()
|
206 |
+
delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
|
207 |
+
out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
|
208 |
+
out_proj_weight, out_proj_bias,
|
209 |
+
A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
|
210 |
+
out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
|
211 |
+
delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
|
212 |
+
A_ref, B_ref, C_ref, D_ref,
|
213 |
+
delta_bias=delta_bias_ref, delta_softplus=True)
|
214 |
+
# dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
|
215 |
+
# dt_u = delta * u
|
216 |
+
|
217 |
+
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
218 |
+
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
219 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
220 |
+
|
221 |
+
g = torch.randn_like(out)
|
222 |
+
out_ref.backward(g)
|
223 |
+
out.backward(g)
|
224 |
+
|
225 |
+
print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
|
226 |
+
print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
|
227 |
+
if not is_variable_B:
|
228 |
+
print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
|
229 |
+
if not is_variable_C:
|
230 |
+
print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
|
231 |
+
print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
|
232 |
+
print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
|
233 |
+
print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
|
234 |
+
print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
|
235 |
+
print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
|
236 |
+
print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
|
237 |
+
print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
|
238 |
+
|
239 |
+
# assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
|
240 |
+
# assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
|
241 |
+
# assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
|
242 |
+
# assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
|
243 |
+
# atol=atolw if not is_variable_B else atol)
|
244 |
+
# assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
|
245 |
+
# atol=atolw if not is_variable_C else atol)
|
246 |
+
# assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
|
247 |
+
# assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
|
tests/ops/triton/test_layernorm_gated.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
import pytest
|
7 |
+
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref
|
11 |
+
|
12 |
+
|
13 |
+
@pytest.mark.parametrize("norm_before_gate", [True, False])
|
14 |
+
# @pytest.mark.parametrize("norm_before_gate", [False])
|
15 |
+
@pytest.mark.parametrize("has_group", [False, True])
|
16 |
+
# @pytest.mark.parametrize("has_group", [False])
|
17 |
+
@pytest.mark.parametrize("is_rms_norm", [False, True])
|
18 |
+
# @pytest.mark.parametrize("is_rms_norm", [True])
|
19 |
+
@pytest.mark.parametrize("has_z", [False, True])
|
20 |
+
# @pytest.mark.parametrize("has_z", [True])
|
21 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
22 |
+
# @pytest.mark.parametrize("has_bias", [False])
|
23 |
+
# @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
|
24 |
+
@pytest.mark.parametrize('dtype', [torch.float16])
|
25 |
+
# @pytest.mark.parametrize("wtype", [torch.float32, torch.float16, torch.bfloat16])
|
26 |
+
@pytest.mark.parametrize("wtype", [torch.float32])
|
27 |
+
@pytest.mark.parametrize('d', [2048, 4096])
|
28 |
+
# @pytest.mark.parametrize('d', [4096])
|
29 |
+
def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate):
|
30 |
+
if not has_z and not norm_before_gate:
|
31 |
+
pytest.skip()
|
32 |
+
if not norm_before_gate and not is_rms_norm: # Reference LN isn't implemented for this case yet
|
33 |
+
pytest.skip()
|
34 |
+
device = 'cuda'
|
35 |
+
rtol, atol = (1e-5, 1e-5) if dtype == torch.float32 else (1e-2, 8e-3)
|
36 |
+
group_size = None if not has_group else 64
|
37 |
+
# set seed
|
38 |
+
torch.random.manual_seed(0)
|
39 |
+
batch = 16
|
40 |
+
seqlen = 1024
|
41 |
+
x = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
|
42 |
+
if has_z:
|
43 |
+
z = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
|
44 |
+
else:
|
45 |
+
z = None
|
46 |
+
weight = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
|
47 |
+
if has_bias:
|
48 |
+
bias = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
|
49 |
+
else:
|
50 |
+
bias = None
|
51 |
+
x_ref = x.detach().clone().requires_grad_()
|
52 |
+
x_pt = x.detach().clone().requires_grad_()
|
53 |
+
z_ref = z.detach().clone().requires_grad_() if z is not None else None
|
54 |
+
z_pt = z.detach().clone().requires_grad_() if z is not None else None
|
55 |
+
weight_ref = weight.detach().clone().requires_grad_()
|
56 |
+
weight_pt = weight.detach().clone().requires_grad_()
|
57 |
+
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
58 |
+
bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
|
59 |
+
out = layernorm_fn(x, weight, bias, z=z, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate,
|
60 |
+
is_rms_norm=is_rms_norm)
|
61 |
+
if not is_rms_norm:
|
62 |
+
if not has_group:
|
63 |
+
out_ref = F.layer_norm(x_ref.float(), (d,), weight=weight_ref.float(), bias=bias_ref.float() if bias_ref is not None else None, eps=1e-5)
|
64 |
+
out_pt = F.layer_norm(x_pt.to(wtype), (d,), weight=weight_pt, bias=bias_pt, eps=1e-5)
|
65 |
+
else:
|
66 |
+
out_ref = rearrange(F.layer_norm(rearrange(x_ref, "... (g d) -> ... g d", d=group_size).float(), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_ref.float()
|
67 |
+
if has_bias:
|
68 |
+
out_ref = out_ref + bias_ref.float()
|
69 |
+
out_pt = rearrange(F.layer_norm(rearrange(x_pt, "... (g d) -> ... g d", d=group_size), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_pt
|
70 |
+
if has_bias:
|
71 |
+
out_pt = out_pt + bias_pt
|
72 |
+
if has_z and norm_before_gate:
|
73 |
+
out_ref = out_ref * F.silu(z_ref.float())
|
74 |
+
out_pt = out_pt * F.silu(z_pt)
|
75 |
+
else:
|
76 |
+
out_ref = rms_norm_ref(x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size,
|
77 |
+
norm_before_gate=norm_before_gate)
|
78 |
+
out_pt = rms_norm_ref(x_pt, weight_pt, bias_pt, z=z_pt, eps=1e-5, group_size=group_size,
|
79 |
+
norm_before_gate=norm_before_gate, upcast=False)
|
80 |
+
print(f"Max diff = {(out - out_ref).abs().max().item()}")
|
81 |
+
print(f"Max diff Pytorch = {(out_pt - out_ref).abs().max().item()}")
|
82 |
+
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + atol
|
83 |
+
|
84 |
+
g = torch.randn_like(out)
|
85 |
+
out.backward(g)
|
86 |
+
out_ref.backward(g)
|
87 |
+
out_pt.backward(g)
|
88 |
+
print(f"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}")
|
89 |
+
print(f"Max dx diff Pytorch = {(x_pt.grad - x_ref.grad).abs().max().item()}")
|
90 |
+
if has_z:
|
91 |
+
print(f"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}")
|
92 |
+
print(f"Max dz diff Pytorch = {(z_pt.grad - z_ref.grad).abs().max().item()}")
|
93 |
+
print(f"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}")
|
94 |
+
print(f"Max dw diff Pytorch = {(weight_pt.grad - weight_ref.grad).abs().max().item()}")
|
95 |
+
if has_bias:
|
96 |
+
print(f"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}")
|
97 |
+
print(f"Max db diff Pytorch = {(bias_pt.grad - bias_ref.grad).abs().max().item()}")
|
98 |
+
assert (x.grad - x_ref.grad).abs().max().item() <= 2 * (x_pt.grad - x_ref.grad).abs().max().item() + atol
|
99 |
+
if has_z:
|
100 |
+
assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol
|
101 |
+
assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * (weight_pt.grad - weight_ref.grad).abs().max().item() + atol
|
102 |
+
if has_bias:
|
103 |
+
assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * (bias_pt.grad - bias_ref.grad).abs().max().item() + atol
|
tests/ops/triton/test_selective_state_update.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2023, Tri Dao.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import pytest
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
|
12 |
+
|
13 |
+
|
14 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
15 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
16 |
+
@pytest.mark.parametrize("has_z", [False, True])
|
17 |
+
# @pytest.mark.parametrize('has_z', [True])
|
18 |
+
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
19 |
+
# @pytest.mark.parametrize("dstate", [16])
|
20 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
21 |
+
# @pytest.mark.parametrize("dim", [2048])
|
22 |
+
def test_selective_state_update(dim, dstate, has_z, itype):
|
23 |
+
device = "cuda"
|
24 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
25 |
+
if itype == torch.bfloat16:
|
26 |
+
rtol, atol = 1e-2, 5e-2
|
27 |
+
if torch.version.hip:
|
28 |
+
atol *= 2
|
29 |
+
# set seed
|
30 |
+
torch.random.manual_seed(0)
|
31 |
+
batch_size = 2
|
32 |
+
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
|
33 |
+
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
34 |
+
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
35 |
+
dt_bias = torch.rand(dim, device=device) - 4.0
|
36 |
+
A = -torch.rand(dim, dstate, device=device) - 1.0
|
37 |
+
B = torch.randn(batch_size, dstate, device=device)
|
38 |
+
C = torch.randn(batch_size, dstate, device=device)
|
39 |
+
D = torch.randn(dim, device=device)
|
40 |
+
if has_z:
|
41 |
+
z = torch.randn_like(x)
|
42 |
+
else:
|
43 |
+
z = None
|
44 |
+
state_ref = state.detach().clone()
|
45 |
+
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
46 |
+
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
47 |
+
|
48 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
49 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
50 |
+
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
51 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
52 |
+
|
53 |
+
|
54 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
55 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
56 |
+
@pytest.mark.parametrize("has_z", [False, True])
|
57 |
+
# @pytest.mark.parametrize('has_z', [True])
|
58 |
+
@pytest.mark.parametrize("tie_hdim", [False, True])
|
59 |
+
# @pytest.mark.parametrize('tie_hdim', [True])
|
60 |
+
@pytest.mark.parametrize("ngroups", [1, 2, 4])
|
61 |
+
# @pytest.mark.parametrize("ngroups", [2])
|
62 |
+
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
63 |
+
# @pytest.mark.parametrize("dstate", [16])
|
64 |
+
@pytest.mark.parametrize("dim", [2048, 4096])
|
65 |
+
# @pytest.mark.parametrize("dim", [2048])
|
66 |
+
def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):
|
67 |
+
device = "cuda"
|
68 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
|
69 |
+
if itype == torch.bfloat16:
|
70 |
+
rtol, atol = 1e-2, 1e-1
|
71 |
+
# set seed
|
72 |
+
torch.random.manual_seed(0)
|
73 |
+
batch_size = 2
|
74 |
+
headdim = 64
|
75 |
+
nheads = dim // headdim
|
76 |
+
state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)
|
77 |
+
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
78 |
+
if not tie_hdim:
|
79 |
+
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
80 |
+
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
|
81 |
+
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
|
82 |
+
D = torch.randn(nheads, headdim, device=device)
|
83 |
+
else:
|
84 |
+
dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
|
85 |
+
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
|
86 |
+
A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
|
87 |
+
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
|
88 |
+
B = torch.randn(batch_size, ngroups, dstate, device=device)
|
89 |
+
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
90 |
+
if has_z:
|
91 |
+
z = torch.randn_like(x)
|
92 |
+
else:
|
93 |
+
z = None
|
94 |
+
state_ref = state.detach().clone()
|
95 |
+
state_og = state.detach().clone()
|
96 |
+
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
97 |
+
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
98 |
+
|
99 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
100 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
101 |
+
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
102 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
103 |
+
|
104 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
105 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
106 |
+
@pytest.mark.parametrize("has_z", [False, True])
|
107 |
+
# @pytest.mark.parametrize('has_z', [True])
|
108 |
+
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
109 |
+
# @pytest.mark.parametrize("dstate", [16])
|
110 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
111 |
+
# @pytest.mark.parametrize("dim", [2048])
|
112 |
+
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
113 |
+
device = "cuda"
|
114 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
115 |
+
if itype == torch.bfloat16:
|
116 |
+
rtol, atol = 6e-2, 6e-2
|
117 |
+
if torch.version.hip:
|
118 |
+
atol *= 2
|
119 |
+
# set seed
|
120 |
+
torch.random.manual_seed(0)
|
121 |
+
batch_size = 16
|
122 |
+
|
123 |
+
total_entries = 10 * batch_size
|
124 |
+
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
125 |
+
state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)
|
126 |
+
|
127 |
+
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
128 |
+
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
129 |
+
dt_bias = torch.rand(dim, device=device) - 4.0
|
130 |
+
A = -torch.rand(dim, dstate, device=device) - 1.0
|
131 |
+
B = torch.randn(batch_size, dstate, device=device)
|
132 |
+
C = torch.randn(batch_size, dstate, device=device)
|
133 |
+
D = torch.randn(dim, device=device)
|
134 |
+
if has_z:
|
135 |
+
z = torch.randn_like(x)
|
136 |
+
else:
|
137 |
+
z = None
|
138 |
+
state_ref = state[state_indices,:].detach().clone()
|
139 |
+
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
|
140 |
+
dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
|
141 |
+
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
142 |
+
|
143 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
144 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
145 |
+
assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)
|
146 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
147 |
+
|
148 |
+
|
149 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
150 |
+
#@pytest.mark.parametrize('itype', [torch.float32])
|
151 |
+
@pytest.mark.parametrize("has_z", [False, True])
|
152 |
+
# @pytest.mark.parametrize('has_z', [True])
|
153 |
+
@pytest.mark.parametrize("tie_hdim", [False, True])
|
154 |
+
# @pytest.mark.parametrize('tie_hdim', [True])
|
155 |
+
@pytest.mark.parametrize("ngroups", [1, 2, 4])
|
156 |
+
# @pytest.mark.parametrize("ngroups", [2])
|
157 |
+
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
158 |
+
# @pytest.mark.parametrize("dstate", [16])
|
159 |
+
@pytest.mark.parametrize("dim", [2048, 4096])
|
160 |
+
# @pytest.mark.parametrize("dim", [2048])
|
161 |
+
def test_selective_state_update_with_heads_with_batch_indices(dim, dstate, ngroups, has_z, tie_hdim, itype):
|
162 |
+
device = "cuda"
|
163 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
|
164 |
+
if itype == torch.bfloat16:
|
165 |
+
rtol, atol = 1e-1, 1e-1
|
166 |
+
# set seed
|
167 |
+
torch.random.manual_seed(0)
|
168 |
+
batch_size = 16
|
169 |
+
headdim = 64
|
170 |
+
nheads = dim // headdim
|
171 |
+
|
172 |
+
total_entries = 10 * batch_size
|
173 |
+
state = torch.randn(total_entries, nheads, headdim, dstate, dtype=itype, device=device)
|
174 |
+
state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)
|
175 |
+
|
176 |
+
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
177 |
+
if not tie_hdim:
|
178 |
+
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
179 |
+
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
|
180 |
+
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
|
181 |
+
D = torch.randn(nheads, headdim, device=device)
|
182 |
+
else:
|
183 |
+
dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
|
184 |
+
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
|
185 |
+
A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
|
186 |
+
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
|
187 |
+
B = torch.randn(batch_size, ngroups, dstate, device=device)
|
188 |
+
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
189 |
+
if has_z:
|
190 |
+
z = torch.randn_like(x)
|
191 |
+
else:
|
192 |
+
z = None
|
193 |
+
state_ref = state[state_indices,:].detach().clone()
|
194 |
+
state_og = state[state_indices,:].detach().clone()
|
195 |
+
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
|
196 |
+
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
|
197 |
+
|
198 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
199 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
200 |
+
assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)
|
201 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
tests/ops/triton/test_ssd.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
import pytest
|
7 |
+
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
|
11 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
|
12 |
+
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
|
13 |
+
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
|
14 |
+
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd
|
15 |
+
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
|
16 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan
|
17 |
+
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref
|
18 |
+
|
19 |
+
|
20 |
+
def detach_clone(*args):
|
21 |
+
return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args])
|
22 |
+
|
23 |
+
|
24 |
+
@pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
|
25 |
+
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
26 |
+
@pytest.mark.parametrize('ngroups', [1, 2, 8, "max"])
|
27 |
+
# @pytest.mark.parametrize('ngroups', [1])
|
28 |
+
@pytest.mark.parametrize('chunk_size', [64, 128])
|
29 |
+
# @pytest.mark.parametrize('chunk_size', [128])
|
30 |
+
def test_chunk_state_varlen(chunk_size, ngroups, dtype):
|
31 |
+
device = 'cuda'
|
32 |
+
rtol, atol = (1e-2, 3e-3)
|
33 |
+
# set seed
|
34 |
+
torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64))
|
35 |
+
batch = 300
|
36 |
+
seqlens = torch.randint(1, 200, (batch,), device=device)
|
37 |
+
# batch = 3
|
38 |
+
# seqlens = torch.tensor([201, 56, 5], device=device)
|
39 |
+
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
|
40 |
+
total_seqlen = seqlens.sum().item()
|
41 |
+
seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seqlens)], dim=0).unsqueeze(0)
|
42 |
+
dim = 4096
|
43 |
+
# dim = 64
|
44 |
+
headdim = 64
|
45 |
+
# dim = 32
|
46 |
+
dstate = 32
|
47 |
+
assert dim % headdim == 0
|
48 |
+
nheads = dim // headdim
|
49 |
+
if ngroups == "max":
|
50 |
+
ngroups = nheads
|
51 |
+
assert nheads % ngroups == 0
|
52 |
+
B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, device=device) / 5
|
53 |
+
x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device)
|
54 |
+
A = -0.1 * (torch.rand(nheads, device=device))
|
55 |
+
dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4)
|
56 |
+
dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size)
|
57 |
+
chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx)
|
58 |
+
chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
|
59 |
+
seq_idx=seq_idx, chunk_size=chunk_size)
|
60 |
+
chunk_states = rearrange(chunk_states, "... (p n) -> ... p n", n=dstate)
|
61 |
+
chunk_states = chunk_states.squeeze(0)
|
62 |
+
dA_cumsum = dA_cumsum.squeeze(0)
|
63 |
+
dt_rounded = dt_rounded.squeeze(0)
|
64 |
+
out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states)
|
65 |
+
out_ref = []
|
66 |
+
for b in range(batch):
|
67 |
+
x_s = x[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
|
68 |
+
B_s = B[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
|
69 |
+
dt_s = dt[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
|
70 |
+
dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size)
|
71 |
+
states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s)
|
72 |
+
_, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1],
|
73 |
+
chunk_size=chunk_size)
|
74 |
+
final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate)
|
75 |
+
out_ref.append(final_states)
|
76 |
+
out_ref = torch.cat(out_ref, dim=0)
|
77 |
+
print(f"Max diff = {(out - out_ref).abs().max().item()}")
|
78 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
tests/test_generation.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
|
5 |
+
from mamba_ssm.models.config_mamba import MambaConfig
|
6 |
+
from mamba_ssm.utils.generation import InferenceParams
|
7 |
+
|
8 |
+
import pytest
|
9 |
+
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
|
12 |
+
|
13 |
+
def test_generation():
|
14 |
+
batch = 3
|
15 |
+
seqlen = 20
|
16 |
+
device = "cuda"
|
17 |
+
dtype = torch.float16
|
18 |
+
|
19 |
+
config = MambaConfig(
|
20 |
+
d_model=1024,
|
21 |
+
n_layer=4,
|
22 |
+
vocab_size=50277,
|
23 |
+
ssm_cfg=dict(layer="Mamba2"),
|
24 |
+
rms_norm=True,
|
25 |
+
residual_in_fp32=True,
|
26 |
+
fused_add_norm=True,
|
27 |
+
pad_vocab_size_multiple=16,
|
28 |
+
)
|
29 |
+
torch.manual_seed(2357)
|
30 |
+
model = MambaLMHeadModel(config, device=device, dtype=dtype)
|
31 |
+
x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long)
|
32 |
+
out_ref = model(x).logits
|
33 |
+
prompt_len = seqlen // 2
|
34 |
+
out = model.generate(
|
35 |
+
input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True,
|
36 |
+
cg=True, # Can turn off CUDA graph for easier debugging
|
37 |
+
# instead of sampling, we take output tokens from x, to get logits for testing
|
38 |
+
# For actual generation, don't pass in teacher_outputs
|
39 |
+
teacher_outputs=x,
|
40 |
+
)
|
41 |
+
out_scores = torch.stack(out.scores, dim=1)
|
42 |
+
print(f"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}")
|
43 |
+
assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2)
|
44 |
+
|
45 |
+
|
46 |
+
def test_generation_varlen():
|
47 |
+
seqlens = [170, 65, 100]
|
48 |
+
genlen = 20
|
49 |
+
total_seqlen = sum(seqlens)
|
50 |
+
device = "cuda"
|
51 |
+
dtype = torch.float16
|
52 |
+
|
53 |
+
config = MambaConfig(
|
54 |
+
d_model=1024,
|
55 |
+
n_layer=4,
|
56 |
+
vocab_size=50277,
|
57 |
+
ssm_cfg=dict(layer="Mamba2"),
|
58 |
+
rms_norm=True,
|
59 |
+
residual_in_fp32=True,
|
60 |
+
fused_add_norm=True,
|
61 |
+
pad_vocab_size_multiple=16,
|
62 |
+
)
|
63 |
+
torch.manual_seed(2357)
|
64 |
+
model = MambaLMHeadModel(config, device=device, dtype=dtype)
|
65 |
+
xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]
|
66 |
+
|
67 |
+
# Reference 1: Forward pass with seq_idx
|
68 |
+
x = torch.cat(xs, dim=1)
|
69 |
+
seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
|
70 |
+
for i, ids in enumerate(xs)], dim=0).unsqueeze(0)
|
71 |
+
cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
|
72 |
+
out_ref = model(x, seq_idx=seq_idx).logits
|
73 |
+
# Only take the last @genlen logits of each sequence
|
74 |
+
out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
|
75 |
+
for i in range(len(seqlens))], dim=0)
|
76 |
+
|
77 |
+
# Reference 2: Generate the last @genlen tokens of each sequence in a for loop
|
78 |
+
out_loop = []
|
79 |
+
for input_ids in xs:
|
80 |
+
out = model.generate(
|
81 |
+
input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
|
82 |
+
return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
|
83 |
+
).scores
|
84 |
+
out_loop.append(torch.stack(out, dim=1))
|
85 |
+
out_loop = torch.cat(out_loop, dim=0)
|
86 |
+
print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")
|
87 |
+
|
88 |
+
# Varlen generation
|
89 |
+
input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
|
90 |
+
prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
|
91 |
+
cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
|
92 |
+
seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
|
93 |
+
for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
|
94 |
+
inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))
|
95 |
+
|
96 |
+
scores, sequences = [], []
|
97 |
+
# Both seq_idx and cu_seqlens must be passed in for varlen generation
|
98 |
+
logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
|
99 |
+
logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
|
100 |
+
scores.append(logits)
|
101 |
+
# In practice we should sample. In this case we take from the teacher_output for testing
|
102 |
+
sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
|
103 |
+
sequences.append(sampled_tokens)
|
104 |
+
for i in range(1, genlen):
|
105 |
+
inference_params.seqlen_offset += 1
|
106 |
+
logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
|
107 |
+
scores.append(logits)
|
108 |
+
# In practice we should sample. In this case we take from the teacher_output for testing
|
109 |
+
sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
|
110 |
+
sequences.append(sampled_tokens)
|
111 |
+
out_varlen = torch.cat(scores, dim=1)
|
112 |
+
print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
|
113 |
+
assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
|
torch-ext/mamba_ssm/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "2.2.4"
|
2 |
+
|
3 |
+
from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
4 |
+
from .modules.mamba_simple import Mamba
|
5 |
+
from .modules.mamba2 import Mamba2
|
6 |
+
from .models.mixer_seq_simple import MambaLMHeadModel
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
"selective_scan_fn",
|
10 |
+
"mamba_inner_fn",
|
11 |
+
"Mamba",
|
12 |
+
"Mamba2",
|
13 |
+
"MambaLMHeadModel",
|
14 |
+
]
|
torch-ext/mamba_ssm/distributed/__init__.py
ADDED
File without changes
|
torch-ext/mamba_ssm/distributed/distributed_utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
from torch.distributed import ProcessGroup
|
6 |
+
|
7 |
+
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
8 |
+
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
9 |
+
# version of PyTorch. The following 4 lines are for backward compatibility with
|
10 |
+
# older PyTorch.
|
11 |
+
if "all_gather_into_tensor" not in dir(torch.distributed):
|
12 |
+
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
13 |
+
if "reduce_scatter_tensor" not in dir(torch.distributed):
|
14 |
+
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
15 |
+
|
16 |
+
|
17 |
+
# Raw operation, does not support autograd, but does support async
|
18 |
+
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
19 |
+
world_size = torch.distributed.get_world_size(process_group)
|
20 |
+
output = torch.empty(
|
21 |
+
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
22 |
+
)
|
23 |
+
handle = torch.distributed.all_gather_into_tensor(
|
24 |
+
output, input_.contiguous(), group=process_group, async_op=async_op
|
25 |
+
)
|
26 |
+
return output, handle
|
27 |
+
|
28 |
+
|
29 |
+
# Raw operation, does not support autograd, but does support async
|
30 |
+
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
31 |
+
world_size = torch.distributed.get_world_size(process_group)
|
32 |
+
assert input_.shape[0] % world_size == 0
|
33 |
+
output = torch.empty(
|
34 |
+
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
|
35 |
+
)
|
36 |
+
handle = torch.distributed.reduce_scatter_tensor(
|
37 |
+
output, input_.contiguous(), group=process_group, async_op=async_op
|
38 |
+
)
|
39 |
+
return output, handle
|
40 |
+
|
41 |
+
|
42 |
+
# Raw operation, does not support autograd, but does support async
|
43 |
+
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
44 |
+
input_ = input_.contiguous()
|
45 |
+
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
|
46 |
+
return input_, handle
|
47 |
+
|
48 |
+
|
49 |
+
class AllGatherFunc(torch.autograd.Function):
|
50 |
+
"""Gather the input from sequence parallel region and concatenate."""
|
51 |
+
|
52 |
+
@staticmethod
|
53 |
+
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
54 |
+
ctx.process_group = process_group
|
55 |
+
output, _ = all_gather_raw(input_, process_group)
|
56 |
+
return output
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def backward(ctx, grad_output: Tensor):
|
60 |
+
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
|
61 |
+
return grad_input, None
|
62 |
+
|
63 |
+
|
64 |
+
# Supports autograd, but does not support async
|
65 |
+
all_gather = AllGatherFunc.apply
|
66 |
+
|
67 |
+
|
68 |
+
class ReduceScatterFunc(torch.autograd.Function):
|
69 |
+
"""Reduce scatter the input from the sequence parallel region and concatenate."""
|
70 |
+
|
71 |
+
@staticmethod
|
72 |
+
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
73 |
+
ctx.process_group = process_group
|
74 |
+
output, _ = reduce_scatter_raw(input_, process_group)
|
75 |
+
return output
|
76 |
+
|
77 |
+
@staticmethod
|
78 |
+
def backward(ctx, grad_output: Tensor):
|
79 |
+
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
|
80 |
+
return grad_input, None
|
81 |
+
|
82 |
+
|
83 |
+
# Supports autograd, but does not support async
|
84 |
+
reduce_scatter = ReduceScatterFunc.apply
|
85 |
+
|
86 |
+
|
87 |
+
class AllReduceFunc(torch.autograd.Function):
|
88 |
+
"""Gather the input from sequence parallel region and concatenate."""
|
89 |
+
|
90 |
+
@staticmethod
|
91 |
+
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
92 |
+
ctx.process_group = process_group
|
93 |
+
output, _ = all_reduce_raw(input_, process_group)
|
94 |
+
return output
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def backward(ctx, grad_output: Tensor):
|
98 |
+
return grad_output, None
|
99 |
+
|
100 |
+
|
101 |
+
# Supports autograd, but does not support async
|
102 |
+
all_reduce = AllReduceFunc.apply
|
103 |
+
|
104 |
+
|
105 |
+
def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
|
106 |
+
# We want to iterate over parameters with _shared_params=True in the same order,
|
107 |
+
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
108 |
+
pamams_shared = {
|
109 |
+
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
|
110 |
+
}
|
111 |
+
for _, p in sorted(pamams_shared.items()):
|
112 |
+
with torch.no_grad():
|
113 |
+
# Broadcast needs src to be global rank, not group rank
|
114 |
+
torch.distributed.broadcast(
|
115 |
+
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
|
120 |
+
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
|
121 |
+
# We want to iterate over parameters with _sequence_parallel=True in the same order,
|
122 |
+
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
|
123 |
+
params_seqparallel = {
|
124 |
+
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
|
125 |
+
}
|
126 |
+
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
|
127 |
+
if grads:
|
128 |
+
with torch.no_grad():
|
129 |
+
coalesced = torch._utils._flatten_dense_tensors(grads)
|
130 |
+
torch.distributed.all_reduce(coalesced, group=process_group)
|
131 |
+
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
|
132 |
+
buf.copy_(synced)
|
133 |
+
|
134 |
+
|
135 |
+
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
|
136 |
+
"""Get the dim for the local rank derived from splitting dim on world_size processes.
|
137 |
+
|
138 |
+
The split may not be even across the world_size processes.
|
139 |
+
"""
|
140 |
+
multiple = dim // multiple_of
|
141 |
+
div = multiple // world_size
|
142 |
+
mod = multiple % world_size
|
143 |
+
local_multiple = div + int(local_rank < mod)
|
144 |
+
return local_multiple * multiple_of
|
torch-ext/mamba_ssm/distributed/tensor_parallel.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import Tensor
|
9 |
+
from torch.distributed import ProcessGroup
|
10 |
+
from ..utils.torch import custom_bwd, custom_fwd
|
11 |
+
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from ..distributed.distributed_utils import (
|
15 |
+
all_gather_raw,
|
16 |
+
all_reduce,
|
17 |
+
all_reduce_raw,
|
18 |
+
reduce_scatter,
|
19 |
+
reduce_scatter_raw,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
class ParallelLinearFunc(torch.autograd.Function):
|
24 |
+
@staticmethod
|
25 |
+
@custom_fwd
|
26 |
+
def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
|
27 |
+
"""
|
28 |
+
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
|
29 |
+
with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
|
30 |
+
"""
|
31 |
+
ctx.compute_weight_gradient = weight.requires_grad
|
32 |
+
ctx.process_group = process_group
|
33 |
+
ctx.sequence_parallel = sequence_parallel
|
34 |
+
|
35 |
+
if torch.is_autocast_enabled():
|
36 |
+
x = x.to(dtype=torch.get_autocast_gpu_dtype())
|
37 |
+
x = x.contiguous()
|
38 |
+
if process_group is not None and sequence_parallel:
|
39 |
+
# We want to kick off the all_gather early, before weight dtype conversion
|
40 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
41 |
+
else:
|
42 |
+
total_x = x
|
43 |
+
|
44 |
+
if torch.is_autocast_enabled():
|
45 |
+
weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
|
46 |
+
bias = (
|
47 |
+
bias.to(dtype=torch.get_autocast_gpu_dtype())
|
48 |
+
if bias is not None
|
49 |
+
else None
|
50 |
+
)
|
51 |
+
weight = weight.contiguous()
|
52 |
+
if process_group is not None and sequence_parallel:
|
53 |
+
handle_x.wait()
|
54 |
+
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
|
55 |
+
batch_dim = batch_shape.numel()
|
56 |
+
# https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
|
57 |
+
output = F.linear(total_x, weight, bias)
|
58 |
+
if ctx.compute_weight_gradient:
|
59 |
+
ctx.save_for_backward(x, weight)
|
60 |
+
else:
|
61 |
+
ctx.save_for_backward(weight)
|
62 |
+
return output
|
63 |
+
|
64 |
+
@staticmethod
|
65 |
+
@custom_bwd
|
66 |
+
def backward(ctx, grad_output):
|
67 |
+
grad_output = grad_output.contiguous()
|
68 |
+
process_group = ctx.process_group
|
69 |
+
sequence_parallel = ctx.sequence_parallel
|
70 |
+
if ctx.compute_weight_gradient:
|
71 |
+
x, weight = ctx.saved_tensors
|
72 |
+
if process_group is not None and sequence_parallel:
|
73 |
+
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
74 |
+
else:
|
75 |
+
total_x = x
|
76 |
+
else:
|
77 |
+
(weight,) = ctx.saved_tensors
|
78 |
+
total_x = None
|
79 |
+
batch_shape = grad_output.shape[:-1]
|
80 |
+
batch_dim = batch_shape.numel()
|
81 |
+
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
82 |
+
if ctx.needs_input_grad[0]:
|
83 |
+
grad_input = F.linear(grad_output, weight.t())
|
84 |
+
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
85 |
+
if process_group is not None:
|
86 |
+
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
|
87 |
+
grad_input, handle_grad_input = reduce_fn(
|
88 |
+
grad_input, process_group, async_op=True
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
grad_input = None
|
92 |
+
if ctx.needs_input_grad[1]:
|
93 |
+
assert ctx.compute_weight_gradient
|
94 |
+
if process_group is not None and sequence_parallel:
|
95 |
+
handle_x.wait()
|
96 |
+
grad_weight = torch.einsum(
|
97 |
+
"bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
grad_weight = None
|
101 |
+
grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
|
102 |
+
if process_group is not None and ctx.needs_input_grad[0]:
|
103 |
+
handle_grad_input.wait()
|
104 |
+
return grad_input, grad_weight, grad_bias, None, None
|
105 |
+
|
106 |
+
|
107 |
+
def parallel_linear_func(
|
108 |
+
x: Tensor,
|
109 |
+
weight: Tensor,
|
110 |
+
bias: Optional[Tensor] = None,
|
111 |
+
process_group: Optional[ProcessGroup] = None,
|
112 |
+
sequence_parallel: bool = True,
|
113 |
+
):
|
114 |
+
return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
|
115 |
+
|
116 |
+
|
117 |
+
class ColumnParallelLinear(nn.Linear):
|
118 |
+
def __init__(
|
119 |
+
self,
|
120 |
+
in_features: int,
|
121 |
+
out_features: int,
|
122 |
+
process_group: ProcessGroup,
|
123 |
+
bias: bool = True,
|
124 |
+
sequence_parallel=True,
|
125 |
+
multiple_of=1,
|
126 |
+
device=None,
|
127 |
+
dtype=None,
|
128 |
+
) -> None:
|
129 |
+
world_size = torch.distributed.get_world_size(process_group)
|
130 |
+
if out_features % multiple_of:
|
131 |
+
raise ValueError(
|
132 |
+
f"out_features ({out_features}) must be a multiple of {multiple_of}"
|
133 |
+
)
|
134 |
+
multiple = out_features // multiple_of
|
135 |
+
# We want to split @multiple across world_size, but it could be an uneven split
|
136 |
+
div = multiple // world_size
|
137 |
+
mod = multiple % world_size
|
138 |
+
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
139 |
+
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
140 |
+
super().__init__(
|
141 |
+
in_features,
|
142 |
+
local_multiple * multiple_of,
|
143 |
+
bias=bias,
|
144 |
+
device=device,
|
145 |
+
dtype=dtype,
|
146 |
+
)
|
147 |
+
self.process_group = process_group
|
148 |
+
self.sequence_parallel = sequence_parallel
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
|
152 |
+
# we do an all_gather of x before doing the matmul.
|
153 |
+
# If not, then the input is already gathered.
|
154 |
+
return parallel_linear_func(
|
155 |
+
x,
|
156 |
+
self.weight,
|
157 |
+
self.bias,
|
158 |
+
process_group=self.process_group,
|
159 |
+
sequence_parallel=self.sequence_parallel,
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
class RowParallelLinear(nn.Linear):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
in_features: int,
|
167 |
+
out_features: int,
|
168 |
+
process_group: ProcessGroup,
|
169 |
+
bias: bool = True,
|
170 |
+
sequence_parallel=True,
|
171 |
+
multiple_of=1,
|
172 |
+
device=None,
|
173 |
+
dtype=None,
|
174 |
+
) -> None:
|
175 |
+
world_size = torch.distributed.get_world_size(process_group)
|
176 |
+
rank = torch.distributed.get_rank(process_group)
|
177 |
+
if in_features % multiple_of:
|
178 |
+
raise ValueError(
|
179 |
+
f"in_features ({in_features}) must be a multiple of {multiple_of}"
|
180 |
+
)
|
181 |
+
multiple = in_features // multiple_of
|
182 |
+
# We want to split @multiple across world_size, but it could be an uneven split
|
183 |
+
div = multiple // world_size
|
184 |
+
mod = multiple % world_size
|
185 |
+
# The first @mod ranks get @div + 1 copies, the rest get @div copies
|
186 |
+
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
|
187 |
+
# Only rank 0 will have bias
|
188 |
+
super().__init__(
|
189 |
+
local_multiple * multiple_of,
|
190 |
+
out_features,
|
191 |
+
bias=bias and rank == 0,
|
192 |
+
device=device,
|
193 |
+
dtype=dtype,
|
194 |
+
)
|
195 |
+
self.process_group = process_group
|
196 |
+
self.sequence_parallel = sequence_parallel
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
"""
|
200 |
+
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
|
201 |
+
a reduce_scatter of the result.
|
202 |
+
"""
|
203 |
+
out = parallel_linear_func(x, self.weight, self.bias)
|
204 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
205 |
+
return reduce_fn(out, self.process_group)
|
206 |
+
|
207 |
+
|
208 |
+
class VocabParallelEmbedding(nn.Embedding):
|
209 |
+
def __init__(
|
210 |
+
self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
|
211 |
+
):
|
212 |
+
self.process_group = process_group
|
213 |
+
if process_group is not None:
|
214 |
+
world_size = torch.distributed.get_world_size(process_group)
|
215 |
+
if num_embeddings % world_size != 0:
|
216 |
+
raise ValueError(
|
217 |
+
f"num_embeddings ({num_embeddings}) must be divisible by "
|
218 |
+
f"world_size ({world_size})"
|
219 |
+
)
|
220 |
+
if world_size > 1 and padding_idx is not None:
|
221 |
+
raise RuntimeError("ParallelEmbedding does not support padding_idx")
|
222 |
+
else:
|
223 |
+
world_size = 1
|
224 |
+
super().__init__(
|
225 |
+
num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
|
226 |
+
)
|
227 |
+
|
228 |
+
def forward(self, input: Tensor) -> Tensor:
|
229 |
+
if self.process_group is None:
|
230 |
+
return super().forward(input)
|
231 |
+
else:
|
232 |
+
rank = torch.distributed.get_rank(self.process_group)
|
233 |
+
vocab_size = self.num_embeddings
|
234 |
+
vocab_start_index, vocab_end_index = (
|
235 |
+
rank * vocab_size,
|
236 |
+
(rank + 1) * vocab_size,
|
237 |
+
)
|
238 |
+
# Create a mask of valid vocab ids (1 means it needs to be masked).
|
239 |
+
input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
|
240 |
+
input = input - vocab_start_index
|
241 |
+
input[input_ids_mask] = 0
|
242 |
+
embeddings = super().forward(input)
|
243 |
+
embeddings[input_ids_mask] = 0.0
|
244 |
+
return embeddings
|
245 |
+
|
246 |
+
|
247 |
+
class ColumnParallelEmbedding(nn.Embedding):
|
248 |
+
def __init__(
|
249 |
+
self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
|
250 |
+
):
|
251 |
+
self.process_group = process_group
|
252 |
+
if process_group is not None:
|
253 |
+
world_size = torch.distributed.get_world_size(process_group)
|
254 |
+
if embedding_dim % world_size != 0:
|
255 |
+
raise ValueError(
|
256 |
+
f"embedding_dim ({embedding_dim}) must be divisible by "
|
257 |
+
f"world_size ({world_size})"
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
world_size = 1
|
261 |
+
super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
|
262 |
+
|
263 |
+
|
264 |
+
class ParallelEmbeddings(nn.Module):
|
265 |
+
def __init__(
|
266 |
+
self,
|
267 |
+
embed_dim,
|
268 |
+
vocab_size,
|
269 |
+
max_position_embeddings,
|
270 |
+
process_group,
|
271 |
+
padding_idx=None,
|
272 |
+
sequence_parallel=True,
|
273 |
+
device=None,
|
274 |
+
dtype=None,
|
275 |
+
):
|
276 |
+
"""
|
277 |
+
If max_position_embeddings <= 0, there's no position embeddings
|
278 |
+
"""
|
279 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
280 |
+
super().__init__()
|
281 |
+
self.process_group = process_group
|
282 |
+
self.sequence_parallel = sequence_parallel
|
283 |
+
self.word_embeddings = VocabParallelEmbedding(
|
284 |
+
vocab_size,
|
285 |
+
embed_dim,
|
286 |
+
padding_idx=padding_idx,
|
287 |
+
process_group=process_group,
|
288 |
+
**factory_kwargs,
|
289 |
+
)
|
290 |
+
self.max_position_embeddings = max_position_embeddings
|
291 |
+
if self.max_position_embeddings > 0:
|
292 |
+
self.position_embeddings = ColumnParallelEmbedding(
|
293 |
+
max_position_embeddings,
|
294 |
+
embed_dim,
|
295 |
+
process_group=process_group,
|
296 |
+
**factory_kwargs,
|
297 |
+
)
|
298 |
+
|
299 |
+
def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
|
300 |
+
"""
|
301 |
+
input_ids: (batch, seqlen)
|
302 |
+
position_ids: (batch, seqlen)
|
303 |
+
"""
|
304 |
+
batch_size, seqlen = input_ids.shape
|
305 |
+
world_size = torch.distributed.get_world_size(self.process_group)
|
306 |
+
embeddings = self.word_embeddings(input_ids)
|
307 |
+
if self.max_position_embeddings > 0:
|
308 |
+
if position_ids is None:
|
309 |
+
position_ids = torch.arange(
|
310 |
+
seqlen, dtype=torch.long, device=input_ids.device
|
311 |
+
)
|
312 |
+
position_embeddings = self.position_embeddings(position_ids)
|
313 |
+
if world_size <= 1:
|
314 |
+
embeddings = embeddings + position_embeddings
|
315 |
+
else:
|
316 |
+
partition_dim = self.position_embeddings.embedding_dim
|
317 |
+
rank = torch.distributed.get_rank(self.process_group)
|
318 |
+
embeddings[
|
319 |
+
..., rank * partition_dim : (rank + 1) * partition_dim
|
320 |
+
] += position_embeddings
|
321 |
+
if combine_batch_seqlen_dim:
|
322 |
+
embeddings = rearrange(embeddings, "b s d -> (b s) d")
|
323 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
324 |
+
return (
|
325 |
+
embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
|
326 |
+
)
|
torch-ext/mamba_ssm/models/__init__.py
ADDED
File without changes
|
torch-ext/mamba_ssm/models/config_mamba.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
|
4 |
+
@dataclass
|
5 |
+
class MambaConfig:
|
6 |
+
|
7 |
+
d_model: int = 2560
|
8 |
+
d_intermediate: int = 0
|
9 |
+
n_layer: int = 64
|
10 |
+
vocab_size: int = 50277
|
11 |
+
ssm_cfg: dict = field(default_factory=dict)
|
12 |
+
attn_layer_idx: list = field(default_factory=list)
|
13 |
+
attn_cfg: dict = field(default_factory=dict)
|
14 |
+
rms_norm: bool = True
|
15 |
+
residual_in_fp32: bool = True
|
16 |
+
fused_add_norm: bool = True
|
17 |
+
pad_vocab_size_multiple: int = 8
|
18 |
+
tie_embeddings: bool = True
|
torch-ext/mamba_ssm/models/mixer_seq_simple.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from functools import partial
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
|
9 |
+
from collections import namedtuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from .config_mamba import MambaConfig
|
15 |
+
from ..modules.mamba_simple import Mamba
|
16 |
+
from ..modules.mamba2 import Mamba2
|
17 |
+
from ..modules.mha import MHA
|
18 |
+
from ..modules.mlp import GatedMLP
|
19 |
+
from ..modules.block import Block
|
20 |
+
from ..utils.generation import GenerationMixin
|
21 |
+
from ..utils.hf import load_config_hf, load_state_dict_hf
|
22 |
+
|
23 |
+
try:
|
24 |
+
from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
25 |
+
except ImportError:
|
26 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
27 |
+
|
28 |
+
|
29 |
+
def create_block(
|
30 |
+
d_model,
|
31 |
+
d_intermediate,
|
32 |
+
ssm_cfg=None,
|
33 |
+
attn_layer_idx=None,
|
34 |
+
attn_cfg=None,
|
35 |
+
norm_epsilon=1e-5,
|
36 |
+
rms_norm=False,
|
37 |
+
residual_in_fp32=False,
|
38 |
+
fused_add_norm=False,
|
39 |
+
layer_idx=None,
|
40 |
+
device=None,
|
41 |
+
dtype=None,
|
42 |
+
):
|
43 |
+
if ssm_cfg is None:
|
44 |
+
ssm_cfg = {}
|
45 |
+
if attn_layer_idx is None:
|
46 |
+
attn_layer_idx = []
|
47 |
+
if attn_cfg is None:
|
48 |
+
attn_cfg = {}
|
49 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
50 |
+
if layer_idx not in attn_layer_idx:
|
51 |
+
# Create a copy of the config to modify
|
52 |
+
ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
|
53 |
+
ssm_layer = ssm_cfg.pop("layer", "Mamba1")
|
54 |
+
if ssm_layer not in ["Mamba1", "Mamba2"]:
|
55 |
+
raise ValueError(
|
56 |
+
f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
|
57 |
+
)
|
58 |
+
mixer_cls = partial(
|
59 |
+
Mamba2 if ssm_layer == "Mamba2" else Mamba,
|
60 |
+
layer_idx=layer_idx,
|
61 |
+
**ssm_cfg,
|
62 |
+
**factory_kwargs,
|
63 |
+
)
|
64 |
+
else:
|
65 |
+
mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
|
66 |
+
norm_cls = partial(
|
67 |
+
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
68 |
+
)
|
69 |
+
if d_intermediate == 0:
|
70 |
+
mlp_cls = nn.Identity
|
71 |
+
else:
|
72 |
+
mlp_cls = partial(
|
73 |
+
GatedMLP,
|
74 |
+
hidden_features=d_intermediate,
|
75 |
+
out_features=d_model,
|
76 |
+
**factory_kwargs,
|
77 |
+
)
|
78 |
+
block = Block(
|
79 |
+
d_model,
|
80 |
+
mixer_cls,
|
81 |
+
mlp_cls,
|
82 |
+
norm_cls=norm_cls,
|
83 |
+
fused_add_norm=fused_add_norm,
|
84 |
+
residual_in_fp32=residual_in_fp32,
|
85 |
+
)
|
86 |
+
block.layer_idx = layer_idx
|
87 |
+
return block
|
88 |
+
|
89 |
+
|
90 |
+
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
91 |
+
def _init_weights(
|
92 |
+
module,
|
93 |
+
n_layer,
|
94 |
+
initializer_range=0.02, # Now only used for embedding layer.
|
95 |
+
rescale_prenorm_residual=True,
|
96 |
+
n_residuals_per_layer=1, # Change to 2 if we have MLP
|
97 |
+
):
|
98 |
+
if isinstance(module, nn.Linear):
|
99 |
+
if module.bias is not None:
|
100 |
+
if not getattr(module.bias, "_no_reinit", False):
|
101 |
+
nn.init.zeros_(module.bias)
|
102 |
+
elif isinstance(module, nn.Embedding):
|
103 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
104 |
+
|
105 |
+
if rescale_prenorm_residual:
|
106 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
107 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
108 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
109 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
110 |
+
#
|
111 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
112 |
+
for name, p in module.named_parameters():
|
113 |
+
if name in ["out_proj.weight", "fc2.weight"]:
|
114 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
115 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
116 |
+
# We need to reinit p since this code could be called multiple times
|
117 |
+
# Having just p *= scale would repeatedly scale it down
|
118 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
119 |
+
with torch.no_grad():
|
120 |
+
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
121 |
+
|
122 |
+
|
123 |
+
class MixerModel(nn.Module):
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
d_model: int,
|
127 |
+
n_layer: int,
|
128 |
+
d_intermediate: int,
|
129 |
+
vocab_size: int,
|
130 |
+
ssm_cfg=None,
|
131 |
+
attn_layer_idx=None,
|
132 |
+
attn_cfg=None,
|
133 |
+
norm_epsilon: float = 1e-5,
|
134 |
+
rms_norm: bool = False,
|
135 |
+
initializer_cfg=None,
|
136 |
+
fused_add_norm=False,
|
137 |
+
residual_in_fp32=False,
|
138 |
+
device=None,
|
139 |
+
dtype=None,
|
140 |
+
) -> None:
|
141 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
142 |
+
super().__init__()
|
143 |
+
self.residual_in_fp32 = residual_in_fp32
|
144 |
+
|
145 |
+
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
146 |
+
|
147 |
+
# We change the order of residual and layer norm:
|
148 |
+
# Instead of LN -> Attn / MLP -> Add, we do:
|
149 |
+
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
|
150 |
+
# the main branch (output of MLP / Mixer). The model definition is unchanged.
|
151 |
+
# This is for performance reason: we can fuse add + layer_norm.
|
152 |
+
self.fused_add_norm = fused_add_norm
|
153 |
+
if self.fused_add_norm:
|
154 |
+
if layer_norm_fn is None or rms_norm_fn is None:
|
155 |
+
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
|
156 |
+
|
157 |
+
self.layers = nn.ModuleList(
|
158 |
+
[
|
159 |
+
create_block(
|
160 |
+
d_model,
|
161 |
+
d_intermediate=d_intermediate,
|
162 |
+
ssm_cfg=ssm_cfg,
|
163 |
+
attn_layer_idx=attn_layer_idx,
|
164 |
+
attn_cfg=attn_cfg,
|
165 |
+
norm_epsilon=norm_epsilon,
|
166 |
+
rms_norm=rms_norm,
|
167 |
+
residual_in_fp32=residual_in_fp32,
|
168 |
+
fused_add_norm=fused_add_norm,
|
169 |
+
layer_idx=i,
|
170 |
+
**factory_kwargs,
|
171 |
+
)
|
172 |
+
for i in range(n_layer)
|
173 |
+
]
|
174 |
+
)
|
175 |
+
|
176 |
+
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
177 |
+
d_model, eps=norm_epsilon, **factory_kwargs
|
178 |
+
)
|
179 |
+
|
180 |
+
self.apply(
|
181 |
+
partial(
|
182 |
+
_init_weights,
|
183 |
+
n_layer=n_layer,
|
184 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
185 |
+
n_residuals_per_layer=(
|
186 |
+
1 if d_intermediate == 0 else 2
|
187 |
+
), # 2 if we have MLP
|
188 |
+
)
|
189 |
+
)
|
190 |
+
|
191 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
192 |
+
return {
|
193 |
+
i: layer.allocate_inference_cache(
|
194 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
195 |
+
)
|
196 |
+
for i, layer in enumerate(self.layers)
|
197 |
+
}
|
198 |
+
|
199 |
+
def forward(self, input_ids, inference_params=None, **mixer_kwargs):
|
200 |
+
hidden_states = self.embedding(input_ids)
|
201 |
+
residual = None
|
202 |
+
for layer in self.layers:
|
203 |
+
hidden_states, residual = layer(
|
204 |
+
hidden_states,
|
205 |
+
residual,
|
206 |
+
inference_params=inference_params,
|
207 |
+
**mixer_kwargs,
|
208 |
+
)
|
209 |
+
if not self.fused_add_norm:
|
210 |
+
residual = (
|
211 |
+
(hidden_states + residual) if residual is not None else hidden_states
|
212 |
+
)
|
213 |
+
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
214 |
+
else:
|
215 |
+
# Set prenorm=False here since we don't need the residual
|
216 |
+
hidden_states = layer_norm_fn(
|
217 |
+
hidden_states,
|
218 |
+
self.norm_f.weight,
|
219 |
+
self.norm_f.bias,
|
220 |
+
eps=self.norm_f.eps,
|
221 |
+
residual=residual,
|
222 |
+
prenorm=False,
|
223 |
+
residual_in_fp32=self.residual_in_fp32,
|
224 |
+
is_rms_norm=isinstance(self.norm_f, RMSNorm),
|
225 |
+
)
|
226 |
+
return hidden_states
|
227 |
+
|
228 |
+
|
229 |
+
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
230 |
+
|
231 |
+
def __init__(
|
232 |
+
self,
|
233 |
+
config: MambaConfig,
|
234 |
+
initializer_cfg=None,
|
235 |
+
device=None,
|
236 |
+
dtype=None,
|
237 |
+
) -> None:
|
238 |
+
self.config = config
|
239 |
+
d_model = config.d_model
|
240 |
+
n_layer = config.n_layer
|
241 |
+
d_intermediate = config.d_intermediate
|
242 |
+
vocab_size = config.vocab_size
|
243 |
+
ssm_cfg = config.ssm_cfg
|
244 |
+
attn_layer_idx = config.attn_layer_idx
|
245 |
+
attn_cfg = config.attn_cfg
|
246 |
+
rms_norm = config.rms_norm
|
247 |
+
residual_in_fp32 = config.residual_in_fp32
|
248 |
+
fused_add_norm = config.fused_add_norm
|
249 |
+
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
250 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
251 |
+
|
252 |
+
super().__init__()
|
253 |
+
if vocab_size % pad_vocab_size_multiple != 0:
|
254 |
+
vocab_size += pad_vocab_size_multiple - (
|
255 |
+
vocab_size % pad_vocab_size_multiple
|
256 |
+
)
|
257 |
+
self.backbone = MixerModel(
|
258 |
+
d_model=d_model,
|
259 |
+
n_layer=n_layer,
|
260 |
+
d_intermediate=d_intermediate,
|
261 |
+
vocab_size=vocab_size,
|
262 |
+
ssm_cfg=ssm_cfg,
|
263 |
+
attn_layer_idx=attn_layer_idx,
|
264 |
+
attn_cfg=attn_cfg,
|
265 |
+
rms_norm=rms_norm,
|
266 |
+
initializer_cfg=initializer_cfg,
|
267 |
+
fused_add_norm=fused_add_norm,
|
268 |
+
residual_in_fp32=residual_in_fp32,
|
269 |
+
**factory_kwargs,
|
270 |
+
)
|
271 |
+
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
272 |
+
|
273 |
+
# Initialize weights and apply final processing
|
274 |
+
self.apply(
|
275 |
+
partial(
|
276 |
+
_init_weights,
|
277 |
+
n_layer=n_layer,
|
278 |
+
**(initializer_cfg if initializer_cfg is not None else {}),
|
279 |
+
)
|
280 |
+
)
|
281 |
+
self.tie_weights()
|
282 |
+
|
283 |
+
def tie_weights(self):
|
284 |
+
if self.config.tie_embeddings:
|
285 |
+
self.lm_head.weight = self.backbone.embedding.weight
|
286 |
+
|
287 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
288 |
+
return self.backbone.allocate_inference_cache(
|
289 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(
|
293 |
+
self,
|
294 |
+
input_ids,
|
295 |
+
position_ids=None,
|
296 |
+
inference_params=None,
|
297 |
+
num_last_tokens=0,
|
298 |
+
**mixer_kwargs,
|
299 |
+
):
|
300 |
+
"""
|
301 |
+
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
302 |
+
num_last_tokens: if > 0, only return the logits for the last n tokens
|
303 |
+
"""
|
304 |
+
hidden_states = self.backbone(
|
305 |
+
input_ids, inference_params=inference_params, **mixer_kwargs
|
306 |
+
)
|
307 |
+
if num_last_tokens > 0:
|
308 |
+
hidden_states = hidden_states[:, -num_last_tokens:]
|
309 |
+
lm_logits = self.lm_head(hidden_states)
|
310 |
+
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
311 |
+
return CausalLMOutput(logits=lm_logits)
|
312 |
+
|
313 |
+
@classmethod
|
314 |
+
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
315 |
+
config_data = load_config_hf(pretrained_model_name)
|
316 |
+
config = MambaConfig(**config_data)
|
317 |
+
model = cls(config, device=device, dtype=dtype, **kwargs)
|
318 |
+
model.load_state_dict(
|
319 |
+
load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
|
320 |
+
)
|
321 |
+
return model
|
322 |
+
|
323 |
+
def save_pretrained(self, save_directory):
|
324 |
+
"""
|
325 |
+
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
326 |
+
Save the model and its configuration file to a directory.
|
327 |
+
"""
|
328 |
+
# Ensure save_directory exists
|
329 |
+
os.makedirs(save_directory, exist_ok=True)
|
330 |
+
|
331 |
+
# Save the model's state_dict
|
332 |
+
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
333 |
+
torch.save(self.state_dict(), model_path)
|
334 |
+
|
335 |
+
# Save the configuration of the model
|
336 |
+
config_path = os.path.join(save_directory, "config.json")
|
337 |
+
with open(config_path, "w") as f:
|
338 |
+
json.dump(self.config.__dict__, f, indent=4)
|
torch-ext/mamba_ssm/modules/__init__.py
ADDED
File without changes
|
torch-ext/mamba_ssm/modules/block.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, Tensor
|
6 |
+
|
7 |
+
from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
|
8 |
+
|
9 |
+
|
10 |
+
class Block(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
dim,
|
14 |
+
mixer_cls,
|
15 |
+
mlp_cls,
|
16 |
+
norm_cls=nn.LayerNorm,
|
17 |
+
fused_add_norm=False,
|
18 |
+
residual_in_fp32=False,
|
19 |
+
):
|
20 |
+
"""
|
21 |
+
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
|
22 |
+
|
23 |
+
This Block has a slightly different structure compared to a regular
|
24 |
+
prenorm Transformer block.
|
25 |
+
The standard block is: LN -> MHA/MLP -> Add.
|
26 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
27 |
+
Here we have: Add -> LN -> Mixer, returning both
|
28 |
+
the hidden_states (output of the mixer) and the residual.
|
29 |
+
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
30 |
+
The residual needs to be provided (except for the very first block).
|
31 |
+
"""
|
32 |
+
super().__init__()
|
33 |
+
self.residual_in_fp32 = residual_in_fp32
|
34 |
+
self.fused_add_norm = fused_add_norm
|
35 |
+
self.norm = norm_cls(dim)
|
36 |
+
self.mixer = mixer_cls(dim)
|
37 |
+
if mlp_cls is not nn.Identity:
|
38 |
+
self.norm2 = norm_cls(dim)
|
39 |
+
self.mlp = mlp_cls(dim)
|
40 |
+
else:
|
41 |
+
self.mlp = None
|
42 |
+
if self.fused_add_norm:
|
43 |
+
assert RMSNorm is not None, "RMSNorm import fails"
|
44 |
+
assert isinstance(
|
45 |
+
self.norm, (nn.LayerNorm, RMSNorm)
|
46 |
+
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
|
47 |
+
|
48 |
+
def forward(
|
49 |
+
self,
|
50 |
+
hidden_states: Tensor,
|
51 |
+
residual: Optional[Tensor] = None,
|
52 |
+
inference_params=None,
|
53 |
+
**mixer_kwargs
|
54 |
+
):
|
55 |
+
r"""Pass the input through the encoder layer.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
hidden_states: the sequence to the encoder layer (required).
|
59 |
+
residual: hidden_states = Mixer(LN(residual))
|
60 |
+
"""
|
61 |
+
if not self.fused_add_norm:
|
62 |
+
residual = (
|
63 |
+
(hidden_states + residual) if residual is not None else hidden_states
|
64 |
+
)
|
65 |
+
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
66 |
+
if self.residual_in_fp32:
|
67 |
+
residual = residual.to(torch.float32)
|
68 |
+
else:
|
69 |
+
hidden_states, residual = layer_norm_fn(
|
70 |
+
hidden_states,
|
71 |
+
self.norm.weight,
|
72 |
+
self.norm.bias,
|
73 |
+
residual=residual,
|
74 |
+
prenorm=True,
|
75 |
+
residual_in_fp32=self.residual_in_fp32,
|
76 |
+
eps=self.norm.eps,
|
77 |
+
is_rms_norm=isinstance(self.norm, RMSNorm),
|
78 |
+
)
|
79 |
+
hidden_states = self.mixer(
|
80 |
+
hidden_states, inference_params=inference_params, **mixer_kwargs
|
81 |
+
)
|
82 |
+
|
83 |
+
if self.mlp is not None:
|
84 |
+
if not self.fused_add_norm:
|
85 |
+
residual = hidden_states + residual
|
86 |
+
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
87 |
+
if self.residual_in_fp32:
|
88 |
+
residual = residual.to(torch.float32)
|
89 |
+
else:
|
90 |
+
hidden_states, residual = layer_norm_fn(
|
91 |
+
hidden_states,
|
92 |
+
self.norm2.weight,
|
93 |
+
self.norm2.bias,
|
94 |
+
residual=residual,
|
95 |
+
prenorm=True,
|
96 |
+
residual_in_fp32=self.residual_in_fp32,
|
97 |
+
eps=self.norm2.eps,
|
98 |
+
is_rms_norm=isinstance(self.norm2, RMSNorm),
|
99 |
+
)
|
100 |
+
hidden_states = self.mlp(hidden_states)
|
101 |
+
|
102 |
+
return hidden_states, residual
|
103 |
+
|
104 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
105 |
+
return self.mixer.allocate_inference_cache(
|
106 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
107 |
+
)
|
torch-ext/mamba_ssm/modules/mamba2.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
try:
|
12 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
13 |
+
except ImportError:
|
14 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
15 |
+
|
16 |
+
try:
|
17 |
+
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
|
18 |
+
except ImportError:
|
19 |
+
causal_conv1d_varlen_states = None
|
20 |
+
|
21 |
+
try:
|
22 |
+
from ..ops.triton.selective_state_update import selective_state_update
|
23 |
+
except ImportError:
|
24 |
+
selective_state_update = None
|
25 |
+
|
26 |
+
from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
|
27 |
+
|
28 |
+
from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
|
29 |
+
from ..distributed.distributed_utils import all_reduce, reduce_scatter
|
30 |
+
|
31 |
+
from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
|
32 |
+
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
33 |
+
|
34 |
+
from huggingface_hub import PyTorchModelHubMixin
|
35 |
+
|
36 |
+
|
37 |
+
class Mamba2(nn.Module, PyTorchModelHubMixin):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
d_model,
|
41 |
+
d_state=128,
|
42 |
+
d_conv=4,
|
43 |
+
conv_init=None,
|
44 |
+
expand=2,
|
45 |
+
headdim=64,
|
46 |
+
d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
|
47 |
+
ngroups=1,
|
48 |
+
A_init_range=(1, 16),
|
49 |
+
D_has_hdim=False,
|
50 |
+
rmsnorm=True,
|
51 |
+
norm_before_gate=False,
|
52 |
+
dt_min=0.001,
|
53 |
+
dt_max=0.1,
|
54 |
+
dt_init_floor=1e-4,
|
55 |
+
dt_limit=(0.0, float("inf")),
|
56 |
+
bias=False,
|
57 |
+
conv_bias=True,
|
58 |
+
# Fused kernel and sharding options
|
59 |
+
chunk_size=256,
|
60 |
+
use_mem_eff_path=True,
|
61 |
+
layer_idx=None, # Absorb kwarg for general module
|
62 |
+
process_group=None,
|
63 |
+
sequence_parallel=True,
|
64 |
+
device=None,
|
65 |
+
dtype=None,
|
66 |
+
):
|
67 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
68 |
+
super().__init__()
|
69 |
+
self.d_model = d_model
|
70 |
+
self.d_state = d_state
|
71 |
+
self.d_conv = d_conv
|
72 |
+
self.conv_init = conv_init
|
73 |
+
self.expand = expand
|
74 |
+
self.process_group = process_group
|
75 |
+
self.sequence_parallel = sequence_parallel
|
76 |
+
self.world_size = 1 if process_group is None else process_group.size()
|
77 |
+
self.local_rank = 0 if process_group is None else process_group.rank()
|
78 |
+
self.d_inner = (self.expand * self.d_model) // self.world_size
|
79 |
+
assert self.d_inner * self.world_size == self.expand * self.d_model
|
80 |
+
self.headdim = headdim
|
81 |
+
self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
|
82 |
+
assert ngroups % self.world_size == 0
|
83 |
+
self.ngroups = ngroups // self.world_size
|
84 |
+
assert self.d_ssm % self.headdim == 0
|
85 |
+
self.nheads = self.d_ssm // self.headdim
|
86 |
+
self.D_has_hdim = D_has_hdim
|
87 |
+
self.rmsnorm = rmsnorm
|
88 |
+
self.norm_before_gate = norm_before_gate
|
89 |
+
self.dt_limit = dt_limit
|
90 |
+
self.activation = "silu"
|
91 |
+
self.chunk_size = chunk_size
|
92 |
+
self.use_mem_eff_path = use_mem_eff_path
|
93 |
+
self.layer_idx = layer_idx
|
94 |
+
|
95 |
+
# Order: [z, x, B, C, dt]
|
96 |
+
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
97 |
+
if self.process_group is None:
|
98 |
+
self.in_proj = nn.Linear(
|
99 |
+
self.d_model, d_in_proj, bias=bias, **factory_kwargs
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
self.in_proj = ColumnParallelLinear(
|
103 |
+
self.d_model,
|
104 |
+
d_in_proj * self.world_size,
|
105 |
+
bias=bias,
|
106 |
+
process_group=self.process_group,
|
107 |
+
sequence_parallel=self.sequence_parallel,
|
108 |
+
**factory_kwargs,
|
109 |
+
)
|
110 |
+
|
111 |
+
conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
|
112 |
+
self.conv1d = nn.Conv1d(
|
113 |
+
in_channels=conv_dim,
|
114 |
+
out_channels=conv_dim,
|
115 |
+
bias=conv_bias,
|
116 |
+
kernel_size=d_conv,
|
117 |
+
groups=conv_dim,
|
118 |
+
padding=d_conv - 1,
|
119 |
+
**factory_kwargs,
|
120 |
+
)
|
121 |
+
if self.conv_init is not None:
|
122 |
+
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
123 |
+
|
124 |
+
self.act = nn.SiLU()
|
125 |
+
|
126 |
+
# Initialize log dt bias
|
127 |
+
dt = torch.exp(
|
128 |
+
torch.rand(self.nheads, **factory_kwargs)
|
129 |
+
* (math.log(dt_max) - math.log(dt_min))
|
130 |
+
+ math.log(dt_min)
|
131 |
+
)
|
132 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
133 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
134 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
135 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
136 |
+
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
137 |
+
# name.endswith("bias") in param_grouping.py
|
138 |
+
self.dt_bias._no_weight_decay = True
|
139 |
+
|
140 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
141 |
+
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
|
142 |
+
*A_init_range
|
143 |
+
)
|
144 |
+
A_log = torch.log(A).to(dtype=dtype)
|
145 |
+
self.A_log = nn.Parameter(A_log)
|
146 |
+
self.A_log._no_weight_decay = True
|
147 |
+
|
148 |
+
# D "skip" parameter
|
149 |
+
self.D = nn.Parameter(
|
150 |
+
torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
|
151 |
+
)
|
152 |
+
self.D._no_weight_decay = True
|
153 |
+
|
154 |
+
if self.rmsnorm:
|
155 |
+
assert RMSNormGated is not None
|
156 |
+
self.norm = RMSNormGated(
|
157 |
+
self.d_ssm,
|
158 |
+
eps=1e-5,
|
159 |
+
norm_before_gate=self.norm_before_gate,
|
160 |
+
group_size=self.d_ssm // ngroups,
|
161 |
+
**factory_kwargs,
|
162 |
+
)
|
163 |
+
|
164 |
+
if self.process_group is None:
|
165 |
+
self.out_proj = nn.Linear(
|
166 |
+
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
self.out_proj = RowParallelLinear(
|
170 |
+
self.d_inner * self.world_size,
|
171 |
+
self.d_model,
|
172 |
+
bias=bias,
|
173 |
+
process_group=self.process_group,
|
174 |
+
sequence_parallel=self.sequence_parallel,
|
175 |
+
**factory_kwargs,
|
176 |
+
)
|
177 |
+
|
178 |
+
def forward(
|
179 |
+
self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
|
180 |
+
):
|
181 |
+
"""
|
182 |
+
u: (batch, seqlen, hidden_dim) if seqlen=None.
|
183 |
+
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
|
184 |
+
split u during sequence parallel, we split the batch * seqlen dimension
|
185 |
+
(in case batch is small).
|
186 |
+
Returns: same shape as u
|
187 |
+
"""
|
188 |
+
seqlen_og = seqlen
|
189 |
+
if seqlen is None:
|
190 |
+
batch, seqlen, dim = u.shape
|
191 |
+
else:
|
192 |
+
batch_seqlen, dim = u.shape
|
193 |
+
batch = batch_seqlen // seqlen
|
194 |
+
|
195 |
+
conv_state, ssm_state = None, None
|
196 |
+
if inference_params is not None:
|
197 |
+
inference_batch = (
|
198 |
+
cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
|
199 |
+
)
|
200 |
+
conv_state, ssm_state = self._get_states_from_cache(
|
201 |
+
inference_params, inference_batch
|
202 |
+
)
|
203 |
+
if inference_params.seqlen_offset > 0:
|
204 |
+
# The states are updated inplace
|
205 |
+
out, _, _ = self.step(u, conv_state, ssm_state)
|
206 |
+
return out
|
207 |
+
|
208 |
+
zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
|
209 |
+
if seqlen_og is not None:
|
210 |
+
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
|
211 |
+
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
212 |
+
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
|
213 |
+
dt_limit_kwargs = (
|
214 |
+
{} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
215 |
+
)
|
216 |
+
if self.use_mem_eff_path and inference_params is None:
|
217 |
+
out = mamba_split_conv1d_scan_combined(
|
218 |
+
zxbcdt,
|
219 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
220 |
+
self.conv1d.bias,
|
221 |
+
self.dt_bias,
|
222 |
+
A,
|
223 |
+
D=(
|
224 |
+
rearrange(self.D, "(h p) -> h p", p=self.headdim)
|
225 |
+
if self.D_has_hdim
|
226 |
+
else self.D
|
227 |
+
),
|
228 |
+
chunk_size=self.chunk_size,
|
229 |
+
seq_idx=seq_idx,
|
230 |
+
activation=self.activation,
|
231 |
+
rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
|
232 |
+
rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
|
233 |
+
outproj_weight=self.out_proj.weight,
|
234 |
+
outproj_bias=self.out_proj.bias,
|
235 |
+
headdim=None if self.D_has_hdim else self.headdim,
|
236 |
+
ngroups=self.ngroups,
|
237 |
+
norm_before_gate=self.norm_before_gate,
|
238 |
+
**dt_limit_kwargs,
|
239 |
+
)
|
240 |
+
if seqlen_og is not None:
|
241 |
+
out = rearrange(out, "b l d -> (b l) d")
|
242 |
+
if self.process_group is not None:
|
243 |
+
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
|
244 |
+
out = reduce_fn(out, self.process_group)
|
245 |
+
else:
|
246 |
+
d_mlp = (
|
247 |
+
zxbcdt.shape[-1]
|
248 |
+
- 2 * self.d_ssm
|
249 |
+
- 2 * self.ngroups * self.d_state
|
250 |
+
- self.nheads
|
251 |
+
) // 2
|
252 |
+
z0, x0, z, xBC, dt = torch.split(
|
253 |
+
zxbcdt,
|
254 |
+
[
|
255 |
+
d_mlp,
|
256 |
+
d_mlp,
|
257 |
+
self.d_ssm,
|
258 |
+
self.d_ssm + 2 * self.ngroups * self.d_state,
|
259 |
+
self.nheads,
|
260 |
+
],
|
261 |
+
dim=-1,
|
262 |
+
)
|
263 |
+
if conv_state is not None:
|
264 |
+
if cu_seqlens is None:
|
265 |
+
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
266 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
267 |
+
xBC_t = rearrange(xBC, "b l d -> b d l")
|
268 |
+
conv_state.copy_(
|
269 |
+
F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
|
270 |
+
) # Update state (B D W)
|
271 |
+
else:
|
272 |
+
assert (
|
273 |
+
causal_conv1d_varlen_states is not None
|
274 |
+
), "varlen inference requires causal_conv1d package"
|
275 |
+
assert (
|
276 |
+
batch == 1
|
277 |
+
), "varlen inference only supports batch dimension 1"
|
278 |
+
conv_varlen_states = causal_conv1d_varlen_states(
|
279 |
+
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
|
280 |
+
)
|
281 |
+
conv_state.copy_(conv_varlen_states)
|
282 |
+
assert self.activation in ["silu", "swish"]
|
283 |
+
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
284 |
+
assert (
|
285 |
+
seq_idx is None
|
286 |
+
), "varlen conv1d requires the causal_conv1d package"
|
287 |
+
xBC = self.act(
|
288 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
|
289 |
+
:, : -(self.d_conv - 1)
|
290 |
+
]
|
291 |
+
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
|
292 |
+
else:
|
293 |
+
xBC = causal_conv1d_fn(
|
294 |
+
xBC.transpose(1, 2),
|
295 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
296 |
+
bias=self.conv1d.bias,
|
297 |
+
activation=self.activation,
|
298 |
+
seq_idx=seq_idx,
|
299 |
+
).transpose(1, 2)
|
300 |
+
x, B, C = torch.split(
|
301 |
+
xBC,
|
302 |
+
[self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
|
303 |
+
dim=-1,
|
304 |
+
)
|
305 |
+
y = mamba_chunk_scan_combined(
|
306 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
307 |
+
dt,
|
308 |
+
A,
|
309 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
310 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
311 |
+
chunk_size=self.chunk_size,
|
312 |
+
D=(
|
313 |
+
rearrange(self.D, "(h p) -> h p", p=self.headdim)
|
314 |
+
if self.D_has_hdim
|
315 |
+
else self.D
|
316 |
+
),
|
317 |
+
z=(
|
318 |
+
rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
|
319 |
+
if not self.rmsnorm
|
320 |
+
else None
|
321 |
+
),
|
322 |
+
dt_bias=self.dt_bias,
|
323 |
+
dt_softplus=True,
|
324 |
+
seq_idx=seq_idx,
|
325 |
+
cu_seqlens=cu_seqlens,
|
326 |
+
**dt_limit_kwargs,
|
327 |
+
return_final_states=ssm_state is not None,
|
328 |
+
return_varlen_states=cu_seqlens is not None
|
329 |
+
and inference_params is not None,
|
330 |
+
)
|
331 |
+
if ssm_state is not None:
|
332 |
+
y, last_state, *rest = y
|
333 |
+
if cu_seqlens is None:
|
334 |
+
ssm_state.copy_(last_state)
|
335 |
+
else:
|
336 |
+
varlen_states = rest[0]
|
337 |
+
ssm_state.copy_(varlen_states)
|
338 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
339 |
+
if self.rmsnorm:
|
340 |
+
y = self.norm(y, z)
|
341 |
+
if d_mlp > 0:
|
342 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
343 |
+
if seqlen_og is not None:
|
344 |
+
y = rearrange(y, "b l d -> (b l) d")
|
345 |
+
out = self.out_proj(y)
|
346 |
+
return out
|
347 |
+
|
348 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
349 |
+
dtype = hidden_states.dtype
|
350 |
+
assert (
|
351 |
+
hidden_states.shape[1] == 1
|
352 |
+
), "Only support decoding with 1 token at a time for now"
|
353 |
+
zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
354 |
+
d_mlp = (
|
355 |
+
zxbcdt.shape[-1]
|
356 |
+
- 2 * self.d_ssm
|
357 |
+
- 2 * self.ngroups * self.d_state
|
358 |
+
- self.nheads
|
359 |
+
) // 2
|
360 |
+
z0, x0, z, xBC, dt = torch.split(
|
361 |
+
zxbcdt,
|
362 |
+
[
|
363 |
+
d_mlp,
|
364 |
+
d_mlp,
|
365 |
+
self.d_ssm,
|
366 |
+
self.d_ssm + 2 * self.ngroups * self.d_state,
|
367 |
+
self.nheads,
|
368 |
+
],
|
369 |
+
dim=-1,
|
370 |
+
)
|
371 |
+
|
372 |
+
# Conv step
|
373 |
+
if causal_conv1d_update is None:
|
374 |
+
conv_state.copy_(
|
375 |
+
torch.roll(conv_state, shifts=-1, dims=-1)
|
376 |
+
) # Update state (B D W)
|
377 |
+
conv_state[:, :, -1] = xBC
|
378 |
+
xBC = torch.sum(
|
379 |
+
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
380 |
+
) # (B D)
|
381 |
+
if self.conv1d.bias is not None:
|
382 |
+
xBC = xBC + self.conv1d.bias
|
383 |
+
xBC = self.act(xBC).to(dtype=dtype)
|
384 |
+
else:
|
385 |
+
xBC = causal_conv1d_update(
|
386 |
+
xBC,
|
387 |
+
conv_state,
|
388 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
389 |
+
self.conv1d.bias,
|
390 |
+
self.activation,
|
391 |
+
)
|
392 |
+
|
393 |
+
x, B, C = torch.split(
|
394 |
+
xBC,
|
395 |
+
[self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
|
396 |
+
dim=-1,
|
397 |
+
)
|
398 |
+
A = -torch.exp(self.A_log.float()) # (nheads,)
|
399 |
+
|
400 |
+
# SSM step
|
401 |
+
if selective_state_update is None:
|
402 |
+
assert (
|
403 |
+
self.ngroups == 1
|
404 |
+
), "Only support ngroups=1 for this inference code path"
|
405 |
+
# Discretize A and B
|
406 |
+
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
|
407 |
+
dA = torch.exp(dt * A) # (batch, nheads)
|
408 |
+
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
409 |
+
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
|
410 |
+
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
|
411 |
+
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
|
412 |
+
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
|
413 |
+
y = rearrange(y, "b h p -> b (h p)")
|
414 |
+
if not self.rmsnorm:
|
415 |
+
y = y * self.act(z) # (B D)
|
416 |
+
else:
|
417 |
+
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
|
418 |
+
dtype=torch.float32
|
419 |
+
)
|
420 |
+
dt = repeat(dt, "b h -> b h p", p=self.headdim)
|
421 |
+
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
|
422 |
+
D = repeat(self.D, "h -> h p", p=self.headdim)
|
423 |
+
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
|
424 |
+
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
|
425 |
+
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
|
426 |
+
if not self.rmsnorm:
|
427 |
+
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
|
428 |
+
y = selective_state_update(
|
429 |
+
ssm_state,
|
430 |
+
x_reshaped,
|
431 |
+
dt,
|
432 |
+
A,
|
433 |
+
B,
|
434 |
+
C,
|
435 |
+
D,
|
436 |
+
z=z if not self.rmsnorm else None,
|
437 |
+
dt_bias=dt_bias,
|
438 |
+
dt_softplus=True,
|
439 |
+
)
|
440 |
+
y = rearrange(y, "b h p -> b (h p)")
|
441 |
+
if self.rmsnorm:
|
442 |
+
y = self.norm(y, z)
|
443 |
+
if d_mlp > 0:
|
444 |
+
y = torch.cat([F.silu(z0) * x0, y], dim=-1)
|
445 |
+
out = self.out_proj(y)
|
446 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
447 |
+
|
448 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
449 |
+
device = self.out_proj.weight.device
|
450 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
451 |
+
conv_state = torch.zeros(
|
452 |
+
batch_size,
|
453 |
+
self.d_conv,
|
454 |
+
self.conv1d.weight.shape[0],
|
455 |
+
device=device,
|
456 |
+
dtype=conv_dtype,
|
457 |
+
).transpose(1, 2)
|
458 |
+
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
|
459 |
+
ssm_state = torch.zeros(
|
460 |
+
batch_size,
|
461 |
+
self.nheads,
|
462 |
+
self.headdim,
|
463 |
+
self.d_state,
|
464 |
+
device=device,
|
465 |
+
dtype=ssm_dtype,
|
466 |
+
)
|
467 |
+
return conv_state, ssm_state
|
468 |
+
|
469 |
+
def _get_states_from_cache(
|
470 |
+
self, inference_params, batch_size, initialize_states=False
|
471 |
+
):
|
472 |
+
assert self.layer_idx is not None
|
473 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
474 |
+
batch_shape = (batch_size,)
|
475 |
+
conv_state = torch.zeros(
|
476 |
+
batch_size,
|
477 |
+
self.d_conv,
|
478 |
+
self.conv1d.weight.shape[0],
|
479 |
+
device=self.conv1d.weight.device,
|
480 |
+
dtype=self.conv1d.weight.dtype,
|
481 |
+
).transpose(1, 2)
|
482 |
+
ssm_state = torch.zeros(
|
483 |
+
batch_size,
|
484 |
+
self.nheads,
|
485 |
+
self.headdim,
|
486 |
+
self.d_state,
|
487 |
+
device=self.in_proj.weight.device,
|
488 |
+
dtype=self.in_proj.weight.dtype,
|
489 |
+
)
|
490 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (
|
491 |
+
conv_state,
|
492 |
+
ssm_state,
|
493 |
+
)
|
494 |
+
else:
|
495 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
496 |
+
self.layer_idx
|
497 |
+
]
|
498 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
499 |
+
if initialize_states:
|
500 |
+
conv_state.zero_()
|
501 |
+
ssm_state.zero_()
|
502 |
+
return conv_state, ssm_state
|
torch-ext/mamba_ssm/modules/mamba2_simple.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
try:
|
11 |
+
from causal_conv1d import causal_conv1d_fn
|
12 |
+
except ImportError:
|
13 |
+
causal_conv1d_fn = None
|
14 |
+
|
15 |
+
try:
|
16 |
+
from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
|
17 |
+
except ImportError:
|
18 |
+
RMSNormGated, LayerNorm = None, None
|
19 |
+
|
20 |
+
from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
|
21 |
+
from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
|
22 |
+
|
23 |
+
|
24 |
+
class Mamba2Simple(nn.Module):
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
d_model,
|
28 |
+
d_state=64,
|
29 |
+
d_conv=4,
|
30 |
+
conv_init=None,
|
31 |
+
expand=2,
|
32 |
+
headdim=128,
|
33 |
+
ngroups=1,
|
34 |
+
A_init_range=(1, 16),
|
35 |
+
dt_min=0.001,
|
36 |
+
dt_max=0.1,
|
37 |
+
dt_init_floor=1e-4,
|
38 |
+
dt_limit=(0.0, float("inf")),
|
39 |
+
learnable_init_states=False,
|
40 |
+
activation="swish",
|
41 |
+
bias=False,
|
42 |
+
conv_bias=True,
|
43 |
+
# Fused kernel and sharding options
|
44 |
+
chunk_size=256,
|
45 |
+
use_mem_eff_path=True,
|
46 |
+
layer_idx=None, # Absorb kwarg for general module
|
47 |
+
device=None,
|
48 |
+
dtype=None,
|
49 |
+
):
|
50 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
51 |
+
super().__init__()
|
52 |
+
self.d_model = d_model
|
53 |
+
self.d_state = d_state
|
54 |
+
self.d_conv = d_conv
|
55 |
+
self.conv_init = conv_init
|
56 |
+
self.expand = expand
|
57 |
+
self.d_inner = self.expand * self.d_model
|
58 |
+
self.headdim = headdim
|
59 |
+
self.ngroups = ngroups
|
60 |
+
assert self.d_inner % self.headdim == 0
|
61 |
+
self.nheads = self.d_inner // self.headdim
|
62 |
+
self.dt_limit = dt_limit
|
63 |
+
self.learnable_init_states = learnable_init_states
|
64 |
+
self.activation = activation
|
65 |
+
self.chunk_size = chunk_size
|
66 |
+
self.use_mem_eff_path = use_mem_eff_path
|
67 |
+
self.layer_idx = layer_idx
|
68 |
+
|
69 |
+
# Order: [z, x, B, C, dt]
|
70 |
+
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
|
71 |
+
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
|
72 |
+
|
73 |
+
conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
|
74 |
+
self.conv1d = nn.Conv1d(
|
75 |
+
in_channels=conv_dim,
|
76 |
+
out_channels=conv_dim,
|
77 |
+
bias=conv_bias,
|
78 |
+
kernel_size=d_conv,
|
79 |
+
groups=conv_dim,
|
80 |
+
padding=d_conv - 1,
|
81 |
+
**factory_kwargs,
|
82 |
+
)
|
83 |
+
if self.conv_init is not None:
|
84 |
+
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
|
85 |
+
# self.conv1d.weight._no_weight_decay = True
|
86 |
+
|
87 |
+
if self.learnable_init_states:
|
88 |
+
self.init_states = nn.Parameter(
|
89 |
+
torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
|
90 |
+
)
|
91 |
+
self.init_states._no_weight_decay = True
|
92 |
+
|
93 |
+
self.act = nn.SiLU()
|
94 |
+
|
95 |
+
# Initialize log dt bias
|
96 |
+
dt = torch.exp(
|
97 |
+
torch.rand(self.nheads, **factory_kwargs)
|
98 |
+
* (math.log(dt_max) - math.log(dt_min))
|
99 |
+
+ math.log(dt_min)
|
100 |
+
)
|
101 |
+
dt = torch.clamp(dt, min=dt_init_floor)
|
102 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
103 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
104 |
+
self.dt_bias = nn.Parameter(inv_dt)
|
105 |
+
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
|
106 |
+
# name.endswith("bias") in param_grouping.py
|
107 |
+
self.dt_bias._no_weight_decay = True
|
108 |
+
|
109 |
+
# A parameter
|
110 |
+
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
|
111 |
+
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
|
112 |
+
*A_init_range
|
113 |
+
)
|
114 |
+
A_log = torch.log(A).to(dtype=dtype)
|
115 |
+
self.A_log = nn.Parameter(A_log)
|
116 |
+
# self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
|
117 |
+
self.A_log._no_weight_decay = True
|
118 |
+
|
119 |
+
# D "skip" parameter
|
120 |
+
self.D = nn.Parameter(torch.ones(self.nheads, device=device))
|
121 |
+
self.D._no_weight_decay = True
|
122 |
+
|
123 |
+
# Extra normalization layer right before output projection
|
124 |
+
assert RMSNormGated is not None
|
125 |
+
self.norm = RMSNormGated(
|
126 |
+
self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
|
127 |
+
)
|
128 |
+
|
129 |
+
self.out_proj = nn.Linear(
|
130 |
+
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
131 |
+
)
|
132 |
+
|
133 |
+
def forward(self, u, seq_idx=None):
|
134 |
+
"""
|
135 |
+
u: (B, L, D)
|
136 |
+
Returns: same shape as u
|
137 |
+
"""
|
138 |
+
batch, seqlen, dim = u.shape
|
139 |
+
|
140 |
+
zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
|
141 |
+
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
|
142 |
+
initial_states = (
|
143 |
+
repeat(self.init_states, "... -> b ...", b=batch)
|
144 |
+
if self.learnable_init_states
|
145 |
+
else None
|
146 |
+
)
|
147 |
+
dt_limit_kwargs = (
|
148 |
+
{} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
|
149 |
+
)
|
150 |
+
|
151 |
+
if self.use_mem_eff_path:
|
152 |
+
# Fully fused path
|
153 |
+
out = mamba_split_conv1d_scan_combined(
|
154 |
+
zxbcdt,
|
155 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
156 |
+
self.conv1d.bias,
|
157 |
+
self.dt_bias,
|
158 |
+
A,
|
159 |
+
D=self.D,
|
160 |
+
chunk_size=self.chunk_size,
|
161 |
+
seq_idx=seq_idx,
|
162 |
+
activation=self.activation,
|
163 |
+
rmsnorm_weight=self.norm.weight,
|
164 |
+
rmsnorm_eps=self.norm.eps,
|
165 |
+
outproj_weight=self.out_proj.weight,
|
166 |
+
outproj_bias=self.out_proj.bias,
|
167 |
+
headdim=self.headdim,
|
168 |
+
ngroups=self.ngroups,
|
169 |
+
norm_before_gate=False,
|
170 |
+
initial_states=initial_states,
|
171 |
+
**dt_limit_kwargs,
|
172 |
+
)
|
173 |
+
else:
|
174 |
+
z, xBC, dt = torch.split(
|
175 |
+
zxbcdt,
|
176 |
+
[
|
177 |
+
self.d_inner,
|
178 |
+
self.d_inner + 2 * self.ngroups * self.d_state,
|
179 |
+
self.nheads,
|
180 |
+
],
|
181 |
+
dim=-1,
|
182 |
+
)
|
183 |
+
dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
|
184 |
+
assert self.activation in ["silu", "swish"]
|
185 |
+
|
186 |
+
# 1D Convolution
|
187 |
+
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
188 |
+
xBC = self.act(
|
189 |
+
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
|
190 |
+
) # (B, L, self.d_inner + 2 * ngroups * d_state)
|
191 |
+
xBC = xBC[:, :seqlen, :]
|
192 |
+
else:
|
193 |
+
xBC = causal_conv1d_fn(
|
194 |
+
x=xBC.transpose(1, 2),
|
195 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
196 |
+
bias=self.conv1d.bias,
|
197 |
+
activation=self.activation,
|
198 |
+
).transpose(1, 2)
|
199 |
+
|
200 |
+
# Split into 3 main branches: X, B, C
|
201 |
+
# These correspond to V, K, Q respectively in the SSM/attention duality
|
202 |
+
x, B, C = torch.split(
|
203 |
+
xBC,
|
204 |
+
[
|
205 |
+
self.d_inner,
|
206 |
+
self.ngroups * self.d_state,
|
207 |
+
self.ngroups * self.d_state,
|
208 |
+
],
|
209 |
+
dim=-1,
|
210 |
+
)
|
211 |
+
y = mamba_chunk_scan_combined(
|
212 |
+
rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
|
213 |
+
dt,
|
214 |
+
A,
|
215 |
+
rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
|
216 |
+
rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
|
217 |
+
chunk_size=self.chunk_size,
|
218 |
+
D=self.D,
|
219 |
+
z=None,
|
220 |
+
seq_idx=seq_idx,
|
221 |
+
initial_states=initial_states,
|
222 |
+
**dt_limit_kwargs,
|
223 |
+
)
|
224 |
+
y = rearrange(y, "b l h p -> b l (h p)")
|
225 |
+
|
226 |
+
# Multiply "gate" branch and apply extra normalization layer
|
227 |
+
y = self.norm(y, z)
|
228 |
+
out = self.out_proj(y)
|
229 |
+
return out
|
torch-ext/mamba_ssm/modules/mamba_simple.py
ADDED
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import Tensor
|
10 |
+
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
|
14 |
+
|
15 |
+
try:
|
16 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
17 |
+
except ImportError:
|
18 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from ..ops.triton.selective_state_update import selective_state_update
|
22 |
+
except ImportError:
|
23 |
+
selective_state_update = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
|
27 |
+
except ImportError:
|
28 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
29 |
+
|
30 |
+
|
31 |
+
class Mamba(nn.Module):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
d_model,
|
35 |
+
d_state=16,
|
36 |
+
d_conv=4,
|
37 |
+
expand=2,
|
38 |
+
dt_rank="auto",
|
39 |
+
dt_min=0.001,
|
40 |
+
dt_max=0.1,
|
41 |
+
dt_init="random",
|
42 |
+
dt_scale=1.0,
|
43 |
+
dt_init_floor=1e-4,
|
44 |
+
conv_bias=True,
|
45 |
+
bias=False,
|
46 |
+
use_fast_path=True, # Fused kernel options
|
47 |
+
layer_idx=None,
|
48 |
+
device=None,
|
49 |
+
dtype=None,
|
50 |
+
):
|
51 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
52 |
+
super().__init__()
|
53 |
+
self.d_model = d_model
|
54 |
+
self.d_state = d_state
|
55 |
+
self.d_conv = d_conv
|
56 |
+
self.expand = expand
|
57 |
+
self.d_inner = int(self.expand * self.d_model)
|
58 |
+
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
59 |
+
self.use_fast_path = use_fast_path
|
60 |
+
self.layer_idx = layer_idx
|
61 |
+
|
62 |
+
self.in_proj = nn.Linear(
|
63 |
+
self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
|
64 |
+
)
|
65 |
+
|
66 |
+
self.conv1d = nn.Conv1d(
|
67 |
+
in_channels=self.d_inner,
|
68 |
+
out_channels=self.d_inner,
|
69 |
+
bias=conv_bias,
|
70 |
+
kernel_size=d_conv,
|
71 |
+
groups=self.d_inner,
|
72 |
+
padding=d_conv - 1,
|
73 |
+
**factory_kwargs,
|
74 |
+
)
|
75 |
+
|
76 |
+
self.activation = "silu"
|
77 |
+
self.act = nn.SiLU()
|
78 |
+
|
79 |
+
self.x_proj = nn.Linear(
|
80 |
+
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
81 |
+
)
|
82 |
+
self.dt_proj = nn.Linear(
|
83 |
+
self.dt_rank, self.d_inner, bias=True, **factory_kwargs
|
84 |
+
)
|
85 |
+
|
86 |
+
# Initialize special dt projection to preserve variance at initialization
|
87 |
+
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
88 |
+
if dt_init == "constant":
|
89 |
+
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
90 |
+
elif dt_init == "random":
|
91 |
+
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
92 |
+
else:
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
96 |
+
dt = torch.exp(
|
97 |
+
torch.rand(self.d_inner, **factory_kwargs)
|
98 |
+
* (math.log(dt_max) - math.log(dt_min))
|
99 |
+
+ math.log(dt_min)
|
100 |
+
).clamp(min=dt_init_floor)
|
101 |
+
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
102 |
+
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
103 |
+
with torch.no_grad():
|
104 |
+
self.dt_proj.bias.copy_(inv_dt)
|
105 |
+
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
106 |
+
self.dt_proj.bias._no_reinit = True
|
107 |
+
|
108 |
+
# S4D real initialization
|
109 |
+
A = repeat(
|
110 |
+
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
111 |
+
"n -> d n",
|
112 |
+
d=self.d_inner,
|
113 |
+
).contiguous()
|
114 |
+
A_log = torch.log(A) # Keep A_log in fp32
|
115 |
+
self.A_log = nn.Parameter(A_log)
|
116 |
+
self.A_log._no_weight_decay = True
|
117 |
+
|
118 |
+
# D "skip" parameter
|
119 |
+
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
120 |
+
self.D._no_weight_decay = True
|
121 |
+
|
122 |
+
self.out_proj = nn.Linear(
|
123 |
+
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, hidden_states, inference_params=None):
|
127 |
+
"""
|
128 |
+
hidden_states: (B, L, D)
|
129 |
+
Returns: same shape as hidden_states
|
130 |
+
"""
|
131 |
+
batch, seqlen, dim = hidden_states.shape
|
132 |
+
|
133 |
+
conv_state, ssm_state = None, None
|
134 |
+
if inference_params is not None:
|
135 |
+
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
136 |
+
if inference_params.seqlen_offset > 0:
|
137 |
+
# The states are updated inplace
|
138 |
+
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
139 |
+
return out
|
140 |
+
|
141 |
+
# We do matmul and transpose BLH -> HBL at the same time
|
142 |
+
xz = rearrange(
|
143 |
+
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
|
144 |
+
"d (b l) -> b d l",
|
145 |
+
l=seqlen,
|
146 |
+
)
|
147 |
+
if self.in_proj.bias is not None:
|
148 |
+
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
|
149 |
+
|
150 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
151 |
+
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
152 |
+
if (
|
153 |
+
self.use_fast_path
|
154 |
+
and causal_conv1d_fn is not None
|
155 |
+
and inference_params is None
|
156 |
+
): # Doesn't support outputting the states
|
157 |
+
out = mamba_inner_fn(
|
158 |
+
xz,
|
159 |
+
self.conv1d.weight,
|
160 |
+
self.conv1d.bias,
|
161 |
+
self.x_proj.weight,
|
162 |
+
self.dt_proj.weight,
|
163 |
+
self.out_proj.weight,
|
164 |
+
self.out_proj.bias,
|
165 |
+
A,
|
166 |
+
None, # input-dependent B
|
167 |
+
None, # input-dependent C
|
168 |
+
self.D.float(),
|
169 |
+
delta_bias=self.dt_proj.bias.float(),
|
170 |
+
delta_softplus=True,
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
x, z = xz.chunk(2, dim=1)
|
174 |
+
# Compute short convolution
|
175 |
+
if conv_state is not None:
|
176 |
+
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
177 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
178 |
+
conv_state.copy_(
|
179 |
+
F.pad(x, (self.d_conv - x.shape[-1], 0))
|
180 |
+
) # Update state (B D W)
|
181 |
+
if causal_conv1d_fn is None:
|
182 |
+
x = self.act(self.conv1d(x)[..., :seqlen])
|
183 |
+
else:
|
184 |
+
assert self.activation in ["silu", "swish"]
|
185 |
+
x = causal_conv1d_fn(
|
186 |
+
x=x,
|
187 |
+
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
188 |
+
bias=self.conv1d.bias,
|
189 |
+
activation=self.activation,
|
190 |
+
)
|
191 |
+
|
192 |
+
# We're careful here about the layout, to avoid extra transposes.
|
193 |
+
# We want dt to have d as the slowest moving dimension
|
194 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
195 |
+
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
196 |
+
dt, B, C = torch.split(
|
197 |
+
x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
|
198 |
+
)
|
199 |
+
dt = self.dt_proj.weight @ dt.t()
|
200 |
+
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
201 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
202 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
203 |
+
assert self.activation in ["silu", "swish"]
|
204 |
+
y = selective_scan_fn(
|
205 |
+
x,
|
206 |
+
dt,
|
207 |
+
A,
|
208 |
+
B,
|
209 |
+
C,
|
210 |
+
self.D.float(),
|
211 |
+
z=z,
|
212 |
+
delta_bias=self.dt_proj.bias.float(),
|
213 |
+
delta_softplus=True,
|
214 |
+
return_last_state=ssm_state is not None,
|
215 |
+
)
|
216 |
+
if ssm_state is not None:
|
217 |
+
y, last_state = y
|
218 |
+
ssm_state.copy_(last_state)
|
219 |
+
y = rearrange(y, "b d l -> b l d")
|
220 |
+
out = self.out_proj(y)
|
221 |
+
return out
|
222 |
+
|
223 |
+
def step(self, hidden_states, conv_state, ssm_state):
|
224 |
+
dtype = hidden_states.dtype
|
225 |
+
assert (
|
226 |
+
hidden_states.shape[1] == 1
|
227 |
+
), "Only support decoding with 1 token at a time for now"
|
228 |
+
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
229 |
+
x, z = xz.chunk(2, dim=-1) # (B D)
|
230 |
+
|
231 |
+
# Conv step
|
232 |
+
if causal_conv1d_update is None:
|
233 |
+
conv_state.copy_(
|
234 |
+
torch.roll(conv_state, shifts=-1, dims=-1)
|
235 |
+
) # Update state (B D W)
|
236 |
+
conv_state[:, :, -1] = x
|
237 |
+
x = torch.sum(
|
238 |
+
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
239 |
+
) # (B D)
|
240 |
+
if self.conv1d.bias is not None:
|
241 |
+
x = x + self.conv1d.bias
|
242 |
+
x = self.act(x).to(dtype=dtype)
|
243 |
+
else:
|
244 |
+
x = causal_conv1d_update(
|
245 |
+
x,
|
246 |
+
conv_state,
|
247 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
248 |
+
self.conv1d.bias,
|
249 |
+
self.activation,
|
250 |
+
)
|
251 |
+
|
252 |
+
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
253 |
+
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
254 |
+
# Don't add dt_bias here
|
255 |
+
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
|
256 |
+
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
257 |
+
|
258 |
+
# SSM step
|
259 |
+
if selective_state_update is None:
|
260 |
+
# Discretize A and B
|
261 |
+
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
|
262 |
+
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
263 |
+
dB = torch.einsum("bd,bn->bdn", dt, B)
|
264 |
+
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
265 |
+
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
266 |
+
y = y + self.D.to(dtype) * x
|
267 |
+
y = y * self.act(z) # (B D)
|
268 |
+
else:
|
269 |
+
y = selective_state_update(
|
270 |
+
ssm_state,
|
271 |
+
x,
|
272 |
+
dt,
|
273 |
+
A,
|
274 |
+
B,
|
275 |
+
C,
|
276 |
+
self.D,
|
277 |
+
z=z,
|
278 |
+
dt_bias=self.dt_proj.bias,
|
279 |
+
dt_softplus=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
out = self.out_proj(y)
|
283 |
+
return out.unsqueeze(1), conv_state, ssm_state
|
284 |
+
|
285 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
286 |
+
device = self.out_proj.weight.device
|
287 |
+
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
288 |
+
conv_state = torch.zeros(
|
289 |
+
batch_size,
|
290 |
+
self.d_model * self.expand,
|
291 |
+
self.d_conv,
|
292 |
+
device=device,
|
293 |
+
dtype=conv_dtype,
|
294 |
+
)
|
295 |
+
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
296 |
+
# ssm_dtype = torch.float32
|
297 |
+
ssm_state = torch.zeros(
|
298 |
+
batch_size,
|
299 |
+
self.d_model * self.expand,
|
300 |
+
self.d_state,
|
301 |
+
device=device,
|
302 |
+
dtype=ssm_dtype,
|
303 |
+
)
|
304 |
+
return conv_state, ssm_state
|
305 |
+
|
306 |
+
def _get_states_from_cache(
|
307 |
+
self, inference_params, batch_size, initialize_states=False
|
308 |
+
):
|
309 |
+
assert self.layer_idx is not None
|
310 |
+
if self.layer_idx not in inference_params.key_value_memory_dict:
|
311 |
+
batch_shape = (batch_size,)
|
312 |
+
conv_state = torch.zeros(
|
313 |
+
batch_size,
|
314 |
+
self.d_model * self.expand,
|
315 |
+
self.d_conv,
|
316 |
+
device=self.conv1d.weight.device,
|
317 |
+
dtype=self.conv1d.weight.dtype,
|
318 |
+
)
|
319 |
+
ssm_state = torch.zeros(
|
320 |
+
batch_size,
|
321 |
+
self.d_model * self.expand,
|
322 |
+
self.d_state,
|
323 |
+
device=self.dt_proj.weight.device,
|
324 |
+
dtype=self.dt_proj.weight.dtype,
|
325 |
+
# dtype=torch.float32,
|
326 |
+
)
|
327 |
+
inference_params.key_value_memory_dict[self.layer_idx] = (
|
328 |
+
conv_state,
|
329 |
+
ssm_state,
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
333 |
+
self.layer_idx
|
334 |
+
]
|
335 |
+
# TODO: What if batch size changes between generation, and we reuse the same states?
|
336 |
+
if initialize_states:
|
337 |
+
conv_state.zero_()
|
338 |
+
ssm_state.zero_()
|
339 |
+
return conv_state, ssm_state
|
torch-ext/mamba_ssm/modules/mha.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
try:
|
11 |
+
from flash_attn import flash_attn_with_kvcache
|
12 |
+
except ImportError:
|
13 |
+
flash_attn_with_kvcache = None
|
14 |
+
|
15 |
+
try:
|
16 |
+
from flash_attn.layers.rotary import RotaryEmbedding
|
17 |
+
except ImportError:
|
18 |
+
RotaryEmbedding = None
|
19 |
+
|
20 |
+
try:
|
21 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
22 |
+
except ImportError:
|
23 |
+
causal_conv1d_fn, causal_conv1d_update = None, None
|
24 |
+
|
25 |
+
|
26 |
+
def _update_kv_cache(kv, inference_params, layer_idx):
|
27 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
28 |
+
# Pre-allocate memory for key-values for inference.
|
29 |
+
num_heads, head_dim = kv.shape[-2:]
|
30 |
+
assert layer_idx in inference_params.key_value_memory_dict
|
31 |
+
kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
|
32 |
+
# Adjust key and value for inference
|
33 |
+
batch_start = inference_params.batch_size_offset
|
34 |
+
batch_end = batch_start + kv.shape[0]
|
35 |
+
sequence_start = inference_params.seqlen_offset
|
36 |
+
sequence_end = sequence_start + kv.shape[1]
|
37 |
+
assert batch_end <= kv_cache.shape[0]
|
38 |
+
assert sequence_end <= kv_cache.shape[1]
|
39 |
+
assert kv_cache is not None
|
40 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
41 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
42 |
+
|
43 |
+
|
44 |
+
class MHA(nn.Module):
|
45 |
+
"""Multi-head self-attention and cross-attention"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
embed_dim,
|
50 |
+
num_heads,
|
51 |
+
num_heads_kv=None,
|
52 |
+
head_dim=None, # If None, use embed_dim // num_heads
|
53 |
+
mlp_dim=0,
|
54 |
+
qkv_proj_bias=True,
|
55 |
+
out_proj_bias=True,
|
56 |
+
softmax_scale=None,
|
57 |
+
causal=False,
|
58 |
+
layer_idx=None,
|
59 |
+
d_conv=0,
|
60 |
+
rotary_emb_dim=0,
|
61 |
+
rotary_emb_base=10000.0,
|
62 |
+
rotary_emb_interleaved=False,
|
63 |
+
device=None,
|
64 |
+
dtype=None,
|
65 |
+
) -> None:
|
66 |
+
"""
|
67 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
68 |
+
return_residual: whether to return the input x along with the output. This is for
|
69 |
+
performance reason: for post-norm architecture, returning the input allows us
|
70 |
+
to fuse the backward of nn.Linear with the residual connection.
|
71 |
+
"""
|
72 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
73 |
+
super().__init__()
|
74 |
+
self.embed_dim = embed_dim
|
75 |
+
self.layer_idx = layer_idx
|
76 |
+
self.d_conv = d_conv
|
77 |
+
self.rotary_emb_dim = rotary_emb_dim
|
78 |
+
self.softmax_scale = softmax_scale
|
79 |
+
self.causal = causal
|
80 |
+
|
81 |
+
self.num_heads = num_heads
|
82 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
83 |
+
assert (
|
84 |
+
self.num_heads % self.num_heads_kv == 0
|
85 |
+
), "num_heads must be divisible by num_heads_kv"
|
86 |
+
if head_dim is None:
|
87 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
88 |
+
self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
|
89 |
+
self.mlp_dim = math.ceil(mlp_dim / 256) * 256
|
90 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
91 |
+
out_dim = self.head_dim * self.num_heads
|
92 |
+
|
93 |
+
if self.rotary_emb_dim > 0:
|
94 |
+
assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
|
95 |
+
self.rotary_emb = RotaryEmbedding(
|
96 |
+
self.rotary_emb_dim,
|
97 |
+
base=rotary_emb_base,
|
98 |
+
interleaved=rotary_emb_interleaved,
|
99 |
+
device=device,
|
100 |
+
)
|
101 |
+
|
102 |
+
self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
|
103 |
+
if self.d_conv > 0:
|
104 |
+
self.conv1d = nn.Conv1d(
|
105 |
+
qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
|
106 |
+
**factory_kwargs
|
107 |
+
)
|
108 |
+
self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
109 |
+
|
110 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
111 |
+
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
112 |
+
device = self.out_proj.weight.device
|
113 |
+
if self.d_conv > 0:
|
114 |
+
conv_state = torch.zeros(
|
115 |
+
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
conv_state = None
|
119 |
+
kv_cache = torch.empty(
|
120 |
+
batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
|
121 |
+
)
|
122 |
+
return kv_cache, conv_state
|
123 |
+
|
124 |
+
def _update_kv_cache(self, kv, inference_params):
|
125 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
126 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
127 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
128 |
+
|
129 |
+
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
130 |
+
"""
|
131 |
+
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
|
132 |
+
q: (batch_size, seqlen_q, nheads, head_dim)
|
133 |
+
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
|
134 |
+
"""
|
135 |
+
assert inference_params is not None and inference_params.seqlen_offset > 0
|
136 |
+
if self.rotary_emb_dim > 0:
|
137 |
+
self.rotary_emb._update_cos_sin_cache(
|
138 |
+
inference_params.max_seqlen, device=q.device, dtype=q.dtype
|
139 |
+
)
|
140 |
+
rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
|
141 |
+
else:
|
142 |
+
rotary_cos, rotary_sin = None, None
|
143 |
+
batch = q.shape[0]
|
144 |
+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
145 |
+
kv_cache = kv_cache[:batch]
|
146 |
+
cache_seqlens = (
|
147 |
+
inference_params.lengths_per_sample[:batch]
|
148 |
+
if inference_params.lengths_per_sample is not None
|
149 |
+
else inference_params.seqlen_offset
|
150 |
+
)
|
151 |
+
assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
|
152 |
+
context = flash_attn_with_kvcache(
|
153 |
+
q,
|
154 |
+
kv_cache[:, :, 0],
|
155 |
+
kv_cache[:, :, 1],
|
156 |
+
kv[:, :, 0],
|
157 |
+
kv[:, :, 1],
|
158 |
+
rotary_cos=rotary_cos,
|
159 |
+
rotary_sin=rotary_sin,
|
160 |
+
cache_seqlens=cache_seqlens,
|
161 |
+
softmax_scale=self.softmax_scale,
|
162 |
+
causal=self.causal,
|
163 |
+
rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
|
164 |
+
)
|
165 |
+
return context
|
166 |
+
|
167 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
|
168 |
+
"""Write kv to inference_params, then do attention"""
|
169 |
+
if (
|
170 |
+
inference_params.seqlen_offset == 0
|
171 |
+
or flash_attn_with_kvcache is None
|
172 |
+
):
|
173 |
+
# TODO: this only uses seqlen_offset and not lengths_per_sample.
|
174 |
+
kv = self._update_kv_cache(kv, inference_params)
|
175 |
+
k, v = kv.unbind(dim=-3)
|
176 |
+
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
177 |
+
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
178 |
+
return F.scaled_dot_product_attention(
|
179 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
|
180 |
+
).transpose(1, 2)
|
181 |
+
else:
|
182 |
+
batch = q.shape[0]
|
183 |
+
kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
|
184 |
+
kv_cache = kv_cache[:batch]
|
185 |
+
cache_seqlens = (
|
186 |
+
inference_params.lengths_per_sample[:batch]
|
187 |
+
if inference_params.lengths_per_sample is not None
|
188 |
+
else inference_params.seqlen_offset
|
189 |
+
)
|
190 |
+
return flash_attn_with_kvcache(
|
191 |
+
q,
|
192 |
+
kv_cache[:, :, 0],
|
193 |
+
kv_cache[:, :, 1],
|
194 |
+
kv[:, :, 0],
|
195 |
+
kv[:, :, 1],
|
196 |
+
cache_seqlens=cache_seqlens,
|
197 |
+
softmax_scale=self.softmax_scale,
|
198 |
+
causal=self.causal,
|
199 |
+
)
|
200 |
+
|
201 |
+
def forward(self, x, inference_params=None):
|
202 |
+
"""
|
203 |
+
Arguments:
|
204 |
+
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
205 |
+
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
206 |
+
is the is the sum of the sequence lengths in the batch.
|
207 |
+
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
208 |
+
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
209 |
+
"""
|
210 |
+
if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
|
211 |
+
inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
|
212 |
+
x.shape[0], inference_params.max_seqlen, dtype=x.dtype
|
213 |
+
)
|
214 |
+
seqlen_offset = (
|
215 |
+
0
|
216 |
+
if inference_params is None
|
217 |
+
else (
|
218 |
+
inference_params.lengths_per_sample
|
219 |
+
if inference_params.lengths_per_sample is not None
|
220 |
+
else inference_params.seqlen_offset
|
221 |
+
)
|
222 |
+
)
|
223 |
+
rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
|
224 |
+
qkv = self.in_proj(x)
|
225 |
+
if self.mlp_dim > 0:
|
226 |
+
qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
|
227 |
+
x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
|
228 |
+
x_mlp = x_mlp_up * F.silu(x_mlp_gate)
|
229 |
+
if self.d_conv > 0:
|
230 |
+
# The inference code for conv1d is pretty messy, should clean it up
|
231 |
+
if (inference_params is None or inference_params.seqlen_offset == 0):
|
232 |
+
if causal_conv1d_fn is None:
|
233 |
+
qkv = rearrange(
|
234 |
+
self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
|
235 |
+
).contiguous()
|
236 |
+
else:
|
237 |
+
qkv = causal_conv1d_fn(
|
238 |
+
qkv.transpose(1, 2),
|
239 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
240 |
+
self.conv1d.bias
|
241 |
+
).transpose(1, 2)
|
242 |
+
if inference_params is not None:
|
243 |
+
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
|
244 |
+
# If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
245 |
+
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
246 |
+
qkv_t = rearrange(qkv, "b l d -> b d l")
|
247 |
+
conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
|
248 |
+
else:
|
249 |
+
_, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
|
250 |
+
assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
|
251 |
+
qkv = qkv.squeeze(1)
|
252 |
+
# Conv step
|
253 |
+
if causal_conv1d_update is None:
|
254 |
+
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
|
255 |
+
conv_state[:, :, -1] = qkv
|
256 |
+
qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
|
257 |
+
if self.conv1d.bias is not None:
|
258 |
+
qkv = qkv + self.conv1d.bias
|
259 |
+
else:
|
260 |
+
qkv = causal_conv1d_update(
|
261 |
+
qkv,
|
262 |
+
conv_state,
|
263 |
+
rearrange(self.conv1d.weight, "d 1 w -> d w"),
|
264 |
+
self.conv1d.bias
|
265 |
+
)
|
266 |
+
qkv = qkv.unsqueeze(1)
|
267 |
+
q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
|
268 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
269 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
270 |
+
if (
|
271 |
+
inference_params is None
|
272 |
+
or inference_params.seqlen_offset == 0
|
273 |
+
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
274 |
+
):
|
275 |
+
if self.rotary_emb_dim > 0:
|
276 |
+
q, kv = self.rotary_emb(
|
277 |
+
q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
|
278 |
+
)
|
279 |
+
if inference_params is None:
|
280 |
+
k, v = kv.unbind(dim=-3)
|
281 |
+
k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
282 |
+
v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
|
283 |
+
context = F.scaled_dot_product_attention(
|
284 |
+
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
|
285 |
+
).transpose(1, 2)
|
286 |
+
else:
|
287 |
+
context = self._update_kvcache_attention(q, kv, inference_params)
|
288 |
+
else:
|
289 |
+
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
290 |
+
context = rearrange(context, "... h d -> ... (h d)")
|
291 |
+
if self.mlp_dim > 0:
|
292 |
+
context = torch.cat([context, x_mlp], dim=-1)
|
293 |
+
out = self.out_proj(context)
|
294 |
+
return out
|
torch-ext/mamba_ssm/modules/mlp.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class GatedMLP(nn.Module):
|
7 |
+
def __init__(
|
8 |
+
self,
|
9 |
+
in_features,
|
10 |
+
hidden_features=None,
|
11 |
+
out_features=None,
|
12 |
+
activation=F.silu,
|
13 |
+
bias=False,
|
14 |
+
multiple_of=128,
|
15 |
+
device=None,
|
16 |
+
dtype=None,
|
17 |
+
):
|
18 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
19 |
+
super().__init__()
|
20 |
+
out_features = out_features if out_features is not None else in_features
|
21 |
+
hidden_features = (
|
22 |
+
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
23 |
+
)
|
24 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
25 |
+
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
|
26 |
+
self.activation = activation
|
27 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
y = self.fc1(x)
|
31 |
+
y, gate = y.chunk(2, dim=-1)
|
32 |
+
y = y * self.activation(gate)
|
33 |
+
y = self.fc2(y)
|
34 |
+
return y
|
torch-ext/mamba_ssm/modules/ssd_minimal.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Albert Gu and Tri Dao.
|
2 |
+
"""Minimal implementation of SSD.
|
3 |
+
|
4 |
+
This is the same as Listing 1 from the paper.
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
|
12 |
+
|
13 |
+
|
14 |
+
def segsum_unstable(x):
|
15 |
+
"""Naive segment sum calculation."""
|
16 |
+
T = x.size(-1)
|
17 |
+
x_cumsum = torch.cumsum(x, dim=-1)
|
18 |
+
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
|
19 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
20 |
+
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
21 |
+
return x_segsum
|
22 |
+
|
23 |
+
|
24 |
+
def segsum(x):
|
25 |
+
"""More stable segment sum calculation."""
|
26 |
+
T = x.size(-1)
|
27 |
+
x = repeat(x, "... d -> ... d e", e=T)
|
28 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
|
29 |
+
x = x.masked_fill(~mask, 0)
|
30 |
+
x_segsum = torch.cumsum(x, dim=-2)
|
31 |
+
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
|
32 |
+
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
|
33 |
+
return x_segsum
|
34 |
+
|
35 |
+
|
36 |
+
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
|
37 |
+
"""
|
38 |
+
Arguments:
|
39 |
+
X: (batch, length, n_heads, d_head)
|
40 |
+
A: (batch, length, n_heads)
|
41 |
+
B: (batch, length, n_heads, d_state)
|
42 |
+
C: (batch, length, n_heads, d_state)
|
43 |
+
Return:
|
44 |
+
Y: (batch, length, n_heads, d_head)
|
45 |
+
"""
|
46 |
+
assert X.dtype == A.dtype == B.dtype == C.dtype
|
47 |
+
assert X.shape[1] % block_len == 0
|
48 |
+
|
49 |
+
# Rearrange into blocks/chunks
|
50 |
+
X, A, B, C = [
|
51 |
+
rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
|
52 |
+
]
|
53 |
+
|
54 |
+
A = rearrange(A, "b c l h -> b h c l")
|
55 |
+
A_cumsum = torch.cumsum(A, dim=-1)
|
56 |
+
|
57 |
+
# 1. Compute the output for each intra-chunk (diagonal blocks)
|
58 |
+
L = torch.exp(segsum(A))
|
59 |
+
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
|
60 |
+
|
61 |
+
# 2. Compute the state for each intra-chunk
|
62 |
+
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
63 |
+
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
|
64 |
+
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
|
65 |
+
|
66 |
+
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
67 |
+
# (middle term of factorization of off-diag blocks; A terms)
|
68 |
+
if initial_states is None:
|
69 |
+
initial_states = torch.zeros_like(states[:, :1])
|
70 |
+
states = torch.cat([initial_states, states], dim=1)
|
71 |
+
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
|
72 |
+
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
|
73 |
+
states, final_state = new_states[:, :-1], new_states[:, -1]
|
74 |
+
|
75 |
+
# 4. Compute state -> output conversion per chunk
|
76 |
+
# (left term of low-rank factorization of off-diagonal blocks; C terms)
|
77 |
+
state_decay_out = torch.exp(A_cumsum)
|
78 |
+
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
|
79 |
+
|
80 |
+
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
|
81 |
+
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
|
82 |
+
return Y, final_state
|
83 |
+
|
84 |
+
|
85 |
+
# Simple test
|
86 |
+
def test_correctness():
|
87 |
+
torch.manual_seed(42)
|
88 |
+
|
89 |
+
## Dimensions
|
90 |
+
# Denoted (B, T, Q, D, P) in the paper
|
91 |
+
batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
|
92 |
+
nheads = dim // headdim # (H) in the paper
|
93 |
+
ngroups = 1 # (G) in the paper
|
94 |
+
dstate = 64 # (N) in the paper
|
95 |
+
dtype = torch.float32
|
96 |
+
device = "cuda"
|
97 |
+
|
98 |
+
x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
|
99 |
+
dt = F.softplus(
|
100 |
+
torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
|
101 |
+
).requires_grad_()
|
102 |
+
A = (
|
103 |
+
-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
|
104 |
+
).requires_grad_()
|
105 |
+
B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
|
106 |
+
C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
|
107 |
+
D = torch.randn(nheads, dtype=dtype, device=device)
|
108 |
+
|
109 |
+
# Comparing fused version and minimal version
|
110 |
+
y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
|
111 |
+
y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
|
torch-ext/mamba_ssm/ops/__init__.py
ADDED
File without changes
|
torch-ext/mamba_ssm/ops/selective_scan_interface.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from ..utils.torch import custom_fwd, custom_bwd
|
6 |
+
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
try:
|
10 |
+
from causal_conv1d import causal_conv1d_fn
|
11 |
+
import causal_conv1d_cuda
|
12 |
+
except ImportError:
|
13 |
+
causal_conv1d_fn = None
|
14 |
+
causal_conv1d_cuda = None
|
15 |
+
|
16 |
+
from .triton.layer_norm import _layer_norm_fwd
|
17 |
+
|
18 |
+
from .._ops import ops
|
19 |
+
|
20 |
+
|
21 |
+
class SelectiveScanFn(torch.autograd.Function):
|
22 |
+
|
23 |
+
@staticmethod
|
24 |
+
def forward(
|
25 |
+
ctx,
|
26 |
+
u,
|
27 |
+
delta,
|
28 |
+
A,
|
29 |
+
B,
|
30 |
+
C,
|
31 |
+
D=None,
|
32 |
+
z=None,
|
33 |
+
delta_bias=None,
|
34 |
+
delta_softplus=False,
|
35 |
+
return_last_state=False,
|
36 |
+
):
|
37 |
+
if u.stride(-1) != 1:
|
38 |
+
u = u.contiguous()
|
39 |
+
if delta.stride(-1) != 1:
|
40 |
+
delta = delta.contiguous()
|
41 |
+
if D is not None:
|
42 |
+
D = D.contiguous()
|
43 |
+
if B.stride(-1) != 1:
|
44 |
+
B = B.contiguous()
|
45 |
+
if C.stride(-1) != 1:
|
46 |
+
C = C.contiguous()
|
47 |
+
if z is not None and z.stride(-1) != 1:
|
48 |
+
z = z.contiguous()
|
49 |
+
if B.dim() == 3:
|
50 |
+
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
51 |
+
ctx.squeeze_B = True
|
52 |
+
if C.dim() == 3:
|
53 |
+
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
54 |
+
ctx.squeeze_C = True
|
55 |
+
out, x, *rest = ops.selective_scan_fwd(
|
56 |
+
u, delta, A, B, C, D, z, delta_bias, delta_softplus
|
57 |
+
)
|
58 |
+
ctx.delta_softplus = delta_softplus
|
59 |
+
ctx.has_z = z is not None
|
60 |
+
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
61 |
+
if not ctx.has_z:
|
62 |
+
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
63 |
+
return out if not return_last_state else (out, last_state)
|
64 |
+
else:
|
65 |
+
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
66 |
+
out_z = rest[0]
|
67 |
+
return out_z if not return_last_state else (out_z, last_state)
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def backward(ctx, dout, *args):
|
71 |
+
if not ctx.has_z:
|
72 |
+
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
73 |
+
z = None
|
74 |
+
out = None
|
75 |
+
else:
|
76 |
+
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
77 |
+
if dout.stride(-1) != 1:
|
78 |
+
dout = dout.contiguous()
|
79 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
80 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
81 |
+
# Here we just pass in None and dz will be allocated in the C++ code.
|
82 |
+
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
|
83 |
+
u,
|
84 |
+
delta,
|
85 |
+
A,
|
86 |
+
B,
|
87 |
+
C,
|
88 |
+
D,
|
89 |
+
z,
|
90 |
+
delta_bias,
|
91 |
+
dout,
|
92 |
+
x,
|
93 |
+
out,
|
94 |
+
None,
|
95 |
+
ctx.delta_softplus,
|
96 |
+
False, # option to recompute out_z, not used here
|
97 |
+
)
|
98 |
+
dz = rest[0] if ctx.has_z else None
|
99 |
+
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
100 |
+
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
101 |
+
return (
|
102 |
+
du,
|
103 |
+
ddelta,
|
104 |
+
dA,
|
105 |
+
dB,
|
106 |
+
dC,
|
107 |
+
dD if D is not None else None,
|
108 |
+
dz,
|
109 |
+
ddelta_bias if delta_bias is not None else None,
|
110 |
+
None,
|
111 |
+
None,
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
def rms_norm_forward(
|
116 |
+
x,
|
117 |
+
weight,
|
118 |
+
bias,
|
119 |
+
eps=1e-6,
|
120 |
+
is_rms_norm=True,
|
121 |
+
):
|
122 |
+
# x (b l) d
|
123 |
+
if x.stride(-1) != 1:
|
124 |
+
x = x.contiguous()
|
125 |
+
weight = weight.contiguous()
|
126 |
+
if bias is not None:
|
127 |
+
bias = bias.contiguous()
|
128 |
+
y = _layer_norm_fwd(
|
129 |
+
x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
|
130 |
+
)[0]
|
131 |
+
# y (b l) d
|
132 |
+
return y
|
133 |
+
|
134 |
+
|
135 |
+
def selective_scan_fn(
|
136 |
+
u,
|
137 |
+
delta,
|
138 |
+
A,
|
139 |
+
B,
|
140 |
+
C,
|
141 |
+
D=None,
|
142 |
+
z=None,
|
143 |
+
delta_bias=None,
|
144 |
+
delta_softplus=False,
|
145 |
+
return_last_state=False,
|
146 |
+
):
|
147 |
+
"""if return_last_state is True, returns (out, last_state)
|
148 |
+
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
|
149 |
+
not considered in the backward pass.
|
150 |
+
"""
|
151 |
+
return SelectiveScanFn.apply(
|
152 |
+
u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
|
153 |
+
)
|
154 |
+
|
155 |
+
|
156 |
+
def selective_scan_ref(
|
157 |
+
u,
|
158 |
+
delta,
|
159 |
+
A,
|
160 |
+
B,
|
161 |
+
C,
|
162 |
+
D=None,
|
163 |
+
z=None,
|
164 |
+
delta_bias=None,
|
165 |
+
delta_softplus=False,
|
166 |
+
return_last_state=False,
|
167 |
+
):
|
168 |
+
"""
|
169 |
+
u: r(B D L)
|
170 |
+
delta: r(B D L)
|
171 |
+
A: c(D N) or r(D N)
|
172 |
+
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
173 |
+
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
174 |
+
D: r(D)
|
175 |
+
z: r(B D L)
|
176 |
+
delta_bias: r(D), fp32
|
177 |
+
|
178 |
+
out: r(B D L)
|
179 |
+
last_state (optional): r(B D dstate) or c(B D dstate)
|
180 |
+
"""
|
181 |
+
dtype_in = u.dtype
|
182 |
+
u = u.float()
|
183 |
+
delta = delta.float()
|
184 |
+
if delta_bias is not None:
|
185 |
+
delta = delta + delta_bias[..., None].float()
|
186 |
+
if delta_softplus:
|
187 |
+
delta = F.softplus(delta)
|
188 |
+
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
189 |
+
is_variable_B = B.dim() >= 3
|
190 |
+
is_variable_C = C.dim() >= 3
|
191 |
+
if A.is_complex():
|
192 |
+
if is_variable_B:
|
193 |
+
B = torch.view_as_complex(
|
194 |
+
rearrange(B.float(), "... (L two) -> ... L two", two=2)
|
195 |
+
)
|
196 |
+
if is_variable_C:
|
197 |
+
C = torch.view_as_complex(
|
198 |
+
rearrange(C.float(), "... (L two) -> ... L two", two=2)
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
B = B.float()
|
202 |
+
C = C.float()
|
203 |
+
x = A.new_zeros((batch, dim, dstate))
|
204 |
+
ys = []
|
205 |
+
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
|
206 |
+
if not is_variable_B:
|
207 |
+
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
|
208 |
+
else:
|
209 |
+
if B.dim() == 3:
|
210 |
+
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
|
211 |
+
else:
|
212 |
+
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
213 |
+
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
|
214 |
+
if is_variable_C and C.dim() == 4:
|
215 |
+
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
216 |
+
last_state = None
|
217 |
+
for i in range(u.shape[2]):
|
218 |
+
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
219 |
+
if not is_variable_C:
|
220 |
+
y = torch.einsum("bdn,dn->bd", x, C)
|
221 |
+
else:
|
222 |
+
if C.dim() == 3:
|
223 |
+
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
|
224 |
+
else:
|
225 |
+
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
|
226 |
+
if i == u.shape[2] - 1:
|
227 |
+
last_state = x
|
228 |
+
if y.is_complex():
|
229 |
+
y = y.real * 2
|
230 |
+
ys.append(y)
|
231 |
+
y = torch.stack(ys, dim=2) # (batch dim L)
|
232 |
+
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
233 |
+
if z is not None:
|
234 |
+
out = out * F.silu(z)
|
235 |
+
out = out.to(dtype=dtype_in)
|
236 |
+
return out if not return_last_state else (out, last_state)
|
237 |
+
|
238 |
+
|
239 |
+
class MambaInnerFn(torch.autograd.Function):
|
240 |
+
|
241 |
+
@staticmethod
|
242 |
+
@custom_fwd
|
243 |
+
def forward(
|
244 |
+
ctx,
|
245 |
+
xz,
|
246 |
+
conv1d_weight,
|
247 |
+
conv1d_bias,
|
248 |
+
x_proj_weight,
|
249 |
+
delta_proj_weight,
|
250 |
+
out_proj_weight,
|
251 |
+
out_proj_bias,
|
252 |
+
A,
|
253 |
+
B=None,
|
254 |
+
C=None,
|
255 |
+
D=None,
|
256 |
+
delta_bias=None,
|
257 |
+
B_proj_bias=None,
|
258 |
+
C_proj_bias=None,
|
259 |
+
delta_softplus=True,
|
260 |
+
checkpoint_lvl=1,
|
261 |
+
b_rms_weight=None,
|
262 |
+
c_rms_weight=None,
|
263 |
+
dt_rms_weight=None,
|
264 |
+
b_c_dt_rms_eps=1e-6,
|
265 |
+
):
|
266 |
+
"""
|
267 |
+
xz: (batch, dim, seqlen)
|
268 |
+
"""
|
269 |
+
assert (
|
270 |
+
causal_conv1d_cuda is not None
|
271 |
+
), "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
272 |
+
assert checkpoint_lvl in [0, 1]
|
273 |
+
L = xz.shape[-1]
|
274 |
+
delta_rank = delta_proj_weight.shape[1]
|
275 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
276 |
+
if torch.is_autocast_enabled():
|
277 |
+
x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
278 |
+
delta_proj_weight = delta_proj_weight.to(
|
279 |
+
dtype=torch.get_autocast_gpu_dtype()
|
280 |
+
)
|
281 |
+
out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
|
282 |
+
out_proj_bias = (
|
283 |
+
out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
|
284 |
+
if out_proj_bias is not None
|
285 |
+
else None
|
286 |
+
)
|
287 |
+
if xz.stride(-1) != 1:
|
288 |
+
xz = xz.contiguous()
|
289 |
+
conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
|
290 |
+
x, z = xz.chunk(2, dim=1)
|
291 |
+
conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
|
292 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
293 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
294 |
+
)
|
295 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
296 |
+
# We want delta to have d as the slowest moving dimension
|
297 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
298 |
+
x_dbl = F.linear(
|
299 |
+
rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
|
300 |
+
) # (bl d)
|
301 |
+
delta = rearrange(
|
302 |
+
delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
|
303 |
+
)
|
304 |
+
ctx.is_variable_B = B is None
|
305 |
+
ctx.is_variable_C = C is None
|
306 |
+
ctx.B_proj_bias_is_None = B_proj_bias is None
|
307 |
+
ctx.C_proj_bias_is_None = C_proj_bias is None
|
308 |
+
if B is None: # variable B
|
309 |
+
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
|
310 |
+
if B_proj_bias is not None:
|
311 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
312 |
+
if not A.is_complex():
|
313 |
+
# B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
314 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
315 |
+
else:
|
316 |
+
B = rearrange(
|
317 |
+
B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
|
318 |
+
).contiguous()
|
319 |
+
else:
|
320 |
+
if B.stride(-1) != 1:
|
321 |
+
B = B.contiguous()
|
322 |
+
if C is None: # variable C
|
323 |
+
C = x_dbl[:, -d_state:] # (bl dstate)
|
324 |
+
if C_proj_bias is not None:
|
325 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
326 |
+
if not A.is_complex():
|
327 |
+
# C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
328 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
329 |
+
else:
|
330 |
+
C = rearrange(
|
331 |
+
C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
|
332 |
+
).contiguous()
|
333 |
+
else:
|
334 |
+
if C.stride(-1) != 1:
|
335 |
+
C = C.contiguous()
|
336 |
+
if D is not None:
|
337 |
+
D = D.contiguous()
|
338 |
+
|
339 |
+
if b_rms_weight is not None:
|
340 |
+
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
341 |
+
B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
342 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
343 |
+
if c_rms_weight is not None:
|
344 |
+
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
345 |
+
C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
|
346 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
347 |
+
if dt_rms_weight is not None:
|
348 |
+
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
|
349 |
+
delta = rms_norm_forward(
|
350 |
+
delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
|
351 |
+
)
|
352 |
+
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
|
353 |
+
|
354 |
+
out, scan_intermediates, out_z = ops.selective_scan_fwd(
|
355 |
+
conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
|
356 |
+
)
|
357 |
+
ctx.delta_softplus = delta_softplus
|
358 |
+
ctx.out_proj_bias_is_None = out_proj_bias is None
|
359 |
+
ctx.checkpoint_lvl = checkpoint_lvl
|
360 |
+
ctx.b_rms_weight = b_rms_weight
|
361 |
+
ctx.c_rms_weight = c_rms_weight
|
362 |
+
ctx.dt_rms_weight = dt_rms_weight
|
363 |
+
ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
|
364 |
+
if (
|
365 |
+
checkpoint_lvl >= 1
|
366 |
+
): # Will recompute conv1d_out and delta in the backward pass
|
367 |
+
conv1d_out, delta = None, None
|
368 |
+
ctx.save_for_backward(
|
369 |
+
xz,
|
370 |
+
conv1d_weight,
|
371 |
+
conv1d_bias,
|
372 |
+
x_dbl,
|
373 |
+
x_proj_weight,
|
374 |
+
delta_proj_weight,
|
375 |
+
out_proj_weight,
|
376 |
+
conv1d_out,
|
377 |
+
delta,
|
378 |
+
A,
|
379 |
+
B,
|
380 |
+
C,
|
381 |
+
D,
|
382 |
+
delta_bias,
|
383 |
+
scan_intermediates,
|
384 |
+
b_rms_weight,
|
385 |
+
c_rms_weight,
|
386 |
+
dt_rms_weight,
|
387 |
+
out,
|
388 |
+
)
|
389 |
+
return F.linear(
|
390 |
+
rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
|
391 |
+
)
|
392 |
+
|
393 |
+
@staticmethod
|
394 |
+
@custom_bwd
|
395 |
+
def backward(ctx, dout):
|
396 |
+
# dout: (batch, seqlen, dim)
|
397 |
+
assert (
|
398 |
+
causal_conv1d_cuda is not None
|
399 |
+
), "causal_conv1d_cuda is not available. Please install causal-conv1d."
|
400 |
+
(
|
401 |
+
xz,
|
402 |
+
conv1d_weight,
|
403 |
+
conv1d_bias,
|
404 |
+
x_dbl,
|
405 |
+
x_proj_weight,
|
406 |
+
delta_proj_weight,
|
407 |
+
out_proj_weight,
|
408 |
+
conv1d_out,
|
409 |
+
delta,
|
410 |
+
A,
|
411 |
+
B,
|
412 |
+
C,
|
413 |
+
D,
|
414 |
+
delta_bias,
|
415 |
+
scan_intermediates,
|
416 |
+
b_rms_weight,
|
417 |
+
c_rms_weight,
|
418 |
+
dt_rms_weight,
|
419 |
+
out,
|
420 |
+
) = ctx.saved_tensors
|
421 |
+
L = xz.shape[-1]
|
422 |
+
delta_rank = delta_proj_weight.shape[1]
|
423 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
424 |
+
x, z = xz.chunk(2, dim=1)
|
425 |
+
if dout.stride(-1) != 1:
|
426 |
+
dout = dout.contiguous()
|
427 |
+
if ctx.checkpoint_lvl == 1:
|
428 |
+
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
|
429 |
+
x, conv1d_weight, conv1d_bias, None, None, None, True
|
430 |
+
)
|
431 |
+
delta = rearrange(
|
432 |
+
delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
|
433 |
+
)
|
434 |
+
if dt_rms_weight is not None:
|
435 |
+
delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
|
436 |
+
delta = rms_norm_forward(
|
437 |
+
delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
|
438 |
+
)
|
439 |
+
delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
|
440 |
+
if b_rms_weight is not None:
|
441 |
+
# Recompute & RMSNorm B
|
442 |
+
B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
443 |
+
B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
|
444 |
+
B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
445 |
+
if c_rms_weight is not None:
|
446 |
+
# Recompute & RMSNorm C
|
447 |
+
C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
|
448 |
+
C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
|
449 |
+
C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
|
450 |
+
|
451 |
+
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
452 |
+
# backward of selective_scan_cuda with the backward of chunk).
|
453 |
+
dxz = torch.empty_like(xz) # (batch, dim, seqlen)
|
454 |
+
dx, dz = dxz.chunk(2, dim=1)
|
455 |
+
dout = rearrange(dout, "b l e -> e (b l)")
|
456 |
+
dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
|
457 |
+
dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
|
458 |
+
ops.selective_scan_bwd(
|
459 |
+
conv1d_out,
|
460 |
+
delta,
|
461 |
+
A,
|
462 |
+
B,
|
463 |
+
C,
|
464 |
+
D,
|
465 |
+
z,
|
466 |
+
delta_bias,
|
467 |
+
dout_y,
|
468 |
+
scan_intermediates,
|
469 |
+
out,
|
470 |
+
dz,
|
471 |
+
ctx.delta_softplus,
|
472 |
+
True, # option to recompute out_z
|
473 |
+
)
|
474 |
+
)
|
475 |
+
dout_proj_weight = torch.einsum(
|
476 |
+
"eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
|
477 |
+
)
|
478 |
+
dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
|
479 |
+
dD = dD if D is not None else None
|
480 |
+
dx_dbl = torch.empty_like(x_dbl)
|
481 |
+
dB_proj_bias = None
|
482 |
+
if ctx.is_variable_B:
|
483 |
+
if not A.is_complex():
|
484 |
+
dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
|
485 |
+
else:
|
486 |
+
dB = rearrange(
|
487 |
+
dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
|
488 |
+
).contiguous()
|
489 |
+
dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
|
490 |
+
dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
|
491 |
+
dB = None
|
492 |
+
dC_proj_bias = None
|
493 |
+
if ctx.is_variable_C:
|
494 |
+
if not A.is_complex():
|
495 |
+
dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
|
496 |
+
else:
|
497 |
+
dC = rearrange(
|
498 |
+
dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
|
499 |
+
).contiguous()
|
500 |
+
dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
|
501 |
+
dx_dbl[:, -d_state:] = dC # (bl d)
|
502 |
+
dC = None
|
503 |
+
ddelta = rearrange(ddelta, "b d l -> d (b l)")
|
504 |
+
ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
|
505 |
+
dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
|
506 |
+
dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
|
507 |
+
dx_proj_weight = torch.einsum(
|
508 |
+
"Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
|
509 |
+
)
|
510 |
+
dconv1d_out = torch.addmm(
|
511 |
+
dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
|
512 |
+
)
|
513 |
+
dconv1d_out = rearrange(
|
514 |
+
dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
|
515 |
+
)
|
516 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
517 |
+
# backward of conv1d with the backward of chunk).
|
518 |
+
dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
|
519 |
+
x,
|
520 |
+
conv1d_weight,
|
521 |
+
conv1d_bias,
|
522 |
+
dconv1d_out,
|
523 |
+
None,
|
524 |
+
None,
|
525 |
+
None,
|
526 |
+
dx,
|
527 |
+
False,
|
528 |
+
True,
|
529 |
+
)
|
530 |
+
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
|
531 |
+
dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
|
532 |
+
return (
|
533 |
+
dxz,
|
534 |
+
dconv1d_weight,
|
535 |
+
dconv1d_bias,
|
536 |
+
dx_proj_weight,
|
537 |
+
ddelta_proj_weight,
|
538 |
+
dout_proj_weight,
|
539 |
+
dout_proj_bias,
|
540 |
+
dA,
|
541 |
+
dB,
|
542 |
+
dC,
|
543 |
+
dD,
|
544 |
+
ddelta_bias if delta_bias is not None else None,
|
545 |
+
# 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
|
546 |
+
dB_proj_bias,
|
547 |
+
dC_proj_bias,
|
548 |
+
None,
|
549 |
+
None,
|
550 |
+
None,
|
551 |
+
None,
|
552 |
+
None,
|
553 |
+
None,
|
554 |
+
)
|
555 |
+
|
556 |
+
|
557 |
+
def mamba_inner_fn(
|
558 |
+
xz,
|
559 |
+
conv1d_weight,
|
560 |
+
conv1d_bias,
|
561 |
+
x_proj_weight,
|
562 |
+
delta_proj_weight,
|
563 |
+
out_proj_weight,
|
564 |
+
out_proj_bias,
|
565 |
+
A,
|
566 |
+
B=None,
|
567 |
+
C=None,
|
568 |
+
D=None,
|
569 |
+
delta_bias=None,
|
570 |
+
B_proj_bias=None,
|
571 |
+
C_proj_bias=None,
|
572 |
+
delta_softplus=True,
|
573 |
+
checkpoint_lvl=1,
|
574 |
+
b_rms_weight=None,
|
575 |
+
c_rms_weight=None,
|
576 |
+
dt_rms_weight=None,
|
577 |
+
b_c_dt_rms_eps=1e-6,
|
578 |
+
):
|
579 |
+
return MambaInnerFn.apply(
|
580 |
+
xz,
|
581 |
+
conv1d_weight,
|
582 |
+
conv1d_bias,
|
583 |
+
x_proj_weight,
|
584 |
+
delta_proj_weight,
|
585 |
+
out_proj_weight,
|
586 |
+
out_proj_bias,
|
587 |
+
A,
|
588 |
+
B,
|
589 |
+
C,
|
590 |
+
D,
|
591 |
+
delta_bias,
|
592 |
+
B_proj_bias,
|
593 |
+
C_proj_bias,
|
594 |
+
delta_softplus,
|
595 |
+
checkpoint_lvl,
|
596 |
+
b_rms_weight,
|
597 |
+
c_rms_weight,
|
598 |
+
dt_rms_weight,
|
599 |
+
b_c_dt_rms_eps,
|
600 |
+
)
|
601 |
+
|
602 |
+
|
603 |
+
def mamba_inner_ref(
|
604 |
+
xz,
|
605 |
+
conv1d_weight,
|
606 |
+
conv1d_bias,
|
607 |
+
x_proj_weight,
|
608 |
+
delta_proj_weight,
|
609 |
+
out_proj_weight,
|
610 |
+
out_proj_bias,
|
611 |
+
A,
|
612 |
+
B=None,
|
613 |
+
C=None,
|
614 |
+
D=None,
|
615 |
+
delta_bias=None,
|
616 |
+
B_proj_bias=None,
|
617 |
+
C_proj_bias=None,
|
618 |
+
delta_softplus=True,
|
619 |
+
):
|
620 |
+
assert (
|
621 |
+
causal_conv1d_fn is not None
|
622 |
+
), "causal_conv1d_fn is not available. Please install causal-conv1d."
|
623 |
+
L = xz.shape[-1]
|
624 |
+
delta_rank = delta_proj_weight.shape[1]
|
625 |
+
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
|
626 |
+
x, z = xz.chunk(2, dim=1)
|
627 |
+
x = causal_conv1d_fn(
|
628 |
+
x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
|
629 |
+
)
|
630 |
+
# We're being very careful here about the layout, to avoid extra transposes.
|
631 |
+
# We want delta to have d as the slowest moving dimension
|
632 |
+
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
633 |
+
x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
|
634 |
+
delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
|
635 |
+
delta = rearrange(delta, "d (b l) -> b d l", l=L)
|
636 |
+
if B is None: # variable B
|
637 |
+
B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
|
638 |
+
if B_proj_bias is not None:
|
639 |
+
B = B + B_proj_bias.to(dtype=B.dtype)
|
640 |
+
if not A.is_complex():
|
641 |
+
B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
|
642 |
+
else:
|
643 |
+
B = rearrange(
|
644 |
+
B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
|
645 |
+
).contiguous()
|
646 |
+
if C is None: # variable B
|
647 |
+
C = x_dbl[:, -d_state:] # (bl d)
|
648 |
+
if C_proj_bias is not None:
|
649 |
+
C = C + C_proj_bias.to(dtype=C.dtype)
|
650 |
+
if not A.is_complex():
|
651 |
+
C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
|
652 |
+
else:
|
653 |
+
C = rearrange(
|
654 |
+
C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
|
655 |
+
).contiguous()
|
656 |
+
y = selective_scan_fn(
|
657 |
+
x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
|
658 |
+
)
|
659 |
+
return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
|
torch-ext/mamba_ssm/ops/triton/__init__.py
ADDED
File without changes
|
torch-ext/mamba_ssm/ops/triton/k_activations.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
|
8 |
+
|
9 |
+
@triton.autotune(
|
10 |
+
configs=[
|
11 |
+
triton.Config({'BLOCK_N': 32}),
|
12 |
+
triton.Config({'BLOCK_N': 64}),
|
13 |
+
triton.Config({'BLOCK_N': 128}),
|
14 |
+
triton.Config({'BLOCK_N': 256}),
|
15 |
+
triton.Config({'BLOCK_N': 512}),
|
16 |
+
triton.Config({'BLOCK_N': 1024}),
|
17 |
+
],
|
18 |
+
key=['ncols'],
|
19 |
+
)
|
20 |
+
@triton.jit
|
21 |
+
def _swiglu_fwd_kernel(
|
22 |
+
X,
|
23 |
+
Y,
|
24 |
+
OUT,
|
25 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
26 |
+
stride_y_row,
|
27 |
+
stride_out_row,
|
28 |
+
ncols,
|
29 |
+
BLOCK_N: tl.constexpr,
|
30 |
+
):
|
31 |
+
# Map the program id to the row of X and Y it should compute.
|
32 |
+
row = tl.program_id(0)
|
33 |
+
start_col = tl.program_id(1) * BLOCK_N
|
34 |
+
X += row * stride_x_row
|
35 |
+
Y += row * stride_y_row
|
36 |
+
OUT += row * stride_out_row
|
37 |
+
cols = start_col + tl.arange(0, BLOCK_N)
|
38 |
+
x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
39 |
+
y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
40 |
+
out = x * tl.sigmoid(x) * y
|
41 |
+
tl.store(OUT + cols, out, mask=cols < ncols)
|
42 |
+
|
43 |
+
|
44 |
+
def _swiglu_fwd(xy, out=None):
|
45 |
+
if xy.stride(-1) != 1:
|
46 |
+
xy = xy.contiguous()
|
47 |
+
batch_shape = xy.shape[:-1]
|
48 |
+
xy = xy.reshape(-1, xy.shape[-1])
|
49 |
+
x, y = xy.chunk(2, dim=-1)
|
50 |
+
if out is None:
|
51 |
+
out = torch.empty_like(x)
|
52 |
+
else:
|
53 |
+
out = out.reshape(-1, out.shape[-1])
|
54 |
+
assert out.shape == x.shape
|
55 |
+
assert out.stride(-1) == 1
|
56 |
+
M, N = x.shape
|
57 |
+
grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
|
58 |
+
with torch.cuda.device(x.device.index):
|
59 |
+
_swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
|
60 |
+
return out.reshape(*batch_shape, out.shape[-1])
|
61 |
+
|
62 |
+
|
63 |
+
@triton.autotune(
|
64 |
+
configs=[
|
65 |
+
triton.Config({'BLOCK_N': 32}),
|
66 |
+
triton.Config({'BLOCK_N': 64}),
|
67 |
+
triton.Config({'BLOCK_N': 128}),
|
68 |
+
triton.Config({'BLOCK_N': 256}),
|
69 |
+
triton.Config({'BLOCK_N': 512}),
|
70 |
+
triton.Config({'BLOCK_N': 1024}),
|
71 |
+
],
|
72 |
+
key=['ncols'],
|
73 |
+
)
|
74 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
|
75 |
+
@triton.jit
|
76 |
+
def _swiglu_bwd_kernel(
|
77 |
+
X,
|
78 |
+
Y,
|
79 |
+
DOUT,
|
80 |
+
OUT,
|
81 |
+
DX,
|
82 |
+
DY,
|
83 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
84 |
+
stride_y_row,
|
85 |
+
stride_dout_row,
|
86 |
+
stride_out_row,
|
87 |
+
stride_dx_row,
|
88 |
+
stride_dy_row,
|
89 |
+
ncols,
|
90 |
+
BLOCK_N: tl.constexpr,
|
91 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
92 |
+
):
|
93 |
+
# Map the program id to the row of X and Y it should compute.
|
94 |
+
row = tl.program_id(0)
|
95 |
+
start_col = tl.program_id(1) * BLOCK_N
|
96 |
+
X += row * stride_x_row
|
97 |
+
Y += row * stride_y_row
|
98 |
+
DOUT += row * stride_dout_row
|
99 |
+
if RECOMPUTE_OUTPUT:
|
100 |
+
OUT += row * stride_out_row
|
101 |
+
DX += row * stride_dx_row
|
102 |
+
DY += row * stride_dy_row
|
103 |
+
cols = start_col + tl.arange(0, BLOCK_N)
|
104 |
+
x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
105 |
+
y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
106 |
+
dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
|
107 |
+
x_sigmoid = tl.sigmoid(x)
|
108 |
+
dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
|
109 |
+
dy = x * x_sigmoid * dout
|
110 |
+
tl.store(DX + cols, dx, mask=cols < ncols)
|
111 |
+
tl.store(DY + cols, dy, mask=cols < ncols)
|
112 |
+
if RECOMPUTE_OUTPUT:
|
113 |
+
out = x * x_sigmoid * y
|
114 |
+
tl.store(OUT + cols, out, mask=cols < ncols)
|
115 |
+
|
116 |
+
|
117 |
+
def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
|
118 |
+
if xy.stride(-1) != 1:
|
119 |
+
xy = xy.contiguous()
|
120 |
+
if dout.stride(-1) != 1:
|
121 |
+
dout = dout.contiguous()
|
122 |
+
batch_shape = xy.shape[:-1]
|
123 |
+
xy = xy.reshape(-1, xy.shape[-1])
|
124 |
+
x, y = xy.chunk(2, dim=-1)
|
125 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
126 |
+
assert dout.shape == x.shape
|
127 |
+
if dxy is None:
|
128 |
+
dxy = torch.empty_like(xy)
|
129 |
+
else:
|
130 |
+
dxy = dxy.reshape(-1, dxy.shape[-1])
|
131 |
+
assert dxy.shape == xy.shape
|
132 |
+
dx, dy = dxy.chunk(2, dim=-1)
|
133 |
+
assert dx.stride(-1) == 1
|
134 |
+
assert dy.stride(-1) == 1
|
135 |
+
if recompute_output:
|
136 |
+
if out is None:
|
137 |
+
out = torch.empty_like(x)
|
138 |
+
else:
|
139 |
+
out = out.reshape(-1, out.shape[-1])
|
140 |
+
assert out.shape == x.shape
|
141 |
+
assert out.stride(-1) == 1
|
142 |
+
M, N = x.shape
|
143 |
+
grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
|
144 |
+
with torch.cuda.device(x.device.index):
|
145 |
+
_swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
|
146 |
+
x.stride(0), y.stride(0), dout.stride(0),
|
147 |
+
out.stride(0) if recompute_output else 0,
|
148 |
+
dx.stride(0), dy.stride(0),
|
149 |
+
N)
|
150 |
+
if not recompute_output:
|
151 |
+
return dxy.reshape(*batch_shape, dxy.shape[-1])
|
152 |
+
else:
|
153 |
+
return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
|
154 |
+
|
155 |
+
|
156 |
+
class SwiGLU(torch.autograd.Function):
|
157 |
+
|
158 |
+
@staticmethod
|
159 |
+
def forward(ctx, xy):
|
160 |
+
ctx.save_for_backward(xy)
|
161 |
+
return _swiglu_fwd(xy)
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def backward(ctx, dout):
|
165 |
+
xy, = ctx.saved_tensors
|
166 |
+
return _swiglu_bwd(xy, dout)
|
167 |
+
|
168 |
+
|
169 |
+
swiglu = SwiGLU.apply
|
torch-ext/mamba_ssm/ops/triton/layer_norm.py
ADDED
@@ -0,0 +1,1166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Implement dropout + residual + layer_norm / rms_norm.
|
3 |
+
|
4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
8 |
+
|
9 |
+
import math
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from ...utils.torch import custom_bwd, custom_fwd
|
15 |
+
|
16 |
+
import triton
|
17 |
+
import triton.language as tl
|
18 |
+
|
19 |
+
|
20 |
+
def layer_norm_ref(
|
21 |
+
x,
|
22 |
+
weight,
|
23 |
+
bias,
|
24 |
+
residual=None,
|
25 |
+
x1=None,
|
26 |
+
weight1=None,
|
27 |
+
bias1=None,
|
28 |
+
eps=1e-6,
|
29 |
+
dropout_p=0.0,
|
30 |
+
rowscale=None,
|
31 |
+
prenorm=False,
|
32 |
+
dropout_mask=None,
|
33 |
+
dropout_mask1=None,
|
34 |
+
upcast=False,
|
35 |
+
):
|
36 |
+
dtype = x.dtype
|
37 |
+
if upcast:
|
38 |
+
x = x.float()
|
39 |
+
weight = weight.float()
|
40 |
+
bias = bias.float() if bias is not None else None
|
41 |
+
residual = residual.float() if residual is not None else residual
|
42 |
+
x1 = x1.float() if x1 is not None else None
|
43 |
+
weight1 = weight1.float() if weight1 is not None else None
|
44 |
+
bias1 = bias1.float() if bias1 is not None else None
|
45 |
+
if x1 is not None:
|
46 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
47 |
+
if rowscale is not None:
|
48 |
+
x = x * rowscale[..., None]
|
49 |
+
if dropout_p > 0.0:
|
50 |
+
if dropout_mask is not None:
|
51 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
52 |
+
else:
|
53 |
+
x = F.dropout(x, p=dropout_p)
|
54 |
+
if x1 is not None:
|
55 |
+
if dropout_mask1 is not None:
|
56 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
57 |
+
else:
|
58 |
+
x1 = F.dropout(x1, p=dropout_p)
|
59 |
+
if x1 is not None:
|
60 |
+
x = x + x1
|
61 |
+
if residual is not None:
|
62 |
+
x = (x + residual).to(x.dtype)
|
63 |
+
out = F.layer_norm(
|
64 |
+
x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
|
65 |
+
).to(dtype)
|
66 |
+
if weight1 is None:
|
67 |
+
return out if not prenorm else (out, x)
|
68 |
+
else:
|
69 |
+
out1 = F.layer_norm(
|
70 |
+
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
|
71 |
+
).to(dtype)
|
72 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
73 |
+
|
74 |
+
|
75 |
+
def rms_norm_ref(
|
76 |
+
x,
|
77 |
+
weight,
|
78 |
+
bias,
|
79 |
+
residual=None,
|
80 |
+
x1=None,
|
81 |
+
weight1=None,
|
82 |
+
bias1=None,
|
83 |
+
eps=1e-6,
|
84 |
+
dropout_p=0.0,
|
85 |
+
rowscale=None,
|
86 |
+
prenorm=False,
|
87 |
+
dropout_mask=None,
|
88 |
+
dropout_mask1=None,
|
89 |
+
upcast=False,
|
90 |
+
):
|
91 |
+
dtype = x.dtype
|
92 |
+
if upcast:
|
93 |
+
x = x.float()
|
94 |
+
weight = weight.float()
|
95 |
+
bias = bias.float() if bias is not None else None
|
96 |
+
residual = residual.float() if residual is not None else residual
|
97 |
+
x1 = x1.float() if x1 is not None else None
|
98 |
+
weight1 = weight1.float() if weight1 is not None else None
|
99 |
+
bias1 = bias1.float() if bias1 is not None else None
|
100 |
+
if x1 is not None:
|
101 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
102 |
+
if rowscale is not None:
|
103 |
+
x = x * rowscale[..., None]
|
104 |
+
if dropout_p > 0.0:
|
105 |
+
if dropout_mask is not None:
|
106 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
107 |
+
else:
|
108 |
+
x = F.dropout(x, p=dropout_p)
|
109 |
+
if x1 is not None:
|
110 |
+
if dropout_mask1 is not None:
|
111 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
112 |
+
else:
|
113 |
+
x1 = F.dropout(x1, p=dropout_p)
|
114 |
+
if x1 is not None:
|
115 |
+
x = x + x1
|
116 |
+
if residual is not None:
|
117 |
+
x = (x + residual).to(x.dtype)
|
118 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
119 |
+
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
|
120 |
+
dtype
|
121 |
+
)
|
122 |
+
if weight1 is None:
|
123 |
+
return out if not prenorm else (out, x)
|
124 |
+
else:
|
125 |
+
out1 = (
|
126 |
+
(x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
|
127 |
+
).to(dtype)
|
128 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
129 |
+
|
130 |
+
|
131 |
+
def config_prune(configs):
|
132 |
+
|
133 |
+
if torch.version.hip:
|
134 |
+
try:
|
135 |
+
# set warp size based on gcn architecure
|
136 |
+
gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
|
137 |
+
if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
|
138 |
+
# radeon
|
139 |
+
warp_size = 32
|
140 |
+
else:
|
141 |
+
# instinct
|
142 |
+
warp_size = 64
|
143 |
+
except AttributeError as e:
|
144 |
+
# fall back to crude method to set warp size
|
145 |
+
device_name = torch.cuda.get_device_properties(0).name
|
146 |
+
if "instinct" in device_name.lower():
|
147 |
+
warp_size = 64
|
148 |
+
else:
|
149 |
+
warp_size = 32
|
150 |
+
warnings.warn(
|
151 |
+
f"{e}, warp size set to {warp_size} based on device name: {device_name}",
|
152 |
+
UserWarning,
|
153 |
+
)
|
154 |
+
|
155 |
+
else:
|
156 |
+
# cuda
|
157 |
+
warp_size = 32
|
158 |
+
|
159 |
+
max_block_sz = 1024
|
160 |
+
max_num_warps = max_block_sz // warp_size
|
161 |
+
pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
|
162 |
+
return pruned_configs
|
163 |
+
|
164 |
+
|
165 |
+
configs_autotune = [
|
166 |
+
triton.Config({}, num_warps=1),
|
167 |
+
triton.Config({}, num_warps=2),
|
168 |
+
triton.Config({}, num_warps=4),
|
169 |
+
triton.Config({}, num_warps=8),
|
170 |
+
triton.Config({}, num_warps=16),
|
171 |
+
triton.Config({}, num_warps=32),
|
172 |
+
]
|
173 |
+
|
174 |
+
pruned_configs_autotune = config_prune(configs_autotune)
|
175 |
+
|
176 |
+
|
177 |
+
@triton.autotune(
|
178 |
+
configs=pruned_configs_autotune,
|
179 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
180 |
+
)
|
181 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
182 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
183 |
+
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
|
184 |
+
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
|
185 |
+
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
|
186 |
+
@triton.jit
|
187 |
+
def _layer_norm_fwd_1pass_kernel(
|
188 |
+
X, # pointer to the input
|
189 |
+
Y, # pointer to the output
|
190 |
+
W, # pointer to the weights
|
191 |
+
B, # pointer to the biases
|
192 |
+
RESIDUAL, # pointer to the residual
|
193 |
+
X1,
|
194 |
+
W1,
|
195 |
+
B1,
|
196 |
+
Y1,
|
197 |
+
RESIDUAL_OUT, # pointer to the residual
|
198 |
+
ROWSCALE,
|
199 |
+
SEEDS, # Dropout seeds for each row
|
200 |
+
DROPOUT_MASK,
|
201 |
+
Mean, # pointer to the mean
|
202 |
+
Rstd, # pointer to the 1/std
|
203 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
204 |
+
stride_y_row,
|
205 |
+
stride_res_row,
|
206 |
+
stride_res_out_row,
|
207 |
+
stride_x1_row,
|
208 |
+
stride_y1_row,
|
209 |
+
M, # number of rows in X
|
210 |
+
N, # number of columns in X
|
211 |
+
eps, # epsilon to avoid division by zero
|
212 |
+
dropout_p, # Dropout probability
|
213 |
+
IS_RMS_NORM: tl.constexpr,
|
214 |
+
BLOCK_N: tl.constexpr,
|
215 |
+
HAS_RESIDUAL: tl.constexpr,
|
216 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
217 |
+
HAS_BIAS: tl.constexpr,
|
218 |
+
HAS_DROPOUT: tl.constexpr,
|
219 |
+
STORE_DROPOUT_MASK: tl.constexpr,
|
220 |
+
HAS_ROWSCALE: tl.constexpr,
|
221 |
+
HAS_X1: tl.constexpr,
|
222 |
+
HAS_W1: tl.constexpr,
|
223 |
+
HAS_B1: tl.constexpr,
|
224 |
+
):
|
225 |
+
# Map the program id to the row of X and Y it should compute.
|
226 |
+
row = tl.program_id(0)
|
227 |
+
X += row * stride_x_row
|
228 |
+
Y += row * stride_y_row
|
229 |
+
if HAS_RESIDUAL:
|
230 |
+
RESIDUAL += row * stride_res_row
|
231 |
+
if STORE_RESIDUAL_OUT:
|
232 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
233 |
+
if HAS_X1:
|
234 |
+
X1 += row * stride_x1_row
|
235 |
+
if HAS_W1:
|
236 |
+
Y1 += row * stride_y1_row
|
237 |
+
# Compute mean and variance
|
238 |
+
cols = tl.arange(0, BLOCK_N)
|
239 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
240 |
+
if HAS_ROWSCALE:
|
241 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
242 |
+
x *= rowscale
|
243 |
+
if HAS_DROPOUT:
|
244 |
+
# Compute dropout mask
|
245 |
+
# 7 rounds is good enough, and reduces register pressure
|
246 |
+
keep_mask = (
|
247 |
+
tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
248 |
+
)
|
249 |
+
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
|
250 |
+
if STORE_DROPOUT_MASK:
|
251 |
+
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
|
252 |
+
if HAS_X1:
|
253 |
+
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
|
254 |
+
if HAS_ROWSCALE:
|
255 |
+
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
|
256 |
+
x1 *= rowscale
|
257 |
+
if HAS_DROPOUT:
|
258 |
+
# Compute dropout mask
|
259 |
+
# 7 rounds is good enough, and reduces register pressure
|
260 |
+
keep_mask = (
|
261 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
|
262 |
+
> dropout_p
|
263 |
+
)
|
264 |
+
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
|
265 |
+
if STORE_DROPOUT_MASK:
|
266 |
+
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
|
267 |
+
x += x1
|
268 |
+
if HAS_RESIDUAL:
|
269 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
270 |
+
x += residual
|
271 |
+
if STORE_RESIDUAL_OUT:
|
272 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
273 |
+
if not IS_RMS_NORM:
|
274 |
+
mean = tl.sum(x, axis=0) / N
|
275 |
+
tl.store(Mean + row, mean)
|
276 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
277 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
278 |
+
else:
|
279 |
+
xbar = tl.where(cols < N, x, 0.0)
|
280 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
281 |
+
rstd = 1 / tl.sqrt(var + eps)
|
282 |
+
tl.store(Rstd + row, rstd)
|
283 |
+
# Normalize and apply linear transformation
|
284 |
+
mask = cols < N
|
285 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
286 |
+
if HAS_BIAS:
|
287 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
288 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
289 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
290 |
+
# Write output
|
291 |
+
tl.store(Y + cols, y, mask=mask)
|
292 |
+
if HAS_W1:
|
293 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
294 |
+
if HAS_B1:
|
295 |
+
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
|
296 |
+
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
|
297 |
+
tl.store(Y1 + cols, y1, mask=mask)
|
298 |
+
|
299 |
+
|
300 |
+
def _layer_norm_fwd(
|
301 |
+
x,
|
302 |
+
weight,
|
303 |
+
bias,
|
304 |
+
eps,
|
305 |
+
residual=None,
|
306 |
+
x1=None,
|
307 |
+
weight1=None,
|
308 |
+
bias1=None,
|
309 |
+
dropout_p=0.0,
|
310 |
+
rowscale=None,
|
311 |
+
out_dtype=None,
|
312 |
+
residual_dtype=None,
|
313 |
+
is_rms_norm=False,
|
314 |
+
return_dropout_mask=False,
|
315 |
+
):
|
316 |
+
if residual is not None:
|
317 |
+
residual_dtype = residual.dtype
|
318 |
+
M, N = x.shape
|
319 |
+
assert x.stride(-1) == 1
|
320 |
+
if residual is not None:
|
321 |
+
assert residual.stride(-1) == 1
|
322 |
+
assert residual.shape == (M, N)
|
323 |
+
assert weight.shape == (N,)
|
324 |
+
assert weight.stride(-1) == 1
|
325 |
+
if bias is not None:
|
326 |
+
assert bias.stride(-1) == 1
|
327 |
+
assert bias.shape == (N,)
|
328 |
+
if x1 is not None:
|
329 |
+
assert x1.shape == x.shape
|
330 |
+
assert rowscale is None
|
331 |
+
assert x1.stride(-1) == 1
|
332 |
+
if weight1 is not None:
|
333 |
+
assert weight1.shape == (N,)
|
334 |
+
assert weight1.stride(-1) == 1
|
335 |
+
if bias1 is not None:
|
336 |
+
assert bias1.shape == (N,)
|
337 |
+
assert bias1.stride(-1) == 1
|
338 |
+
if rowscale is not None:
|
339 |
+
assert rowscale.is_contiguous()
|
340 |
+
assert rowscale.shape == (M,)
|
341 |
+
# allocate output
|
342 |
+
y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
343 |
+
assert y.stride(-1) == 1
|
344 |
+
if weight1 is not None:
|
345 |
+
y1 = torch.empty_like(y)
|
346 |
+
assert y1.stride(-1) == 1
|
347 |
+
else:
|
348 |
+
y1 = None
|
349 |
+
if (
|
350 |
+
residual is not None
|
351 |
+
or (residual_dtype is not None and residual_dtype != x.dtype)
|
352 |
+
or dropout_p > 0.0
|
353 |
+
or rowscale is not None
|
354 |
+
or x1 is not None
|
355 |
+
):
|
356 |
+
residual_out = torch.empty(
|
357 |
+
M,
|
358 |
+
N,
|
359 |
+
device=x.device,
|
360 |
+
dtype=residual_dtype if residual_dtype is not None else x.dtype,
|
361 |
+
)
|
362 |
+
assert residual_out.stride(-1) == 1
|
363 |
+
else:
|
364 |
+
residual_out = None
|
365 |
+
mean = (
|
366 |
+
torch.empty((M,), dtype=torch.float32, device=x.device)
|
367 |
+
if not is_rms_norm
|
368 |
+
else None
|
369 |
+
)
|
370 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
371 |
+
if dropout_p > 0.0:
|
372 |
+
seeds = torch.randint(
|
373 |
+
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
374 |
+
)
|
375 |
+
else:
|
376 |
+
seeds = None
|
377 |
+
if return_dropout_mask and dropout_p > 0.0:
|
378 |
+
dropout_mask = torch.empty(
|
379 |
+
M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
dropout_mask = None
|
383 |
+
# Less than 64KB per feature: enqueue fused kernel
|
384 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
385 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
386 |
+
if N > BLOCK_N:
|
387 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
388 |
+
with torch.cuda.device(x.device.index):
|
389 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
390 |
+
x,
|
391 |
+
y,
|
392 |
+
weight,
|
393 |
+
bias,
|
394 |
+
residual,
|
395 |
+
x1,
|
396 |
+
weight1,
|
397 |
+
bias1,
|
398 |
+
y1,
|
399 |
+
residual_out,
|
400 |
+
rowscale,
|
401 |
+
seeds,
|
402 |
+
dropout_mask,
|
403 |
+
mean,
|
404 |
+
rstd,
|
405 |
+
x.stride(0),
|
406 |
+
y.stride(0),
|
407 |
+
residual.stride(0) if residual is not None else 0,
|
408 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
409 |
+
x1.stride(0) if x1 is not None else 0,
|
410 |
+
y1.stride(0) if y1 is not None else 0,
|
411 |
+
M,
|
412 |
+
N,
|
413 |
+
eps,
|
414 |
+
dropout_p,
|
415 |
+
is_rms_norm,
|
416 |
+
BLOCK_N,
|
417 |
+
residual is not None,
|
418 |
+
residual_out is not None,
|
419 |
+
bias is not None,
|
420 |
+
dropout_p > 0.0,
|
421 |
+
dropout_mask is not None,
|
422 |
+
rowscale is not None,
|
423 |
+
)
|
424 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
|
425 |
+
if dropout_mask is not None and x1 is not None:
|
426 |
+
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
|
427 |
+
else:
|
428 |
+
dropout_mask1 = None
|
429 |
+
return (
|
430 |
+
y,
|
431 |
+
y1,
|
432 |
+
mean,
|
433 |
+
rstd,
|
434 |
+
residual_out if residual_out is not None else x,
|
435 |
+
seeds,
|
436 |
+
dropout_mask,
|
437 |
+
dropout_mask1,
|
438 |
+
)
|
439 |
+
|
440 |
+
|
441 |
+
@triton.autotune(
|
442 |
+
configs=pruned_configs_autotune,
|
443 |
+
key=[
|
444 |
+
"N",
|
445 |
+
"HAS_DRESIDUAL",
|
446 |
+
"STORE_DRESIDUAL",
|
447 |
+
"IS_RMS_NORM",
|
448 |
+
"HAS_BIAS",
|
449 |
+
"HAS_DROPOUT",
|
450 |
+
],
|
451 |
+
)
|
452 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
453 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
454 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
455 |
+
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
456 |
+
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
457 |
+
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
458 |
+
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
459 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
460 |
+
@triton.jit
|
461 |
+
def _layer_norm_bwd_kernel(
|
462 |
+
X, # pointer to the input
|
463 |
+
W, # pointer to the weights
|
464 |
+
B, # pointer to the biases
|
465 |
+
Y, # pointer to the output to be recomputed
|
466 |
+
DY, # pointer to the output gradient
|
467 |
+
DX, # pointer to the input gradient
|
468 |
+
DW, # pointer to the partial sum of weights gradient
|
469 |
+
DB, # pointer to the partial sum of biases gradient
|
470 |
+
DRESIDUAL,
|
471 |
+
W1,
|
472 |
+
DY1,
|
473 |
+
DX1,
|
474 |
+
DW1,
|
475 |
+
DB1,
|
476 |
+
DRESIDUAL_IN,
|
477 |
+
ROWSCALE,
|
478 |
+
SEEDS,
|
479 |
+
Mean, # pointer to the mean
|
480 |
+
Rstd, # pointer to the 1/std
|
481 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
482 |
+
stride_y_row,
|
483 |
+
stride_dy_row,
|
484 |
+
stride_dx_row,
|
485 |
+
stride_dres_row,
|
486 |
+
stride_dy1_row,
|
487 |
+
stride_dx1_row,
|
488 |
+
stride_dres_in_row,
|
489 |
+
M, # number of rows in X
|
490 |
+
N, # number of columns in X
|
491 |
+
eps, # epsilon to avoid division by zero
|
492 |
+
dropout_p,
|
493 |
+
rows_per_program,
|
494 |
+
IS_RMS_NORM: tl.constexpr,
|
495 |
+
BLOCK_N: tl.constexpr,
|
496 |
+
HAS_DRESIDUAL: tl.constexpr,
|
497 |
+
STORE_DRESIDUAL: tl.constexpr,
|
498 |
+
HAS_BIAS: tl.constexpr,
|
499 |
+
HAS_DROPOUT: tl.constexpr,
|
500 |
+
HAS_ROWSCALE: tl.constexpr,
|
501 |
+
HAS_DY1: tl.constexpr,
|
502 |
+
HAS_DX1: tl.constexpr,
|
503 |
+
HAS_B1: tl.constexpr,
|
504 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
505 |
+
):
|
506 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
507 |
+
row_block_id = tl.program_id(0)
|
508 |
+
row_start = row_block_id * rows_per_program
|
509 |
+
# Do not early exit if row_start >= M, because we need to write DW and DB
|
510 |
+
cols = tl.arange(0, BLOCK_N)
|
511 |
+
mask = cols < N
|
512 |
+
X += row_start * stride_x_row
|
513 |
+
if HAS_DRESIDUAL:
|
514 |
+
DRESIDUAL += row_start * stride_dres_row
|
515 |
+
if STORE_DRESIDUAL:
|
516 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
517 |
+
DY += row_start * stride_dy_row
|
518 |
+
DX += row_start * stride_dx_row
|
519 |
+
if HAS_DY1:
|
520 |
+
DY1 += row_start * stride_dy1_row
|
521 |
+
if HAS_DX1:
|
522 |
+
DX1 += row_start * stride_dx1_row
|
523 |
+
if RECOMPUTE_OUTPUT:
|
524 |
+
Y += row_start * stride_y_row
|
525 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
526 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
527 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
528 |
+
if HAS_DY1:
|
529 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
530 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
531 |
+
if HAS_BIAS:
|
532 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
533 |
+
if HAS_DY1:
|
534 |
+
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
535 |
+
if HAS_B1:
|
536 |
+
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
537 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
538 |
+
for row in range(row_start, row_end):
|
539 |
+
# Load data to SRAM
|
540 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
541 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
542 |
+
if HAS_DY1:
|
543 |
+
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
544 |
+
if not IS_RMS_NORM:
|
545 |
+
mean = tl.load(Mean + row)
|
546 |
+
rstd = tl.load(Rstd + row)
|
547 |
+
# Compute dx
|
548 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
549 |
+
xhat = tl.where(mask, xhat, 0.0)
|
550 |
+
if RECOMPUTE_OUTPUT:
|
551 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
552 |
+
tl.store(Y + cols, y, mask=mask)
|
553 |
+
wdy = w * dy
|
554 |
+
dw += dy * xhat
|
555 |
+
if HAS_BIAS:
|
556 |
+
db += dy
|
557 |
+
if HAS_DY1:
|
558 |
+
wdy += w1 * dy1
|
559 |
+
dw1 += dy1 * xhat
|
560 |
+
if HAS_B1:
|
561 |
+
db1 += dy1
|
562 |
+
if not IS_RMS_NORM:
|
563 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
564 |
+
c2 = tl.sum(wdy, axis=0) / N
|
565 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
566 |
+
else:
|
567 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
568 |
+
dx = (wdy - xhat * c1) * rstd
|
569 |
+
if HAS_DRESIDUAL:
|
570 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
571 |
+
dx += dres
|
572 |
+
# Write dx
|
573 |
+
if STORE_DRESIDUAL:
|
574 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
575 |
+
if HAS_DX1:
|
576 |
+
if HAS_DROPOUT:
|
577 |
+
keep_mask = (
|
578 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
|
579 |
+
> dropout_p
|
580 |
+
)
|
581 |
+
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
582 |
+
else:
|
583 |
+
dx1 = dx
|
584 |
+
tl.store(DX1 + cols, dx1, mask=mask)
|
585 |
+
if HAS_DROPOUT:
|
586 |
+
keep_mask = (
|
587 |
+
tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
|
588 |
+
> dropout_p
|
589 |
+
)
|
590 |
+
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
591 |
+
if HAS_ROWSCALE:
|
592 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
593 |
+
dx *= rowscale
|
594 |
+
tl.store(DX + cols, dx, mask=mask)
|
595 |
+
|
596 |
+
X += stride_x_row
|
597 |
+
if HAS_DRESIDUAL:
|
598 |
+
DRESIDUAL += stride_dres_row
|
599 |
+
if STORE_DRESIDUAL:
|
600 |
+
DRESIDUAL_IN += stride_dres_in_row
|
601 |
+
if RECOMPUTE_OUTPUT:
|
602 |
+
Y += stride_y_row
|
603 |
+
DY += stride_dy_row
|
604 |
+
DX += stride_dx_row
|
605 |
+
if HAS_DY1:
|
606 |
+
DY1 += stride_dy1_row
|
607 |
+
if HAS_DX1:
|
608 |
+
DX1 += stride_dx1_row
|
609 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
610 |
+
if HAS_BIAS:
|
611 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
612 |
+
if HAS_DY1:
|
613 |
+
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
614 |
+
if HAS_B1:
|
615 |
+
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
|
616 |
+
|
617 |
+
|
618 |
+
def _layer_norm_bwd(
|
619 |
+
dy,
|
620 |
+
x,
|
621 |
+
weight,
|
622 |
+
bias,
|
623 |
+
eps,
|
624 |
+
mean,
|
625 |
+
rstd,
|
626 |
+
dresidual=None,
|
627 |
+
dy1=None,
|
628 |
+
weight1=None,
|
629 |
+
bias1=None,
|
630 |
+
seeds=None,
|
631 |
+
dropout_p=0.0,
|
632 |
+
rowscale=None,
|
633 |
+
has_residual=False,
|
634 |
+
has_x1=False,
|
635 |
+
is_rms_norm=False,
|
636 |
+
x_dtype=None,
|
637 |
+
recompute_output=False,
|
638 |
+
):
|
639 |
+
M, N = x.shape
|
640 |
+
assert x.stride(-1) == 1
|
641 |
+
assert dy.stride(-1) == 1
|
642 |
+
assert dy.shape == (M, N)
|
643 |
+
if dresidual is not None:
|
644 |
+
assert dresidual.stride(-1) == 1
|
645 |
+
assert dresidual.shape == (M, N)
|
646 |
+
assert weight.shape == (N,)
|
647 |
+
assert weight.stride(-1) == 1
|
648 |
+
if bias is not None:
|
649 |
+
assert bias.stride(-1) == 1
|
650 |
+
assert bias.shape == (N,)
|
651 |
+
if dy1 is not None:
|
652 |
+
assert weight1 is not None
|
653 |
+
assert dy1.shape == dy.shape
|
654 |
+
assert dy1.stride(-1) == 1
|
655 |
+
if weight1 is not None:
|
656 |
+
assert weight1.shape == (N,)
|
657 |
+
assert weight1.stride(-1) == 1
|
658 |
+
if bias1 is not None:
|
659 |
+
assert bias1.shape == (N,)
|
660 |
+
assert bias1.stride(-1) == 1
|
661 |
+
if seeds is not None:
|
662 |
+
assert seeds.is_contiguous()
|
663 |
+
assert seeds.shape == (M if not has_x1 else M * 2,)
|
664 |
+
if rowscale is not None:
|
665 |
+
assert rowscale.is_contiguous()
|
666 |
+
assert rowscale.shape == (M,)
|
667 |
+
# allocate output
|
668 |
+
dx = (
|
669 |
+
torch.empty_like(x)
|
670 |
+
if x_dtype is None
|
671 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
672 |
+
)
|
673 |
+
dresidual_in = (
|
674 |
+
torch.empty_like(x)
|
675 |
+
if has_residual
|
676 |
+
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
|
677 |
+
else None
|
678 |
+
)
|
679 |
+
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
|
680 |
+
y = (
|
681 |
+
torch.empty(M, N, dtype=dy.dtype, device=dy.device)
|
682 |
+
if recompute_output
|
683 |
+
else None
|
684 |
+
)
|
685 |
+
if recompute_output:
|
686 |
+
assert (
|
687 |
+
weight1 is None
|
688 |
+
), "recompute_output is not supported with parallel LayerNorm"
|
689 |
+
|
690 |
+
# Less than 64KB per feature: enqueue fused kernel
|
691 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
692 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
693 |
+
if N > BLOCK_N:
|
694 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
695 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
696 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
697 |
+
_db = (
|
698 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
699 |
+
if bias is not None
|
700 |
+
else None
|
701 |
+
)
|
702 |
+
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
|
703 |
+
_db1 = torch.empty_like(_db) if bias1 is not None else None
|
704 |
+
rows_per_program = math.ceil(M / sm_count)
|
705 |
+
grid = (sm_count,)
|
706 |
+
with torch.cuda.device(x.device.index):
|
707 |
+
_layer_norm_bwd_kernel[grid](
|
708 |
+
x,
|
709 |
+
weight,
|
710 |
+
bias,
|
711 |
+
y,
|
712 |
+
dy,
|
713 |
+
dx,
|
714 |
+
_dw,
|
715 |
+
_db,
|
716 |
+
dresidual,
|
717 |
+
weight1,
|
718 |
+
dy1,
|
719 |
+
dx1,
|
720 |
+
_dw1,
|
721 |
+
_db1,
|
722 |
+
dresidual_in,
|
723 |
+
rowscale,
|
724 |
+
seeds,
|
725 |
+
mean,
|
726 |
+
rstd,
|
727 |
+
x.stride(0),
|
728 |
+
0 if not recompute_output else y.stride(0),
|
729 |
+
dy.stride(0),
|
730 |
+
dx.stride(0),
|
731 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
732 |
+
dy1.stride(0) if dy1 is not None else 0,
|
733 |
+
dx1.stride(0) if dx1 is not None else 0,
|
734 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
735 |
+
M,
|
736 |
+
N,
|
737 |
+
eps,
|
738 |
+
dropout_p,
|
739 |
+
rows_per_program,
|
740 |
+
is_rms_norm,
|
741 |
+
BLOCK_N,
|
742 |
+
dresidual is not None,
|
743 |
+
dresidual_in is not None,
|
744 |
+
bias is not None,
|
745 |
+
dropout_p > 0.0,
|
746 |
+
)
|
747 |
+
dw = _dw.sum(0).to(weight.dtype)
|
748 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
749 |
+
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
|
750 |
+
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
|
751 |
+
# Don't need to compute dresidual_in separately in this case
|
752 |
+
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
|
753 |
+
dresidual_in = dx
|
754 |
+
if has_x1 and dropout_p == 0.0:
|
755 |
+
dx1 = dx
|
756 |
+
return (
|
757 |
+
(dx, dw, db, dresidual_in, dx1, dw1, db1)
|
758 |
+
if not recompute_output
|
759 |
+
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
|
760 |
+
)
|
761 |
+
|
762 |
+
|
763 |
+
class LayerNormFn(torch.autograd.Function):
|
764 |
+
@staticmethod
|
765 |
+
def forward(
|
766 |
+
ctx,
|
767 |
+
x,
|
768 |
+
weight,
|
769 |
+
bias,
|
770 |
+
residual=None,
|
771 |
+
x1=None,
|
772 |
+
weight1=None,
|
773 |
+
bias1=None,
|
774 |
+
eps=1e-6,
|
775 |
+
dropout_p=0.0,
|
776 |
+
rowscale=None,
|
777 |
+
prenorm=False,
|
778 |
+
residual_in_fp32=False,
|
779 |
+
is_rms_norm=False,
|
780 |
+
return_dropout_mask=False,
|
781 |
+
):
|
782 |
+
x_shape_og = x.shape
|
783 |
+
# reshape input data into 2D tensor
|
784 |
+
x = x.reshape(-1, x.shape[-1])
|
785 |
+
if x.stride(-1) != 1:
|
786 |
+
x = x.contiguous()
|
787 |
+
if residual is not None:
|
788 |
+
assert residual.shape == x_shape_og
|
789 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
790 |
+
if residual.stride(-1) != 1:
|
791 |
+
residual = residual.contiguous()
|
792 |
+
if x1 is not None:
|
793 |
+
assert x1.shape == x_shape_og
|
794 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
795 |
+
x1 = x1.reshape(-1, x1.shape[-1])
|
796 |
+
if x1.stride(-1) != 1:
|
797 |
+
x1 = x1.contiguous()
|
798 |
+
weight = weight.contiguous()
|
799 |
+
if bias is not None:
|
800 |
+
bias = bias.contiguous()
|
801 |
+
if weight1 is not None:
|
802 |
+
weight1 = weight1.contiguous()
|
803 |
+
if bias1 is not None:
|
804 |
+
bias1 = bias1.contiguous()
|
805 |
+
if rowscale is not None:
|
806 |
+
rowscale = rowscale.reshape(-1).contiguous()
|
807 |
+
residual_dtype = (
|
808 |
+
residual.dtype
|
809 |
+
if residual is not None
|
810 |
+
else (torch.float32 if residual_in_fp32 else None)
|
811 |
+
)
|
812 |
+
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
|
813 |
+
_layer_norm_fwd(
|
814 |
+
x,
|
815 |
+
weight,
|
816 |
+
bias,
|
817 |
+
eps,
|
818 |
+
residual,
|
819 |
+
x1,
|
820 |
+
weight1,
|
821 |
+
bias1,
|
822 |
+
dropout_p=dropout_p,
|
823 |
+
rowscale=rowscale,
|
824 |
+
residual_dtype=residual_dtype,
|
825 |
+
is_rms_norm=is_rms_norm,
|
826 |
+
return_dropout_mask=return_dropout_mask,
|
827 |
+
)
|
828 |
+
)
|
829 |
+
ctx.save_for_backward(
|
830 |
+
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
|
831 |
+
)
|
832 |
+
ctx.x_shape_og = x_shape_og
|
833 |
+
ctx.eps = eps
|
834 |
+
ctx.dropout_p = dropout_p
|
835 |
+
ctx.is_rms_norm = is_rms_norm
|
836 |
+
ctx.has_residual = residual is not None
|
837 |
+
ctx.has_x1 = x1 is not None
|
838 |
+
ctx.prenorm = prenorm
|
839 |
+
ctx.x_dtype = x.dtype
|
840 |
+
y = y.reshape(x_shape_og)
|
841 |
+
y1 = y1.reshape(x_shape_og) if y1 is not None else None
|
842 |
+
residual_out = (
|
843 |
+
residual_out.reshape(x_shape_og) if residual_out is not None else None
|
844 |
+
)
|
845 |
+
dropout_mask = (
|
846 |
+
dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
|
847 |
+
)
|
848 |
+
dropout_mask1 = (
|
849 |
+
dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
|
850 |
+
)
|
851 |
+
if not return_dropout_mask:
|
852 |
+
if weight1 is None:
|
853 |
+
return y if not prenorm else (y, residual_out)
|
854 |
+
else:
|
855 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
856 |
+
else:
|
857 |
+
if weight1 is None:
|
858 |
+
return (
|
859 |
+
(y, dropout_mask, dropout_mask1)
|
860 |
+
if not prenorm
|
861 |
+
else (y, residual_out, dropout_mask, dropout_mask1)
|
862 |
+
)
|
863 |
+
else:
|
864 |
+
return (
|
865 |
+
(y, y1, dropout_mask, dropout_mask1)
|
866 |
+
if not prenorm
|
867 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1)
|
868 |
+
)
|
869 |
+
|
870 |
+
@staticmethod
|
871 |
+
def backward(ctx, dy, *args):
|
872 |
+
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
|
873 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
874 |
+
if dy.stride(-1) != 1:
|
875 |
+
dy = dy.contiguous()
|
876 |
+
assert dy.shape == x.shape
|
877 |
+
if weight1 is not None:
|
878 |
+
dy1, args = args[0], args[1:]
|
879 |
+
dy1 = dy1.reshape(-1, dy1.shape[-1])
|
880 |
+
if dy1.stride(-1) != 1:
|
881 |
+
dy1 = dy1.contiguous()
|
882 |
+
assert dy1.shape == x.shape
|
883 |
+
else:
|
884 |
+
dy1 = None
|
885 |
+
if ctx.prenorm:
|
886 |
+
dresidual = args[0]
|
887 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
888 |
+
if dresidual.stride(-1) != 1:
|
889 |
+
dresidual = dresidual.contiguous()
|
890 |
+
assert dresidual.shape == x.shape
|
891 |
+
else:
|
892 |
+
dresidual = None
|
893 |
+
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
|
894 |
+
dy,
|
895 |
+
x,
|
896 |
+
weight,
|
897 |
+
bias,
|
898 |
+
ctx.eps,
|
899 |
+
mean,
|
900 |
+
rstd,
|
901 |
+
dresidual,
|
902 |
+
dy1,
|
903 |
+
weight1,
|
904 |
+
bias1,
|
905 |
+
seeds,
|
906 |
+
ctx.dropout_p,
|
907 |
+
rowscale,
|
908 |
+
ctx.has_residual,
|
909 |
+
ctx.has_x1,
|
910 |
+
ctx.is_rms_norm,
|
911 |
+
x_dtype=ctx.x_dtype,
|
912 |
+
)
|
913 |
+
return (
|
914 |
+
dx.reshape(ctx.x_shape_og),
|
915 |
+
dw,
|
916 |
+
db,
|
917 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
918 |
+
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
|
919 |
+
dw1,
|
920 |
+
db1,
|
921 |
+
None,
|
922 |
+
None,
|
923 |
+
None,
|
924 |
+
None,
|
925 |
+
None,
|
926 |
+
None,
|
927 |
+
None,
|
928 |
+
)
|
929 |
+
|
930 |
+
|
931 |
+
def layer_norm_fn(
|
932 |
+
x,
|
933 |
+
weight,
|
934 |
+
bias,
|
935 |
+
residual=None,
|
936 |
+
x1=None,
|
937 |
+
weight1=None,
|
938 |
+
bias1=None,
|
939 |
+
eps=1e-6,
|
940 |
+
dropout_p=0.0,
|
941 |
+
rowscale=None,
|
942 |
+
prenorm=False,
|
943 |
+
residual_in_fp32=False,
|
944 |
+
is_rms_norm=False,
|
945 |
+
return_dropout_mask=False,
|
946 |
+
):
|
947 |
+
return LayerNormFn.apply(
|
948 |
+
x,
|
949 |
+
weight,
|
950 |
+
bias,
|
951 |
+
residual,
|
952 |
+
x1,
|
953 |
+
weight1,
|
954 |
+
bias1,
|
955 |
+
eps,
|
956 |
+
dropout_p,
|
957 |
+
rowscale,
|
958 |
+
prenorm,
|
959 |
+
residual_in_fp32,
|
960 |
+
is_rms_norm,
|
961 |
+
return_dropout_mask,
|
962 |
+
)
|
963 |
+
|
964 |
+
|
965 |
+
def rms_norm_fn(
|
966 |
+
x,
|
967 |
+
weight,
|
968 |
+
bias,
|
969 |
+
residual=None,
|
970 |
+
x1=None,
|
971 |
+
weight1=None,
|
972 |
+
bias1=None,
|
973 |
+
eps=1e-6,
|
974 |
+
dropout_p=0.0,
|
975 |
+
rowscale=None,
|
976 |
+
prenorm=False,
|
977 |
+
residual_in_fp32=False,
|
978 |
+
return_dropout_mask=False,
|
979 |
+
):
|
980 |
+
return LayerNormFn.apply(
|
981 |
+
x,
|
982 |
+
weight,
|
983 |
+
bias,
|
984 |
+
residual,
|
985 |
+
x1,
|
986 |
+
weight1,
|
987 |
+
bias1,
|
988 |
+
eps,
|
989 |
+
dropout_p,
|
990 |
+
rowscale,
|
991 |
+
prenorm,
|
992 |
+
residual_in_fp32,
|
993 |
+
True,
|
994 |
+
return_dropout_mask,
|
995 |
+
)
|
996 |
+
|
997 |
+
|
998 |
+
class RMSNorm(torch.nn.Module):
|
999 |
+
|
1000 |
+
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
|
1001 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
1002 |
+
super().__init__()
|
1003 |
+
self.eps = eps
|
1004 |
+
if dropout_p > 0.0:
|
1005 |
+
self.drop = torch.nn.Dropout(dropout_p)
|
1006 |
+
else:
|
1007 |
+
self.drop = None
|
1008 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
1009 |
+
self.register_parameter("bias", None)
|
1010 |
+
self.reset_parameters()
|
1011 |
+
|
1012 |
+
def reset_parameters(self):
|
1013 |
+
torch.nn.init.ones_(self.weight)
|
1014 |
+
|
1015 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
1016 |
+
return rms_norm_fn(
|
1017 |
+
x,
|
1018 |
+
self.weight,
|
1019 |
+
self.bias,
|
1020 |
+
residual=residual,
|
1021 |
+
eps=self.eps,
|
1022 |
+
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
|
1023 |
+
prenorm=prenorm,
|
1024 |
+
residual_in_fp32=residual_in_fp32,
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
|
1028 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
1029 |
+
@staticmethod
|
1030 |
+
@custom_fwd
|
1031 |
+
def forward(
|
1032 |
+
ctx,
|
1033 |
+
x,
|
1034 |
+
norm_weight,
|
1035 |
+
norm_bias,
|
1036 |
+
linear_weight,
|
1037 |
+
linear_bias,
|
1038 |
+
residual=None,
|
1039 |
+
eps=1e-6,
|
1040 |
+
prenorm=False,
|
1041 |
+
residual_in_fp32=False,
|
1042 |
+
is_rms_norm=False,
|
1043 |
+
):
|
1044 |
+
x_shape_og = x.shape
|
1045 |
+
# reshape input data into 2D tensor
|
1046 |
+
x = x.reshape(-1, x.shape[-1])
|
1047 |
+
if x.stride(-1) != 1:
|
1048 |
+
x = x.contiguous()
|
1049 |
+
if residual is not None:
|
1050 |
+
assert residual.shape == x_shape_og
|
1051 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
1052 |
+
if residual.stride(-1) != 1:
|
1053 |
+
residual = residual.contiguous()
|
1054 |
+
norm_weight = norm_weight.contiguous()
|
1055 |
+
if norm_bias is not None:
|
1056 |
+
norm_bias = norm_bias.contiguous()
|
1057 |
+
residual_dtype = (
|
1058 |
+
residual.dtype
|
1059 |
+
if residual is not None
|
1060 |
+
else (torch.float32 if residual_in_fp32 else None)
|
1061 |
+
)
|
1062 |
+
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
1063 |
+
x,
|
1064 |
+
norm_weight,
|
1065 |
+
norm_bias,
|
1066 |
+
eps,
|
1067 |
+
residual,
|
1068 |
+
out_dtype=(
|
1069 |
+
None
|
1070 |
+
if not torch.is_autocast_enabled()
|
1071 |
+
else torch.get_autocast_gpu_dtype()
|
1072 |
+
),
|
1073 |
+
residual_dtype=residual_dtype,
|
1074 |
+
is_rms_norm=is_rms_norm,
|
1075 |
+
)
|
1076 |
+
y = y.reshape(x_shape_og)
|
1077 |
+
dtype = (
|
1078 |
+
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
1079 |
+
)
|
1080 |
+
linear_weight = linear_weight.to(dtype)
|
1081 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
1082 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
1083 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
1084 |
+
ctx.save_for_backward(
|
1085 |
+
residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
|
1086 |
+
)
|
1087 |
+
ctx.x_shape_og = x_shape_og
|
1088 |
+
ctx.eps = eps
|
1089 |
+
ctx.is_rms_norm = is_rms_norm
|
1090 |
+
ctx.has_residual = residual is not None
|
1091 |
+
ctx.prenorm = prenorm
|
1092 |
+
ctx.x_dtype = x.dtype
|
1093 |
+
ctx.linear_bias_is_none = linear_bias is None
|
1094 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
1095 |
+
|
1096 |
+
@staticmethod
|
1097 |
+
@custom_bwd
|
1098 |
+
def backward(ctx, dout, *args):
|
1099 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
1100 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
1101 |
+
dy = F.linear(dout, linear_weight.t())
|
1102 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
1103 |
+
if dy.stride(-1) != 1:
|
1104 |
+
dy = dy.contiguous()
|
1105 |
+
assert dy.shape == x.shape
|
1106 |
+
if ctx.prenorm:
|
1107 |
+
dresidual = args[0]
|
1108 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
1109 |
+
if dresidual.stride(-1) != 1:
|
1110 |
+
dresidual = dresidual.contiguous()
|
1111 |
+
assert dresidual.shape == x.shape
|
1112 |
+
else:
|
1113 |
+
dresidual = None
|
1114 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
1115 |
+
dy,
|
1116 |
+
x,
|
1117 |
+
norm_weight,
|
1118 |
+
norm_bias,
|
1119 |
+
ctx.eps,
|
1120 |
+
mean,
|
1121 |
+
rstd,
|
1122 |
+
dresidual=dresidual,
|
1123 |
+
has_residual=ctx.has_residual,
|
1124 |
+
is_rms_norm=ctx.is_rms_norm,
|
1125 |
+
x_dtype=ctx.x_dtype,
|
1126 |
+
recompute_output=True,
|
1127 |
+
)
|
1128 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
1129 |
+
return (
|
1130 |
+
dx.reshape(ctx.x_shape_og),
|
1131 |
+
dnorm_weight,
|
1132 |
+
dnorm_bias,
|
1133 |
+
dlinear_weight,
|
1134 |
+
dlinear_bias,
|
1135 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
1136 |
+
None,
|
1137 |
+
None,
|
1138 |
+
None,
|
1139 |
+
None,
|
1140 |
+
)
|
1141 |
+
|
1142 |
+
|
1143 |
+
def layer_norm_linear_fn(
|
1144 |
+
x,
|
1145 |
+
norm_weight,
|
1146 |
+
norm_bias,
|
1147 |
+
linear_weight,
|
1148 |
+
linear_bias,
|
1149 |
+
residual=None,
|
1150 |
+
eps=1e-6,
|
1151 |
+
prenorm=False,
|
1152 |
+
residual_in_fp32=False,
|
1153 |
+
is_rms_norm=False,
|
1154 |
+
):
|
1155 |
+
return LayerNormLinearFn.apply(
|
1156 |
+
x,
|
1157 |
+
norm_weight,
|
1158 |
+
norm_bias,
|
1159 |
+
linear_weight,
|
1160 |
+
linear_bias,
|
1161 |
+
residual,
|
1162 |
+
eps,
|
1163 |
+
prenorm,
|
1164 |
+
residual_in_fp32,
|
1165 |
+
is_rms_norm,
|
1166 |
+
)
|
torch-ext/mamba_ssm/ops/triton/layernorm_gated.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
3 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
4 |
+
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
5 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
import triton
|
13 |
+
import triton.language as tl
|
14 |
+
|
15 |
+
from einops import rearrange
|
16 |
+
|
17 |
+
|
18 |
+
def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
|
19 |
+
dtype = x.dtype
|
20 |
+
N = x.shape[-1]
|
21 |
+
weight = weight.float()
|
22 |
+
bias = bias.float() if bias is not None else None
|
23 |
+
if upcast:
|
24 |
+
x = x.float()
|
25 |
+
z = z.float() if z is not None else z
|
26 |
+
if z is not None and not norm_before_gate:
|
27 |
+
x = x * F.silu(z)
|
28 |
+
if group_size is None:
|
29 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
30 |
+
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
31 |
+
else:
|
32 |
+
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
33 |
+
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
|
34 |
+
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
35 |
+
if bias is not None:
|
36 |
+
out = out + bias
|
37 |
+
if z is not None and norm_before_gate:
|
38 |
+
out *= F.silu(z)
|
39 |
+
return out.to(dtype)
|
40 |
+
|
41 |
+
|
42 |
+
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
43 |
+
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
44 |
+
@triton.jit
|
45 |
+
def _layer_norm_fwd_1pass_kernel(
|
46 |
+
X, # pointer to the input
|
47 |
+
Y, # pointer to the output
|
48 |
+
W, # pointer to the weights
|
49 |
+
B, # pointer to the biases
|
50 |
+
Z, # pointer to the other branch
|
51 |
+
Mean, # pointer to the mean
|
52 |
+
Rstd, # pointer to the 1/std
|
53 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
54 |
+
stride_y_row,
|
55 |
+
stride_z_row,
|
56 |
+
M, # number of rows in X
|
57 |
+
N, # number of columns in X
|
58 |
+
eps, # epsilon to avoid division by zero
|
59 |
+
BLOCK_N: tl.constexpr,
|
60 |
+
HAS_BIAS: tl.constexpr,
|
61 |
+
HAS_Z: tl.constexpr,
|
62 |
+
NORM_BEFORE_GATE: tl.constexpr,
|
63 |
+
IS_RMS_NORM: tl.constexpr,
|
64 |
+
):
|
65 |
+
# Map the program id to the row of X and Y it should compute.
|
66 |
+
row = tl.program_id(0)
|
67 |
+
group = tl.program_id(1)
|
68 |
+
X += row * stride_x_row + group * N
|
69 |
+
Y += row * stride_y_row + group * N
|
70 |
+
if HAS_Z:
|
71 |
+
Z += row * stride_z_row + group * N
|
72 |
+
if not IS_RMS_NORM:
|
73 |
+
Mean += group * M
|
74 |
+
Rstd += group * M
|
75 |
+
W += group * N
|
76 |
+
if HAS_BIAS:
|
77 |
+
B += group * N
|
78 |
+
# Compute mean and variance
|
79 |
+
cols = tl.arange(0, BLOCK_N)
|
80 |
+
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
81 |
+
if HAS_Z and not NORM_BEFORE_GATE:
|
82 |
+
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
83 |
+
x *= z * tl.sigmoid(z)
|
84 |
+
if not IS_RMS_NORM:
|
85 |
+
mean = tl.sum(x, axis=0) / N
|
86 |
+
tl.store(Mean + row, mean)
|
87 |
+
xbar = tl.where(cols < N, x - mean, 0.)
|
88 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
89 |
+
else:
|
90 |
+
xbar = tl.where(cols < N, x, 0.)
|
91 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
92 |
+
rstd = 1 / tl.sqrt(var + eps)
|
93 |
+
tl.store(Rstd + row, rstd)
|
94 |
+
# Normalize and apply linear transformation
|
95 |
+
mask = cols < N
|
96 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
97 |
+
if HAS_BIAS:
|
98 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
99 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
100 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
101 |
+
if HAS_Z and NORM_BEFORE_GATE:
|
102 |
+
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
103 |
+
y *= z * tl.sigmoid(z)
|
104 |
+
# Write output
|
105 |
+
tl.store(Y + cols, y, mask=mask)
|
106 |
+
|
107 |
+
|
108 |
+
def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
|
109 |
+
M, N = x.shape
|
110 |
+
if group_size is None:
|
111 |
+
group_size = N
|
112 |
+
assert N % group_size == 0
|
113 |
+
ngroups = N // group_size
|
114 |
+
assert x.stride(-1) == 1
|
115 |
+
if z is not None:
|
116 |
+
assert z.stride(-1) == 1
|
117 |
+
assert z.shape == (M, N)
|
118 |
+
assert weight.shape == (N,)
|
119 |
+
assert weight.stride(-1) == 1
|
120 |
+
if bias is not None:
|
121 |
+
assert bias.stride(-1) == 1
|
122 |
+
assert bias.shape == (N,)
|
123 |
+
# allocate output
|
124 |
+
if out is not None:
|
125 |
+
assert out.shape == x.shape
|
126 |
+
else:
|
127 |
+
out = torch.empty_like(x)
|
128 |
+
assert out.stride(-1) == 1
|
129 |
+
mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
130 |
+
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
131 |
+
# Less than 64KB per feature: enqueue fused kernel
|
132 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
133 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
134 |
+
if group_size > BLOCK_N:
|
135 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
136 |
+
# heuristics for number of warps
|
137 |
+
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
138 |
+
grid = (M, ngroups)
|
139 |
+
with torch.cuda.device(x.device.index):
|
140 |
+
_layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
|
141 |
+
x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
|
142 |
+
M, group_size, eps,
|
143 |
+
BLOCK_N=BLOCK_N,
|
144 |
+
NORM_BEFORE_GATE=norm_before_gate,
|
145 |
+
IS_RMS_NORM=is_rms_norm,
|
146 |
+
num_warps=num_warps)
|
147 |
+
return out, mean, rstd
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
152 |
+
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
153 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
154 |
+
@triton.jit
|
155 |
+
def _layer_norm_bwd_kernel(
|
156 |
+
X, # pointer to the input
|
157 |
+
W, # pointer to the weights
|
158 |
+
B, # pointer to the biases
|
159 |
+
Z, # pointer to the other branch
|
160 |
+
Y, # pointer to the output to be recomputed
|
161 |
+
DY, # pointer to the output gradient
|
162 |
+
DX, # pointer to the input gradient
|
163 |
+
DW, # pointer to the partial sum of weights gradient
|
164 |
+
DB, # pointer to the partial sum of biases gradient
|
165 |
+
DZ, # pointer to the other branch
|
166 |
+
Mean, # pointer to the mean
|
167 |
+
Rstd, # pointer to the 1/std
|
168 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
169 |
+
stride_z_row,
|
170 |
+
stride_y_row,
|
171 |
+
stride_dy_row,
|
172 |
+
stride_dx_row,
|
173 |
+
stride_dz_row,
|
174 |
+
stride_dw_row,
|
175 |
+
stride_db_row,
|
176 |
+
M, # number of rows in X
|
177 |
+
N, # number of columns in X
|
178 |
+
eps, # epsilon to avoid division by zero
|
179 |
+
rows_per_program,
|
180 |
+
NORM_BEFORE_GATE: tl.constexpr,
|
181 |
+
IS_RMS_NORM: tl.constexpr,
|
182 |
+
HAS_BIAS: tl.constexpr,
|
183 |
+
HAS_Z: tl.constexpr,
|
184 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
185 |
+
BLOCK_N: tl.constexpr,
|
186 |
+
):
|
187 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
188 |
+
row_block_id = tl.program_id(0)
|
189 |
+
group = tl.program_id(1)
|
190 |
+
row_start = row_block_id * rows_per_program
|
191 |
+
cols = tl.arange(0, BLOCK_N)
|
192 |
+
mask = cols < N
|
193 |
+
X += row_start * stride_x_row + group * N
|
194 |
+
if HAS_Z:
|
195 |
+
Z += row_start * stride_z_row + group * N
|
196 |
+
DZ += row_start * stride_dz_row + group * N
|
197 |
+
DY += row_start * stride_dy_row + group * N
|
198 |
+
DX += row_start * stride_dx_row + group * N
|
199 |
+
if RECOMPUTE_OUTPUT:
|
200 |
+
Y += row_start * stride_y_row + group * N
|
201 |
+
if not IS_RMS_NORM:
|
202 |
+
Mean += group * M
|
203 |
+
Rstd += group * M
|
204 |
+
W += group * N
|
205 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
206 |
+
if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
|
207 |
+
B += group * N
|
208 |
+
b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
|
209 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
210 |
+
if HAS_BIAS:
|
211 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
212 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
213 |
+
for row in range(row_start, row_end):
|
214 |
+
# Load data to SRAM
|
215 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
216 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
217 |
+
if not IS_RMS_NORM:
|
218 |
+
mean = tl.load(Mean + row)
|
219 |
+
if HAS_Z and not NORM_BEFORE_GATE:
|
220 |
+
z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
|
221 |
+
x_og = x
|
222 |
+
x = x_og * z * tl.sigmoid(z)
|
223 |
+
rstd = tl.load(Rstd + row)
|
224 |
+
# Compute dx
|
225 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
226 |
+
xhat = tl.where(mask, xhat, 0.)
|
227 |
+
if HAS_Z and NORM_BEFORE_GATE:
|
228 |
+
z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
|
229 |
+
z_sigmoid = tl.sigmoid(z)
|
230 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
231 |
+
if RECOMPUTE_OUTPUT:
|
232 |
+
tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
|
233 |
+
dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
|
234 |
+
tl.store(DZ + cols, dz, mask=mask)
|
235 |
+
dy *= z * z_sigmoid
|
236 |
+
else:
|
237 |
+
if RECOMPUTE_OUTPUT:
|
238 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
239 |
+
tl.store(Y + cols, y, mask=mask)
|
240 |
+
wdy = w * dy
|
241 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
242 |
+
if not IS_RMS_NORM:
|
243 |
+
c2 = tl.sum(wdy, axis=0) / N
|
244 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
245 |
+
else:
|
246 |
+
dx = (wdy - xhat * c1) * rstd
|
247 |
+
dw += dy * xhat
|
248 |
+
if HAS_BIAS:
|
249 |
+
db += dy
|
250 |
+
if HAS_Z and not NORM_BEFORE_GATE:
|
251 |
+
z_sigmoid = tl.sigmoid(z)
|
252 |
+
dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
|
253 |
+
tl.store(DZ + cols, dz, mask=mask)
|
254 |
+
dx *= z * z_sigmoid
|
255 |
+
# Write dx
|
256 |
+
tl.store(DX + cols, dx, mask=mask)
|
257 |
+
|
258 |
+
X += stride_x_row
|
259 |
+
if HAS_Z:
|
260 |
+
Z += stride_z_row
|
261 |
+
DZ += stride_dz_row
|
262 |
+
if RECOMPUTE_OUTPUT:
|
263 |
+
Y += stride_y_row
|
264 |
+
DY += stride_dy_row
|
265 |
+
DX += stride_dx_row
|
266 |
+
tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
|
267 |
+
if HAS_BIAS:
|
268 |
+
tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
|
269 |
+
|
270 |
+
|
271 |
+
def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
|
272 |
+
norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
|
273 |
+
M, N = x.shape
|
274 |
+
if group_size is None:
|
275 |
+
group_size = N
|
276 |
+
assert N % group_size == 0
|
277 |
+
ngroups = N // group_size
|
278 |
+
assert x.stride(-1) == 1
|
279 |
+
assert dy.stride(-1) == 1
|
280 |
+
assert dy.shape == (M, N)
|
281 |
+
if z is not None:
|
282 |
+
assert z.stride(-1) == 1
|
283 |
+
assert z.shape == (M, N)
|
284 |
+
assert weight.shape == (N,)
|
285 |
+
assert weight.stride(-1) == 1
|
286 |
+
if bias is not None:
|
287 |
+
assert bias.stride(-1) == 1
|
288 |
+
assert bias.shape == (N,)
|
289 |
+
# allocate output
|
290 |
+
dx = torch.empty_like(x)
|
291 |
+
if dz is not None:
|
292 |
+
assert z is not None
|
293 |
+
assert dz.shape == z.shape
|
294 |
+
assert dz.stride(-1) == 1
|
295 |
+
else:
|
296 |
+
dz = torch.empty_like(z) if z is not None else None
|
297 |
+
if recompute_output:
|
298 |
+
if out is None:
|
299 |
+
out = torch.empty_like(x)
|
300 |
+
assert out.shape == x.shape
|
301 |
+
|
302 |
+
# Less than 64KB per feature: enqueue fused kernel
|
303 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
304 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
305 |
+
if group_size > BLOCK_N:
|
306 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
307 |
+
# heuristics for number of warps
|
308 |
+
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
309 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
310 |
+
# If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
|
311 |
+
# would limit the occupancy.
|
312 |
+
nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
|
313 |
+
_dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
|
314 |
+
_db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
|
315 |
+
rows_per_program = math.ceil(M / nrow_groups)
|
316 |
+
grid = (nrow_groups, ngroups)
|
317 |
+
with torch.cuda.device(x.device.index):
|
318 |
+
_layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
|
319 |
+
dy, dx, _dw, _db, dz, mean, rstd,
|
320 |
+
x.stride(0),
|
321 |
+
z.stride(0) if z is not None else 0,
|
322 |
+
0 if not recompute_output else out.stride(0),
|
323 |
+
dy.stride(0), dx.stride(0),
|
324 |
+
dz.stride(0) if dz is not None else 0,
|
325 |
+
_dw.stride(0),
|
326 |
+
_db.stride(0) if _db is not None else 0,
|
327 |
+
M, group_size, eps,
|
328 |
+
rows_per_program,
|
329 |
+
BLOCK_N=BLOCK_N,
|
330 |
+
NORM_BEFORE_GATE=norm_before_gate,
|
331 |
+
IS_RMS_NORM=is_rms_norm,
|
332 |
+
num_warps=num_warps)
|
333 |
+
dw = _dw.sum(0).to(weight.dtype)
|
334 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
335 |
+
return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
|
336 |
+
|
337 |
+
|
338 |
+
class LayerNormFn(torch.autograd.Function):
|
339 |
+
|
340 |
+
@staticmethod
|
341 |
+
def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
|
342 |
+
is_rms_norm=False):
|
343 |
+
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
344 |
+
"""
|
345 |
+
|
346 |
+
x_shape_og = x.shape
|
347 |
+
# reshape input data into 2D tensor
|
348 |
+
x = x.reshape(-1, x.shape[-1])
|
349 |
+
if x.stride(-1) != 1:
|
350 |
+
x = x.contiguous()
|
351 |
+
if z is not None:
|
352 |
+
assert z.shape == x_shape_og
|
353 |
+
z = z.reshape(-1, z.shape[-1])
|
354 |
+
if z.stride(-1) != 1:
|
355 |
+
z = z.contiguous()
|
356 |
+
weight = weight.contiguous()
|
357 |
+
if bias is not None:
|
358 |
+
bias = bias.contiguous()
|
359 |
+
y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
|
360 |
+
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
361 |
+
ctx.x_shape_og = x_shape_og
|
362 |
+
ctx.eps = eps
|
363 |
+
ctx.group_size = group_size
|
364 |
+
ctx.norm_before_gate = norm_before_gate
|
365 |
+
ctx.is_rms_norm = is_rms_norm
|
366 |
+
return y.reshape(x_shape_og)
|
367 |
+
|
368 |
+
@staticmethod
|
369 |
+
def backward(ctx, dy):
|
370 |
+
x, weight, bias, mean, rstd, z = ctx.saved_tensors
|
371 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
372 |
+
if dy.stride(-1) != 1:
|
373 |
+
dy = dy.contiguous()
|
374 |
+
assert dy.shape == x.shape
|
375 |
+
dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
|
376 |
+
ctx.norm_before_gate, ctx.is_rms_norm)
|
377 |
+
return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
|
378 |
+
|
379 |
+
|
380 |
+
def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
|
381 |
+
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
|
382 |
+
|
383 |
+
|
384 |
+
def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
|
385 |
+
return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
|
386 |
+
|
387 |
+
|
388 |
+
class LayerNorm(torch.nn.Module):
|
389 |
+
|
390 |
+
def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
|
391 |
+
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
392 |
+
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
393 |
+
"""
|
394 |
+
|
395 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
396 |
+
super().__init__()
|
397 |
+
self.eps = eps
|
398 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
399 |
+
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
400 |
+
self.group_size = group_size
|
401 |
+
self.norm_before_gate = norm_before_gate
|
402 |
+
self.reset_parameters()
|
403 |
+
|
404 |
+
def reset_parameters(self):
|
405 |
+
torch.nn.init.ones_(self.weight)
|
406 |
+
torch.nn.init.zeros_(self.bias)
|
407 |
+
|
408 |
+
def forward(self, x, z=None):
|
409 |
+
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
410 |
+
"""
|
411 |
+
return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
|
412 |
+
norm_before_gate=self.norm_before_gate)
|
413 |
+
|
414 |
+
|
415 |
+
class RMSNorm(torch.nn.Module):
|
416 |
+
|
417 |
+
def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
|
418 |
+
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
419 |
+
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
420 |
+
"""
|
421 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
422 |
+
super().__init__()
|
423 |
+
self.eps = eps
|
424 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
425 |
+
self.register_parameter("bias", None)
|
426 |
+
self.group_size = group_size
|
427 |
+
self.norm_before_gate = norm_before_gate
|
428 |
+
self.reset_parameters()
|
429 |
+
|
430 |
+
def reset_parameters(self):
|
431 |
+
torch.nn.init.ones_(self.weight)
|
432 |
+
|
433 |
+
def forward(self, x, z=None):
|
434 |
+
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
435 |
+
"""
|
436 |
+
return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
|
437 |
+
norm_before_gate=self.norm_before_gate)
|
torch-ext/mamba_ssm/ops/triton/selective_state_update.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from .softplus import softplus
|
16 |
+
|
17 |
+
|
18 |
+
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
19 |
+
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
20 |
+
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
21 |
+
@triton.heuristics(
|
22 |
+
{
|
23 |
+
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
|
24 |
+
is not None
|
25 |
+
}
|
26 |
+
)
|
27 |
+
@triton.heuristics(
|
28 |
+
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
|
29 |
+
)
|
30 |
+
@triton.jit
|
31 |
+
def _selective_scan_update_kernel(
|
32 |
+
# Pointers to matrices
|
33 |
+
state_ptr,
|
34 |
+
x_ptr,
|
35 |
+
dt_ptr,
|
36 |
+
dt_bias_ptr,
|
37 |
+
A_ptr,
|
38 |
+
B_ptr,
|
39 |
+
C_ptr,
|
40 |
+
D_ptr,
|
41 |
+
z_ptr,
|
42 |
+
out_ptr,
|
43 |
+
state_batch_indices_ptr,
|
44 |
+
# Matrix dimensions
|
45 |
+
batch,
|
46 |
+
nheads,
|
47 |
+
dim,
|
48 |
+
dstate,
|
49 |
+
nheads_ngroups_ratio,
|
50 |
+
# Strides
|
51 |
+
stride_state_batch,
|
52 |
+
stride_state_head,
|
53 |
+
stride_state_dim,
|
54 |
+
stride_state_dstate,
|
55 |
+
stride_x_batch,
|
56 |
+
stride_x_head,
|
57 |
+
stride_x_dim,
|
58 |
+
stride_dt_batch,
|
59 |
+
stride_dt_head,
|
60 |
+
stride_dt_dim,
|
61 |
+
stride_dt_bias_head,
|
62 |
+
stride_dt_bias_dim,
|
63 |
+
stride_A_head,
|
64 |
+
stride_A_dim,
|
65 |
+
stride_A_dstate,
|
66 |
+
stride_B_batch,
|
67 |
+
stride_B_group,
|
68 |
+
stride_B_dstate,
|
69 |
+
stride_C_batch,
|
70 |
+
stride_C_group,
|
71 |
+
stride_C_dstate,
|
72 |
+
stride_D_head,
|
73 |
+
stride_D_dim,
|
74 |
+
stride_z_batch,
|
75 |
+
stride_z_head,
|
76 |
+
stride_z_dim,
|
77 |
+
stride_out_batch,
|
78 |
+
stride_out_head,
|
79 |
+
stride_out_dim,
|
80 |
+
# Meta-parameters
|
81 |
+
DT_SOFTPLUS: tl.constexpr,
|
82 |
+
TIE_HDIM: tl.constexpr,
|
83 |
+
BLOCK_SIZE_M: tl.constexpr,
|
84 |
+
HAS_DT_BIAS: tl.constexpr,
|
85 |
+
HAS_D: tl.constexpr,
|
86 |
+
HAS_Z: tl.constexpr,
|
87 |
+
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
88 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
89 |
+
):
|
90 |
+
pid_m = tl.program_id(axis=0)
|
91 |
+
pid_b = tl.program_id(axis=1)
|
92 |
+
pid_h = tl.program_id(axis=2)
|
93 |
+
|
94 |
+
if HAS_STATE_BATCH_INDICES:
|
95 |
+
state_batch_indices_ptr += pid_b
|
96 |
+
state_batch_idx = tl.load(state_batch_indices_ptr)
|
97 |
+
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
98 |
+
else:
|
99 |
+
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
100 |
+
|
101 |
+
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
102 |
+
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
103 |
+
if HAS_DT_BIAS:
|
104 |
+
dt_bias_ptr += pid_h * stride_dt_bias_head
|
105 |
+
A_ptr += pid_h * stride_A_head
|
106 |
+
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
107 |
+
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
108 |
+
if HAS_Z:
|
109 |
+
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
110 |
+
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
111 |
+
|
112 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
113 |
+
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
114 |
+
state_ptrs = state_ptr + (
|
115 |
+
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
116 |
+
)
|
117 |
+
x_ptrs = x_ptr + offs_m * stride_x_dim
|
118 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
119 |
+
if HAS_DT_BIAS:
|
120 |
+
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
121 |
+
if HAS_D:
|
122 |
+
D_ptr += pid_h * stride_D_head
|
123 |
+
A_ptrs = A_ptr + (
|
124 |
+
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
|
125 |
+
)
|
126 |
+
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
127 |
+
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
128 |
+
if HAS_D:
|
129 |
+
D_ptrs = D_ptr + offs_m * stride_D_dim
|
130 |
+
if HAS_Z:
|
131 |
+
z_ptrs = z_ptr + offs_m * stride_z_dim
|
132 |
+
out_ptrs = out_ptr + offs_m * stride_out_dim
|
133 |
+
|
134 |
+
state = tl.load(
|
135 |
+
state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
|
136 |
+
)
|
137 |
+
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
138 |
+
if not TIE_HDIM:
|
139 |
+
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
140 |
+
if HAS_DT_BIAS:
|
141 |
+
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
142 |
+
if DT_SOFTPLUS:
|
143 |
+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
144 |
+
A = tl.load(
|
145 |
+
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
|
146 |
+
).to(tl.float32)
|
147 |
+
dA = tl.exp(A * dt[:, None])
|
148 |
+
else:
|
149 |
+
dt = tl.load(dt_ptr).to(tl.float32)
|
150 |
+
if HAS_DT_BIAS:
|
151 |
+
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
152 |
+
if DT_SOFTPLUS:
|
153 |
+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
154 |
+
A = tl.load(A_ptr).to(tl.float32)
|
155 |
+
dA = tl.exp(A * dt) # scalar, not a matrix
|
156 |
+
|
157 |
+
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
158 |
+
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
159 |
+
if HAS_D:
|
160 |
+
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
161 |
+
if HAS_Z:
|
162 |
+
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
163 |
+
|
164 |
+
if not TIE_HDIM:
|
165 |
+
dB = B[None, :] * dt[:, None]
|
166 |
+
else:
|
167 |
+
dB = B * dt # vector of size (dstate,)
|
168 |
+
state = state * dA + dB * x[:, None]
|
169 |
+
tl.store(
|
170 |
+
state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
171 |
+
)
|
172 |
+
out = tl.sum(state * C[None, :], axis=1)
|
173 |
+
if HAS_D:
|
174 |
+
out += x * D
|
175 |
+
if HAS_Z:
|
176 |
+
out *= z * tl.sigmoid(z)
|
177 |
+
tl.store(out_ptrs, out, mask=offs_m < dim)
|
178 |
+
|
179 |
+
|
180 |
+
def selective_state_update(
|
181 |
+
state,
|
182 |
+
x,
|
183 |
+
dt,
|
184 |
+
A,
|
185 |
+
B,
|
186 |
+
C,
|
187 |
+
D=None,
|
188 |
+
z=None,
|
189 |
+
dt_bias=None,
|
190 |
+
dt_softplus=False,
|
191 |
+
state_batch_indices=None,
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
Argument:
|
195 |
+
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
196 |
+
x: (batch, dim) or (batch, nheads, dim)
|
197 |
+
dt: (batch, dim) or (batch, nheads, dim)
|
198 |
+
A: (dim, dstate) or (nheads, dim, dstate)
|
199 |
+
B: (batch, dstate) or (batch, ngroups, dstate)
|
200 |
+
C: (batch, dstate) or (batch, ngroups, dstate)
|
201 |
+
D: (dim,) or (nheads, dim)
|
202 |
+
z: (batch, dim) or (batch, nheads, dim)
|
203 |
+
dt_bias: (dim,) or (nheads, dim)
|
204 |
+
Return:
|
205 |
+
out: (batch, dim) or (batch, nheads, dim)
|
206 |
+
"""
|
207 |
+
has_heads = state.dim() > 3
|
208 |
+
if state.dim() == 3:
|
209 |
+
state = state.unsqueeze(1)
|
210 |
+
if x.dim() == 2:
|
211 |
+
x = x.unsqueeze(1)
|
212 |
+
if dt.dim() == 2:
|
213 |
+
dt = dt.unsqueeze(1)
|
214 |
+
if A.dim() == 2:
|
215 |
+
A = A.unsqueeze(0)
|
216 |
+
if B.dim() == 2:
|
217 |
+
B = B.unsqueeze(1)
|
218 |
+
if C.dim() == 2:
|
219 |
+
C = C.unsqueeze(1)
|
220 |
+
if D is not None and D.dim() == 1:
|
221 |
+
D = D.unsqueeze(0)
|
222 |
+
if z is not None and z.dim() == 2:
|
223 |
+
z = z.unsqueeze(1)
|
224 |
+
if dt_bias is not None and dt_bias.dim() == 1:
|
225 |
+
dt_bias = dt_bias.unsqueeze(0)
|
226 |
+
_, nheads, dim, dstate = state.shape
|
227 |
+
batch = x.shape[0]
|
228 |
+
if x.shape != (batch, nheads, dim):
|
229 |
+
print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
|
230 |
+
assert x.shape == (batch, nheads, dim)
|
231 |
+
assert dt.shape == x.shape
|
232 |
+
assert A.shape == (nheads, dim, dstate)
|
233 |
+
ngroups = B.shape[1]
|
234 |
+
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
235 |
+
assert B.shape == (batch, ngroups, dstate)
|
236 |
+
assert C.shape == B.shape
|
237 |
+
if D is not None:
|
238 |
+
assert D.shape == (nheads, dim)
|
239 |
+
if z is not None:
|
240 |
+
assert z.shape == x.shape
|
241 |
+
if dt_bias is not None:
|
242 |
+
assert dt_bias.shape == (nheads, dim)
|
243 |
+
if state_batch_indices is not None:
|
244 |
+
assert state_batch_indices.shape == (batch,)
|
245 |
+
out = torch.empty_like(x)
|
246 |
+
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
|
247 |
+
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
248 |
+
# We don't want autotune since it will overwrite the state
|
249 |
+
# We instead tune by hand.
|
250 |
+
BLOCK_SIZE_M, num_warps = (
|
251 |
+
(32, 4)
|
252 |
+
if dstate <= 16
|
253 |
+
else (
|
254 |
+
(16, 4)
|
255 |
+
if dstate <= 32
|
256 |
+
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
|
257 |
+
)
|
258 |
+
)
|
259 |
+
tie_hdim = (
|
260 |
+
A.stride(-1) == 0
|
261 |
+
and A.stride(-2) == 0
|
262 |
+
and dt.stride(-1) == 0
|
263 |
+
and dt_bias.stride(-1) == 0
|
264 |
+
)
|
265 |
+
with torch.cuda.device(x.device.index):
|
266 |
+
_selective_scan_update_kernel[grid](
|
267 |
+
state,
|
268 |
+
x,
|
269 |
+
dt,
|
270 |
+
dt_bias,
|
271 |
+
A,
|
272 |
+
B,
|
273 |
+
C,
|
274 |
+
D,
|
275 |
+
z,
|
276 |
+
out,
|
277 |
+
state_batch_indices,
|
278 |
+
batch,
|
279 |
+
nheads,
|
280 |
+
dim,
|
281 |
+
dstate,
|
282 |
+
nheads // ngroups,
|
283 |
+
state.stride(0),
|
284 |
+
state.stride(1),
|
285 |
+
state.stride(2),
|
286 |
+
state.stride(3),
|
287 |
+
x.stride(0),
|
288 |
+
x.stride(1),
|
289 |
+
x.stride(2),
|
290 |
+
dt.stride(0),
|
291 |
+
dt.stride(1),
|
292 |
+
dt.stride(2),
|
293 |
+
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
294 |
+
A.stride(0),
|
295 |
+
A.stride(1),
|
296 |
+
A.stride(2),
|
297 |
+
B.stride(0),
|
298 |
+
B.stride(1),
|
299 |
+
B.stride(2),
|
300 |
+
C.stride(0),
|
301 |
+
C.stride(1),
|
302 |
+
C.stride(2),
|
303 |
+
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
304 |
+
z_strides[0],
|
305 |
+
z_strides[1],
|
306 |
+
z_strides[2],
|
307 |
+
out.stride(0),
|
308 |
+
out.stride(1),
|
309 |
+
out.stride(2),
|
310 |
+
dt_softplus,
|
311 |
+
tie_hdim,
|
312 |
+
BLOCK_SIZE_M,
|
313 |
+
num_warps=num_warps,
|
314 |
+
)
|
315 |
+
if not has_heads:
|
316 |
+
out = out.squeeze(1)
|
317 |
+
return out
|
318 |
+
|
319 |
+
|
320 |
+
def selective_state_update_ref(
|
321 |
+
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
|
322 |
+
):
|
323 |
+
"""
|
324 |
+
Argument:
|
325 |
+
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
326 |
+
x: (batch, dim) or (batch, nheads, dim)
|
327 |
+
dt: (batch, dim) or (batch, nheads, dim)
|
328 |
+
A: (dim, dstate) or (nheads, dim, dstate)
|
329 |
+
B: (batch, dstate) or (batch, ngroups, dstate)
|
330 |
+
C: (batch, dstate) or (batch, ngroups, dstate)
|
331 |
+
D: (dim,) or (nheads, dim)
|
332 |
+
z: (batch, dim) or (batch, nheads, dim)
|
333 |
+
dt_bias: (dim,) or (nheads, dim)
|
334 |
+
Return:
|
335 |
+
out: (batch, dim) or (batch, nheads, dim)
|
336 |
+
"""
|
337 |
+
has_heads = state.dim() > 3
|
338 |
+
if state.dim() == 3:
|
339 |
+
state = state.unsqueeze(1)
|
340 |
+
if x.dim() == 2:
|
341 |
+
x = x.unsqueeze(1)
|
342 |
+
if dt.dim() == 2:
|
343 |
+
dt = dt.unsqueeze(1)
|
344 |
+
if A.dim() == 2:
|
345 |
+
A = A.unsqueeze(0)
|
346 |
+
if B.dim() == 2:
|
347 |
+
B = B.unsqueeze(1)
|
348 |
+
if C.dim() == 2:
|
349 |
+
C = C.unsqueeze(1)
|
350 |
+
if D is not None and D.dim() == 1:
|
351 |
+
D = D.unsqueeze(0)
|
352 |
+
if z is not None and z.dim() == 2:
|
353 |
+
z = z.unsqueeze(1)
|
354 |
+
if dt_bias is not None and dt_bias.dim() == 1:
|
355 |
+
dt_bias = dt_bias.unsqueeze(0)
|
356 |
+
batch, nheads, dim, dstate = state.shape
|
357 |
+
assert x.shape == (batch, nheads, dim)
|
358 |
+
assert dt.shape == x.shape
|
359 |
+
assert A.shape == (nheads, dim, dstate)
|
360 |
+
ngroups = B.shape[1]
|
361 |
+
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
362 |
+
assert B.shape == (batch, ngroups, dstate)
|
363 |
+
assert C.shape == B.shape
|
364 |
+
if D is not None:
|
365 |
+
assert D.shape == (nheads, dim)
|
366 |
+
if z is not None:
|
367 |
+
assert z.shape == x.shape
|
368 |
+
if dt_bias is not None:
|
369 |
+
assert dt_bias.shape == (nheads, dim)
|
370 |
+
dt = dt + dt_bias
|
371 |
+
dt = F.softplus(dt) if dt_softplus else dt
|
372 |
+
dA = torch.exp(
|
373 |
+
rearrange(dt, "b h d -> b h d 1") * A
|
374 |
+
) # (batch, nheads, dim, dstate)
|
375 |
+
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
376 |
+
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
|
377 |
+
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
|
378 |
+
B, "b h n -> b h 1 n"
|
379 |
+
) # (batch, nheads, dim, dstate)
|
380 |
+
state.copy_(
|
381 |
+
state * dA + dB * rearrange(x, "b h d -> b h d 1")
|
382 |
+
) # (batch, dim, dstate
|
383 |
+
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
|
384 |
+
if D is not None:
|
385 |
+
out += (x * D).to(out.dtype)
|
386 |
+
out = (out if z is None else out * F.silu(z)).to(x.dtype)
|
387 |
+
if not has_heads:
|
388 |
+
out = out.squeeze(1)
|
389 |
+
return out
|
torch-ext/mamba_ssm/ops/triton/softplus.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
from packaging import version
|
4 |
+
|
5 |
+
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
|
6 |
+
|
7 |
+
|
8 |
+
if TRITON3:
|
9 |
+
@triton.jit
|
10 |
+
def softplus(dt):
|
11 |
+
return tl.math.log(tl.math.exp(dt) + 1)
|
12 |
+
else:
|
13 |
+
@triton.jit
|
14 |
+
def softplus(dt):
|
15 |
+
return tl.math.log1p(tl.exp(dt))
|
torch-ext/mamba_ssm/ops/triton/ssd_bmm.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
|
16 |
+
def init_to_zero(names):
|
17 |
+
return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
|
18 |
+
|
19 |
+
|
20 |
+
@triton.autotune(
|
21 |
+
configs=[
|
22 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
|
23 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
24 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
25 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
26 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
27 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
|
28 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
29 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
|
30 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
|
31 |
+
],
|
32 |
+
key=['chunk_size', 'K', 'IS_CAUSAL'],
|
33 |
+
)
|
34 |
+
@triton.jit
|
35 |
+
def _bmm_chunk_fwd_kernel(
|
36 |
+
# Pointers to matrices
|
37 |
+
a_ptr, b_ptr, out_ptr, seq_idx_ptr,
|
38 |
+
# Matrix dimensions
|
39 |
+
seqlen, chunk_size, K, ngroups,
|
40 |
+
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
41 |
+
stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
|
42 |
+
stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
|
43 |
+
stride_seq_idx_batch, stride_seq_idx_seqlen,
|
44 |
+
# Meta-parameters
|
45 |
+
IS_CAUSAL: tl.constexpr,
|
46 |
+
dot_dtype: tl.constexpr,
|
47 |
+
HAS_SEQ_IDX: tl.constexpr,
|
48 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
|
49 |
+
):
|
50 |
+
pid_b = tl.program_id(axis=1)
|
51 |
+
pid_ch = tl.program_id(axis=2)
|
52 |
+
pid_c = pid_ch // ngroups
|
53 |
+
pid_h = pid_ch - pid_c * ngroups
|
54 |
+
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
55 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
56 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
57 |
+
if IS_CAUSAL:
|
58 |
+
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
59 |
+
return
|
60 |
+
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
61 |
+
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
|
62 |
+
if HAS_SEQ_IDX:
|
63 |
+
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
64 |
+
|
65 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
66 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
67 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
68 |
+
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
69 |
+
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
70 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
71 |
+
|
72 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
73 |
+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
74 |
+
a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
|
75 |
+
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
|
76 |
+
acc += tl.dot(a, b)
|
77 |
+
a_ptrs += BLOCK_SIZE_K * stride_ak
|
78 |
+
b_ptrs += BLOCK_SIZE_K * stride_bk
|
79 |
+
|
80 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
81 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
82 |
+
if HAS_SEQ_IDX:
|
83 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
84 |
+
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
85 |
+
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
|
86 |
+
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
87 |
+
out = acc.to(out_ptr.dtype.element_ty)
|
88 |
+
|
89 |
+
out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
|
90 |
+
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
91 |
+
tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
|
92 |
+
|
93 |
+
|
94 |
+
@triton.autotune(
|
95 |
+
configs=[
|
96 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
|
97 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
98 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
99 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
100 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
101 |
+
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
|
102 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
103 |
+
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
|
104 |
+
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
|
105 |
+
],
|
106 |
+
key=['chunk_size', 'K'],
|
107 |
+
)
|
108 |
+
@triton.jit
|
109 |
+
def _bmm_chunk_bwd_kernel(
|
110 |
+
# Pointers to matrices
|
111 |
+
a_ptr, dout_ptr, db_ptr, res_ptr,
|
112 |
+
# Matrix dimensions
|
113 |
+
seqlen, chunk_size, K, ngroups,
|
114 |
+
stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
|
115 |
+
stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
|
116 |
+
stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
|
117 |
+
stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
|
118 |
+
# Meta-parameters
|
119 |
+
dot_dtype: tl.constexpr,
|
120 |
+
HAS_RESIDUAL: tl.constexpr,
|
121 |
+
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
|
122 |
+
):
|
123 |
+
pid_b = tl.program_id(axis=1)
|
124 |
+
pid_ch = tl.program_id(axis=2)
|
125 |
+
pid_c = pid_ch // ngroups
|
126 |
+
pid_h = pid_ch - pid_c * ngroups
|
127 |
+
num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
|
128 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
129 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
130 |
+
|
131 |
+
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
132 |
+
dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
|
133 |
+
|
134 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
135 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
136 |
+
offs_cs = tl.arange(0, BLOCK_SIZE_CS)
|
137 |
+
dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
|
138 |
+
a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
|
139 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
140 |
+
|
141 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
142 |
+
for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
|
143 |
+
dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
|
144 |
+
a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
|
145 |
+
acc += tl.dot(dout, a)
|
146 |
+
dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
|
147 |
+
a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
|
148 |
+
|
149 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
150 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
151 |
+
if HAS_RESIDUAL:
|
152 |
+
res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
|
153 |
+
res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
|
154 |
+
res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
|
155 |
+
acc += res
|
156 |
+
db = acc.to(db_ptr.dtype.element_ty)
|
157 |
+
|
158 |
+
db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
|
159 |
+
db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
|
160 |
+
tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
|
161 |
+
|
162 |
+
|
163 |
+
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
|
164 |
+
"""
|
165 |
+
Argument:
|
166 |
+
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
167 |
+
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
168 |
+
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
169 |
+
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
170 |
+
guaranteed to be correct.
|
171 |
+
Return:
|
172 |
+
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
173 |
+
"""
|
174 |
+
# Check constraints.
|
175 |
+
has_groups = a.dim() == 4
|
176 |
+
if not has_groups:
|
177 |
+
batch, seqlen, k = a.shape
|
178 |
+
else:
|
179 |
+
batch, seqlen, ngroups, k = a.shape
|
180 |
+
assert b.shape == a.shape
|
181 |
+
if seq_idx is not None:
|
182 |
+
assert seq_idx.shape == (batch, seqlen)
|
183 |
+
if a.stride(-1) != 1 and a.stride(1) != 1:
|
184 |
+
a = a.contiguous()
|
185 |
+
if b.stride(-1) != 1 and b.stride(1) != 1:
|
186 |
+
b = b.contiguous()
|
187 |
+
nchunks = math.ceil(seqlen / chunk_size)
|
188 |
+
# Allocates output.
|
189 |
+
out_dtype = a.dtype if output_dtype is None else output_dtype
|
190 |
+
out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
|
191 |
+
device=a.device, dtype=out_dtype)
|
192 |
+
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
193 |
+
(tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
|
194 |
+
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
|
195 |
+
batch, nchunks if not has_groups else nchunks * ngroups)
|
196 |
+
with torch.cuda.device(a.device.index):
|
197 |
+
_bmm_chunk_fwd_kernel[grid](
|
198 |
+
a, b, out, seq_idx,
|
199 |
+
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
200 |
+
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
201 |
+
b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
|
202 |
+
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
|
203 |
+
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
204 |
+
causal,
|
205 |
+
dot_dtype,
|
206 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
207 |
+
)
|
208 |
+
return out
|
209 |
+
|
210 |
+
|
211 |
+
def _bmm_chunk_bwd(a, dout, residual=None, out=None):
|
212 |
+
"""
|
213 |
+
Argument:
|
214 |
+
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
215 |
+
dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
216 |
+
residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
217 |
+
Return:
|
218 |
+
out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
219 |
+
|
220 |
+
If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
|
221 |
+
zeroed out before calling this function.
|
222 |
+
"""
|
223 |
+
# Check constraints.
|
224 |
+
has_groups = a.dim() == 4
|
225 |
+
if not has_groups:
|
226 |
+
batch, seqlen, k = a.shape
|
227 |
+
else:
|
228 |
+
batch, seqlen, ngroups, k = a.shape
|
229 |
+
nchunks, chunk_size = dout.shape[1], dout.shape[-1]
|
230 |
+
if a.stride(-1) != 1 and a.stride(-2) != 1:
|
231 |
+
a = a.contiguous()
|
232 |
+
if dout.stride(-1) != 1 and dout.stride(-2) != 1:
|
233 |
+
dout = dout.contiguous()
|
234 |
+
if residual is not None:
|
235 |
+
assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
|
236 |
+
if residual.stride(-1) != 1 and residual.stride(1) != 1:
|
237 |
+
residual = residual.contiguous()
|
238 |
+
# Allocates output.
|
239 |
+
if out is not None:
|
240 |
+
assert out.shape == a.shape
|
241 |
+
assert out.stride(-1) == 1 or out.stride(1) == 1
|
242 |
+
else:
|
243 |
+
out = torch.empty_like(a)
|
244 |
+
dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
|
245 |
+
(tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
|
246 |
+
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
|
247 |
+
nchunks if not has_groups else nchunks * ngroups)
|
248 |
+
residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
|
249 |
+
residual.stride(-1))
|
250 |
+
if residual is not None else (0, 0, 0, 0))
|
251 |
+
with torch.cuda.device(a.device.index):
|
252 |
+
_bmm_chunk_bwd_kernel[grid](
|
253 |
+
a, dout, out, residual,
|
254 |
+
seqlen, chunk_size, k, ngroups if has_groups else 1,
|
255 |
+
a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
|
256 |
+
dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
|
257 |
+
out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
|
258 |
+
residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
|
259 |
+
dot_dtype,
|
260 |
+
HAS_RESIDUAL=residual is not None,
|
261 |
+
)
|
262 |
+
return out
|
torch-ext/mamba_ssm/ops/triton/ssd_chunk_scan.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
torch-ext/mamba_ssm/ops/triton/ssd_chunk_state.py
ADDED
@@ -0,0 +1,2012 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
2 |
+
|
3 |
+
"""We want triton==2.1.0 or 2.2.0 for this
|
4 |
+
"""
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import triton
|
11 |
+
import triton.language as tl
|
12 |
+
|
13 |
+
from einops import rearrange, repeat
|
14 |
+
|
15 |
+
from .softplus import softplus
|
16 |
+
|
17 |
+
|
18 |
+
def init_to_zero(names):
|
19 |
+
return lambda nargs: [
|
20 |
+
nargs[name].zero_() for name in names if nargs[name] is not None
|
21 |
+
]
|
22 |
+
|
23 |
+
|
24 |
+
@triton.autotune(
|
25 |
+
configs=[
|
26 |
+
triton.Config({"BLOCK_SIZE_H": 1}),
|
27 |
+
triton.Config({"BLOCK_SIZE_H": 2}),
|
28 |
+
triton.Config({"BLOCK_SIZE_H": 4}),
|
29 |
+
triton.Config({"BLOCK_SIZE_H": 8}),
|
30 |
+
triton.Config({"BLOCK_SIZE_H": 16}),
|
31 |
+
triton.Config({"BLOCK_SIZE_H": 32}),
|
32 |
+
triton.Config({"BLOCK_SIZE_H": 64}),
|
33 |
+
],
|
34 |
+
key=["chunk_size", "nheads"],
|
35 |
+
)
|
36 |
+
@triton.jit
|
37 |
+
def _chunk_cumsum_fwd_kernel(
|
38 |
+
# Pointers to matrices
|
39 |
+
dt_ptr,
|
40 |
+
A_ptr,
|
41 |
+
dt_bias_ptr,
|
42 |
+
dt_out_ptr,
|
43 |
+
dA_cumsum_ptr,
|
44 |
+
# Matrix dimension
|
45 |
+
batch,
|
46 |
+
seqlen,
|
47 |
+
nheads,
|
48 |
+
chunk_size,
|
49 |
+
dt_min,
|
50 |
+
dt_max,
|
51 |
+
# Strides
|
52 |
+
stride_dt_batch,
|
53 |
+
stride_dt_seqlen,
|
54 |
+
stride_dt_head,
|
55 |
+
stride_A_head,
|
56 |
+
stride_dt_bias_head,
|
57 |
+
stride_dt_out_batch,
|
58 |
+
stride_dt_out_chunk,
|
59 |
+
stride_dt_out_head,
|
60 |
+
stride_dt_out_csize,
|
61 |
+
stride_dA_cs_batch,
|
62 |
+
stride_dA_cs_chunk,
|
63 |
+
stride_dA_cs_head,
|
64 |
+
stride_dA_cs_csize,
|
65 |
+
# Meta-parameters
|
66 |
+
DT_SOFTPLUS: tl.constexpr,
|
67 |
+
HAS_DT_BIAS: tl.constexpr,
|
68 |
+
BLOCK_SIZE_H: tl.constexpr,
|
69 |
+
BLOCK_SIZE_CHUNK: tl.constexpr,
|
70 |
+
):
|
71 |
+
pid_b = tl.program_id(axis=0)
|
72 |
+
pid_c = tl.program_id(axis=1)
|
73 |
+
pid_h = tl.program_id(axis=2)
|
74 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
75 |
+
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
76 |
+
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
77 |
+
|
78 |
+
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
79 |
+
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
80 |
+
dt_ptrs = dt_ptr + (
|
81 |
+
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
82 |
+
)
|
83 |
+
A_ptrs = A_ptr + offs_h * stride_A_head
|
84 |
+
dt_out_ptrs = dt_out_ptr + (
|
85 |
+
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
86 |
+
)
|
87 |
+
dA_cs_ptrs = dA_cumsum_ptr + (
|
88 |
+
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
89 |
+
)
|
90 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
91 |
+
|
92 |
+
dt = tl.load(
|
93 |
+
dt_ptrs,
|
94 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
95 |
+
other=0.0,
|
96 |
+
).to(tl.float32)
|
97 |
+
if HAS_DT_BIAS:
|
98 |
+
dt_bias = tl.load(
|
99 |
+
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
100 |
+
).to(tl.float32)
|
101 |
+
dt += dt_bias[:, None]
|
102 |
+
if DT_SOFTPLUS:
|
103 |
+
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
104 |
+
# As of Triton 2.2.0, tl.clamp is not available yet
|
105 |
+
# dt = tl.clamp(dt, dt_min, dt_max)
|
106 |
+
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
107 |
+
dt = tl.where(
|
108 |
+
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
109 |
+
)
|
110 |
+
tl.store(
|
111 |
+
dt_out_ptrs,
|
112 |
+
dt,
|
113 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
114 |
+
)
|
115 |
+
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
116 |
+
dA = dt * A[:, None]
|
117 |
+
dA_cs = tl.cumsum(dA, axis=1)
|
118 |
+
tl.store(
|
119 |
+
dA_cs_ptrs,
|
120 |
+
dA_cs,
|
121 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
@triton.autotune(
|
126 |
+
configs=[
|
127 |
+
triton.Config(
|
128 |
+
{"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
129 |
+
),
|
130 |
+
triton.Config(
|
131 |
+
{"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
132 |
+
),
|
133 |
+
triton.Config(
|
134 |
+
{"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
135 |
+
),
|
136 |
+
triton.Config(
|
137 |
+
{"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
138 |
+
),
|
139 |
+
triton.Config(
|
140 |
+
{"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
141 |
+
),
|
142 |
+
triton.Config(
|
143 |
+
{"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
144 |
+
),
|
145 |
+
triton.Config(
|
146 |
+
{"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
|
147 |
+
),
|
148 |
+
],
|
149 |
+
key=["chunk_size", "nheads"],
|
150 |
+
)
|
151 |
+
@triton.jit
|
152 |
+
def _chunk_cumsum_bwd_kernel(
|
153 |
+
# Pointers to matrices
|
154 |
+
ddA_ptr,
|
155 |
+
ddt_out_ptr,
|
156 |
+
dt_ptr,
|
157 |
+
A_ptr,
|
158 |
+
dt_bias_ptr,
|
159 |
+
ddt_ptr,
|
160 |
+
dA_ptr,
|
161 |
+
ddt_bias_ptr,
|
162 |
+
# Matrix dimensions
|
163 |
+
batch,
|
164 |
+
seqlen,
|
165 |
+
nheads,
|
166 |
+
chunk_size,
|
167 |
+
dt_min,
|
168 |
+
dt_max,
|
169 |
+
# Strides
|
170 |
+
stride_ddA_batch,
|
171 |
+
stride_ddA_chunk,
|
172 |
+
stride_ddA_head,
|
173 |
+
stride_ddA_csize,
|
174 |
+
stride_ddt_out_batch,
|
175 |
+
stride_ddt_out_chunk,
|
176 |
+
stride_ddt_out_head,
|
177 |
+
stride_ddt_out_csize,
|
178 |
+
stride_dt_batch,
|
179 |
+
stride_dt_seqlen,
|
180 |
+
stride_dt_head,
|
181 |
+
stride_A_head,
|
182 |
+
stride_dt_bias_head,
|
183 |
+
stride_ddt_batch,
|
184 |
+
stride_ddt_seqlen,
|
185 |
+
stride_ddt_head,
|
186 |
+
stride_dA_head,
|
187 |
+
stride_ddt_bias_head,
|
188 |
+
# Meta-parameters
|
189 |
+
DT_SOFTPLUS: tl.constexpr,
|
190 |
+
HAS_DT_BIAS: tl.constexpr,
|
191 |
+
BLOCK_SIZE_H: tl.constexpr,
|
192 |
+
BLOCK_SIZE_CHUNK: tl.constexpr,
|
193 |
+
):
|
194 |
+
pid_b = tl.program_id(axis=0)
|
195 |
+
pid_c = tl.program_id(axis=1)
|
196 |
+
pid_h = tl.program_id(axis=2)
|
197 |
+
ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
|
198 |
+
ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
|
199 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
200 |
+
ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
|
201 |
+
|
202 |
+
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
203 |
+
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
204 |
+
ddt_out_ptrs = ddt_out_ptr + (
|
205 |
+
offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
|
206 |
+
)
|
207 |
+
ddA_ptrs = ddA_ptr + (
|
208 |
+
offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
|
209 |
+
)
|
210 |
+
dt_ptrs = dt_ptr + (
|
211 |
+
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
212 |
+
)
|
213 |
+
ddt_ptrs = ddt_ptr + (
|
214 |
+
offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
|
215 |
+
)
|
216 |
+
A_ptrs = A_ptr + offs_h * stride_A_head
|
217 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
218 |
+
|
219 |
+
ddA = tl.load(
|
220 |
+
ddA_ptrs,
|
221 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
222 |
+
other=0.0,
|
223 |
+
).to(tl.float32)
|
224 |
+
ddt_out = tl.load(
|
225 |
+
ddt_out_ptrs,
|
226 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
227 |
+
other=0.0,
|
228 |
+
).to(tl.float32)
|
229 |
+
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
230 |
+
ddt = ddA * A[:, None] + ddt_out
|
231 |
+
dt = tl.load(
|
232 |
+
dt_ptrs,
|
233 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
234 |
+
other=0.0,
|
235 |
+
).to(tl.float32)
|
236 |
+
if HAS_DT_BIAS:
|
237 |
+
dt_bias = tl.load(
|
238 |
+
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
239 |
+
).to(tl.float32)
|
240 |
+
dt += dt_bias[:, None]
|
241 |
+
if DT_SOFTPLUS:
|
242 |
+
dt_presoftplus = dt
|
243 |
+
dt = tl.where(dt <= 20.0, softplus(dt), ddt)
|
244 |
+
clamp_mask = (dt < dt_min) | (dt > dt_max)
|
245 |
+
# As of Triton 2.2.0, tl.clamp is not available yet
|
246 |
+
# dt = tl.clamp(dt, dt_min, dt_max)
|
247 |
+
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
248 |
+
dt = tl.where(
|
249 |
+
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
250 |
+
)
|
251 |
+
ddt = tl.where(
|
252 |
+
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
|
253 |
+
)
|
254 |
+
ddt = tl.where(clamp_mask, 0.0, ddt)
|
255 |
+
if DT_SOFTPLUS:
|
256 |
+
ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
|
257 |
+
tl.store(
|
258 |
+
ddt_ptrs,
|
259 |
+
ddt,
|
260 |
+
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
261 |
+
)
|
262 |
+
dA = tl.sum(ddA * dt, axis=1)
|
263 |
+
tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
|
264 |
+
if HAS_DT_BIAS:
|
265 |
+
ddt_bias = tl.sum(ddt, axis=1)
|
266 |
+
tl.atomic_add(
|
267 |
+
ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
@triton.autotune(
|
272 |
+
configs=[
|
273 |
+
triton.Config(
|
274 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
275 |
+
num_stages=3,
|
276 |
+
num_warps=8,
|
277 |
+
),
|
278 |
+
triton.Config(
|
279 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
280 |
+
num_stages=4,
|
281 |
+
num_warps=4,
|
282 |
+
),
|
283 |
+
triton.Config(
|
284 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
285 |
+
num_stages=4,
|
286 |
+
num_warps=4,
|
287 |
+
),
|
288 |
+
triton.Config(
|
289 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
290 |
+
num_stages=4,
|
291 |
+
num_warps=4,
|
292 |
+
),
|
293 |
+
triton.Config(
|
294 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
295 |
+
num_stages=4,
|
296 |
+
num_warps=4,
|
297 |
+
),
|
298 |
+
triton.Config(
|
299 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
300 |
+
num_stages=4,
|
301 |
+
num_warps=4,
|
302 |
+
),
|
303 |
+
triton.Config(
|
304 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
305 |
+
num_stages=5,
|
306 |
+
num_warps=2,
|
307 |
+
),
|
308 |
+
triton.Config(
|
309 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
310 |
+
num_stages=5,
|
311 |
+
num_warps=2,
|
312 |
+
),
|
313 |
+
triton.Config(
|
314 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
315 |
+
num_stages=4,
|
316 |
+
num_warps=2,
|
317 |
+
),
|
318 |
+
],
|
319 |
+
key=["hdim", "dstate", "chunk_size"],
|
320 |
+
)
|
321 |
+
@triton.jit
|
322 |
+
def _chunk_state_fwd_kernel(
|
323 |
+
# Pointers to matrices
|
324 |
+
x_ptr,
|
325 |
+
b_ptr,
|
326 |
+
states_ptr,
|
327 |
+
dt_ptr,
|
328 |
+
dA_cumsum_ptr,
|
329 |
+
seq_idx_ptr,
|
330 |
+
# Matrix dimensions
|
331 |
+
hdim,
|
332 |
+
dstate,
|
333 |
+
chunk_size,
|
334 |
+
batch,
|
335 |
+
seqlen,
|
336 |
+
nheads_ngroups_ratio,
|
337 |
+
# Strides
|
338 |
+
stride_x_batch,
|
339 |
+
stride_x_seqlen,
|
340 |
+
stride_x_head,
|
341 |
+
stride_x_hdim,
|
342 |
+
stride_b_batch,
|
343 |
+
stride_b_seqlen,
|
344 |
+
stride_b_head,
|
345 |
+
stride_b_dstate,
|
346 |
+
stride_states_batch,
|
347 |
+
stride_states_chunk,
|
348 |
+
stride_states_head,
|
349 |
+
stride_states_hdim,
|
350 |
+
stride_states_dstate,
|
351 |
+
stride_dt_batch,
|
352 |
+
stride_dt_chunk,
|
353 |
+
stride_dt_head,
|
354 |
+
stride_dt_csize,
|
355 |
+
stride_dA_cs_batch,
|
356 |
+
stride_dA_cs_chunk,
|
357 |
+
stride_dA_cs_head,
|
358 |
+
stride_dA_cs_csize,
|
359 |
+
stride_seq_idx_batch,
|
360 |
+
stride_seq_idx_seqlen,
|
361 |
+
# Meta-parameters
|
362 |
+
HAS_SEQ_IDX: tl.constexpr,
|
363 |
+
BLOCK_SIZE_M: tl.constexpr,
|
364 |
+
BLOCK_SIZE_N: tl.constexpr,
|
365 |
+
BLOCK_SIZE_K: tl.constexpr,
|
366 |
+
):
|
367 |
+
pid_bc = tl.program_id(axis=1)
|
368 |
+
pid_c = pid_bc // batch
|
369 |
+
pid_b = pid_bc - pid_c * batch
|
370 |
+
pid_h = tl.program_id(axis=2)
|
371 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
372 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
373 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
374 |
+
b_ptr += (
|
375 |
+
pid_b * stride_b_batch
|
376 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
377 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
378 |
+
)
|
379 |
+
x_ptr += (
|
380 |
+
pid_b * stride_x_batch
|
381 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
382 |
+
+ pid_h * stride_x_head
|
383 |
+
)
|
384 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
385 |
+
dA_cumsum_ptr += (
|
386 |
+
pid_b * stride_dA_cs_batch
|
387 |
+
+ pid_c * stride_dA_cs_chunk
|
388 |
+
+ pid_h * stride_dA_cs_head
|
389 |
+
)
|
390 |
+
if HAS_SEQ_IDX:
|
391 |
+
seq_idx_ptr += (
|
392 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
393 |
+
)
|
394 |
+
|
395 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
396 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
397 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
398 |
+
x_ptrs = x_ptr + (
|
399 |
+
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
400 |
+
)
|
401 |
+
b_ptrs = b_ptr + (
|
402 |
+
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
403 |
+
)
|
404 |
+
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
405 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
406 |
+
tl.float32
|
407 |
+
)
|
408 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
409 |
+
if HAS_SEQ_IDX:
|
410 |
+
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
411 |
+
|
412 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
413 |
+
if HAS_SEQ_IDX:
|
414 |
+
seq_idx_last = tl.load(
|
415 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
416 |
+
)
|
417 |
+
|
418 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
419 |
+
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
420 |
+
x = tl.load(
|
421 |
+
x_ptrs,
|
422 |
+
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
423 |
+
other=0.0,
|
424 |
+
)
|
425 |
+
b = tl.load(
|
426 |
+
b_ptrs,
|
427 |
+
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
428 |
+
other=0.0,
|
429 |
+
).to(tl.float32)
|
430 |
+
dA_cs_k = tl.load(
|
431 |
+
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
432 |
+
).to(tl.float32)
|
433 |
+
if HAS_SEQ_IDX:
|
434 |
+
seq_idx_k = tl.load(
|
435 |
+
seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
|
436 |
+
)
|
437 |
+
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
438 |
+
tl.float32
|
439 |
+
)
|
440 |
+
if not HAS_SEQ_IDX:
|
441 |
+
scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
|
442 |
+
else:
|
443 |
+
scale = tl.where(
|
444 |
+
seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
|
445 |
+
)
|
446 |
+
b *= scale[:, None]
|
447 |
+
b = b.to(x_ptr.dtype.element_ty)
|
448 |
+
acc += tl.dot(x, b)
|
449 |
+
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
450 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
451 |
+
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
452 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
453 |
+
if HAS_SEQ_IDX:
|
454 |
+
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
455 |
+
states = acc.to(states_ptr.dtype.element_ty)
|
456 |
+
|
457 |
+
states_ptr += (
|
458 |
+
pid_b * stride_states_batch
|
459 |
+
+ pid_c * stride_states_chunk
|
460 |
+
+ pid_h * stride_states_head
|
461 |
+
)
|
462 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
463 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
464 |
+
states_ptrs = states_ptr + (
|
465 |
+
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
466 |
+
)
|
467 |
+
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
468 |
+
tl.store(states_ptrs, states, mask=c_mask)
|
469 |
+
|
470 |
+
|
471 |
+
@triton.autotune(
|
472 |
+
configs=[
|
473 |
+
triton.Config(
|
474 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
475 |
+
num_stages=3,
|
476 |
+
num_warps=8,
|
477 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
478 |
+
),
|
479 |
+
triton.Config(
|
480 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
481 |
+
num_stages=4,
|
482 |
+
num_warps=4,
|
483 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
484 |
+
),
|
485 |
+
triton.Config(
|
486 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
487 |
+
num_stages=4,
|
488 |
+
num_warps=4,
|
489 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
490 |
+
),
|
491 |
+
triton.Config(
|
492 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
493 |
+
num_stages=4,
|
494 |
+
num_warps=4,
|
495 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
496 |
+
),
|
497 |
+
triton.Config(
|
498 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
499 |
+
num_stages=4,
|
500 |
+
num_warps=4,
|
501 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
502 |
+
),
|
503 |
+
triton.Config(
|
504 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
505 |
+
num_stages=4,
|
506 |
+
num_warps=4,
|
507 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
508 |
+
),
|
509 |
+
triton.Config(
|
510 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
511 |
+
num_stages=5,
|
512 |
+
num_warps=4,
|
513 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
514 |
+
),
|
515 |
+
triton.Config(
|
516 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
517 |
+
num_stages=5,
|
518 |
+
num_warps=4,
|
519 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
520 |
+
),
|
521 |
+
triton.Config(
|
522 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
523 |
+
num_stages=4,
|
524 |
+
num_warps=4,
|
525 |
+
pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
|
526 |
+
),
|
527 |
+
],
|
528 |
+
key=["chunk_size", "hdim", "dstate"],
|
529 |
+
)
|
530 |
+
@triton.jit
|
531 |
+
def _chunk_state_bwd_dx_kernel(
|
532 |
+
# Pointers to matrices
|
533 |
+
x_ptr,
|
534 |
+
b_ptr,
|
535 |
+
dstates_ptr,
|
536 |
+
dt_ptr,
|
537 |
+
dA_cumsum_ptr,
|
538 |
+
dx_ptr,
|
539 |
+
ddt_ptr,
|
540 |
+
ddA_cumsum_ptr,
|
541 |
+
# Matrix dimensions
|
542 |
+
chunk_size,
|
543 |
+
hdim,
|
544 |
+
dstate,
|
545 |
+
batch,
|
546 |
+
seqlen,
|
547 |
+
nheads_ngroups_ratio,
|
548 |
+
# Strides
|
549 |
+
stride_x_batch,
|
550 |
+
stride_x_seqlen,
|
551 |
+
stride_x_head,
|
552 |
+
stride_x_hdim,
|
553 |
+
stride_b_batch,
|
554 |
+
stride_b_seqlen,
|
555 |
+
stride_b_head,
|
556 |
+
stride_b_dstate,
|
557 |
+
stride_dstates_batch,
|
558 |
+
stride_dstates_chunk,
|
559 |
+
stride_states_head,
|
560 |
+
stride_states_hdim,
|
561 |
+
stride_states_dstate,
|
562 |
+
stride_dt_batch,
|
563 |
+
stride_dt_chunk,
|
564 |
+
stride_dt_head,
|
565 |
+
stride_dt_csize,
|
566 |
+
stride_dA_cs_batch,
|
567 |
+
stride_dA_cs_chunk,
|
568 |
+
stride_dA_cs_head,
|
569 |
+
stride_dA_cs_csize,
|
570 |
+
stride_dx_batch,
|
571 |
+
stride_dx_seqlen,
|
572 |
+
stride_dx_head,
|
573 |
+
stride_dx_hdim,
|
574 |
+
stride_ddt_batch,
|
575 |
+
stride_ddt_chunk,
|
576 |
+
stride_ddt_head,
|
577 |
+
stride_ddt_csize,
|
578 |
+
stride_ddA_cs_batch,
|
579 |
+
stride_ddA_cs_chunk,
|
580 |
+
stride_ddA_cs_head,
|
581 |
+
stride_ddA_cs_csize,
|
582 |
+
# Meta-parameters
|
583 |
+
BLOCK_SIZE_M: tl.constexpr,
|
584 |
+
BLOCK_SIZE_N: tl.constexpr,
|
585 |
+
BLOCK_SIZE_K: tl.constexpr,
|
586 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
587 |
+
):
|
588 |
+
pid_bc = tl.program_id(axis=1)
|
589 |
+
pid_c = pid_bc // batch
|
590 |
+
pid_b = pid_bc - pid_c * batch
|
591 |
+
pid_h = tl.program_id(axis=2)
|
592 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
593 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
594 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
595 |
+
x_ptr += (
|
596 |
+
pid_b * stride_x_batch
|
597 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
598 |
+
+ pid_h * stride_x_head
|
599 |
+
)
|
600 |
+
b_ptr += (
|
601 |
+
pid_b * stride_b_batch
|
602 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
603 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
604 |
+
)
|
605 |
+
dstates_ptr += (
|
606 |
+
pid_b * stride_dstates_batch
|
607 |
+
+ pid_c * stride_dstates_chunk
|
608 |
+
+ pid_h * stride_states_head
|
609 |
+
)
|
610 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
611 |
+
ddt_ptr += (
|
612 |
+
pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
|
613 |
+
)
|
614 |
+
ddA_cumsum_ptr += (
|
615 |
+
pid_b * stride_ddA_cs_batch
|
616 |
+
+ pid_c * stride_ddA_cs_chunk
|
617 |
+
+ pid_h * stride_ddA_cs_head
|
618 |
+
)
|
619 |
+
dA_cumsum_ptr += (
|
620 |
+
pid_b * stride_dA_cs_batch
|
621 |
+
+ pid_c * stride_dA_cs_chunk
|
622 |
+
+ pid_h * stride_dA_cs_head
|
623 |
+
)
|
624 |
+
|
625 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
626 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
627 |
+
|
628 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
629 |
+
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
630 |
+
offs_k = tl.arange(
|
631 |
+
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
632 |
+
)
|
633 |
+
b_ptrs = b_ptr + (
|
634 |
+
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
635 |
+
)
|
636 |
+
dstates_ptrs = dstates_ptr + (
|
637 |
+
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
638 |
+
)
|
639 |
+
if BLOCK_SIZE_DSTATE <= 128:
|
640 |
+
b = tl.load(
|
641 |
+
b_ptrs,
|
642 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
643 |
+
other=0.0,
|
644 |
+
)
|
645 |
+
dstates = tl.load(
|
646 |
+
dstates_ptrs,
|
647 |
+
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
648 |
+
other=0.0,
|
649 |
+
)
|
650 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
651 |
+
acc = tl.dot(b, dstates)
|
652 |
+
else:
|
653 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
654 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
655 |
+
b = tl.load(
|
656 |
+
b_ptrs,
|
657 |
+
mask=(offs_m[:, None] < chunk_size_limit)
|
658 |
+
& (offs_k[None, :] < dstate - k),
|
659 |
+
other=0.0,
|
660 |
+
)
|
661 |
+
dstates = tl.load(
|
662 |
+
dstates_ptrs,
|
663 |
+
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
664 |
+
other=0.0,
|
665 |
+
)
|
666 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
667 |
+
acc += tl.dot(b, dstates)
|
668 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
669 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
670 |
+
|
671 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
672 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
673 |
+
|
674 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
675 |
+
tl.float32
|
676 |
+
)
|
677 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
678 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
679 |
+
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
680 |
+
tl.float32
|
681 |
+
)
|
682 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
683 |
+
acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
|
684 |
+
|
685 |
+
x_ptrs = x_ptr + (
|
686 |
+
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
687 |
+
)
|
688 |
+
x = tl.load(
|
689 |
+
x_ptrs,
|
690 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
691 |
+
other=0.0,
|
692 |
+
).to(tl.float32)
|
693 |
+
ddt = tl.sum(acc * x, axis=1)
|
694 |
+
ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
|
695 |
+
tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
|
696 |
+
ddA_cs = -(ddt * dt_m)
|
697 |
+
ddA_cs_last = -tl.sum(ddA_cs)
|
698 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
699 |
+
tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
700 |
+
tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
|
701 |
+
|
702 |
+
dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
|
703 |
+
dx_ptr += (
|
704 |
+
pid_b * stride_dx_batch
|
705 |
+
+ pid_c * chunk_size * stride_dx_seqlen
|
706 |
+
+ pid_h * stride_dx_head
|
707 |
+
)
|
708 |
+
dx_ptrs = dx_ptr + (
|
709 |
+
offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
|
710 |
+
)
|
711 |
+
tl.store(
|
712 |
+
dx_ptrs,
|
713 |
+
dx,
|
714 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
715 |
+
)
|
716 |
+
|
717 |
+
|
718 |
+
@triton.autotune(
|
719 |
+
configs=[
|
720 |
+
triton.Config(
|
721 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
|
722 |
+
num_stages=3,
|
723 |
+
num_warps=4,
|
724 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
725 |
+
),
|
726 |
+
triton.Config(
|
727 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
|
728 |
+
num_stages=3,
|
729 |
+
num_warps=4,
|
730 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
731 |
+
),
|
732 |
+
triton.Config(
|
733 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
|
734 |
+
num_stages=3,
|
735 |
+
num_warps=4,
|
736 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
737 |
+
),
|
738 |
+
triton.Config(
|
739 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
|
740 |
+
num_stages=3,
|
741 |
+
num_warps=4,
|
742 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
743 |
+
),
|
744 |
+
triton.Config(
|
745 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
|
746 |
+
num_stages=3,
|
747 |
+
num_warps=4,
|
748 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
749 |
+
),
|
750 |
+
triton.Config(
|
751 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
|
752 |
+
num_stages=3,
|
753 |
+
num_warps=4,
|
754 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
755 |
+
),
|
756 |
+
triton.Config(
|
757 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
|
758 |
+
num_stages=3,
|
759 |
+
num_warps=4,
|
760 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
761 |
+
),
|
762 |
+
triton.Config(
|
763 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
|
764 |
+
num_stages=3,
|
765 |
+
num_warps=4,
|
766 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
767 |
+
),
|
768 |
+
],
|
769 |
+
key=["chunk_size", "dstate", "hdim"],
|
770 |
+
)
|
771 |
+
@triton.jit
|
772 |
+
def _chunk_state_bwd_db_kernel(
|
773 |
+
# Pointers to matrices
|
774 |
+
x_ptr,
|
775 |
+
dstates_ptr,
|
776 |
+
b_ptr,
|
777 |
+
dt_ptr,
|
778 |
+
dA_cumsum_ptr,
|
779 |
+
seq_idx_ptr,
|
780 |
+
db_ptr,
|
781 |
+
ddA_cumsum_ptr,
|
782 |
+
# Matrix dimensions
|
783 |
+
chunk_size,
|
784 |
+
dstate,
|
785 |
+
hdim,
|
786 |
+
batch,
|
787 |
+
seqlen,
|
788 |
+
nheads,
|
789 |
+
nheads_per_program,
|
790 |
+
ngroups,
|
791 |
+
# Strides
|
792 |
+
stride_x_batch,
|
793 |
+
stride_x_seqlen,
|
794 |
+
stride_x_head,
|
795 |
+
stride_x_hdim,
|
796 |
+
stride_dstates_batch,
|
797 |
+
stride_dstates_chunk,
|
798 |
+
stride_states_head,
|
799 |
+
stride_states_hdim,
|
800 |
+
stride_states_dstate,
|
801 |
+
stride_b_batch,
|
802 |
+
stride_b_seqlen,
|
803 |
+
stride_b_head,
|
804 |
+
stride_b_dstate,
|
805 |
+
stride_dt_batch,
|
806 |
+
stride_dt_chunk,
|
807 |
+
stride_dt_head,
|
808 |
+
stride_dt_csize,
|
809 |
+
stride_dA_cs_batch,
|
810 |
+
stride_dA_cs_chunk,
|
811 |
+
stride_dA_cs_head,
|
812 |
+
stride_dA_cs_csize,
|
813 |
+
stride_seq_idx_batch,
|
814 |
+
stride_seq_idx_seqlen,
|
815 |
+
stride_db_batch,
|
816 |
+
stride_db_seqlen,
|
817 |
+
stride_db_split,
|
818 |
+
stride_db_group,
|
819 |
+
stride_db_dstate,
|
820 |
+
stride_ddA_cs_batch,
|
821 |
+
stride_ddA_cs_chunk,
|
822 |
+
stride_ddA_cs_head,
|
823 |
+
stride_ddA_cs_csize,
|
824 |
+
# Meta-parameters
|
825 |
+
HAS_DDA_CS: tl.constexpr,
|
826 |
+
HAS_SEQ_IDX: tl.constexpr,
|
827 |
+
BLOCK_SIZE_M: tl.constexpr,
|
828 |
+
BLOCK_SIZE_N: tl.constexpr,
|
829 |
+
BLOCK_SIZE_K: tl.constexpr,
|
830 |
+
):
|
831 |
+
pid_bc = tl.program_id(axis=1)
|
832 |
+
pid_c = pid_bc // batch
|
833 |
+
pid_b = pid_bc - pid_c * batch
|
834 |
+
pid_sg = tl.program_id(axis=2)
|
835 |
+
pid_s = pid_sg // ngroups
|
836 |
+
pid_g = pid_sg - pid_s * ngroups
|
837 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
838 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
839 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
840 |
+
x_ptr += (
|
841 |
+
pid_b * stride_x_batch
|
842 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
843 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
|
844 |
+
)
|
845 |
+
db_ptr += (
|
846 |
+
pid_b * stride_db_batch
|
847 |
+
+ pid_c * chunk_size * stride_db_seqlen
|
848 |
+
+ pid_g * stride_db_group
|
849 |
+
+ pid_s * stride_db_split
|
850 |
+
)
|
851 |
+
dstates_ptr += (
|
852 |
+
pid_b * stride_dstates_batch
|
853 |
+
+ pid_c * stride_dstates_chunk
|
854 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
855 |
+
* stride_states_head
|
856 |
+
)
|
857 |
+
dt_ptr += (
|
858 |
+
pid_b * stride_dt_batch
|
859 |
+
+ pid_c * stride_dt_chunk
|
860 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
|
861 |
+
)
|
862 |
+
dA_cumsum_ptr += (
|
863 |
+
pid_b * stride_dA_cs_batch
|
864 |
+
+ pid_c * stride_dA_cs_chunk
|
865 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
|
866 |
+
)
|
867 |
+
if HAS_DDA_CS:
|
868 |
+
b_ptr += (
|
869 |
+
pid_b * stride_b_batch
|
870 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
871 |
+
+ pid_g * stride_b_head
|
872 |
+
)
|
873 |
+
ddA_cumsum_ptr += (
|
874 |
+
pid_b * stride_ddA_cs_batch
|
875 |
+
+ pid_c * stride_ddA_cs_chunk
|
876 |
+
+ (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
|
877 |
+
* stride_ddA_cs_head
|
878 |
+
)
|
879 |
+
if HAS_SEQ_IDX:
|
880 |
+
seq_idx_ptr += (
|
881 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
882 |
+
)
|
883 |
+
|
884 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
885 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
886 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
887 |
+
x_ptrs = x_ptr + (
|
888 |
+
offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
|
889 |
+
)
|
890 |
+
dstates_ptrs = dstates_ptr + (
|
891 |
+
offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
|
892 |
+
)
|
893 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
894 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
|
895 |
+
if HAS_DDA_CS:
|
896 |
+
b_ptrs = b_ptr + (
|
897 |
+
offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
|
898 |
+
)
|
899 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
900 |
+
|
901 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
902 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
903 |
+
if HAS_DDA_CS:
|
904 |
+
b = tl.load(
|
905 |
+
b_ptrs,
|
906 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
907 |
+
other=0.0,
|
908 |
+
).to(tl.float32)
|
909 |
+
if HAS_SEQ_IDX:
|
910 |
+
seq_idx_m = tl.load(
|
911 |
+
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
912 |
+
mask=offs_m < chunk_size_limit,
|
913 |
+
other=-1,
|
914 |
+
)
|
915 |
+
seq_idx_last = tl.load(
|
916 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
917 |
+
)
|
918 |
+
nheads_iter = min(
|
919 |
+
nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
|
920 |
+
)
|
921 |
+
for h in range(nheads_iter):
|
922 |
+
x = tl.load(
|
923 |
+
x_ptrs,
|
924 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
|
925 |
+
other=0.0,
|
926 |
+
)
|
927 |
+
dstates = tl.load(
|
928 |
+
dstates_ptrs,
|
929 |
+
mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
|
930 |
+
other=0.0,
|
931 |
+
)
|
932 |
+
dstates = dstates.to(x_ptrs.dtype.element_ty)
|
933 |
+
db = tl.dot(x, dstates)
|
934 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
935 |
+
tl.float32
|
936 |
+
)
|
937 |
+
dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
|
938 |
+
tl.float32
|
939 |
+
)
|
940 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
941 |
+
if not HAS_SEQ_IDX:
|
942 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
943 |
+
else:
|
944 |
+
scale = tl.where(
|
945 |
+
seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
|
946 |
+
)
|
947 |
+
db *= (scale * dt_m)[:, None]
|
948 |
+
if HAS_DDA_CS:
|
949 |
+
# This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
|
950 |
+
ddA_cs = tl.sum(db * b, axis=1)
|
951 |
+
tl.atomic_add(
|
952 |
+
ddA_cumsum_ptrs + stride_ddA_cs_csize,
|
953 |
+
ddA_cs,
|
954 |
+
mask=offs_m < chunk_size - 1,
|
955 |
+
)
|
956 |
+
acc += db
|
957 |
+
x_ptrs += stride_x_head
|
958 |
+
dstates_ptrs += stride_states_head
|
959 |
+
dt_ptrs += stride_dt_head
|
960 |
+
dA_cumsum_ptr += stride_dA_cs_head
|
961 |
+
dA_cumsum_ptrs += stride_dA_cs_head
|
962 |
+
if HAS_DDA_CS:
|
963 |
+
ddA_cumsum_ptrs += stride_ddA_cs_head
|
964 |
+
|
965 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
966 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
967 |
+
# if HAS_SEQ_IDX:
|
968 |
+
# seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
969 |
+
# seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
|
970 |
+
# acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
|
971 |
+
db_ptrs = db_ptr + (
|
972 |
+
offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
|
973 |
+
)
|
974 |
+
tl.store(
|
975 |
+
db_ptrs,
|
976 |
+
acc,
|
977 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
|
978 |
+
)
|
979 |
+
|
980 |
+
|
981 |
+
@triton.autotune(
|
982 |
+
configs=[
|
983 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
984 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
985 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
986 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
987 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
988 |
+
# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
989 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
990 |
+
# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
991 |
+
# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
|
992 |
+
triton.Config(
|
993 |
+
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
994 |
+
num_stages=3,
|
995 |
+
num_warps=4,
|
996 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
997 |
+
),
|
998 |
+
triton.Config(
|
999 |
+
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1000 |
+
num_stages=3,
|
1001 |
+
num_warps=4,
|
1002 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1003 |
+
),
|
1004 |
+
triton.Config(
|
1005 |
+
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1006 |
+
num_stages=3,
|
1007 |
+
num_warps=4,
|
1008 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1009 |
+
),
|
1010 |
+
triton.Config(
|
1011 |
+
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1012 |
+
num_stages=3,
|
1013 |
+
num_warps=4,
|
1014 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1015 |
+
),
|
1016 |
+
triton.Config(
|
1017 |
+
{"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
|
1018 |
+
num_stages=4,
|
1019 |
+
num_warps=8,
|
1020 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1021 |
+
),
|
1022 |
+
triton.Config(
|
1023 |
+
{"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1024 |
+
num_stages=4,
|
1025 |
+
num_warps=8,
|
1026 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1027 |
+
),
|
1028 |
+
triton.Config(
|
1029 |
+
{"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1030 |
+
num_stages=4,
|
1031 |
+
num_warps=8,
|
1032 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1033 |
+
),
|
1034 |
+
triton.Config(
|
1035 |
+
{"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1036 |
+
num_stages=4,
|
1037 |
+
num_warps=8,
|
1038 |
+
pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
|
1039 |
+
),
|
1040 |
+
],
|
1041 |
+
key=["chunk_size", "hdim", "dstate"],
|
1042 |
+
)
|
1043 |
+
@triton.jit
|
1044 |
+
def _chunk_state_bwd_ddAcs_stable_kernel(
|
1045 |
+
# Pointers to matrices
|
1046 |
+
x_ptr,
|
1047 |
+
b_ptr,
|
1048 |
+
dstates_ptr,
|
1049 |
+
dt_ptr,
|
1050 |
+
dA_cumsum_ptr,
|
1051 |
+
seq_idx_ptr,
|
1052 |
+
ddA_cumsum_ptr,
|
1053 |
+
# Matrix dimensions
|
1054 |
+
chunk_size,
|
1055 |
+
hdim,
|
1056 |
+
dstate,
|
1057 |
+
batch,
|
1058 |
+
seqlen,
|
1059 |
+
nheads_ngroups_ratio,
|
1060 |
+
# Strides
|
1061 |
+
stride_x_batch,
|
1062 |
+
stride_x_seqlen,
|
1063 |
+
stride_x_head,
|
1064 |
+
stride_x_hdim,
|
1065 |
+
stride_b_batch,
|
1066 |
+
stride_b_seqlen,
|
1067 |
+
stride_b_head,
|
1068 |
+
stride_b_dstate,
|
1069 |
+
stride_dstates_batch,
|
1070 |
+
stride_dstates_chunk,
|
1071 |
+
stride_states_head,
|
1072 |
+
stride_states_hdim,
|
1073 |
+
stride_states_dstate,
|
1074 |
+
stride_dt_batch,
|
1075 |
+
stride_dt_chunk,
|
1076 |
+
stride_dt_head,
|
1077 |
+
stride_dt_csize,
|
1078 |
+
stride_dA_cs_batch,
|
1079 |
+
stride_dA_cs_chunk,
|
1080 |
+
stride_dA_cs_head,
|
1081 |
+
stride_dA_cs_csize,
|
1082 |
+
stride_seq_idx_batch,
|
1083 |
+
stride_seq_idx_seqlen,
|
1084 |
+
stride_ddA_cs_batch,
|
1085 |
+
stride_ddA_cs_chunk,
|
1086 |
+
stride_ddA_cs_head,
|
1087 |
+
stride_ddA_cs_csize,
|
1088 |
+
# Meta-parameters
|
1089 |
+
HAS_SEQ_IDX: tl.constexpr,
|
1090 |
+
BLOCK_SIZE_M: tl.constexpr,
|
1091 |
+
BLOCK_SIZE_N: tl.constexpr,
|
1092 |
+
BLOCK_SIZE_K: tl.constexpr,
|
1093 |
+
BLOCK_SIZE_DSTATE: tl.constexpr,
|
1094 |
+
):
|
1095 |
+
pid_bc = tl.program_id(axis=1)
|
1096 |
+
pid_c = pid_bc // batch
|
1097 |
+
pid_b = pid_bc - pid_c * batch
|
1098 |
+
pid_h = tl.program_id(axis=2)
|
1099 |
+
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
1100 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
1101 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
1102 |
+
x_ptr += (
|
1103 |
+
pid_b * stride_x_batch
|
1104 |
+
+ pid_c * chunk_size * stride_x_seqlen
|
1105 |
+
+ pid_h * stride_x_head
|
1106 |
+
)
|
1107 |
+
b_ptr += (
|
1108 |
+
pid_b * stride_b_batch
|
1109 |
+
+ pid_c * chunk_size * stride_b_seqlen
|
1110 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
1111 |
+
)
|
1112 |
+
dstates_ptr += (
|
1113 |
+
pid_b * stride_dstates_batch
|
1114 |
+
+ pid_c * stride_dstates_chunk
|
1115 |
+
+ pid_h * stride_states_head
|
1116 |
+
)
|
1117 |
+
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
1118 |
+
ddA_cumsum_ptr += (
|
1119 |
+
pid_b * stride_ddA_cs_batch
|
1120 |
+
+ pid_c * stride_ddA_cs_chunk
|
1121 |
+
+ pid_h * stride_ddA_cs_head
|
1122 |
+
)
|
1123 |
+
dA_cumsum_ptr += (
|
1124 |
+
pid_b * stride_dA_cs_batch
|
1125 |
+
+ pid_c * stride_dA_cs_chunk
|
1126 |
+
+ pid_h * stride_dA_cs_head
|
1127 |
+
)
|
1128 |
+
if HAS_SEQ_IDX:
|
1129 |
+
seq_idx_ptr += (
|
1130 |
+
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1134 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1135 |
+
|
1136 |
+
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
1137 |
+
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
1138 |
+
offs_k = tl.arange(
|
1139 |
+
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
1140 |
+
)
|
1141 |
+
b_ptrs = b_ptr + (
|
1142 |
+
offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
|
1143 |
+
)
|
1144 |
+
dstates_ptrs = dstates_ptr + (
|
1145 |
+
offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
|
1146 |
+
)
|
1147 |
+
if BLOCK_SIZE_DSTATE <= 128:
|
1148 |
+
b = tl.load(
|
1149 |
+
b_ptrs,
|
1150 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
|
1151 |
+
other=0.0,
|
1152 |
+
)
|
1153 |
+
dstates = tl.load(
|
1154 |
+
dstates_ptrs,
|
1155 |
+
mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
|
1156 |
+
other=0.0,
|
1157 |
+
)
|
1158 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
1159 |
+
acc = tl.dot(b, dstates)
|
1160 |
+
else:
|
1161 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
1162 |
+
for k in range(0, dstate, BLOCK_SIZE_K):
|
1163 |
+
b = tl.load(
|
1164 |
+
b_ptrs,
|
1165 |
+
mask=(offs_m[:, None] < chunk_size_limit)
|
1166 |
+
& (offs_k[None, :] < dstate - k),
|
1167 |
+
other=0.0,
|
1168 |
+
)
|
1169 |
+
dstates = tl.load(
|
1170 |
+
dstates_ptrs,
|
1171 |
+
mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
|
1172 |
+
other=0.0,
|
1173 |
+
)
|
1174 |
+
dstates = dstates.to(b_ptr.dtype.element_ty)
|
1175 |
+
acc += tl.dot(b, dstates)
|
1176 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_dstate
|
1177 |
+
dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
|
1178 |
+
|
1179 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1180 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1181 |
+
|
1182 |
+
dA_cs_m = tl.load(
|
1183 |
+
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
1184 |
+
).to(tl.float32)
|
1185 |
+
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
1186 |
+
tl.float32
|
1187 |
+
)
|
1188 |
+
if not HAS_SEQ_IDX:
|
1189 |
+
scale = tl.exp(dA_cs_last - dA_cs_m)
|
1190 |
+
else:
|
1191 |
+
seq_idx_m = tl.load(
|
1192 |
+
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
1193 |
+
mask=offs_m < chunk_size_limit,
|
1194 |
+
other=-1,
|
1195 |
+
)
|
1196 |
+
seq_idx_last = tl.load(
|
1197 |
+
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
1198 |
+
)
|
1199 |
+
scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
|
1200 |
+
acc *= scale[:, None]
|
1201 |
+
|
1202 |
+
x_ptrs = x_ptr + (
|
1203 |
+
offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
1204 |
+
)
|
1205 |
+
x = tl.load(
|
1206 |
+
x_ptrs,
|
1207 |
+
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
1208 |
+
other=0.0,
|
1209 |
+
).to(tl.float32)
|
1210 |
+
dt_ptrs = dt_ptr + offs_m * stride_dt_csize
|
1211 |
+
dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
|
1212 |
+
ddt = tl.sum(acc * x, axis=1)
|
1213 |
+
# ddA_cs = -(ddt * dt_m)
|
1214 |
+
# Triton 2.2.0 errors if we have the cumsum here, so we just write it out
|
1215 |
+
# then call torch.cumsum outside this kernel.
|
1216 |
+
# ddA_cs = tl.cumsum(ddt * dt_m)
|
1217 |
+
ddA_cs = ddt * dt_m
|
1218 |
+
ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
|
1219 |
+
# tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
|
1220 |
+
tl.atomic_add(
|
1221 |
+
ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
|
1222 |
+
)
|
1223 |
+
|
1224 |
+
|
1225 |
+
@triton.autotune(
|
1226 |
+
configs=[
|
1227 |
+
triton.Config(
|
1228 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
1229 |
+
num_stages=3,
|
1230 |
+
num_warps=8,
|
1231 |
+
),
|
1232 |
+
triton.Config(
|
1233 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
1234 |
+
num_stages=4,
|
1235 |
+
num_warps=4,
|
1236 |
+
),
|
1237 |
+
triton.Config(
|
1238 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1239 |
+
num_stages=4,
|
1240 |
+
num_warps=4,
|
1241 |
+
),
|
1242 |
+
triton.Config(
|
1243 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1244 |
+
num_stages=4,
|
1245 |
+
num_warps=4,
|
1246 |
+
),
|
1247 |
+
triton.Config(
|
1248 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
1249 |
+
num_stages=4,
|
1250 |
+
num_warps=4,
|
1251 |
+
),
|
1252 |
+
triton.Config(
|
1253 |
+
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1254 |
+
num_stages=4,
|
1255 |
+
num_warps=4,
|
1256 |
+
),
|
1257 |
+
triton.Config(
|
1258 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
1259 |
+
num_stages=5,
|
1260 |
+
num_warps=2,
|
1261 |
+
),
|
1262 |
+
triton.Config(
|
1263 |
+
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1264 |
+
num_stages=5,
|
1265 |
+
num_warps=2,
|
1266 |
+
),
|
1267 |
+
triton.Config(
|
1268 |
+
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
1269 |
+
num_stages=4,
|
1270 |
+
num_warps=2,
|
1271 |
+
),
|
1272 |
+
],
|
1273 |
+
key=["hdim", "dstate", "chunk_size"],
|
1274 |
+
)
|
1275 |
+
@triton.jit
|
1276 |
+
def _chunk_state_varlen_kernel(
|
1277 |
+
# Pointers to matrices
|
1278 |
+
x_ptr,
|
1279 |
+
b_ptr,
|
1280 |
+
dt_ptr,
|
1281 |
+
dA_cumsum_ptr,
|
1282 |
+
chunk_states_ptr,
|
1283 |
+
cu_seqlens_ptr,
|
1284 |
+
states_ptr,
|
1285 |
+
# Matrix dimensions
|
1286 |
+
hdim,
|
1287 |
+
dstate,
|
1288 |
+
chunk_size,
|
1289 |
+
seqlen,
|
1290 |
+
nheads_ngroups_ratio,
|
1291 |
+
# Strides
|
1292 |
+
stride_x_seqlen,
|
1293 |
+
stride_x_head,
|
1294 |
+
stride_x_hdim,
|
1295 |
+
stride_b_seqlen,
|
1296 |
+
stride_b_head,
|
1297 |
+
stride_b_dstate,
|
1298 |
+
stride_dt_chunk,
|
1299 |
+
stride_dt_head,
|
1300 |
+
stride_dt_csize,
|
1301 |
+
stride_dA_cs_chunk,
|
1302 |
+
stride_dA_cs_head,
|
1303 |
+
stride_dA_cs_csize,
|
1304 |
+
stride_chunk_states_chunk,
|
1305 |
+
stride_chunk_states_head,
|
1306 |
+
stride_chunk_states_hdim,
|
1307 |
+
stride_chunk_states_dstate,
|
1308 |
+
stride_states_batch,
|
1309 |
+
stride_states_head,
|
1310 |
+
stride_states_hdim,
|
1311 |
+
stride_states_dstate,
|
1312 |
+
# Meta-parameters
|
1313 |
+
BLOCK_SIZE_M: tl.constexpr,
|
1314 |
+
BLOCK_SIZE_N: tl.constexpr,
|
1315 |
+
BLOCK_SIZE_K: tl.constexpr,
|
1316 |
+
):
|
1317 |
+
pid_b = tl.program_id(axis=1)
|
1318 |
+
pid_h = tl.program_id(axis=2)
|
1319 |
+
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
1320 |
+
pid_m = tl.program_id(axis=0) // num_pid_n
|
1321 |
+
pid_n = tl.program_id(axis=0) % num_pid_n
|
1322 |
+
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
1323 |
+
pid_c = (end_idx - 1) // chunk_size
|
1324 |
+
b_ptr += (
|
1325 |
+
pid_c * chunk_size * stride_b_seqlen
|
1326 |
+
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
1327 |
+
)
|
1328 |
+
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
1329 |
+
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
1330 |
+
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
1331 |
+
chunk_states_ptr += (
|
1332 |
+
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
1333 |
+
)
|
1334 |
+
|
1335 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1336 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1337 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
1338 |
+
x_ptrs = x_ptr + (
|
1339 |
+
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
1340 |
+
)
|
1341 |
+
b_ptrs = b_ptr + (
|
1342 |
+
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
1343 |
+
)
|
1344 |
+
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
1345 |
+
dA_cs_last = tl.load(
|
1346 |
+
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
1347 |
+
).to(tl.float32)
|
1348 |
+
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
1349 |
+
|
1350 |
+
chunk_size_limit = end_idx - pid_c * chunk_size
|
1351 |
+
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
1352 |
+
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
1353 |
+
|
1354 |
+
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
1355 |
+
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
1356 |
+
x = tl.load(
|
1357 |
+
x_ptrs,
|
1358 |
+
mask=(offs_m[:, None] < hdim)
|
1359 |
+
& (offs_k[None, :] < chunk_size_limit - k)
|
1360 |
+
& (offs_k[None, :] >= start_idx_cur - k),
|
1361 |
+
other=0.0,
|
1362 |
+
)
|
1363 |
+
b = tl.load(
|
1364 |
+
b_ptrs,
|
1365 |
+
mask=(offs_k[:, None] < chunk_size_limit - k)
|
1366 |
+
& (offs_n[None, :] < dstate)
|
1367 |
+
& (offs_k[:, None] >= start_idx_cur - k),
|
1368 |
+
other=0.0,
|
1369 |
+
).to(tl.float32)
|
1370 |
+
dA_cs_k = tl.load(
|
1371 |
+
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
1372 |
+
).to(tl.float32)
|
1373 |
+
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
1374 |
+
tl.float32
|
1375 |
+
)
|
1376 |
+
scale = tl.where(
|
1377 |
+
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
1378 |
+
tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
|
1379 |
+
0.0,
|
1380 |
+
)
|
1381 |
+
b *= scale[:, None]
|
1382 |
+
b = b.to(x_ptr.dtype.element_ty)
|
1383 |
+
acc += tl.dot(x, b)
|
1384 |
+
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
1385 |
+
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
1386 |
+
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
1387 |
+
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
1388 |
+
|
1389 |
+
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
1390 |
+
if start_idx < pid_c * chunk_size:
|
1391 |
+
chunk_states_ptrs = chunk_states_ptr + (
|
1392 |
+
offs_m[:, None] * stride_chunk_states_hdim
|
1393 |
+
+ offs_n[None, :] * stride_chunk_states_dstate
|
1394 |
+
)
|
1395 |
+
chunk_states = tl.load(
|
1396 |
+
chunk_states_ptrs,
|
1397 |
+
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
|
1398 |
+
other=0.0,
|
1399 |
+
).to(tl.float32)
|
1400 |
+
# scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
|
1401 |
+
scale = tl.exp(dA_cs_last)
|
1402 |
+
acc += chunk_states * scale
|
1403 |
+
|
1404 |
+
states = acc.to(states_ptr.dtype.element_ty)
|
1405 |
+
|
1406 |
+
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
1407 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
1408 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
1409 |
+
states_ptrs = states_ptr + (
|
1410 |
+
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
1411 |
+
)
|
1412 |
+
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
1413 |
+
tl.store(states_ptrs, states, mask=c_mask)
|
1414 |
+
|
1415 |
+
|
1416 |
+
def _chunk_cumsum_fwd(
|
1417 |
+
dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
|
1418 |
+
):
|
1419 |
+
batch, seqlen, nheads = dt.shape
|
1420 |
+
assert A.shape == (nheads,)
|
1421 |
+
if dt_bias is not None:
|
1422 |
+
assert dt_bias.shape == (nheads,)
|
1423 |
+
nchunks = math.ceil(seqlen / chunk_size)
|
1424 |
+
dt_out = torch.empty(
|
1425 |
+
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
1426 |
+
)
|
1427 |
+
dA_cumsum = torch.empty(
|
1428 |
+
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
1429 |
+
)
|
1430 |
+
grid_chunk_cs = lambda META: (
|
1431 |
+
batch,
|
1432 |
+
nchunks,
|
1433 |
+
triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
|
1434 |
+
)
|
1435 |
+
with torch.cuda.device(dt.device.index):
|
1436 |
+
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
1437 |
+
dt,
|
1438 |
+
A,
|
1439 |
+
dt_bias,
|
1440 |
+
dt_out,
|
1441 |
+
dA_cumsum,
|
1442 |
+
batch,
|
1443 |
+
seqlen,
|
1444 |
+
nheads,
|
1445 |
+
chunk_size,
|
1446 |
+
dt_limit[0],
|
1447 |
+
dt_limit[1],
|
1448 |
+
dt.stride(0),
|
1449 |
+
dt.stride(1),
|
1450 |
+
dt.stride(2),
|
1451 |
+
A.stride(0),
|
1452 |
+
dt_bias.stride(0) if dt_bias is not None else 0,
|
1453 |
+
dt_out.stride(0),
|
1454 |
+
dt_out.stride(2),
|
1455 |
+
dt_out.stride(1),
|
1456 |
+
dt_out.stride(3),
|
1457 |
+
dA_cumsum.stride(0),
|
1458 |
+
dA_cumsum.stride(2),
|
1459 |
+
dA_cumsum.stride(1),
|
1460 |
+
dA_cumsum.stride(3),
|
1461 |
+
dt_softplus,
|
1462 |
+
HAS_DT_BIAS=dt_bias is not None,
|
1463 |
+
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
1464 |
+
)
|
1465 |
+
return dA_cumsum, dt_out
|
1466 |
+
|
1467 |
+
|
1468 |
+
def _chunk_cumsum_bwd(
|
1469 |
+
ddA,
|
1470 |
+
ddt_out,
|
1471 |
+
dt,
|
1472 |
+
A,
|
1473 |
+
dt_bias=None,
|
1474 |
+
dt_softplus=False,
|
1475 |
+
dt_limit=(0.0, float("inf")),
|
1476 |
+
ddt=None,
|
1477 |
+
):
|
1478 |
+
batch, seqlen, nheads = dt.shape
|
1479 |
+
_, _, nchunks, chunk_size = ddA.shape
|
1480 |
+
assert ddA.shape == (batch, nheads, nchunks, chunk_size)
|
1481 |
+
assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
|
1482 |
+
assert A.shape == (nheads,)
|
1483 |
+
if dt_bias is not None:
|
1484 |
+
assert dt_bias.shape == (nheads,)
|
1485 |
+
ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
|
1486 |
+
else:
|
1487 |
+
ddt_bias = None
|
1488 |
+
if ddt is not None:
|
1489 |
+
assert ddt.shape == dt.shape
|
1490 |
+
else:
|
1491 |
+
ddt = torch.empty_like(dt)
|
1492 |
+
dA = torch.empty_like(A, dtype=torch.float32)
|
1493 |
+
grid_chunk_cs = lambda META: (
|
1494 |
+
batch,
|
1495 |
+
nchunks,
|
1496 |
+
triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
|
1497 |
+
)
|
1498 |
+
with torch.cuda.device(dt.device.index):
|
1499 |
+
_chunk_cumsum_bwd_kernel[grid_chunk_cs](
|
1500 |
+
ddA,
|
1501 |
+
ddt_out,
|
1502 |
+
dt,
|
1503 |
+
A,
|
1504 |
+
dt_bias,
|
1505 |
+
ddt,
|
1506 |
+
dA,
|
1507 |
+
ddt_bias,
|
1508 |
+
batch,
|
1509 |
+
seqlen,
|
1510 |
+
nheads,
|
1511 |
+
chunk_size,
|
1512 |
+
dt_limit[0],
|
1513 |
+
dt_limit[1],
|
1514 |
+
ddA.stride(0),
|
1515 |
+
ddA.stride(2),
|
1516 |
+
ddA.stride(1),
|
1517 |
+
ddA.stride(3),
|
1518 |
+
ddt_out.stride(0),
|
1519 |
+
ddt_out.stride(2),
|
1520 |
+
ddt_out.stride(1),
|
1521 |
+
ddt_out.stride(3),
|
1522 |
+
dt.stride(0),
|
1523 |
+
dt.stride(1),
|
1524 |
+
dt.stride(2),
|
1525 |
+
A.stride(0),
|
1526 |
+
dt_bias.stride(0) if dt_bias is not None else 0,
|
1527 |
+
ddt.stride(0),
|
1528 |
+
ddt.stride(1),
|
1529 |
+
ddt.stride(2),
|
1530 |
+
dA.stride(0),
|
1531 |
+
ddt_bias.stride(0) if ddt_bias is not None else 0,
|
1532 |
+
dt_softplus,
|
1533 |
+
HAS_DT_BIAS=dt_bias is not None,
|
1534 |
+
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
1535 |
+
)
|
1536 |
+
return ddt, dA, ddt_bias
|
1537 |
+
|
1538 |
+
|
1539 |
+
def _chunk_state_fwd(
|
1540 |
+
B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
|
1541 |
+
):
|
1542 |
+
batch, seqlen, nheads, headdim = x.shape
|
1543 |
+
_, _, nchunks, chunk_size = dt.shape
|
1544 |
+
_, _, ngroups, dstate = B.shape
|
1545 |
+
assert nheads % ngroups == 0
|
1546 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1547 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1548 |
+
assert dA_cumsum.shape == dt.shape
|
1549 |
+
if seq_idx is not None:
|
1550 |
+
assert seq_idx.shape == (batch, seqlen)
|
1551 |
+
if states is not None:
|
1552 |
+
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
|
1553 |
+
else:
|
1554 |
+
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
1555 |
+
states = torch.empty(
|
1556 |
+
(batch, nchunks, nheads, headdim, dstate),
|
1557 |
+
device=x.device,
|
1558 |
+
dtype=states_dtype,
|
1559 |
+
)
|
1560 |
+
grid = lambda META: (
|
1561 |
+
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
1562 |
+
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
1563 |
+
batch * nchunks,
|
1564 |
+
nheads,
|
1565 |
+
)
|
1566 |
+
with torch.cuda.device(x.device.index):
|
1567 |
+
_chunk_state_fwd_kernel[grid](
|
1568 |
+
x,
|
1569 |
+
B,
|
1570 |
+
states,
|
1571 |
+
dt,
|
1572 |
+
dA_cumsum,
|
1573 |
+
seq_idx,
|
1574 |
+
headdim,
|
1575 |
+
dstate,
|
1576 |
+
chunk_size,
|
1577 |
+
batch,
|
1578 |
+
seqlen,
|
1579 |
+
nheads // ngroups,
|
1580 |
+
x.stride(0),
|
1581 |
+
x.stride(1),
|
1582 |
+
x.stride(2),
|
1583 |
+
x.stride(3),
|
1584 |
+
B.stride(0),
|
1585 |
+
B.stride(1),
|
1586 |
+
B.stride(2),
|
1587 |
+
B.stride(-1),
|
1588 |
+
states.stride(0),
|
1589 |
+
states.stride(1),
|
1590 |
+
states.stride(2),
|
1591 |
+
states.stride(3),
|
1592 |
+
states.stride(4),
|
1593 |
+
dt.stride(0),
|
1594 |
+
dt.stride(2),
|
1595 |
+
dt.stride(1),
|
1596 |
+
dt.stride(3),
|
1597 |
+
dA_cumsum.stride(0),
|
1598 |
+
dA_cumsum.stride(2),
|
1599 |
+
dA_cumsum.stride(1),
|
1600 |
+
dA_cumsum.stride(3),
|
1601 |
+
*(
|
1602 |
+
(seq_idx.stride(0), seq_idx.stride(1))
|
1603 |
+
if seq_idx is not None
|
1604 |
+
else (0, 0)
|
1605 |
+
),
|
1606 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
1607 |
+
)
|
1608 |
+
return states
|
1609 |
+
|
1610 |
+
|
1611 |
+
def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
|
1612 |
+
batch, seqlen, nheads, headdim = x.shape
|
1613 |
+
_, _, nchunks, chunk_size = dt.shape
|
1614 |
+
_, _, ngroups, dstate = B.shape
|
1615 |
+
assert nheads % ngroups == 0
|
1616 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1617 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1618 |
+
assert dA_cumsum.shape == dt.shape
|
1619 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1620 |
+
if dx is not None:
|
1621 |
+
assert dx.shape == x.shape
|
1622 |
+
else:
|
1623 |
+
dx = torch.empty_like(x)
|
1624 |
+
ddt = torch.empty(
|
1625 |
+
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
1626 |
+
)
|
1627 |
+
ddA_cumsum = torch.empty(
|
1628 |
+
batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
|
1629 |
+
)
|
1630 |
+
grid_dx = lambda META: (
|
1631 |
+
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
1632 |
+
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
1633 |
+
batch * nchunks,
|
1634 |
+
nheads,
|
1635 |
+
)
|
1636 |
+
with torch.cuda.device(x.device.index):
|
1637 |
+
_chunk_state_bwd_dx_kernel[grid_dx](
|
1638 |
+
x,
|
1639 |
+
B,
|
1640 |
+
dstates,
|
1641 |
+
dt,
|
1642 |
+
dA_cumsum,
|
1643 |
+
dx,
|
1644 |
+
ddt,
|
1645 |
+
ddA_cumsum,
|
1646 |
+
chunk_size,
|
1647 |
+
headdim,
|
1648 |
+
dstate,
|
1649 |
+
batch,
|
1650 |
+
seqlen,
|
1651 |
+
nheads // ngroups,
|
1652 |
+
x.stride(0),
|
1653 |
+
x.stride(1),
|
1654 |
+
x.stride(2),
|
1655 |
+
x.stride(3),
|
1656 |
+
B.stride(0),
|
1657 |
+
B.stride(1),
|
1658 |
+
B.stride(2),
|
1659 |
+
B.stride(-1),
|
1660 |
+
dstates.stride(0),
|
1661 |
+
dstates.stride(1),
|
1662 |
+
dstates.stride(2),
|
1663 |
+
dstates.stride(3),
|
1664 |
+
dstates.stride(4),
|
1665 |
+
dt.stride(0),
|
1666 |
+
dt.stride(2),
|
1667 |
+
dt.stride(1),
|
1668 |
+
dt.stride(3),
|
1669 |
+
dA_cumsum.stride(0),
|
1670 |
+
dA_cumsum.stride(2),
|
1671 |
+
dA_cumsum.stride(1),
|
1672 |
+
dA_cumsum.stride(3),
|
1673 |
+
dx.stride(0),
|
1674 |
+
dx.stride(1),
|
1675 |
+
dx.stride(2),
|
1676 |
+
dx.stride(3),
|
1677 |
+
ddt.stride(0),
|
1678 |
+
ddt.stride(2),
|
1679 |
+
ddt.stride(1),
|
1680 |
+
ddt.stride(3),
|
1681 |
+
ddA_cumsum.stride(0),
|
1682 |
+
ddA_cumsum.stride(2),
|
1683 |
+
ddA_cumsum.stride(1),
|
1684 |
+
ddA_cumsum.stride(3),
|
1685 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
1686 |
+
)
|
1687 |
+
return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
|
1688 |
+
|
1689 |
+
|
1690 |
+
def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
|
1691 |
+
batch, seqlen, nheads, headdim = x.shape
|
1692 |
+
_, _, nchunks, chunk_size = dt.shape
|
1693 |
+
dstate = dstates.shape[-1]
|
1694 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1695 |
+
assert dA_cumsum.shape == dt.shape
|
1696 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1697 |
+
if seq_idx is not None:
|
1698 |
+
assert seq_idx.shape == (batch, seqlen)
|
1699 |
+
if B is not None:
|
1700 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1701 |
+
B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
|
1702 |
+
# Use torch.empty since the Triton kernel will call init_to_zero
|
1703 |
+
ddA_cumsum = torch.empty(
|
1704 |
+
batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
|
1705 |
+
)
|
1706 |
+
ddA_cumsum_strides = (
|
1707 |
+
ddA_cumsum.stride(0),
|
1708 |
+
ddA_cumsum.stride(2),
|
1709 |
+
ddA_cumsum.stride(1),
|
1710 |
+
ddA_cumsum.stride(3),
|
1711 |
+
)
|
1712 |
+
else:
|
1713 |
+
B_strides = (0, 0, 0, 0)
|
1714 |
+
ddA_cumsum = None
|
1715 |
+
ddA_cumsum_strides = (0, 0, 0, 0)
|
1716 |
+
nheads_ngroups_ratio = nheads // ngroups
|
1717 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
1718 |
+
nheads_per_program = max(
|
1719 |
+
min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
|
1720 |
+
)
|
1721 |
+
nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
|
1722 |
+
dB = torch.empty(
|
1723 |
+
batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
|
1724 |
+
)
|
1725 |
+
grid_db = lambda META: (
|
1726 |
+
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
1727 |
+
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
1728 |
+
batch * nchunks,
|
1729 |
+
nsplits * ngroups,
|
1730 |
+
)
|
1731 |
+
with torch.cuda.device(x.device.index):
|
1732 |
+
_chunk_state_bwd_db_kernel[grid_db](
|
1733 |
+
x,
|
1734 |
+
dstates,
|
1735 |
+
B,
|
1736 |
+
dt,
|
1737 |
+
dA_cumsum,
|
1738 |
+
seq_idx,
|
1739 |
+
dB,
|
1740 |
+
ddA_cumsum,
|
1741 |
+
chunk_size,
|
1742 |
+
dstate,
|
1743 |
+
headdim,
|
1744 |
+
batch,
|
1745 |
+
seqlen,
|
1746 |
+
nheads,
|
1747 |
+
nheads_per_program,
|
1748 |
+
ngroups,
|
1749 |
+
x.stride(0),
|
1750 |
+
x.stride(1),
|
1751 |
+
x.stride(2),
|
1752 |
+
x.stride(3),
|
1753 |
+
dstates.stride(0),
|
1754 |
+
dstates.stride(1),
|
1755 |
+
dstates.stride(2),
|
1756 |
+
dstates.stride(3),
|
1757 |
+
dstates.stride(4),
|
1758 |
+
*B_strides,
|
1759 |
+
dt.stride(0),
|
1760 |
+
dt.stride(2),
|
1761 |
+
dt.stride(1),
|
1762 |
+
dt.stride(3),
|
1763 |
+
dA_cumsum.stride(0),
|
1764 |
+
dA_cumsum.stride(2),
|
1765 |
+
dA_cumsum.stride(1),
|
1766 |
+
dA_cumsum.stride(3),
|
1767 |
+
*(
|
1768 |
+
(seq_idx.stride(0), seq_idx.stride(1))
|
1769 |
+
if seq_idx is not None
|
1770 |
+
else (0, 0)
|
1771 |
+
),
|
1772 |
+
dB.stride(0),
|
1773 |
+
dB.stride(1),
|
1774 |
+
dB.stride(2),
|
1775 |
+
dB.stride(3),
|
1776 |
+
dB.stride(4),
|
1777 |
+
*ddA_cumsum_strides,
|
1778 |
+
HAS_DDA_CS=ddA_cumsum is not None,
|
1779 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
1780 |
+
BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
|
1781 |
+
)
|
1782 |
+
dB = dB.sum(2)
|
1783 |
+
if ddA_cumsum is not None:
|
1784 |
+
# The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
|
1785 |
+
# to the state of the chunk.
|
1786 |
+
# torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
1787 |
+
# But it's easier to just do the cumsum for all elements, the result will be the same.
|
1788 |
+
torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
|
1789 |
+
return dB if B is None else (dB, ddA_cumsum)
|
1790 |
+
|
1791 |
+
|
1792 |
+
def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
|
1793 |
+
batch, seqlen, nheads, headdim = x.shape
|
1794 |
+
_, _, nchunks, chunk_size = dt.shape
|
1795 |
+
_, _, ngroups, dstate = B.shape
|
1796 |
+
assert nheads % ngroups == 0
|
1797 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1798 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1799 |
+
assert dA_cumsum.shape == dt.shape
|
1800 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1801 |
+
if seq_idx is not None:
|
1802 |
+
assert seq_idx.shape == (batch, seqlen)
|
1803 |
+
# Use torch.empty since the Triton kernel will call init_to_zero
|
1804 |
+
ddA_cumsum = torch.empty(
|
1805 |
+
batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
|
1806 |
+
)
|
1807 |
+
grid_ddtcs = lambda META: (
|
1808 |
+
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
1809 |
+
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
1810 |
+
batch * nchunks,
|
1811 |
+
nheads,
|
1812 |
+
)
|
1813 |
+
with torch.cuda.device(x.device.index):
|
1814 |
+
_chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
|
1815 |
+
x,
|
1816 |
+
B,
|
1817 |
+
dstates,
|
1818 |
+
dt,
|
1819 |
+
dA_cumsum,
|
1820 |
+
seq_idx,
|
1821 |
+
ddA_cumsum,
|
1822 |
+
chunk_size,
|
1823 |
+
headdim,
|
1824 |
+
dstate,
|
1825 |
+
batch,
|
1826 |
+
seqlen,
|
1827 |
+
nheads // ngroups,
|
1828 |
+
x.stride(0),
|
1829 |
+
x.stride(1),
|
1830 |
+
x.stride(2),
|
1831 |
+
x.stride(3),
|
1832 |
+
B.stride(0),
|
1833 |
+
B.stride(1),
|
1834 |
+
B.stride(2),
|
1835 |
+
B.stride(-1),
|
1836 |
+
dstates.stride(0),
|
1837 |
+
dstates.stride(1),
|
1838 |
+
dstates.stride(2),
|
1839 |
+
dstates.stride(3),
|
1840 |
+
dstates.stride(4),
|
1841 |
+
dt.stride(0),
|
1842 |
+
dt.stride(2),
|
1843 |
+
dt.stride(1),
|
1844 |
+
dt.stride(3),
|
1845 |
+
dA_cumsum.stride(0),
|
1846 |
+
dA_cumsum.stride(2),
|
1847 |
+
dA_cumsum.stride(1),
|
1848 |
+
dA_cumsum.stride(3),
|
1849 |
+
*(
|
1850 |
+
(seq_idx.stride(0), seq_idx.stride(1))
|
1851 |
+
if seq_idx is not None
|
1852 |
+
else (0, 0)
|
1853 |
+
),
|
1854 |
+
ddA_cumsum.stride(0),
|
1855 |
+
ddA_cumsum.stride(2),
|
1856 |
+
ddA_cumsum.stride(1),
|
1857 |
+
ddA_cumsum.stride(3),
|
1858 |
+
HAS_SEQ_IDX=seq_idx is not None,
|
1859 |
+
BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
|
1860 |
+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
1861 |
+
)
|
1862 |
+
torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
|
1863 |
+
return ddA_cumsum
|
1864 |
+
|
1865 |
+
|
1866 |
+
def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
|
1867 |
+
total_seqlen, nheads, headdim = x.shape
|
1868 |
+
_, nchunks, chunk_size = dt.shape
|
1869 |
+
_, ngroups, dstate = B.shape
|
1870 |
+
batch = cu_seqlens.shape[0] - 1
|
1871 |
+
cu_seqlens = cu_seqlens.contiguous()
|
1872 |
+
assert nheads % ngroups == 0
|
1873 |
+
assert B.shape == (total_seqlen, ngroups, dstate)
|
1874 |
+
assert dt.shape == (nheads, nchunks, chunk_size)
|
1875 |
+
assert dA_cumsum.shape == dt.shape
|
1876 |
+
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
1877 |
+
states = torch.empty(
|
1878 |
+
batch,
|
1879 |
+
nheads,
|
1880 |
+
headdim,
|
1881 |
+
dstate,
|
1882 |
+
dtype=chunk_states.dtype,
|
1883 |
+
device=chunk_states.device,
|
1884 |
+
)
|
1885 |
+
grid = lambda META: (
|
1886 |
+
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
1887 |
+
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
1888 |
+
batch,
|
1889 |
+
nheads,
|
1890 |
+
)
|
1891 |
+
with torch.cuda.device(x.device.index):
|
1892 |
+
_chunk_state_varlen_kernel[grid](
|
1893 |
+
x,
|
1894 |
+
B,
|
1895 |
+
dt,
|
1896 |
+
dA_cumsum,
|
1897 |
+
chunk_states,
|
1898 |
+
cu_seqlens,
|
1899 |
+
states,
|
1900 |
+
headdim,
|
1901 |
+
dstate,
|
1902 |
+
chunk_size,
|
1903 |
+
total_seqlen,
|
1904 |
+
nheads // ngroups,
|
1905 |
+
x.stride(0),
|
1906 |
+
x.stride(1),
|
1907 |
+
x.stride(2),
|
1908 |
+
B.stride(0),
|
1909 |
+
B.stride(1),
|
1910 |
+
B.stride(2),
|
1911 |
+
dt.stride(1),
|
1912 |
+
dt.stride(0),
|
1913 |
+
dt.stride(2),
|
1914 |
+
dA_cumsum.stride(1),
|
1915 |
+
dA_cumsum.stride(0),
|
1916 |
+
dA_cumsum.stride(2),
|
1917 |
+
chunk_states.stride(0),
|
1918 |
+
chunk_states.stride(1),
|
1919 |
+
chunk_states.stride(2),
|
1920 |
+
chunk_states.stride(3),
|
1921 |
+
states.stride(0),
|
1922 |
+
states.stride(1),
|
1923 |
+
states.stride(2),
|
1924 |
+
states.stride(3),
|
1925 |
+
)
|
1926 |
+
return states
|
1927 |
+
|
1928 |
+
|
1929 |
+
class ChunkStateFn(torch.autograd.Function):
|
1930 |
+
|
1931 |
+
@staticmethod
|
1932 |
+
def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
|
1933 |
+
batch, seqlen, nheads, headdim = x.shape
|
1934 |
+
_, _, nchunks, chunk_size = dt.shape
|
1935 |
+
assert seqlen <= nchunks * chunk_size
|
1936 |
+
_, _, ngroups, dstate = B.shape
|
1937 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1938 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1939 |
+
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
1940 |
+
if B.stride(-1) != 1:
|
1941 |
+
B = B.contiguous()
|
1942 |
+
if (
|
1943 |
+
x.stride(-1) != 1 and x.stride(1) != 1
|
1944 |
+
): # Either M or K dimension should be contiguous
|
1945 |
+
x = x.contiguous()
|
1946 |
+
states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
|
1947 |
+
ctx.save_for_backward(B, x, dt, dA_cumsum)
|
1948 |
+
return states
|
1949 |
+
|
1950 |
+
@staticmethod
|
1951 |
+
def backward(ctx, dstates):
|
1952 |
+
B, x, dt, dA_cumsum = ctx.saved_tensors
|
1953 |
+
batch, seqlen, nheads, headdim = x.shape
|
1954 |
+
_, _, nchunks, chunk_size = dt.shape
|
1955 |
+
_, _, ngroups, dstate = B.shape
|
1956 |
+
assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
|
1957 |
+
if dstates.stride(-1) != 1:
|
1958 |
+
dstates = dstates.contiguous()
|
1959 |
+
dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
|
1960 |
+
dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
|
1961 |
+
dB = dB.to(B.dtype)
|
1962 |
+
return dB, dx, ddt, ddA_cumsum, None
|
1963 |
+
|
1964 |
+
|
1965 |
+
def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
|
1966 |
+
"""
|
1967 |
+
Argument:
|
1968 |
+
B: (batch, seqlen, ngroups, headdim)
|
1969 |
+
x: (batch, seqlen, nheads, headdim)
|
1970 |
+
dt: (batch, nheads, nchunks, chunk_size)
|
1971 |
+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
1972 |
+
Return:
|
1973 |
+
states: (batch, nchunks, nheads, headdim, dstate)
|
1974 |
+
"""
|
1975 |
+
return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
|
1976 |
+
|
1977 |
+
|
1978 |
+
def chunk_state_ref(B, x, dt, dA_cumsum):
|
1979 |
+
"""
|
1980 |
+
Argument:
|
1981 |
+
B: (batch, seqlen, ngroups, headdim)
|
1982 |
+
x: (batch, seqlen, nheads, headdim)
|
1983 |
+
dt: (batch, nheads, nchunks, chunk_size)
|
1984 |
+
dA_cumsum: (batch, nheads, nchunks, chunk_size)
|
1985 |
+
Return:
|
1986 |
+
states: (batch, nchunks, nheads, headdim, dstate)
|
1987 |
+
"""
|
1988 |
+
# Check constraints.
|
1989 |
+
batch, seqlen, nheads, headdim = x.shape
|
1990 |
+
dstate = B.shape[-1]
|
1991 |
+
_, _, nchunks, chunk_size = dt.shape
|
1992 |
+
assert seqlen <= nchunks * chunk_size
|
1993 |
+
assert x.shape == (batch, seqlen, nheads, headdim)
|
1994 |
+
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
1995 |
+
ngroups = B.shape[2]
|
1996 |
+
assert nheads % ngroups == 0
|
1997 |
+
assert B.shape == (batch, seqlen, ngroups, dstate)
|
1998 |
+
B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
|
1999 |
+
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
2000 |
+
if seqlen < nchunks * chunk_size:
|
2001 |
+
x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
2002 |
+
B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
|
2003 |
+
x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
|
2004 |
+
B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
|
2005 |
+
decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
|
2006 |
+
return torch.einsum(
|
2007 |
+
"bclhn,bhcl,bhcl,bclhp->bchpn",
|
2008 |
+
B.to(x.dtype),
|
2009 |
+
decay_states.to(x.dtype),
|
2010 |
+
dt.to(x.dtype),
|
2011 |
+
x,
|
2012 |
+
)
|