danieldk (HF Staff) committed
Commit 23d26f4 · 0 Parent(s)

Import mamba-ssm kernels

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +14 -0
  2. build.toml +34 -0
  3. selective-scan/reverse_scan.cuh +415 -0
  4. selective-scan/selective_scan.cpp +497 -0
  5. selective-scan/selective_scan.h +101 -0
  6. selective-scan/selective_scan_bwd_bf16_complex.cu +9 -0
  7. selective-scan/selective_scan_bwd_bf16_real.cu +9 -0
  8. selective-scan/selective_scan_bwd_fp16_complex.cu +9 -0
  9. selective-scan/selective_scan_bwd_fp16_real.cu +9 -0
  10. selective-scan/selective_scan_bwd_fp32_complex.cu +9 -0
  11. selective-scan/selective_scan_bwd_fp32_real.cu +9 -0
  12. selective-scan/selective_scan_bwd_kernel.cuh +561 -0
  13. selective-scan/selective_scan_common.h +255 -0
  14. selective-scan/selective_scan_fwd_bf16.cu +10 -0
  15. selective-scan/selective_scan_fwd_fp16.cu +10 -0
  16. selective-scan/selective_scan_fwd_fp32.cu +10 -0
  17. selective-scan/selective_scan_fwd_kernel.cuh +376 -0
  18. selective-scan/static_switch.h +25 -0
  19. selective-scan/uninitialized_copy.cuh +77 -0
  20. tests/ops/test_selective_scan.py +247 -0
  21. tests/ops/triton/test_layernorm_gated.py +103 -0
  22. tests/ops/triton/test_selective_state_update.py +201 -0
  23. tests/ops/triton/test_ssd.py +78 -0
  24. tests/test_generation.py +113 -0
  25. torch-ext/mamba_ssm/__init__.py +14 -0
  26. torch-ext/mamba_ssm/distributed/__init__.py +0 -0
  27. torch-ext/mamba_ssm/distributed/distributed_utils.py +144 -0
  28. torch-ext/mamba_ssm/distributed/tensor_parallel.py +326 -0
  29. torch-ext/mamba_ssm/models/__init__.py +0 -0
  30. torch-ext/mamba_ssm/models/config_mamba.py +18 -0
  31. torch-ext/mamba_ssm/models/mixer_seq_simple.py +338 -0
  32. torch-ext/mamba_ssm/modules/__init__.py +0 -0
  33. torch-ext/mamba_ssm/modules/block.py +107 -0
  34. torch-ext/mamba_ssm/modules/mamba2.py +502 -0
  35. torch-ext/mamba_ssm/modules/mamba2_simple.py +229 -0
  36. torch-ext/mamba_ssm/modules/mamba_simple.py +339 -0
  37. torch-ext/mamba_ssm/modules/mha.py +294 -0
  38. torch-ext/mamba_ssm/modules/mlp.py +34 -0
  39. torch-ext/mamba_ssm/modules/ssd_minimal.py +111 -0
  40. torch-ext/mamba_ssm/ops/__init__.py +0 -0
  41. torch-ext/mamba_ssm/ops/selective_scan_interface.py +659 -0
  42. torch-ext/mamba_ssm/ops/triton/__init__.py +0 -0
  43. torch-ext/mamba_ssm/ops/triton/k_activations.py +169 -0
  44. torch-ext/mamba_ssm/ops/triton/layer_norm.py +1166 -0
  45. torch-ext/mamba_ssm/ops/triton/layernorm_gated.py +437 -0
  46. torch-ext/mamba_ssm/ops/triton/selective_state_update.py +389 -0
  47. torch-ext/mamba_ssm/ops/triton/softplus.py +15 -0
  48. torch-ext/mamba_ssm/ops/triton/ssd_bmm.py +262 -0
  49. torch-ext/mamba_ssm/ops/triton/ssd_chunk_scan.py +0 -0
  50. torch-ext/mamba_ssm/ops/triton/ssd_chunk_state.py +2012 -0
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ license: apache-2.0
+ ---
+
+ ## Mamba
+
+ Mamba state space kernels + models from [state-spaces/mamba](https://github.com/state-spaces/mamba).
+
+ ## Warning
+
+ Some functionality depends on `einops` and `transformers`; however, we
+ currently have no way of declaring these dependencies. The scope of the
+ Hub kernel is probably too large (it should perhaps contain only the
+ selective-scan and Triton kernels).
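
For orientation, the snippet below sketches how the selective-scan forward op is typically invoked through the Python wrapper in `torch-ext/mamba_ssm/ops/selective_scan_interface.py`. It is a minimal sketch, not part of this commit: the import path assumes the package is installed as `mamba_ssm`, the function name `selective_scan_fn` and its keyword arguments follow upstream mamba-ssm, and the tensor shapes mirror the checks in `selective-scan/selective_scan.cpp`.

```python
# Minimal usage sketch (assumes upstream mamba-ssm's selective_scan_fn API).
# Shapes follow the CHECK_SHAPE calls in selective-scan/selective_scan.cpp.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, seqlen, dstate, n_groups = 2, 64, 128, 16, 1

u = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
delta = torch.rand(batch, dim, seqlen, device="cuda", dtype=torch.float16)
A = -torch.rand(dim, dstate, device="cuda", dtype=torch.float32)  # real weight type
B = torch.randn(batch, n_groups, dstate, seqlen, device="cuda", dtype=torch.float16)  # variable B
C = torch.randn(batch, n_groups, dstate, seqlen, device="cuda", dtype=torch.float16)  # variable C
D = torch.ones(dim, device="cuda", dtype=torch.float32)

y = selective_scan_fn(u, delta, A, B, C, D=D, delta_softplus=True)
print(y.shape)  # torch.Size([2, 64, 128])
```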
build.toml ADDED
@@ -0,0 +1,34 @@
+ [general]
+ version = "0.0.1"
+
+ [torch]
+ name = "mamba_ssm"
+ src = [
+ "torch-ext/registration.h",
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+ ]
+ pyroot = "torch-ext"
+
+ [kernel.selective_scan]
+ capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
+ src = [
+ "selective-scan/reverse_scan.cuh",
+ "selective-scan/selective_scan.cpp",
+ "selective-scan/selective_scan.h",
+ "selective-scan/selective_scan_bwd_bf16_complex.cu",
+ "selective-scan/selective_scan_bwd_bf16_real.cu",
+ "selective-scan/selective_scan_bwd_fp16_complex.cu",
+ "selective-scan/selective_scan_bwd_fp16_real.cu",
+ "selective-scan/selective_scan_bwd_fp32_complex.cu",
+ "selective-scan/selective_scan_bwd_fp32_real.cu",
+ "selective-scan/selective_scan_bwd_kernel.cuh",
+ "selective-scan/selective_scan_common.h",
+ "selective-scan/selective_scan_fwd_bf16.cu",
+ "selective-scan/selective_scan_fwd_fp16.cu",
+ "selective-scan/selective_scan_fwd_fp32.cu",
+ "selective-scan/selective_scan_fwd_kernel.cuh",
+ "selective-scan/static_switch.h",
+ "selective-scan/uninitialized_copy.cuh",
+ ]
+ depends = [ "torch" ]
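
The `[torch]` and `[kernel.selective_scan]` tables above follow the kernel-builder layout: `pyroot` points at the Python package under `torch-ext`, and the listed `.cu`/`.cpp` sources are compiled for the given compute capabilities. As a rough illustration (not part of this commit), a kernel published from such a build is usually loaded with the `kernels` library; the repository id below is hypothetical.

```python
# Hypothetical loading sketch for a Hub kernel built from this build.toml.
# The repo id "someuser/mamba-ssm" is a placeholder, not a real repository.
from kernels import get_kernel

mamba_ssm = get_kernel("someuser/mamba-ssm")
print(dir(mamba_ssm))  # exposes the torch-ext/mamba_ssm package contents
```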
selective-scan/reverse_scan.cuh ADDED
@@ -0,0 +1,415 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #ifndef USE_ROCM
8
+ #include <cub/config.cuh>
9
+
10
+ #include <cub/util_ptx.cuh>
11
+ #include <cub/util_type.cuh>
12
+ #include <cub/block/block_raking_layout.cuh>
13
+ // #include <cub/detail/uninitialized_copy.cuh>
14
+ #else
15
+ #include <hipcub/hipcub.hpp>
16
+ namespace cub = hipcub;
17
+ #endif
18
+ #include "uninitialized_copy.cuh"
19
+
20
+ /**
21
+ * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned.
22
+ */
23
+ template <
24
+ int LENGTH,
25
+ typename T,
26
+ typename ReductionOp>
27
+ __device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) {
28
+ static_assert(LENGTH > 0);
29
+ T retval = input[LENGTH - 1];
30
+ #pragma unroll
31
+ for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); }
32
+ return retval;
33
+ }
34
+
35
+ /**
36
+ * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
37
+ */
38
+ template <
39
+ int LENGTH,
40
+ typename T,
41
+ typename ScanOp>
42
+ __device__ __forceinline__ T ThreadReverseScanInclusive(
43
+ const T (&input)[LENGTH],
44
+ T (&output)[LENGTH],
45
+ ScanOp scan_op,
46
+ const T postfix)
47
+ {
48
+ T inclusive = postfix;
49
+ #pragma unroll
50
+ for (int i = LENGTH - 1; i >= 0; --i) {
51
+ inclusive = scan_op(inclusive, input[i]);
52
+ output[i] = inclusive;
53
+ }
54
+ return inclusive;
55
+ }
56
+
57
+ /**
58
+ * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned.
59
+ */
60
+ template <
61
+ int LENGTH,
62
+ typename T,
63
+ typename ScanOp>
64
+ __device__ __forceinline__ T ThreadReverseScanExclusive(
65
+ const T (&input)[LENGTH],
66
+ T (&output)[LENGTH],
67
+ ScanOp scan_op,
68
+ const T postfix)
69
+ {
70
+ // Careful, output maybe be aliased to input
71
+ T exclusive = postfix;
72
+ T inclusive;
73
+ #pragma unroll
74
+ for (int i = LENGTH - 1; i >= 0; --i) {
75
+ inclusive = scan_op(exclusive, input[i]);
76
+ output[i] = exclusive;
77
+ exclusive = inclusive;
78
+ }
79
+ return inclusive;
80
+ }
81
+
82
+
83
+ /**
84
+ * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp.
85
+ *
86
+ * LOGICAL_WARP_THREADS must be a power-of-two
87
+ */
88
+ template <
89
+ typename T, ///< Data type being scanned
90
+ int LOGICAL_WARP_THREADS ///< Number of threads per logical warp
91
+ >
92
+ struct WarpReverseScan {
93
+ //---------------------------------------------------------------------
94
+ // Constants and type definitions
95
+ //---------------------------------------------------------------------
96
+
97
+ /// Whether the logical warp size and the PTX warp size coincide
98
+
99
+ // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size()
100
+ // While in cub, it's defined as a macro that takes a redundant unused argument.
101
+ #ifndef USE_ROCM
102
+ #define WARP_THREADS CUB_WARP_THREADS(0)
103
+ #else
104
+ #define WARP_THREADS HIPCUB_WARP_THREADS
105
+ #endif
106
+ static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
107
+ /// The number of warp scan steps
108
+ static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
109
+ static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
110
+
111
+
112
+ //---------------------------------------------------------------------
113
+ // Thread fields
114
+ //---------------------------------------------------------------------
115
+
116
+ /// Lane index in logical warp
117
+ unsigned int lane_id;
118
+
119
+ /// Logical warp index in 32-thread physical warp
120
+ unsigned int warp_id;
121
+
122
+ /// 32-thread physical warp member mask of logical warp
123
+ unsigned int member_mask;
124
+
125
+ //---------------------------------------------------------------------
126
+ // Construction
127
+ //---------------------------------------------------------------------
128
+
129
+ /// Constructor
130
+ explicit __device__ __forceinline__
131
+ WarpReverseScan()
132
+ : lane_id(cub::LaneId())
133
+ , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS))
134
+ , member_mask(cub::WarpMask<LOGICAL_WARP_THREADS>(warp_id))
135
+ {
136
+ if (!IS_ARCH_WARP) {
137
+ lane_id = lane_id % LOGICAL_WARP_THREADS;
138
+ }
139
+ }
140
+
141
+
142
+ /// Broadcast
143
+ __device__ __forceinline__ T Broadcast(
144
+ T input, ///< [in] The value to broadcast
145
+ int src_lane) ///< [in] Which warp lane is to do the broadcasting
146
+ {
147
+ return cub::ShuffleIndex<LOGICAL_WARP_THREADS>(input, src_lane, member_mask);
148
+ }
149
+
150
+
151
+ /// Inclusive scan
152
+ template <typename ScanOpT>
153
+ __device__ __forceinline__ void InclusiveReverseScan(
154
+ T input, ///< [in] Calling thread's input item.
155
+ T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
156
+ ScanOpT scan_op) ///< [in] Binary scan operator
157
+ {
158
+ inclusive_output = input;
159
+ #pragma unroll
160
+ for (int STEP = 0; STEP < STEPS; STEP++) {
161
+ int offset = 1 << STEP;
162
+ T temp = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
163
+ inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask
164
+ );
165
+ // Perform scan op if from a valid peer
166
+ inclusive_output = static_cast<int>(lane_id) >= LOGICAL_WARP_THREADS - offset
167
+ ? inclusive_output : scan_op(temp, inclusive_output);
168
+ }
169
+ }
170
+
171
+ /// Exclusive scan
172
+ // Get exclusive from inclusive
173
+ template <typename ScanOpT>
174
+ __device__ __forceinline__ void ExclusiveReverseScan(
175
+ T input, ///< [in] Calling thread's input item.
176
+ T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input.
177
+ ScanOpT scan_op, ///< [in] Binary scan operator
178
+ T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items.
179
+ {
180
+ T inclusive_output;
181
+ InclusiveReverseScan(input, inclusive_output, scan_op);
182
+ warp_aggregate = cub::ShuffleIndex<LOGICAL_WARP_THREADS>(inclusive_output, 0, member_mask);
183
+ // initial value unknown
184
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
185
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
186
+ );
187
+ }
188
+
189
+ /**
190
+ * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last <em>warp-lane</em> is undefined.
191
+ */
192
+ template <typename ScanOpT>
193
+ __device__ __forceinline__ void ReverseScan(
194
+ T input, ///< [in] Calling thread's input item.
195
+ T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item.
196
+ T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item.
197
+ ScanOpT scan_op) ///< [in] Binary scan operator
198
+ {
199
+ InclusiveReverseScan(input, inclusive_output, scan_op);
200
+ // initial value unknown
201
+ exclusive_output = cub::ShuffleDown<LOGICAL_WARP_THREADS>(
202
+ inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask
203
+ );
204
+ }
205
+
206
+ };
207
+
208
+ /**
209
+ * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block.
210
+ */
211
+ template <
212
+ typename T, ///< Data type being scanned
213
+ int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension
214
+ bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure
215
+ >
216
+ struct BlockReverseScan {
217
+ //---------------------------------------------------------------------
218
+ // Types and constants
219
+ //---------------------------------------------------------------------
220
+
221
+ /// Constants
222
+ /// The thread block size in threads
223
+ static constexpr int BLOCK_THREADS = BLOCK_DIM_X;
224
+
225
+ /// Layout type for padded thread block raking grid
226
+ using BlockRakingLayout = cub::BlockRakingLayout<T, BLOCK_THREADS>;
227
+ // The number of reduction elements is not a multiple of the number of raking threads for now
228
+ static_assert(BlockRakingLayout::UNGUARDED);
229
+
230
+ /// Number of raking threads
231
+ static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS;
232
+ /// Number of raking elements per warp synchronous raking thread
233
+ static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH;
234
+ /// Cooperative work can be entirely warp synchronous
235
+ static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS));
236
+
237
+ /// WarpReverseScan utility type
238
+ using WarpReverseScan = WarpReverseScan<T, RAKING_THREADS>;
239
+
240
+ /// Shared memory storage layout type
241
+ struct _TempStorage {
242
+ typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid
243
+ };
244
+
245
+
246
+ /// Alias wrapper allowing storage to be unioned
247
+ struct TempStorage : cub::Uninitialized<_TempStorage> {};
248
+
249
+
250
+ //---------------------------------------------------------------------
251
+ // Per-thread fields
252
+ //---------------------------------------------------------------------
253
+
254
+ // Thread fields
255
+ _TempStorage &temp_storage;
256
+ unsigned int linear_tid;
257
+ T cached_segment[SEGMENT_LENGTH];
258
+
259
+
260
+ //---------------------------------------------------------------------
261
+ // Utility methods
262
+ //---------------------------------------------------------------------
263
+
264
+ /// Performs upsweep raking reduction, returning the aggregate
265
+ template <typename ScanOp>
266
+ __device__ __forceinline__ T Upsweep(ScanOp scan_op) {
267
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
268
+ // Read data into registers
269
+ #pragma unroll
270
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
271
+ T raking_partial = cached_segment[SEGMENT_LENGTH - 1];
272
+ #pragma unroll
273
+ for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) {
274
+ raking_partial = scan_op(raking_partial, cached_segment[i]);
275
+ }
276
+ return raking_partial;
277
+ }
278
+
279
+
280
+ /// Performs exclusive downsweep raking scan
281
+ template <typename ScanOp>
282
+ __device__ __forceinline__ void ExclusiveDownsweep(
283
+ ScanOp scan_op,
284
+ T raking_partial)
285
+ {
286
+ T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid);
287
+ // Read data back into registers
288
+ if (!MEMOIZE) {
289
+ #pragma unroll
290
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; }
291
+ }
292
+ ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial);
293
+ // Write data back to smem
294
+ #pragma unroll
295
+ for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; }
296
+ }
297
+
298
+
299
+ //---------------------------------------------------------------------
300
+ // Constructors
301
+ //---------------------------------------------------------------------
302
+
303
+ /// Constructor
304
+ __device__ __forceinline__ BlockReverseScan(
305
+ TempStorage &temp_storage)
306
+ :
307
+ temp_storage(temp_storage.Alias()),
308
+ linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1))
309
+ {}
310
+
311
+
312
+ /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
313
+ template <
314
+ typename ScanOp,
315
+ typename BlockPostfixCallbackOp>
316
+ __device__ __forceinline__ void ExclusiveReverseScan(
317
+ T input, ///< [in] Calling thread's input item
318
+ T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input)
319
+ ScanOp scan_op, ///< [in] Binary scan operator
320
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a thread block-wide postfix to be applied to all inputs.
321
+ {
322
+ if (WARP_SYNCHRONOUS) {
323
+ // Short-circuit directly to warp-synchronous scan
324
+ T block_aggregate;
325
+ WarpReverseScan warp_scan;
326
+ warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate);
327
+ // Obtain warp-wide postfix in lane0, then broadcast to other lanes
328
+ T block_postfix = block_postfix_callback_op(block_aggregate);
329
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
330
+ exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output);
331
+ } else {
332
+ // Place thread partial into shared memory raking grid
333
+ T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid);
334
+ detail::uninitialized_copy(placement_ptr, input);
335
+ cub::CTA_SYNC();
336
+ // Reduce parallelism down to just raking threads
337
+ if (linear_tid < RAKING_THREADS) {
338
+ WarpReverseScan warp_scan;
339
+ // Raking upsweep reduction across shared partials
340
+ T upsweep_partial = Upsweep(scan_op);
341
+ // Warp-synchronous scan
342
+ T exclusive_partial, block_aggregate;
343
+ warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate);
344
+ // Obtain block-wide postfix in lane0, then broadcast to other lanes
345
+ T block_postfix = block_postfix_callback_op(block_aggregate);
346
+ block_postfix = warp_scan.Broadcast(block_postfix, 0);
347
+ // Update postfix with warpscan exclusive partial
348
+ T downsweep_postfix = linear_tid == RAKING_THREADS - 1
349
+ ? block_postfix : scan_op(block_postfix, exclusive_partial);
350
+ // Exclusive raking downsweep scan
351
+ ExclusiveDownsweep(scan_op, downsweep_postfix);
352
+ }
353
+ cub::CTA_SYNC();
354
+ // Grab thread postfix from shared memory
355
+ exclusive_output = *placement_ptr;
356
+
357
+ // // Compute warp scan in each warp.
358
+ // // The exclusive output from the last lane in each warp is invalid.
359
+ // T inclusive_output;
360
+ // WarpReverseScan warp_scan;
361
+ // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op);
362
+
363
+ // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid.
364
+ // T block_aggregate;
365
+ // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate);
366
+
367
+ // // Apply warp postfix to our lane's partial
368
+ // if (warp_id != 0) {
369
+ // exclusive_output = scan_op(warp_postfix, exclusive_output);
370
+ // if (lane_id == 0) { exclusive_output = warp_postfix; }
371
+ // }
372
+
373
+ // // Use the first warp to determine the thread block postfix, returning the result in lane0
374
+ // if (warp_id == 0) {
375
+ // T block_postfix = block_postfix_callback_op(block_aggregate);
376
+ // if (lane_id == 0) {
377
+ // // Share the postfix with all threads
378
+ // detail::uninitialized_copy(&temp_storage.block_postfix,
379
+ // block_postfix);
380
+
381
+ // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0
382
+ // }
383
+ // }
384
+
385
+ // cub::CTA_SYNC();
386
+
387
+ // // Incorporate thread block postfix into outputs
388
+ // T block_postfix = temp_storage.block_postfix;
389
+ // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); }
390
+ }
391
+ }
392
+
393
+
394
+ /**
395
+ * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by <em>lane</em><sub>0</sub> in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs.
396
+ */
397
+ template <
398
+ int ITEMS_PER_THREAD,
399
+ typename ScanOp,
400
+ typename BlockPostfixCallbackOp>
401
+ __device__ __forceinline__ void InclusiveReverseScan(
402
+ T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items
403
+ T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input)
404
+ ScanOp scan_op, ///< [in] Binary scan functor
405
+ BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] <b>[<em>warp</em><sub>0</sub> only]</b> Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence.
406
+ {
407
+ // Reduce consecutive thread items in registers
408
+ T thread_postfix = ThreadReverseReduce(input, scan_op);
409
+ // Exclusive thread block-scan
410
+ ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op);
411
+ // Inclusive scan in registers with postfix as seed
412
+ ThreadReverseScanInclusive(input, output, scan_op, thread_postfix);
413
+ }
414
+
415
+ };
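
The `ThreadReverseScanInclusive`/`ThreadReverseScanExclusive` helpers above scan a fixed-length array from its last element toward the first, seeded with a `postfix` value. As a plain-Python reference for those semantics only (assumed equivalent; it is not the warp- or block-level CUDA code), consider:

```python
# Reference semantics for the reverse (postfix) scans above, in plain Python.
# Mirrors the per-thread loops; the aggregate is returned alongside the output.
def reverse_scan_inclusive(xs, postfix, op):
    out = [None] * len(xs)
    acc = postfix
    for i in range(len(xs) - 1, -1, -1):
        acc = op(acc, xs[i])
        out[i] = acc
    return out, acc  # acc is the aggregate over all items and the postfix

def reverse_scan_exclusive(xs, postfix, op):
    out = [None] * len(xs)
    excl = postfix
    for i in range(len(xs) - 1, -1, -1):
        incl = op(excl, xs[i])
        out[i] = excl
        excl = incl
    return out, excl  # excl ends up equal to the inclusive aggregate

xs = [1, 2, 3, 4]
print(reverse_scan_inclusive(xs, 0, lambda a, b: a + b))  # ([10, 9, 7, 4], 10)
print(reverse_scan_exclusive(xs, 0, lambda a, b: a + b))  # ([9, 7, 4, 0], 10)
```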
selective-scan/selective_scan.cpp ADDED
@@ -0,0 +1,497 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/cuda/CUDAGuard.h>
6
+ #include <c10/cuda/CUDAStream.h>
7
+ #include <torch/all.h>
8
+ #include <vector>
9
+
10
+ #include "selective_scan.h"
11
+
12
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
13
+
14
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
15
+ if (ITYPE == at::ScalarType::Half) { \
16
+ using input_t = at::Half; \
17
+ __VA_ARGS__(); \
18
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
19
+ using input_t = at::BFloat16; \
20
+ __VA_ARGS__(); \
21
+ } else if (ITYPE == at::ScalarType::Float) { \
22
+ using input_t = float; \
23
+ __VA_ARGS__(); \
24
+ } else { \
25
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
26
+ }
27
+
28
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
29
+ if (WTYPE == at::ScalarType::Half) { \
30
+ using weight_t = at::Half; \
31
+ __VA_ARGS__(); \
32
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
33
+ using weight_t = at::BFloat16; \
34
+ __VA_ARGS__(); \
35
+ } else if (WTYPE == at::ScalarType::Float) { \
36
+ using weight_t = float; \
37
+ __VA_ARGS__(); \
38
+ } else { \
39
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
40
+ }
41
+
42
+ #define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
43
+ if (WTYPE == at::ScalarType::Float) { \
44
+ using weight_t = float; \
45
+ __VA_ARGS__(); \
46
+ } else if (WTYPE == at::ScalarType::ComplexFloat) { \
47
+ using weight_t = c10::complex<float>; \
48
+ __VA_ARGS__(); \
49
+ } else { \
50
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
51
+ }
52
+
53
+ template<typename input_t, typename weight_t>
54
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
55
+
56
+ template <typename input_t, typename weight_t>
57
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream);
58
+
59
+ void set_ssm_params_fwd(SSMParamsBase &params,
60
+ // sizes
61
+ const size_t batch,
62
+ const size_t dim,
63
+ const size_t seqlen,
64
+ const size_t dstate,
65
+ const size_t n_groups,
66
+ const size_t n_chunks,
67
+ const bool is_variable_B,
68
+ const bool is_variable_C,
69
+ // device pointers
70
+ const at::Tensor u,
71
+ const at::Tensor delta,
72
+ const at::Tensor A,
73
+ const at::Tensor B,
74
+ const at::Tensor C,
75
+ const at::Tensor out,
76
+ const at::Tensor z,
77
+ const at::Tensor out_z,
78
+ void* D_ptr,
79
+ void* delta_bias_ptr,
80
+ void* x_ptr,
81
+ bool has_z,
82
+ bool delta_softplus) {
83
+
84
+ // Reset the parameters
85
+ memset(&params, 0, sizeof(params));
86
+
87
+ params.batch = batch;
88
+ params.dim = dim;
89
+ params.seqlen = seqlen;
90
+ params.dstate = dstate;
91
+ params.n_groups = n_groups;
92
+ params.n_chunks = n_chunks;
93
+ params.dim_ngroups_ratio = dim / n_groups;
94
+
95
+ params.delta_softplus = delta_softplus;
96
+
97
+ params.is_variable_B = is_variable_B;
98
+ params.is_variable_C = is_variable_C;
99
+
100
+ // Set the pointers and strides.
101
+ params.u_ptr = u.data_ptr();
102
+ params.delta_ptr = delta.data_ptr();
103
+ params.A_ptr = A.data_ptr();
104
+ params.B_ptr = B.data_ptr();
105
+ params.C_ptr = C.data_ptr();
106
+ params.D_ptr = D_ptr;
107
+ params.delta_bias_ptr = delta_bias_ptr;
108
+ params.out_ptr = out.data_ptr();
109
+ params.x_ptr = x_ptr;
110
+ params.z_ptr = has_z ? z.data_ptr() : nullptr;
111
+ params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
112
+ // All stride are in elements, not bytes.
113
+ params.A_d_stride = A.stride(0);
114
+ params.A_dstate_stride = A.stride(1);
115
+ if (!is_variable_B) {
116
+ params.B_d_stride = B.stride(0);
117
+ } else {
118
+ params.B_batch_stride = B.stride(0);
119
+ params.B_group_stride = B.stride(1);
120
+ }
121
+ params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
122
+ if (!is_variable_C) {
123
+ params.C_d_stride = C.stride(0);
124
+ } else {
125
+ params.C_batch_stride = C.stride(0);
126
+ params.C_group_stride = C.stride(1);
127
+ }
128
+ params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
129
+ params.u_batch_stride = u.stride(0);
130
+ params.u_d_stride = u.stride(1);
131
+ params.delta_batch_stride = delta.stride(0);
132
+ params.delta_d_stride = delta.stride(1);
133
+ if (has_z) {
134
+ params.z_batch_stride = z.stride(0);
135
+ params.z_d_stride = z.stride(1);
136
+ params.out_z_batch_stride = out_z.stride(0);
137
+ params.out_z_d_stride = out_z.stride(1);
138
+ }
139
+ params.out_batch_stride = out.stride(0);
140
+ params.out_d_stride = out.stride(1);
141
+ }
142
+
143
+ void set_ssm_params_bwd(SSMParamsBwd &params,
144
+ // sizes
145
+ const size_t batch,
146
+ const size_t dim,
147
+ const size_t seqlen,
148
+ const size_t dstate,
149
+ const size_t n_groups,
150
+ const size_t n_chunks,
151
+ const bool is_variable_B,
152
+ const bool is_variable_C,
153
+ // device pointers
154
+ const at::Tensor u,
155
+ const at::Tensor delta,
156
+ const at::Tensor A,
157
+ const at::Tensor B,
158
+ const at::Tensor C,
159
+ const at::Tensor z,
160
+ const at::Tensor out,
161
+ const at::Tensor out_z,
162
+ void* D_ptr,
163
+ void* delta_bias_ptr,
164
+ void* x_ptr,
165
+ const at::Tensor dout,
166
+ const at::Tensor du,
167
+ const at::Tensor ddelta,
168
+ const at::Tensor dA,
169
+ const at::Tensor dB,
170
+ const at::Tensor dC,
171
+ const at::Tensor dz,
172
+ void* dD_ptr,
173
+ void* ddelta_bias_ptr,
174
+ bool has_z,
175
+ bool delta_softplus,
176
+ bool recompute_out_z) {
177
+ // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z
178
+ set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
179
+ u, delta, A, B, C, has_z ? out : dout,
180
+ has_z ? z : dout,
181
+ // If not recompute_out_z, pass dout instead of out_z.
182
+ // This won't be used by the bwd kernel
183
+ recompute_out_z ? out_z : dout,
184
+ D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus);
185
+ if (!recompute_out_z) { params.out_z_ptr = nullptr; }
186
+
187
+ // Set the pointers and strides.
188
+ params.dout_ptr = dout.data_ptr();
189
+ params.du_ptr = du.data_ptr();
190
+ params.dA_ptr = dA.data_ptr();
191
+ params.dB_ptr = dB.data_ptr();
192
+ params.dC_ptr = dC.data_ptr();
193
+ params.dD_ptr = dD_ptr;
194
+ params.ddelta_ptr = ddelta.data_ptr();
195
+ params.ddelta_bias_ptr = ddelta_bias_ptr;
196
+ params.dz_ptr = has_z ? dz.data_ptr() : nullptr;
197
+ // All stride are in elements, not bytes.
198
+ params.dout_batch_stride = dout.stride(0);
199
+ params.dout_d_stride = dout.stride(1);
200
+ params.dA_d_stride = dA.stride(0);
201
+ params.dA_dstate_stride = dA.stride(1);
202
+ if (!is_variable_B) {
203
+ params.dB_d_stride = dB.stride(0);
204
+ } else {
205
+ params.dB_batch_stride = dB.stride(0);
206
+ params.dB_group_stride = dB.stride(1);
207
+ }
208
+ params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2);
209
+ if (!is_variable_C) {
210
+ params.dC_d_stride = dC.stride(0);
211
+ } else {
212
+ params.dC_batch_stride = dC.stride(0);
213
+ params.dC_group_stride = dC.stride(1);
214
+ }
215
+ params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2);
216
+ params.du_batch_stride = du.stride(0);
217
+ params.du_d_stride = du.stride(1);
218
+ params.ddelta_batch_stride = ddelta.stride(0);
219
+ params.ddelta_d_stride = ddelta.stride(1);
220
+ if (has_z) {
221
+ params.dz_batch_stride = dz.stride(0);
222
+ params.dz_d_stride = dz.stride(1);
223
+ }
224
+ }
225
+
226
+ std::vector<at::Tensor>
227
+ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
228
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
229
+ const c10::optional<at::Tensor> &D_,
230
+ const c10::optional<at::Tensor> &z_,
231
+ const c10::optional<at::Tensor> &delta_bias_,
232
+ bool delta_softplus) {
233
+ auto input_type = u.scalar_type();
234
+ auto weight_type = A.scalar_type();
235
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
236
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
237
+
238
+ const bool is_variable_B = B.dim() >= 3;
239
+ const bool is_variable_C = C.dim() >= 3;
240
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
241
+
242
+ TORCH_CHECK(delta.scalar_type() == input_type);
243
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
244
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
245
+
246
+ TORCH_CHECK(u.is_cuda());
247
+ TORCH_CHECK(delta.is_cuda());
248
+ TORCH_CHECK(A.is_cuda());
249
+ TORCH_CHECK(B.is_cuda());
250
+ TORCH_CHECK(C.is_cuda());
251
+
252
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
253
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
254
+
255
+ const auto sizes = u.sizes();
256
+ const int batch_size = sizes[0];
257
+ const int dim = sizes[1];
258
+ const int seqlen = sizes[2];
259
+ const int dstate = A.size(1);
260
+ const int n_groups = is_variable_B ? B.size(1) : 1;
261
+
262
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
263
+
264
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
265
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
266
+ CHECK_SHAPE(A, dim, dstate);
267
+ if (!is_variable_B) {
268
+ CHECK_SHAPE(B, dim, dstate);
269
+ } else {
270
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
271
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
272
+ }
273
+ if (!is_variable_C) {
274
+ CHECK_SHAPE(C, dim, dstate);
275
+ } else {
276
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
277
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
278
+ }
279
+
280
+ if (D_.has_value()) {
281
+ auto D = D_.value();
282
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
283
+ TORCH_CHECK(D.is_cuda());
284
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
285
+ CHECK_SHAPE(D, dim);
286
+ }
287
+
288
+ if (delta_bias_.has_value()) {
289
+ auto delta_bias = delta_bias_.value();
290
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
291
+ TORCH_CHECK(delta_bias.is_cuda());
292
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
293
+ CHECK_SHAPE(delta_bias, dim);
294
+ }
295
+
296
+ at::Tensor z, out_z;
297
+ const bool has_z = z_.has_value();
298
+ if (has_z) {
299
+ z = z_.value();
300
+ TORCH_CHECK(z.scalar_type() == input_type);
301
+ TORCH_CHECK(z.is_cuda());
302
+ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
303
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
304
+ out_z = torch::empty_like(z);
305
+ }
306
+
307
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
308
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
309
+ // at::Tensor out = torch::empty_like(u);
310
+ // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
311
+ at::Tensor out = torch::empty_like(delta);
312
+ at::Tensor x;
313
+ x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
314
+
315
+ SSMParamsBase params;
316
+ set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
317
+ u, delta, A, B, C, out, z, out_z,
318
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
319
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
320
+ x.data_ptr(),
321
+ has_z,
322
+ delta_softplus);
323
+
324
+ // Otherwise the kernel will be launched from cuda:0 device
325
+ // Cast to char to avoid compiler warning about narrowing
326
+ at::cuda::CUDAGuard device_guard{u.device()};
327
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
328
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
329
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
330
+ selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
331
+ });
332
+ });
333
+ std::vector<at::Tensor> result = {out, x};
334
+ if (has_z) { result.push_back(out_z); }
335
+ return result;
336
+ }
337
+
338
+ std::vector<at::Tensor>
339
+ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
340
+ const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
341
+ const c10::optional<at::Tensor> &D_,
342
+ const c10::optional<at::Tensor> &z_,
343
+ const c10::optional<at::Tensor> &delta_bias_,
344
+ const at::Tensor &dout,
345
+ const c10::optional<at::Tensor> &x_,
346
+ const c10::optional<at::Tensor> &out_,
347
+ c10::optional<at::Tensor> dz_,
348
+ bool delta_softplus,
349
+ bool recompute_out_z) {
350
+ auto input_type = u.scalar_type();
351
+ auto weight_type = A.scalar_type();
352
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
353
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
354
+
355
+ const bool is_variable_B = B.dim() >= 3;
356
+ const bool is_variable_C = C.dim() >= 3;
357
+ const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
358
+
359
+ TORCH_CHECK(delta.scalar_type() == input_type);
360
+ TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
361
+ TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
362
+ TORCH_CHECK(dout.scalar_type() == input_type);
363
+
364
+ TORCH_CHECK(u.is_cuda());
365
+ TORCH_CHECK(delta.is_cuda());
366
+ TORCH_CHECK(A.is_cuda());
367
+ TORCH_CHECK(B.is_cuda());
368
+ TORCH_CHECK(C.is_cuda());
369
+ TORCH_CHECK(dout.is_cuda());
370
+
371
+ TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
372
+ TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
373
+ TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);
374
+
375
+ const auto sizes = u.sizes();
376
+ const int batch_size = sizes[0];
377
+ const int dim = sizes[1];
378
+ const int seqlen = sizes[2];
379
+ const int dstate = A.size(1);
380
+ const int n_groups = is_variable_B ? B.size(1) : 1;
381
+
382
+ TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
383
+
384
+ CHECK_SHAPE(u, batch_size, dim, seqlen);
385
+ CHECK_SHAPE(delta, batch_size, dim, seqlen);
386
+ CHECK_SHAPE(A, dim, dstate);
387
+ if (!is_variable_B) {
388
+ CHECK_SHAPE(B, dim, dstate);
389
+ } else {
390
+ CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
391
+ TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
392
+ }
393
+ if (!is_variable_C) {
394
+ CHECK_SHAPE(C, dim, dstate);
395
+ } else {
396
+ CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
397
+ TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
398
+ }
399
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
400
+
401
+ if (D_.has_value()) {
402
+ auto D = D_.value();
403
+ TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
404
+ TORCH_CHECK(D.is_cuda());
405
+ TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
406
+ CHECK_SHAPE(D, dim);
407
+ }
408
+
409
+ if (delta_bias_.has_value()) {
410
+ auto delta_bias = delta_bias_.value();
411
+ TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
412
+ TORCH_CHECK(delta_bias.is_cuda());
413
+ TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
414
+ CHECK_SHAPE(delta_bias, dim);
415
+ }
416
+
417
+ at::Tensor z, out, dz, out_z;
418
+ const bool has_z = z_.has_value();
419
+ if (has_z) {
420
+ z = z_.value();
421
+ TORCH_CHECK(z.scalar_type() == input_type);
422
+ TORCH_CHECK(z.is_cuda());
423
+ TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
424
+ CHECK_SHAPE(z, batch_size, dim, seqlen);
425
+
426
+ TORCH_CHECK(out_.has_value());
427
+ out = out_.value();
428
+ TORCH_CHECK(out.scalar_type() == input_type);
429
+ TORCH_CHECK(out.is_cuda());
430
+ TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
431
+ CHECK_SHAPE(out, batch_size, dim, seqlen);
432
+
433
+ if (dz_.has_value()) {
434
+ dz = dz_.value();
435
+ TORCH_CHECK(dz.scalar_type() == input_type);
436
+ TORCH_CHECK(dz.is_cuda());
437
+ TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
438
+ CHECK_SHAPE(dz, batch_size, dim, seqlen);
439
+ } else {
440
+ dz = torch::empty_like(z);
441
+ }
442
+ if (recompute_out_z) {
443
+ out_z = torch::empty_like(out);
444
+ }
445
+ }
446
+
447
+ const int n_chunks = (seqlen + 2048 - 1) / 2048;
448
+ // const int n_chunks = (seqlen + 1024 - 1) / 1024;
449
+ if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); }
450
+ if (x_.has_value()) {
451
+ auto x = x_.value();
452
+ TORCH_CHECK(x.scalar_type() == weight_type);
453
+ TORCH_CHECK(x.is_cuda());
454
+ TORCH_CHECK(x.is_contiguous());
455
+ CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate);
456
+ }
457
+
458
+ at::Tensor du = torch::empty_like(u);
459
+ at::Tensor ddelta = torch::empty_like(delta);
460
+ at::Tensor dA = torch::zeros_like(A);
461
+ at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32));
462
+ at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32));
463
+ at::Tensor dD;
464
+ if (D_.has_value()) { dD = torch::zeros_like(D_.value()); }
465
+ at::Tensor ddelta_bias;
466
+ if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); }
467
+
468
+ SSMParamsBwd params;
469
+ set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
470
+ u, delta, A, B, C, z, out, out_z,
471
+ D_.has_value() ? D_.value().data_ptr() : nullptr,
472
+ delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
473
+ x_.has_value() ? x_.value().data_ptr() : nullptr,
474
+ dout, du, ddelta, dA, dB, dC, dz,
475
+ D_.has_value() ? dD.data_ptr() : nullptr,
476
+ delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr,
477
+ has_z, delta_softplus, recompute_out_z);
478
+
479
+ // Otherwise the kernel will be launched from cuda:0 device
480
+ // Cast to char to avoid compiler warning about narrowing
481
+ at::cuda::CUDAGuard device_guard{u.device()};
482
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
483
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] {
484
+ DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] {
485
+ selective_scan_bwd_cuda<input_t, weight_t>(params, stream);
486
+ });
487
+ });
488
+ std::vector<at::Tensor> result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias};
489
+ if (has_z) { result.push_back(dz); }
490
+ if (recompute_out_z) { result.push_back(out_z); }
491
+ return result;
492
+ }
493
+
494
+ //PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
495
+ // m.def("fwd", &selective_scan_fwd, "Selective scan forward");
496
+ // m.def("bwd", &selective_scan_bwd, "Selective scan backward");
497
+ //}
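
For readers orienting themselves in `selective_scan_fwd`, the recurrence the kernel evaluates can be written as a slow PyTorch loop; upstream mamba-ssm ships a similar reference as `selective_scan_ref`. The sketch below assumes real (non-complex) weights and variable, single-group B and C, and it ignores the 2048-element chunking and saved states used by the CUDA path.

```python
# Slow reference sketch for the selective-scan recurrence (real weights, variable B/C).
# Shapes follow the checks above: u, delta (B, D, L); A (D, N); B, C (B, 1, N, L).
import torch
import torch.nn.functional as F

def selective_scan_reference(u, delta, A, B, C, D=None, z=None,
                             delta_bias=None, delta_softplus=False):
    if delta_bias is not None:
        delta = delta + delta_bias[..., None]
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = u.new_zeros(batch, dim, dstate, dtype=torch.float32)  # running state
    ys = []
    for t in range(seqlen):
        dt = delta[:, :, t].float()                            # (B, D)
        dA = torch.exp(dt.unsqueeze(-1) * A)                   # (B, D, N)
        dBu = dt.unsqueeze(-1) * B[:, 0, :, t].unsqueeze(1).float() \
              * u[:, :, t].float().unsqueeze(-1)               # (B, D, N)
        x = dA * x + dBu                                       # state update
        y = (x * C[:, 0, :, t].unsqueeze(1).float()).sum(-1)   # (B, D)
        ys.append(y)
    out = torch.stack(ys, dim=-1)                              # (B, D, L)
    if D is not None:
        out = out + D[None, :, None] * u.float()
    if z is not None:
        out = out * F.silu(z.float())
    return out.to(u.dtype)
```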
selective-scan/selective_scan.h ADDED
@@ -0,0 +1,101 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ #pragma once
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct SSMScanParamsBase {
+ using index_t = uint32_t;
+
+ int batch, seqlen, n_chunks;
+ index_t a_batch_stride;
+ index_t b_batch_stride;
+ index_t out_batch_stride;
+
+ // Common data pointers.
+ void *__restrict__ a_ptr;
+ void *__restrict__ b_ptr;
+ void *__restrict__ out_ptr;
+ void *__restrict__ x_ptr;
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct SSMParamsBase {
+ using index_t = uint32_t;
+
+ int batch, dim, seqlen, dstate, n_groups, n_chunks;
+ int dim_ngroups_ratio;
+ bool is_variable_B;
+ bool is_variable_C;
+
+ bool delta_softplus;
+
+ index_t A_d_stride;
+ index_t A_dstate_stride;
+ index_t B_batch_stride;
+ index_t B_d_stride;
+ index_t B_dstate_stride;
+ index_t B_group_stride;
+ index_t C_batch_stride;
+ index_t C_d_stride;
+ index_t C_dstate_stride;
+ index_t C_group_stride;
+ index_t u_batch_stride;
+ index_t u_d_stride;
+ index_t delta_batch_stride;
+ index_t delta_d_stride;
+ index_t z_batch_stride;
+ index_t z_d_stride;
+ index_t out_batch_stride;
+ index_t out_d_stride;
+ index_t out_z_batch_stride;
+ index_t out_z_d_stride;
+
+ // Common data pointers.
+ void *__restrict__ A_ptr;
+ void *__restrict__ B_ptr;
+ void *__restrict__ C_ptr;
+ void *__restrict__ D_ptr;
+ void *__restrict__ u_ptr;
+ void *__restrict__ delta_ptr;
+ void *__restrict__ delta_bias_ptr;
+ void *__restrict__ out_ptr;
+ void *__restrict__ x_ptr;
+ void *__restrict__ z_ptr;
+ void *__restrict__ out_z_ptr;
+ };
+
+ struct SSMParamsBwd: public SSMParamsBase {
+ index_t dout_batch_stride;
+ index_t dout_d_stride;
+ index_t dA_d_stride;
+ index_t dA_dstate_stride;
+ index_t dB_batch_stride;
+ index_t dB_group_stride;
+ index_t dB_d_stride;
+ index_t dB_dstate_stride;
+ index_t dC_batch_stride;
+ index_t dC_group_stride;
+ index_t dC_d_stride;
+ index_t dC_dstate_stride;
+ index_t du_batch_stride;
+ index_t du_d_stride;
+ index_t dz_batch_stride;
+ index_t dz_d_stride;
+ index_t ddelta_batch_stride;
+ index_t ddelta_d_stride;
+
+ // Common data pointers.
+ void *__restrict__ dout_ptr;
+ void *__restrict__ dA_ptr;
+ void *__restrict__ dB_ptr;
+ void *__restrict__ dC_ptr;
+ void *__restrict__ dD_ptr;
+ void *__restrict__ du_ptr;
+ void *__restrict__ dz_ptr;
+ void *__restrict__ ddelta_ptr;
+ void *__restrict__ ddelta_bias_ptr;
+ };
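
The `index_t` stride fields above are filled from `Tensor::stride()` in `selective_scan.cpp`, and, as the comment there notes, they count elements rather than bytes, exactly like `torch.Tensor.stride()` on the Python side:

```python
# Strides in PyTorch are reported in elements, matching the index_t stride
# fields above ("All stride are in elements, not bytes" in selective_scan.cpp).
import torch

u = torch.empty(2, 64, 128)               # (batch, dim, seqlen), float32
print(u.stride())                          # (8192, 128, 1) -- element strides
print(u.stride(0) * u.element_size())      # 32768: the same stride in bytes
```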
selective-scan/selective_scan_bwd_bf16_complex.cu ADDED
@@ -0,0 +1,9 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ // Split into multiple files to compile in parallel
+
+ #include "selective_scan_bwd_kernel.cuh"
+
+ template void selective_scan_bwd_cuda<at::BFloat16, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
selective-scan/selective_scan_bwd_bf16_real.cu ADDED
@@ -0,0 +1,9 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ // Split into multiple files to compile in parallel
+
+ #include "selective_scan_bwd_kernel.cuh"
+
+ template void selective_scan_bwd_cuda<at::BFloat16, float>(SSMParamsBwd &params, cudaStream_t stream);
selective-scan/selective_scan_bwd_fp16_complex.cu ADDED
@@ -0,0 +1,9 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ // Split into multiple files to compile in parallel
+
+ #include "selective_scan_bwd_kernel.cuh"
+
+ template void selective_scan_bwd_cuda<at::Half, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
selective-scan/selective_scan_bwd_fp16_real.cu ADDED
@@ -0,0 +1,9 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ // Split into multiple files to compile in parallel
+
+ #include "selective_scan_bwd_kernel.cuh"
+
+ template void selective_scan_bwd_cuda<at::Half, float>(SSMParamsBwd &params, cudaStream_t stream);
selective-scan/selective_scan_bwd_fp32_complex.cu ADDED
@@ -0,0 +1,9 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ // Split into multiple files to compile in parallel
+
+ #include "selective_scan_bwd_kernel.cuh"
+
+ template void selective_scan_bwd_cuda<float, complex_t>(SSMParamsBwd &params, cudaStream_t stream);
selective-scan/selective_scan_bwd_fp32_real.cu ADDED
@@ -0,0 +1,9 @@
+ /******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+ // Split into multiple files to compile in parallel
+
+ #include "selective_scan_bwd_kernel.cuh"
+
+ template void selective_scan_bwd_cuda<float, float>(SSMParamsBwd &params, cudaStream_t stream);
selective-scan/selective_scan_bwd_kernel.cuh ADDED
@@ -0,0 +1,561 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+ #include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex
11
+
12
+ #ifndef USE_ROCM
13
+ #include <cub/block/block_load.cuh>
14
+ #include <cub/block/block_store.cuh>
15
+ #include <cub/block/block_scan.cuh>
16
+ #include <cub/block/block_reduce.cuh>
17
+ #else
18
+ #include <hipcub/hipcub.hpp>
19
+ namespace cub = hipcub;
20
+ #endif
21
+
22
+ #include "selective_scan.h"
23
+ #include "selective_scan_common.h"
24
+ #include "reverse_scan.cuh"
25
+ #include "static_switch.h"
26
+
27
+ template<typename scalar_t> __device__ __forceinline__ scalar_t conj(scalar_t x);
28
+ template<> __device__ __forceinline__ float conj<float>(float x) { return x; }
29
+ template<> __device__ __forceinline__ complex_t conj<complex_t>(complex_t x) { return std::conj(x); }
30
+
31
+ template<int kNThreads_, int kNItems_, bool kIsEvenLen_, bool kIsVariableB_, bool kIsVariableC_,
32
+ bool kDeltaSoftplus_, bool kHasZ_, typename input_t_, typename weight_t_>
33
+ struct Selective_Scan_bwd_kernel_traits {
34
+ static_assert(kNItems_ % 4 == 0);
35
+ using input_t = input_t_;
36
+ using weight_t = weight_t_;
37
+ static constexpr int kNThreads = kNThreads_;
38
+ static constexpr int kNItems = kNItems_;
39
+ static constexpr int kNBytes = sizeof(input_t);
40
+ static_assert(kNBytes == 2 || kNBytes == 4);
41
+ static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
42
+ static_assert(kNItems % kNElts == 0);
43
+ static constexpr int kNLoads = kNItems / kNElts;
44
+ static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
45
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
46
+ static constexpr bool kIsVariableB = kIsVariableB_;
47
+ static constexpr bool kIsVariableC = kIsVariableC_;
48
+ static constexpr bool kDeltaSoftplus = kDeltaSoftplus_;
49
+ static constexpr bool kHasZ = kHasZ_;
50
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy.
51
+ // For complex this would lead to massive register spilling, so we keep it at 2.
52
+ static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2;
53
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
54
+ using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
55
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
56
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
57
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
58
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
59
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
60
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads, cub::BLOCK_STORE_WARP_TRANSPOSE>;
61
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
62
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
63
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
64
+ using BlockReverseScanT = BlockReverseScan<scan_t, kNThreads>;
65
+ using BlockReduceT = cub::BlockReduce<scan_t, kNThreads>;
66
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
67
+ using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
68
+ using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
69
+
70
+ static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
71
+ sizeof(typename BlockLoadVecT::TempStorage),
72
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
73
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
74
+ sizeof(typename BlockStoreT::TempStorage),
75
+ sizeof(typename BlockStoreVecT::TempStorage)});
76
+ static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
77
+ static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
78
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
79
+ };
80
+
81
+ template<typename Ktraits>
82
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
83
+ void selective_scan_bwd_kernel(SSMParamsBwd params) {
84
+ constexpr bool kIsComplex = Ktraits::kIsComplex;
85
+ constexpr bool kIsVariableB = Ktraits::kIsVariableB;
86
+ constexpr bool kIsVariableC = Ktraits::kIsVariableC;
87
+ constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus;
88
+ constexpr bool kHasZ = Ktraits::kHasZ;
89
+ constexpr int kNThreads = Ktraits::kNThreads;
90
+ constexpr int kNItems = Ktraits::kNItems;
91
+ using input_t = typename Ktraits::input_t;
92
+ using weight_t = typename Ktraits::weight_t;
93
+ using scan_t = typename Ktraits::scan_t;
94
+
95
+ // Shared memory.
96
+ extern __shared__ char smem_[];
97
+ // cast to lvalue reference of expected type
98
+ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
99
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
100
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
101
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
102
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
103
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
104
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
105
+ auto& smem_exchange = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
106
+ auto& smem_exchange1 = *reinterpret_cast<typename Ktraits::BlockExchangeT::TempStorage*>(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage));
107
+ auto& smem_reduce = *reinterpret_cast<typename Ktraits::BlockReduceT::TempStorage*>(reinterpret_cast<char *>(&smem_exchange) + Ktraits::kSmemExchangeSize);
108
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(&smem_reduce);
109
+ auto& smem_reduce_complex = *reinterpret_cast<typename Ktraits::BlockReduceComplexT::TempStorage*>(&smem_reduce);
110
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(reinterpret_cast<char *>(&smem_reduce) + Ktraits::kSmemReduceSize);
111
+ auto& smem_reverse_scan = *reinterpret_cast<typename Ktraits::BlockReverseScanT::TempStorage*>(reinterpret_cast<char *>(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage));
112
+ weight_t *smem_delta_a = reinterpret_cast<weight_t *>(smem_ + Ktraits::kSmemSize);
113
+ scan_t *smem_running_postfix = reinterpret_cast<scan_t *>(smem_delta_a + 2 * MAX_DSTATE + kNThreads);
114
+ weight_t *smem_da = reinterpret_cast<weight_t *>(smem_running_postfix + MAX_DSTATE);
115
+ weight_t *smem_dbc = reinterpret_cast<weight_t *>(smem_da + MAX_DSTATE);
116
+
117
+ const int batch_id = blockIdx.x;
118
+ const int dim_id = blockIdx.y;
119
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
120
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
121
+ + dim_id * params.u_d_stride;
122
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
123
+ + dim_id * params.delta_d_stride;
124
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
125
+ + dim_id * params.dout_d_stride;
126
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * params.A_d_stride;
127
+ weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * params.B_d_stride;
128
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
129
+ weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * params.C_d_stride;
130
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
131
+ weight_t *dA = reinterpret_cast<weight_t *>(params.dA_ptr) + dim_id * params.dA_d_stride;
132
+ weight_t *dB = reinterpret_cast<weight_t *>(params.dB_ptr)
133
+ + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride);
134
+ weight_t *dC = reinterpret_cast<weight_t *>(params.dC_ptr)
135
+ + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride);
136
+ float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.dD_ptr) + dim_id;
137
+ float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.D_ptr)[dim_id];
138
+ float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast<float *>(params.ddelta_bias_ptr) + dim_id;
139
+ float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id];
140
+ scan_t *x = params.x_ptr == nullptr
141
+ ? nullptr
142
+ : reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate;
143
+ float dD_val = 0;
144
+ float ddelta_bias_val = 0;
145
+
146
+ constexpr int kChunkSize = kNThreads * kNItems;
147
+ u += (params.n_chunks - 1) * kChunkSize;
148
+ delta += (params.n_chunks - 1) * kChunkSize;
149
+ dout += (params.n_chunks - 1) * kChunkSize;
150
+ Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
151
+ Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2);
152
+ for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) {
153
+ input_t u_vals[kNItems];
154
+ input_t delta_vals_load[kNItems];
155
+ input_t dout_vals_load[kNItems];
156
+ __syncthreads();
157
+ load_input<Ktraits>(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize);
158
+ u -= kChunkSize;
159
+ __syncthreads();
160
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
161
+ // Will reload delta at the same location if kDeltaSoftplus
162
+ if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; }
163
+ __syncthreads();
164
+ load_input<Ktraits>(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
165
+ dout -= kChunkSize;
166
+
167
+ float dout_vals[kNItems], delta_vals[kNItems];
168
+ #pragma unroll
169
+ for (int i = 0; i < kNItems; ++i) {
170
+ dout_vals[i] = float(dout_vals_load[i]);
171
+ delta_vals[i] = float(delta_vals_load[i]) + delta_bias;
172
+ if constexpr (kDeltaSoftplus) {
173
+ delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i];
174
+ }
175
+ }
176
+
177
+ if constexpr (kHasZ) {
178
+ input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
179
+ + dim_id * params.z_d_stride + chunk * kChunkSize;
180
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
181
+ + dim_id * params.out_d_stride + chunk * kChunkSize;
182
+ input_t *dz = reinterpret_cast<input_t *>(params.dz_ptr) + batch_id * params.dz_batch_stride
183
+ + dim_id * params.dz_d_stride + chunk * kChunkSize;
184
+ input_t z_vals[kNItems], out_vals[kNItems];
185
+ __syncthreads();
186
+ load_input<Ktraits>(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
187
+ __syncthreads();
188
+ load_input<Ktraits>(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize);
189
+ float dz_vals[kNItems], z_silu_vals[kNItems];
190
+ #pragma unroll
191
+ for (int i = 0; i < kNItems; ++i) {
192
+ float z_val = z_vals[i];
193
+ float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val));
194
+ z_silu_vals[i] = z_val * z_sigmoid_val;
195
+ dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val
196
+ * (1.0f + z_val * (1.0f - z_sigmoid_val));
197
+ dout_vals[i] *= z_silu_vals[i];
198
+ }
199
+ __syncthreads();
200
+ store_output<Ktraits>(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize);
201
+ if (params.out_z_ptr != nullptr) { // Recompute and store out_z
202
+ float out_z_vals[kNItems];
203
+ #pragma unroll
204
+ for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; }
205
+ // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
206
+ // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]);
207
+ // }
208
+ input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
209
+ + dim_id * params.out_z_d_stride + chunk * kChunkSize;
210
+ __syncthreads();
211
+ store_output<Ktraits>(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize);
212
+ }
213
+ }
214
+
215
+ float du_vals[kNItems];
216
+ #pragma unroll
217
+ for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; }
218
+ #pragma unroll
219
+ for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); }
220
+
221
+ float ddelta_vals[kNItems] = {0};
222
+ __syncthreads();
223
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
224
+ const weight_t A_val = A[state_idx * params.A_dstate_stride];
225
+ // Multiply the real part of A by LOG2E so we can use exp2f instead of expf.
226
+ weight_t A_scaled;
227
+ constexpr float kLog2e = M_LOG2E;
228
+ if constexpr (!kIsComplex) {
229
+ A_scaled = A_val * kLog2e;
230
+ } else {
231
+ A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_);
232
+ }
233
+ weight_t B_val, C_val;
234
+ weight_t B_vals[kNItems], C_vals[kNItems];
235
+ if constexpr (!kIsVariableB) {
236
+ B_val = B[state_idx * params.B_dstate_stride];
237
+ } else {
238
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
239
+ smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
240
+ }
241
+ if constexpr (!kIsVariableC) {
242
+ C_val = C[state_idx * params.C_dstate_stride];
243
+ } else {
244
+ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
245
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
246
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
247
+ }
248
+ // const weight_t A_val = smem_a[state_idx];
249
+ scan_t thread_data[kNItems], thread_reverse_data[kNItems];
250
+ if constexpr (!kIsComplex) {
251
+ #pragma unroll
252
+ for (int i = 0; i < kNItems; ++i) {
253
+ const float delta_a_exp = exp2f(delta_vals[i] * A_scaled);
254
+ thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
255
+ if (i == 0) {
256
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
257
+ } else {
258
+ thread_reverse_data[i - 1].x = delta_a_exp;
259
+ }
260
+ thread_reverse_data[i].y = dout_vals[i] *
261
+ (!kIsVariableC
262
+ ? (!kIsVariableB ? B_val * C_val : C_val)
263
+ : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
264
+ }
265
+ __syncthreads();
266
+ thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1
267
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
268
+ : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
269
+ // Initialize running total
270
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
271
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
272
+ typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
273
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
274
+ );
275
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
276
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
277
+ typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
278
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
279
+ );
280
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
281
+ weight_t dA_val = 0, dBC_val = 0;
282
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
283
+ #pragma unroll
284
+ for (int i = 0; i < kNItems; ++i) {
285
+ const float dx = thread_reverse_data[i].y;
286
+ const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i];
287
+ du_vals[i] += ddelta_u * delta_vals[i];
288
+ const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]);
289
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a;
290
+ dA_val += dx * delta_vals[i] * a;
291
+ if constexpr (!kIsVariableB || !kIsVariableC) {
292
+ if constexpr (!kIsVariableB) { // dBC_val is dB_val
293
+ dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]);
294
+ } else { // dBC_val is dC_val
295
+ dBC_val += dout_vals[i] * thread_data[i].y;
296
+ }
297
+ }
298
+ if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
299
+ if constexpr (kIsVariableC) {
300
+ dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y);
301
+ }
302
+ }
303
+ // Block-exchange to make the atomicAdds coalesced, otherwise they're much slower
304
+ if constexpr (kIsVariableB || kIsVariableC) {
305
+ if constexpr (kIsVariableB) {
306
+ typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
307
+ }
308
+ if constexpr (kIsVariableC) {
309
+ auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
310
+ typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
311
+ }
312
+ const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
313
+ weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
314
+ weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x;
315
+ #pragma unroll
316
+ for (int i = 0; i < kNItems; ++i) {
317
+ if (i * kNThreads < seqlen_remaining) {
318
+ if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); }
319
+ if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); }
320
+ }
321
+ }
322
+ }
323
+ if constexpr (!kIsVariableB || !kIsVariableC) {
324
+ float2 dA_dBC_val = make_float2(dA_val, dBC_val);
325
+ dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
326
+ dA_val = dA_dBC_val.x;
327
+ if (threadIdx.x == 0) {
328
+ smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
329
+ }
330
+ } else {
331
+ dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
332
+ }
333
+ if (threadIdx.x == 0) {
334
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
335
+ }
336
+ } else {
337
+ #pragma unroll
338
+ for (int i = 0; i < kNItems; ++i) {
339
+ // PyTorch's implementation of complex exp (which calls Thrust) is very slow
340
+ complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled);
341
+ weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]);
342
+ thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
343
+ if (i == 0) {
344
+ smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp;
345
+ } else {
346
+ thread_reverse_data[i - 1].x = delta_a_exp.real_;
347
+ thread_reverse_data[i - 1].y = -delta_a_exp.imag_;
348
+ }
349
+ complex_t dout_BC = 2 * dout_vals[i]
350
+ * conj(!kIsVariableC
351
+ ? (!kIsVariableB ? B_val * C_val : C_val)
352
+ : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i]));
353
+ thread_reverse_data[i].z = dout_BC.real_;
354
+ thread_reverse_data[i].w = dout_BC.imag_;
355
+ }
356
+ __syncthreads();
357
+ complex_t delta_a_exp = threadIdx.x == kNThreads - 1
358
+ ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE])
359
+ : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE];
360
+ thread_reverse_data[kNItems - 1].x = delta_a_exp.real_;
361
+ thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_;
362
+ // Initialize running total
363
+ scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
364
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
365
+ typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
366
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
367
+ );
368
+ scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
369
+ SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
370
+ typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
371
+ thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
372
+ );
373
+ if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
374
+ weight_t dA_val = 0, dBC_val = 0;
375
+ weight_t dB_vals[kNItems], dC_vals[kNItems];
376
+ #pragma unroll
377
+ for (int i = 0; i < kNItems; ++i) {
378
+ complex_t x = complex_t(thread_data[i].z, thread_data[i].w);
379
+ complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w);
380
+ float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_;
381
+ if constexpr (!kIsVariableB || !kIsVariableC) {
382
+ if constexpr (!kIsVariableB) { // dBC_val is dB_val
383
+ dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]);
384
+ } else { // dBC_val is dC_val
385
+ dBC_val += (2 * dout_vals[i]) * conj(x);
386
+ }
387
+ }
388
+ const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]));
389
+ du_vals[i] += ddelta_u * delta_vals[i];
390
+ ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_;
391
+ dA_val += delta_vals[i] * dx * a_conj;
392
+ if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); }
393
+ if constexpr (kIsVariableC) {
394
+ dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x);
395
+ }
396
+ }
397
+ // Block-exchange to make the atomicAdds coalesced, otherwise they're much slower
398
+ if constexpr (kIsVariableB || kIsVariableC) {
399
+ float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2];
400
+ if constexpr (kIsVariableB) {
401
+ #pragma unroll
402
+ for (int i = 0; i < kNItems; ++i) {
403
+ dB_vals_f[i * 2] = dB_vals[i].real_;
404
+ dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
405
+ }
406
+ typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
407
+ }
408
+ if constexpr (kIsVariableC) {
409
+ #pragma unroll
410
+ for (int i = 0; i < kNItems; ++i) {
411
+ dC_vals_f[i * 2] = dC_vals[i].real_;
412
+ dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
413
+ }
414
+ auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
415
+ typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
416
+ }
417
+ const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
418
+ float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
419
+ float *dC_cur = reinterpret_cast<float *>(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
420
+ #pragma unroll
421
+ for (int i = 0; i < kNItems * 2; ++i) {
422
+ if (i * kNThreads < seqlen_remaining) {
423
+ if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); }
424
+ if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); }
425
+ }
426
+ }
427
+ }
428
+ if constexpr (!kIsVariableB || !kIsVariableC) {
429
+ float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
430
+ dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
431
+ dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
432
+ dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
433
+ if (threadIdx.x == 0) {
434
+ smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
435
+ }
436
+ } else {
437
+ dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
438
+ }
439
+ if (threadIdx.x == 0) {
440
+ smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
441
+ }
442
+ }
443
+ }
444
+
445
+ if constexpr (kDeltaSoftplus) {
446
+ __syncthreads();
447
+ input_t delta_vals_load[kNItems];
448
+ load_input<Ktraits>(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize);
449
+ delta -= kChunkSize;
450
+ #pragma unroll
451
+ for (int i = 0; i < kNItems; ++i) {
452
+ float delta_val = float(delta_vals_load[i]) + delta_bias;
453
+ float delta_val_neg_exp = expf(-delta_val);
454
+ ddelta_vals[i] = delta_val <= 20.f
455
+ ? ddelta_vals[i] / (1.f + delta_val_neg_exp)
456
+ : ddelta_vals[i];
457
+ }
458
+ }
459
+ for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; }
460
+
461
+ input_t *du = reinterpret_cast<input_t *>(params.du_ptr) + batch_id * params.du_batch_stride
462
+ + dim_id * params.du_d_stride + chunk * kChunkSize;
463
+ input_t *ddelta = reinterpret_cast<input_t *>(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride
464
+ + dim_id * params.ddelta_d_stride + chunk * kChunkSize;
465
+ __syncthreads();
466
+ store_output<Ktraits>(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize);
467
+ __syncthreads();
468
+ store_output<Ktraits>(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize);
469
+
470
+ Bvar -= kChunkSize * (!kIsComplex ? 1 : 2);
471
+ Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
472
+ }
473
+ if (params.dD_ptr != nullptr) {
474
+ dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
475
+ if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
476
+ }
477
+ if (params.ddelta_bias_ptr != nullptr) {
478
+ __syncthreads();
479
+ ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
480
+ if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
481
+ }
482
+ for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
483
+ gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]);
484
+ weight_t dBC_val;
485
+ if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; }
486
+ if constexpr (!kIsVariableB) {
487
+ gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]),
488
+ !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val);
489
+ }
490
+ if constexpr (!kIsVariableC) {
491
+ gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]),
492
+ !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val);
493
+ }
494
+ }
495
+ }
496
+
497
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
498
+ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
499
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
500
+ BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
501
+ BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
502
+ BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] {
503
+ BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
504
+ using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, kIsEvenLen, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
505
+ // using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
506
+ // TODO: check this
507
+ constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
508
+
509
+ dim3 grid(params.batch, params.dim);
510
+
511
+ auto kernel = &selective_scan_bwd_kernel<Ktraits>;
512
+
513
+ if (kSmemSize >= 48 * 1024) {
514
+
515
+ #ifndef USE_ROCM
516
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
517
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
518
+ #else
519
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
520
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
521
+ std::cerr << "Warning (selective_scan_bwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU, which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
522
+ #endif
523
+
524
+ }
525
+
526
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
527
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
528
+ });
529
+ });
530
+ });
531
+ });
532
+ });
533
+ }
534
+
535
+ template<typename input_t, typename weight_t>
536
+ void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
537
+
538
+ #ifndef USE_ROCM
539
+ if (params.seqlen <= 128) {
540
+ selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
541
+ } else if (params.seqlen <= 256) {
542
+ selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
543
+ } else if (params.seqlen <= 512) {
544
+ selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
545
+ } else if (params.seqlen <= 1024) {
546
+ selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
547
+ } else {
548
+ selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
549
+ }
550
+ #else
551
+ if (params.seqlen <= 256) {
552
+ selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
553
+ } else if (params.seqlen <= 512) {
554
+ selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
555
+ } else if (params.seqlen <= 1024) {
556
+ selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
557
+ } else {
558
+ selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
559
+ }
560
+ #endif
561
+ }
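
The launcher above has to opt the kernel into more than 48 KB of dynamic shared memory with `cudaFuncSetAttribute` before it can pass `kSmemSize` at launch. A minimal, self-contained sketch of that pattern (the kernel, sizes, and names below are placeholders for illustration, not part of mamba-ssm):

```cuda
// Sketch: requesting > 48 KB of dynamic shared memory, as selective_scan_bwd_launch does.
// dummy_kernel and the 64 KB figure are hypothetical; pick a size your device supports.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float *out) {
    extern __shared__ char smem_[];                     // dynamic shared memory
    float *buf = reinterpret_cast<float *>(smem_);
    buf[threadIdx.x] = float(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) { out[blockIdx.x] = buf[0]; }
}

int main() {
    const int kSmemSize = 64 * 1024;                    // above the default 48 KB limit
    // Without this opt-in, launches requesting > 48 KB of dynamic smem fail.
    cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
    float *out;
    cudaMalloc(&out, sizeof(float));
    dummy_kernel<<<1, 128, kSmemSize>>>(out);
    cudaError_t err = cudaGetLastError();               // catches launch-configuration errors
    cudaDeviceSynchronize();
    printf("launch: %s\n", cudaGetErrorString(err));
    cudaFree(out);
    return 0;
}
```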
selective-scan/selective_scan_common.h ADDED
@@ -0,0 +1,255 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #ifndef USE_ROCM
8
+ #include <cuda_bf16.h>
9
+ #else
10
+ #include <hip/hip_bf16.h>
11
+ #endif
12
+ #include <cuda_fp16.h>
13
+ #include <c10/util/complex.h> // For scalar_value_type
14
+
15
+
16
+ #ifndef USE_ROCM
17
+
18
+ constexpr size_t custom_max(std::initializer_list<size_t> ilist)
19
+ {
20
+ return std::max(ilist);
21
+ }
22
+
23
+ template<typename T>
24
+ constexpr T constexpr_min(T a, T b) {
25
+ return std::min(a, b);
26
+ }
27
+
28
+ #else
29
+ constexpr size_t custom_max(std::initializer_list<size_t> ilist)
30
+ {
31
+ return *std::max_element(ilist.begin(), ilist.end());
32
+ }
33
+
34
+ template<typename T>
35
+ constexpr T constexpr_min(T a, T b) {
36
+ return a < b ? a : b;
37
+ }
38
+ #endif
39
+
40
+
41
+ #define MAX_DSTATE 256
42
+
43
+ using complex_t = c10::complex<float>;
44
+
45
+ inline __device__ float2 operator+(const float2 & a, const float2 & b){
46
+ return {a.x + b.x, a.y + b.y};
47
+ }
48
+
49
+ inline __device__ float3 operator+(const float3 &a, const float3 &b) {
50
+ return {a.x + b.x, a.y + b.y, a.z + b.z};
51
+ }
52
+
53
+ inline __device__ float4 operator+(const float4 & a, const float4 & b){
54
+ return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
55
+ }
56
+
57
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
58
+
59
+ template<int BYTES> struct BytesToType {};
60
+
61
+ template<> struct BytesToType<16> {
62
+ using Type = uint4;
63
+ static_assert(sizeof(Type) == 16);
64
+ };
65
+
66
+ template<> struct BytesToType<8> {
67
+ using Type = uint64_t;
68
+ static_assert(sizeof(Type) == 8);
69
+ };
70
+
71
+ template<> struct BytesToType<4> {
72
+ using Type = uint32_t;
73
+ static_assert(sizeof(Type) == 4);
74
+ };
75
+
76
+ template<> struct BytesToType<2> {
77
+ using Type = uint16_t;
78
+ static_assert(sizeof(Type) == 2);
79
+ };
80
+
81
+ template<> struct BytesToType<1> {
82
+ using Type = uint8_t;
83
+ static_assert(sizeof(Type) == 1);
84
+ };
85
+
86
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
87
+
88
+ template<typename scalar_t, int N>
89
+ struct Converter{
90
+ static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
91
+ #pragma unroll
92
+ for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
93
+ }
94
+ };
95
+
96
+ template<int N>
97
+ struct Converter<at::Half, N>{
98
+ static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
99
+ static_assert(N % 2 == 0);
100
+ auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
101
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
102
+ #pragma unroll
103
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
104
+ }
105
+ };
106
+
107
+ #if __CUDA_ARCH__ >= 800
108
+ template<int N>
109
+ struct Converter<at::BFloat16, N>{
110
+ static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
111
+ static_assert(N % 2 == 0);
112
+ auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
113
+ auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
114
+ #pragma unroll
115
+ for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
116
+ }
117
+ };
118
+ #endif
119
+
120
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
121
+
122
+ // From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp
123
+ // and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696
124
+ __device__ __forceinline__ complex_t cexp2f(complex_t z) {
125
+ float t = exp2f(z.real_);
126
+ float c, s;
127
+ sincosf(z.imag_, &s, &c);
128
+ return complex_t(c * t, s * t);
129
+ }
130
+
131
+ __device__ __forceinline__ complex_t cexpf(complex_t z) {
132
+ float t = expf(z.real_);
133
+ float c, s;
134
+ sincosf(z.imag_, &s, &c);
135
+ return complex_t(c * t, s * t);
136
+ }
137
+
138
+ template<typename scalar_t> struct SSMScanOp;
139
+
140
+ template<>
141
+ struct SSMScanOp<float> {
142
+ __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
143
+ return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
144
+ }
145
+ };
146
+
147
+ template<>
148
+ struct SSMScanOp<complex_t> {
149
+ __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const {
150
+ complex_t a0 = complex_t(ab0.x, ab0.y);
151
+ complex_t b0 = complex_t(ab0.z, ab0.w);
152
+ complex_t a1 = complex_t(ab1.x, ab1.y);
153
+ complex_t b1 = complex_t(ab1.z, ab1.w);
154
+ complex_t out_a = a1 * a0;
155
+ complex_t out_b = a1 * b0 + b1;
156
+ return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_);
157
+ }
158
+ };
159
+
160
+ // A stateful callback functor that maintains a running prefix to be applied
161
+ // during consecutive scan operations.
162
+ template <typename scalar_t> struct SSMScanPrefixCallbackOp {
163
+ using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
164
+ scan_t running_prefix;
165
+ // Constructor
166
+ __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
167
+ // Callback operator to be entered by the first warp of threads in the block.
168
+ // Thread-0 is responsible for returning a value for seeding the block-wide scan.
169
+ __device__ scan_t operator()(scan_t block_aggregate) {
170
+ scan_t old_prefix = running_prefix;
171
+ running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
172
+ return old_prefix;
173
+ }
174
+ };
175
+
176
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
177
+
178
+ template<typename Ktraits>
179
+ inline __device__ void load_input(typename Ktraits::input_t *u,
180
+ typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
181
+ typename Ktraits::BlockLoadT::TempStorage &smem_load,
182
+ int seqlen) {
183
+ if constexpr (Ktraits::kIsEvenLen) {
184
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
185
+ using vec_t = typename Ktraits::vec_t;
186
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
187
+ reinterpret_cast<vec_t*>(u),
188
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
189
+ #ifdef USE_ROCM
190
+ , Ktraits::kNThreads * Ktraits::kNLoads
191
+ #endif
192
+
193
+ );
194
+ } else {
195
+ typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
196
+ }
197
+ }
198
+
199
+ template<typename Ktraits>
200
+ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
201
+ typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
202
+ typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
203
+ int seqlen) {
204
+ constexpr int kNItems = Ktraits::kNItems;
205
+ if constexpr (!Ktraits::kIsComplex) {
206
+ typename Ktraits::input_t B_vals_load[kNItems];
207
+ if constexpr (Ktraits::kIsEvenLen) {
208
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
209
+ using vec_t = typename Ktraits::vec_t;
210
+ typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
211
+ reinterpret_cast<vec_t*>(Bvar),
212
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
213
+ );
214
+ } else {
215
+ typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
216
+ }
217
+ // #pragma unroll
218
+ // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
219
+ Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
220
+ } else {
221
+ typename Ktraits::input_t B_vals_load[kNItems * 2];
222
+ if constexpr (Ktraits::kIsEvenLen) {
223
+ auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
224
+ using vec_t = typename Ktraits::vec_t;
225
+ typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
226
+ reinterpret_cast<vec_t*>(Bvar),
227
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads * 2]>(B_vals_load)
228
+ );
229
+ } else {
230
+ typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
231
+ }
232
+ #pragma unroll
233
+ for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); }
234
+ }
235
+ }
236
+
237
+ template<typename Ktraits>
238
+ inline __device__ void store_output(typename Ktraits::input_t *out,
239
+ const float (&out_vals)[Ktraits::kNItems],
240
+ typename Ktraits::BlockStoreT::TempStorage &smem_store,
241
+ int seqlen) {
242
+ typename Ktraits::input_t write_vals[Ktraits::kNItems];
243
+ #pragma unroll
244
+ for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
245
+ if constexpr (Ktraits::kIsEvenLen) {
246
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
247
+ using vec_t = typename Ktraits::vec_t;
248
+ typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
249
+ reinterpret_cast<vec_t*>(out),
250
+ reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
251
+ );
252
+ } else {
253
+ typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
254
+ }
255
+ }
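
`SSMScanOp` above is the associative combine that lets `cub::BlockScan` parallelize the first-order recurrence x_t = a_t * x_{t-1} + b_t used by the selective scan. A host-side sketch (illustration only; the names are hypothetical) showing that an inclusive scan with this combine reproduces the sequential recurrence in the real-valued case:

```cpp
// Sketch: the combine below mirrors SSMScanOp<float>::operator().
#include <cstdio>
#include <utility>
#include <vector>

using AB = std::pair<float, float>;  // (a, b), standing in for float2

AB combine(const AB &ab0, const AB &ab1) {
    // Same formula as SSMScanOp<float>: ab0 is the accumulated left prefix, ab1 the new element.
    return {ab1.first * ab0.first, ab1.first * ab0.second + ab1.second};
}

int main() {
    std::vector<AB> in = {{0.9f, 1.0f}, {0.8f, 0.5f}, {0.7f, -0.2f}, {0.6f, 0.3f}};

    // Sequential recurrence: x_t = a_t * x_{t-1} + b_t, with x_{-1} = 0.
    float x = 0.f;
    for (const auto &ab : in) { x = ab.first * x + ab.second; }

    // Inclusive scan with the associative combine (what BlockScan parallelizes).
    AB acc = in[0];
    for (size_t t = 1; t < in.size(); ++t) { acc = combine(acc, in[t]); }

    printf("recurrence: %f  scan: %f\n", x, acc.second);  // the two should match
    return 0;
}
```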
selective-scan/selective_scan_fwd_bf16.cu ADDED
@@ -0,0 +1,10 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<at::BFloat16, complex_t>(SSMParamsBase &params, cudaStream_t stream);
selective-scan/selective_scan_fwd_fp16.cu ADDED
@@ -0,0 +1,10 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<at::Half, complex_t>(SSMParamsBase &params, cudaStream_t stream);
selective-scan/selective_scan_fwd_fp32.cu ADDED
@@ -0,0 +1,10 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ // Split into multiple files to compile in parallel
6
+
7
+ #include "selective_scan_fwd_kernel.cuh"
8
+
9
+ template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);
10
+ template void selective_scan_fwd_cuda<float, complex_t>(SSMParamsBase &params, cudaStream_t stream);
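
As the `// Split into multiple files to compile in parallel` comment notes, each (input, weight) dtype combination of the forward kernel is explicitly instantiated in its own translation unit so the compiler can work on them concurrently. A condensed, single-file sketch of the same mechanism (`selective_scan_fwd_stub` is a hypothetical stand-in for the real launcher):

```cpp
// Sketch: explicit template instantiation; in the real tree each "template void ..." line
// lives in a separate .cu file so the instantiations build in parallel.
#include <cstdio>

template <typename input_t, typename weight_t>
void selective_scan_fwd_stub() {
    printf("instantiated for %zu-byte input, %zu-byte weight\n", sizeof(input_t), sizeof(weight_t));
}

// Explicit instantiations (one per dtype combination).
template void selective_scan_fwd_stub<float, float>();
template void selective_scan_fwd_stub<double, float>();

int main() {
    selective_scan_fwd_stub<float, float>();
    selective_scan_fwd_stub<double, float>();
    return 0;
}
```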
selective-scan/selective_scan_fwd_kernel.cuh ADDED
@@ -0,0 +1,376 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #include <c10/util/BFloat16.h>
8
+ #include <c10/util/Half.h>
9
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
10
+
11
+ #ifndef USE_ROCM
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_store.cuh>
14
+ #include <cub/block/block_scan.cuh>
15
+ #else
16
+ #include <hipcub/hipcub.hpp>
17
+ namespace cub = hipcub;
18
+ #endif
19
+
20
+ #include "selective_scan.h"
21
+ #include "selective_scan_common.h"
22
+ #include "static_switch.h"
23
+
24
+ template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
25
+ bool kIsVariableB_, bool kIsVariableC_,
26
+ bool kHasZ_, typename input_t_, typename weight_t_>
27
+ struct Selective_Scan_fwd_kernel_traits {
28
+ static_assert(kNItems_ % 4 == 0);
29
+ using input_t = input_t_;
30
+ using weight_t = weight_t_;
31
+ static constexpr int kNThreads = kNThreads_;
32
+ // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
33
+ static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
34
+ static constexpr int kNItems = kNItems_;
35
+ static constexpr int kNRows = kNRows_;
36
+ static constexpr int kNBytes = sizeof(input_t);
37
+ static_assert(kNBytes == 2 || kNBytes == 4);
38
+ static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
39
+ static_assert(kNItems % kNElts == 0);
40
+ static constexpr int kNLoads = kNItems / kNElts;
41
+ static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
42
+ static constexpr bool kIsEvenLen = kIsEvenLen_;
43
+ static constexpr bool kIsVariableB = kIsVariableB_;
44
+ static constexpr bool kIsVariableC = kIsVariableC_;
45
+ static constexpr bool kHasZ = kHasZ_;
46
+
47
+ static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
48
+
49
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
50
+ using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
51
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
52
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
53
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
54
+ using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
55
+ using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
56
+ !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
57
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
58
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
59
+ !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
60
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
61
+ // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
62
+ using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
63
+ static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
64
+ sizeof(typename BlockLoadVecT::TempStorage),
65
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
66
+ (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
67
+ sizeof(typename BlockStoreT::TempStorage),
68
+ sizeof(typename BlockStoreVecT::TempStorage)});
69
+ static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
70
+ };
71
+
72
+ template<typename Ktraits>
73
+ __global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
74
+ void selective_scan_fwd_kernel(SSMParamsBase params) {
75
+ constexpr bool kIsComplex = Ktraits::kIsComplex;
76
+ constexpr bool kIsVariableB = Ktraits::kIsVariableB;
77
+ constexpr bool kIsVariableC = Ktraits::kIsVariableC;
78
+ constexpr bool kHasZ = Ktraits::kHasZ;
79
+ constexpr int kNThreads = Ktraits::kNThreads;
80
+ constexpr int kNItems = Ktraits::kNItems;
81
+ constexpr int kNRows = Ktraits::kNRows;
82
+ constexpr bool kDirectIO = Ktraits::kDirectIO;
83
+ using input_t = typename Ktraits::input_t;
84
+ using weight_t = typename Ktraits::weight_t;
85
+ using scan_t = typename Ktraits::scan_t;
86
+
87
+ // Shared memory.
88
+ extern __shared__ char smem_[];
89
+ // cast to lvalue reference of expected type
90
+ // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
91
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
92
+ // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
93
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
94
+ auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
95
+ auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
96
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
97
+ auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
98
+ // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
99
+ // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
100
+ scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);
101
+
102
+ const int batch_id = blockIdx.x;
103
+ const int dim_id = blockIdx.y;
104
+ const int group_id = dim_id / (params.dim_ngroups_ratio);
105
+ input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
106
+ + dim_id * kNRows * params.u_d_stride;
107
+ input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
108
+ + dim_id * kNRows * params.delta_d_stride;
109
+ weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
110
+ weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
111
+ input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
112
+ weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
113
+ input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
114
+ scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
115
+
116
+ float D_val[kNRows] = {0};
117
+ if (params.D_ptr != nullptr) {
118
+ #pragma unroll
119
+ for (int r = 0; r < kNRows; ++r) {
120
+ D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
121
+ }
122
+ }
123
+ float delta_bias[kNRows] = {0};
124
+ if (params.delta_bias_ptr != nullptr) {
125
+ #pragma unroll
126
+ for (int r = 0; r < kNRows; ++r) {
127
+ delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
128
+ }
129
+ }
130
+
131
+ // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
132
+ // smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
133
+ // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
134
+ // }
135
+
136
+ constexpr int kChunkSize = kNThreads * kNItems;
137
+ for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
138
+ input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
139
+ __syncthreads();
140
+ #pragma unroll
141
+ for (int r = 0; r < kNRows; ++r) {
142
+ if constexpr (!kDirectIO) {
143
+ if (r > 0) { __syncthreads(); }
144
+ }
145
+ load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
146
+ if constexpr (!kDirectIO) { __syncthreads(); }
147
+ load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
148
+ }
149
+ u += kChunkSize;
150
+ delta += kChunkSize;
151
+
152
+ float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
153
+ #pragma unroll
154
+ for (int r = 0; r < kNRows; ++r) {
155
+ #pragma unroll
156
+ for (int i = 0; i < kNItems; ++i) {
157
+ float u_val = float(u_vals[r][i]);
158
+ delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
159
+ if (params.delta_softplus) {
160
+ delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
161
+ }
162
+ delta_u_vals[r][i] = delta_vals[r][i] * u_val;
163
+ out_vals[r][i] = D_val[r] * u_val;
164
+ }
165
+ }
166
+
167
+ __syncthreads();
168
+ for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
169
+ weight_t A_val[kNRows];
170
+ #pragma unroll
171
+ for (int r = 0; r < kNRows; ++r) {
172
+ A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
173
+ // Multiply the real part of A by LOG2E so we can use exp2f instead of expf.
174
+ constexpr float kLog2e = M_LOG2E;
175
+ if constexpr (!kIsComplex) {
176
+ A_val[r] *= kLog2e;
177
+ } else {
178
+ A_val[r].real_ *= kLog2e;
179
+ }
180
+ }
181
+ // This variable holds B * C if both B and C are constant across seqlen. If only B varies
182
+ // across seqlen, this holds C. If only C varies across seqlen, this holds B.
183
+ // If both B and C vary, this is unused.
184
+ weight_t BC_val[kNRows];
185
+ weight_t B_vals[kNItems], C_vals[kNItems];
186
+ if constexpr (kIsVariableB) {
187
+ load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
188
+ smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
189
+ if constexpr (!kIsVariableC) {
190
+ #pragma unroll
191
+ for (int r = 0; r < kNRows; ++r) {
192
+ BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
193
+ }
194
+ }
195
+ }
196
+ if constexpr (kIsVariableC) {
197
+ auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
198
+ load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
199
+ smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
200
+ if constexpr (!kIsVariableB) {
201
+ #pragma unroll
202
+ for (int r = 0; r < kNRows; ++r) {
203
+ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
204
+ }
205
+ }
206
+ }
207
+ if constexpr (!kIsVariableB && !kIsVariableC) {
208
+ #pragma unroll
209
+ for (int r = 0; r < kNRows; ++r) {
210
+ BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
211
+ }
212
+ }
213
+
214
+ #pragma unroll
215
+ for (int r = 0; r < kNRows; ++r) {
216
+ if (r > 0) { __syncthreads(); } // Scan could be using the same smem
217
+ scan_t thread_data[kNItems];
218
+ #pragma unroll
219
+ for (int i = 0; i < kNItems; ++i) {
220
+ if constexpr (!kIsComplex) {
221
+ thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
222
+ !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
223
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
224
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
225
+ thread_data[i] = make_float2(1.f, 0.f);
226
+ }
227
+ }
228
+ } else {
229
+ // PyTorch's implementation of complex exp (which calls Thrust) is very slow
230
+ complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
231
+ weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
232
+ thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
233
+ if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
234
+ if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
235
+ thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
236
+ }
237
+ }
238
+ }
239
+ }
240
+ // Initialize running total
241
+ scan_t running_prefix;
242
+ if constexpr (!kIsComplex) {
243
+ // If we use WARP_SCANS then lane 0 of every warp (not just thread 0) needs to read
244
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
245
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
246
+ } else {
247
+ running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
248
+ // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
249
+ }
250
+ SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
251
+ typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
252
+ thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
253
+ );
254
+ // There's a syncthreads in the scan op, so we don't need to sync here.
255
+ // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
256
+ if (threadIdx.x == 0) {
257
+ smem_running_prefix[state_idx] = prefix_op.running_prefix;
258
+ x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
259
+ }
260
+ #pragma unroll
261
+ for (int i = 0; i < kNItems; ++i) {
262
+ const weight_t C_val = !kIsVariableC
263
+ ? BC_val[r]
264
+ : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
265
+ if constexpr (!kIsComplex) {
266
+ out_vals[r][i] += thread_data[i].y * C_val;
267
+ } else {
268
+ out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
269
+ }
270
+ }
271
+ }
272
+ }
273
+
274
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
275
+ + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
276
+ __syncthreads();
277
+ #pragma unroll
278
+ for (int r = 0; r < kNRows; ++r) {
279
+ if constexpr (!kDirectIO) {
280
+ if (r > 0) { __syncthreads(); }
281
+ }
282
+ store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
283
+ }
284
+
285
+ if constexpr (kHasZ) {
286
+ input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
287
+ + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
288
+ input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
289
+ + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
290
+ #pragma unroll
291
+ for (int r = 0; r < kNRows; ++r) {
292
+ input_t z_vals[kNItems];
293
+ __syncthreads();
294
+ load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
295
+ #pragma unroll
296
+ for (int i = 0; i < kNItems; ++i) {
297
+ float z_val = z_vals[i];
298
+ out_vals[r][i] *= z_val / (1 + expf(-z_val));
299
+ }
300
+ __syncthreads();
301
+ store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
302
+ }
303
+ }
304
+
305
+ Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
306
+ Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
307
+ }
308
+ }
309
+
310
+ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
311
+ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
312
+ // Only kNRows == 1 is tested for now, which of course doesn't differ from the previous version
313
+ // where each block processed 1 row.
314
+ constexpr int kNRows = 1;
315
+ BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
316
+ BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
317
+ BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
318
+ BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
319
+ using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
320
+
321
+ constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
322
+ dim3 grid(params.batch, params.dim / kNRows);
323
+
324
+ // Had to change this substantially since the HIP interface for setting kernel launch
325
+ // attributes is potentially slightly different from CUDA's. In particular, it seems
326
+ // to expect a plain const void * pointer.
327
+
328
+ auto kernel = &selective_scan_fwd_kernel<Ktraits>;
329
+
330
+
331
+ if (kSmemSize >= 48 * 1024) {
332
+ #ifndef USE_ROCM
333
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
334
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
335
+ #else
336
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
337
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
338
+ std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU, which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
339
+ #endif
340
+ }
341
+
342
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
343
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
344
+ });
345
+ });
346
+ });
347
+ });
348
+ }
349
+
350
+ template<typename input_t, typename weight_t>
351
+ void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
352
+
353
+ #ifndef USE_ROCM
354
+ if (params.seqlen <= 128) {
355
+ selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
356
+ } else if (params.seqlen <= 256) {
357
+ selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
358
+ } else if (params.seqlen <= 512) {
359
+ selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
360
+ } else if (params.seqlen <= 1024) {
361
+ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
362
+ } else {
363
+ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
364
+ }
365
+ #else
366
+ if (params.seqlen <= 256) {
367
+ selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
368
+ } else if (params.seqlen <= 512) {
369
+ selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
370
+ } else if (params.seqlen <= 1024) {
371
+ selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
372
+ } else {
373
+ selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
374
+ }
375
+ #endif
376
+ }
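
The forward kernel above processes the sequence in chunks of `kChunkSize = kNThreads * kNItems` and carries the scan state across chunks through `smem_running_prefix` and `SSMScanPrefixCallbackOp`. A host-side sketch (illustration only; all names are hypothetical) checking that a chunked scan with a carried prefix equals one monolithic scan:

```cpp
// Sketch: chunked inclusive scan with a running prefix, mirroring the per-chunk loop
// in selective_scan_fwd_kernel. (1, 0) is the identity of the combine.
#include <cmath>
#include <cstdio>
#include <vector>

struct AB { float a, b; };  // stands in for float2 (exp(delta*A), B*delta*u)

AB combine(AB prev, AB cur) { return {cur.a * prev.a, cur.a * prev.b + cur.b}; }

int main() {
    const int seqlen = 8, chunk = 3;
    std::vector<AB> in(seqlen);
    for (int t = 0; t < seqlen; ++t) { in[t] = {std::exp(-0.1f * (t + 1)), 0.25f * t}; }

    // Monolithic inclusive scan over the whole sequence.
    AB full = in[0];
    for (int t = 1; t < seqlen; ++t) { full = combine(full, in[t]); }

    // Chunked scan: start from the identity prefix, carry the result into the next chunk.
    AB prefix = {1.f, 0.f};
    for (int start = 0; start < seqlen; start += chunk) {
        AB acc = prefix;
        for (int t = start; t < start + chunk && t < seqlen; ++t) { acc = combine(acc, in[t]); }
        prefix = acc;  // the "running prefix" of the next chunk
    }

    printf("full: %f  chunked: %f\n", full.b, prefix.b);  // should agree
    return 0;
}
```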
selective-scan/static_switch.h ADDED
@@ -0,0 +1,25 @@
1
+ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
+ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
+
4
+ #pragma once
5
+
6
+ /// @param COND - a boolean expression to switch by
7
+ /// @param CONST_NAME - a name given for the constexpr bool variable.
8
+ /// @param ... - code to execute for true and false
9
+ ///
10
+ /// Usage:
11
+ /// ```
12
+ /// BOOL_SWITCH(flag, BoolConst, [&] {
13
+ /// some_function<BoolConst>(...);
14
+ /// });
15
+ /// ```
16
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
+ [&] { \
18
+ if (COND) { \
19
+ constexpr bool CONST_NAME = true; \
20
+ return __VA_ARGS__(); \
21
+ } else { \
22
+ constexpr bool CONST_NAME = false; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ }()
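
The launcher above nests four `BOOL_SWITCH` calls, so every `(kNThreads, kNItems)` choice instantiates 2^4 compile-time kernel specializations. A hedged Python sketch that just enumerates those combinations (the flag names come from the launcher; everything else is illustrative):

```python
# Enumerate the compile-time variants produced by the nested BOOL_SWITCH
# calls in selective_scan_fwd_launch. Purely illustrative.
from itertools import product

flags = ("kIsEvenLen", "kIsVariableB", "kIsVariableC", "kHasZ")
variants = list(product((False, True), repeat=len(flags)))
print(f"{len(variants)} kernel specializations per (kNThreads, kNItems) pair")
for values in variants[:4]:  # show a few
    print(dict(zip(flags, values)))
```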
selective-scan/uninitialized_copy.cuh ADDED
@@ -0,0 +1,77 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved.
3
+ *
4
+ * Redistribution and use in source and binary forms, with or without
5
+ * modification, are permitted provided that the following conditions are met:
6
+ * * Redistributions of source code must retain the above copyright
7
+ * notice, this list of conditions and the following disclaimer.
8
+ * * Redistributions in binary form must reproduce the above copyright
9
+ * notice, this list of conditions and the following disclaimer in the
10
+ * documentation and/or other materials provided with the distribution.
11
+ * * Neither the name of the NVIDIA CORPORATION nor the
12
+ * names of its contributors may be used to endorse or promote products
13
+ * derived from this software without specific prior written permission.
14
+ *
15
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18
+ * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
19
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ *
26
+ ******************************************************************************/
27
+
28
+ #pragma once
29
+
30
+ #ifndef USE_ROCM
31
+ #include <cub/config.cuh>
32
+
33
+ #include <cuda/std/type_traits>
34
+ #else
35
+ #include <hipcub/hipcub.hpp>
36
+ // Map ::cuda::std to the standard std namespace
37
+ namespace cuda {
38
+ namespace std = ::std;
39
+ }
40
+ #endif
41
+
42
+
43
+ namespace detail
44
+ {
45
+
46
+ #if defined(_NVHPC_CUDA)
47
+ template <typename T, typename U>
48
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
49
+ {
50
+ // NVBug 3384810
51
+ new (ptr) T(::cuda::std::forward<U>(val));
52
+ }
53
+ #else
54
+ template <typename T,
55
+ typename U,
56
+ typename ::cuda::std::enable_if<
57
+ ::cuda::std::is_trivially_copyable<T>::value,
58
+ int
59
+ >::type = 0>
60
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
61
+ {
62
+ *ptr = ::cuda::std::forward<U>(val);
63
+ }
64
+
65
+ template <typename T,
66
+ typename U,
67
+ typename ::cuda::std::enable_if<
68
+ !::cuda::std::is_trivially_copyable<T>::value,
69
+ int
70
+ >::type = 0>
71
+ __host__ __device__ void uninitialized_copy(T *ptr, U &&val)
72
+ {
73
+ new (ptr) T(::cuda::std::forward<U>(val));
74
+ }
75
+ #endif
76
+
77
+ } // namespace detail
tests/ops/test_selective_scan.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright (C) 2023, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import pytest
8
+
9
+ from einops import rearrange
10
+
11
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
12
+ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
13
+
14
+
15
+ # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
16
+ @pytest.mark.parametrize('wtype', [torch.float32])
17
+ # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
18
+ @pytest.mark.parametrize('itype', [torch.float32])
19
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
20
+ @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
21
+ # @pytest.mark.parametrize('seqlen', [128])
22
+ # @pytest.mark.parametrize("return_last_state", [False, True])
23
+ @pytest.mark.parametrize("return_last_state", [True])
24
+ # @pytest.mark.parametrize('has_delta_bias', [False, True])
25
+ @pytest.mark.parametrize('has_delta_bias', [True])
26
+ # @pytest.mark.parametrize('delta_softplus', [False, True])
27
+ @pytest.mark.parametrize('delta_softplus', [True])
28
+ # @pytest.mark.parametrize('has_z', [False, True])
29
+ @pytest.mark.parametrize('has_z', [True])
30
+ # @pytest.mark.parametrize('has_D', [False, True])
31
+ @pytest.mark.parametrize('has_D', [True])
32
+ @pytest.mark.parametrize("varBC_groups", [1, 2])
33
+ # @pytest.mark.parametrize("varBC_groups", [1])
34
+ # @pytest.mark.parametrize("is_variable_C", [False, True])
35
+ @pytest.mark.parametrize("is_variable_C", [True])
36
+ # @pytest.mark.parametrize("is_variable_B", [False, True])
37
+ @pytest.mark.parametrize("is_variable_B", [True])
38
+ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
39
+ delta_softplus, return_last_state, seqlen, itype, wtype):
40
+ if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
41
+ pytest.skip() # This config is not applicable
42
+ device = 'cuda'
43
+ rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
44
+ if itype == torch.bfloat16:
45
+ rtol, atol = 3e-2, 5e-2
46
+ rtolw, atolw = (1e-3, 1e-3)
47
+ if has_z: # If we have z, the errors on the weights seem higher
48
+ rtolw = max(rtolw, rtol)
49
+ atolw = max(atolw, atol)
50
+ # set seed
51
+ torch.random.manual_seed(0)
52
+ batch_size = 2
53
+ dim = 4
54
+ dstate = 8
55
+ is_complex = wtype == torch.complex64
56
+ A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
57
+ if not is_variable_B:
58
+ B_shape = (dim, dstate)
59
+ elif varBC_groups == 1:
60
+ B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
61
+ else:
62
+ B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
63
+ B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
64
+ requires_grad=True)
65
+ if not is_variable_C:
66
+ C_shape = (dim, dstate)
67
+ elif varBC_groups == 1:
68
+ C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
69
+ else:
70
+ C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
71
+ C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
72
+ requires_grad=True)
73
+ if has_D:
74
+ D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
75
+ else:
76
+ D = None
77
+ if has_z:
78
+ z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
79
+ else:
80
+ z = None
81
+ if has_delta_bias:
82
+ delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
83
+ else:
84
+ delta_bias = None
85
+ u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
86
+ delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)).requires_grad_()
87
+ A_ref = A.detach().clone().requires_grad_()
88
+ B_ref = B.detach().clone().requires_grad_()
89
+ C_ref = C.detach().clone().requires_grad_()
90
+ D_ref = D.detach().clone().requires_grad_() if D is not None else None
91
+ z_ref = z.detach().clone().requires_grad_() if z is not None else None
92
+ u_ref = u.detach().clone().requires_grad_()
93
+ delta_ref = delta.detach().clone().requires_grad_()
94
+ delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
95
+ out, *rest = selective_scan_fn(
96
+ u, delta, A, B, C, D, z=z,
97
+ delta_bias=delta_bias, delta_softplus=delta_softplus,
98
+ return_last_state=return_last_state
99
+ )
100
+ if return_last_state:
101
+ state = rest[0]
102
+ out_ref, *rest = selective_scan_ref(
103
+ u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
104
+ delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
105
+ return_last_state=return_last_state
106
+ )
107
+ if return_last_state:
108
+ state_ref = rest[0]
109
+ # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
110
+ # dt_u = delta * u
111
+
112
+ print(f'Output max diff: {(out - out_ref).abs().max().item()}')
113
+ print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
114
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
115
+ if return_last_state:
116
+ print(f'State max diff: {(state - state_ref).abs().max().item()}')
117
+ assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
118
+
119
+ g = torch.randn_like(out)
120
+ out_ref.backward(g)
121
+ out.backward(g)
122
+
123
+ print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
124
+ print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
125
+ print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
126
+ print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
127
+ print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
128
+ if has_D:
129
+ print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
130
+ if has_z:
131
+ print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
132
+ if has_delta_bias:
133
+ print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
134
+
135
+ assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
136
+ assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
137
+ assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
138
+ assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
139
+ atol=atolw if not is_variable_B else atol)
140
+ assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
141
+ atol=atolw if not is_variable_C else atol)
142
+ if has_D:
143
+ assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
144
+ if has_z:
145
+ assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
146
+ if has_delta_bias:
147
+ assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
148
+
149
+
150
+ @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
151
+ # @pytest.mark.parametrize('wtype', [torch.complex64])
152
+ # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
153
+ @pytest.mark.parametrize('itype', [torch.float32])
154
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
155
+ @pytest.mark.parametrize('seqlen', [128])
156
+ @pytest.mark.parametrize("is_variable_C", [False, True])
157
+ # @pytest.mark.parametrize("is_variable_C", [False])
158
+ @pytest.mark.parametrize("is_variable_B", [False, True])
159
+ # @pytest.mark.parametrize("is_variable_B", [True])
160
+ def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
161
+ device = 'cuda'
162
+ rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
163
+ if itype == torch.bfloat16:
164
+ rtol, atol = 3e-2, 5e-2
165
+ rtolw, atolw = (1e-3, 1e-3)
166
+ # If we have z, the errors on the weights seem higher
167
+ rtolw = max(rtolw, rtol)
168
+ atolw = max(atolw, atol)
169
+ # set seed
170
+ torch.random.manual_seed(0)
171
+ batch_size = 2
172
+ dim = 768
173
+ dstate = 8
174
+ dt_rank = 48
175
+ is_complex = wtype == torch.complex64
176
+ xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
177
+ conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
178
+ conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
179
+ x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
180
+ * (1 if not is_complex else 2),
181
+ dim, device=device, dtype=itype, requires_grad=True)
182
+ delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
183
+ out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
184
+ out_proj_bias = None
185
+ A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
186
+ B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
187
+ if not is_variable_B else None)
188
+ C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
189
+ if not is_variable_C else None)
190
+ D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
191
+ delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
192
+ B_proj_bias = None
193
+ C_proj_bias = None
194
+ xz_ref = xz.detach().clone().requires_grad_()
195
+ conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
196
+ conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
197
+ x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
198
+ delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
199
+ out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
200
+ out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
201
+ if out_proj_bias is not None else None)
202
+ A_ref = A.detach().clone().requires_grad_()
203
+ B_ref = B.detach().clone().requires_grad_() if B is not None else None
204
+ C_ref = C.detach().clone().requires_grad_() if C is not None else None
205
+ D_ref = D.detach().clone().requires_grad_()
206
+ delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
207
+ out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
208
+ out_proj_weight, out_proj_bias,
209
+ A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
210
+ out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
211
+ delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
212
+ A_ref, B_ref, C_ref, D_ref,
213
+ delta_bias=delta_bias_ref, delta_softplus=True)
214
+ # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
215
+ # dt_u = delta * u
216
+
217
+ print(f'Output max diff: {(out - out_ref).abs().max().item()}')
218
+ print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
219
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
220
+
221
+ g = torch.randn_like(out)
222
+ out_ref.backward(g)
223
+ out.backward(g)
224
+
225
+ print(f'dxz max diff: {(xz.grad - xz_ref.grad).abs().max().item()}')
226
+ print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
227
+ if not is_variable_B:
228
+ print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
229
+ if not is_variable_C:
230
+ print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
231
+ print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
232
+ print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
233
+ print(f'dout_proj_weight max diff: {(out_proj_weight.grad - out_proj_weight_ref.grad).abs().max().item()}')
234
+ print(f'ddelta_proj_weight max diff: {(delta_proj_weight.grad - delta_proj_weight_ref.grad).abs().max().item()}')
235
+ print(f'dx_proj_weight max diff: {(x_proj_weight.grad - x_proj_weight_ref.grad).abs().max().item()}')
236
+ print(f'dconv1d_weight max diff: {(conv1d_weight.grad - conv1d_weight_ref.grad).abs().max().item()}')
237
+ print(f'dconv1d_bias max diff: {(conv1d_bias.grad - conv1d_bias_ref.grad).abs().max().item()}')
238
+
239
+ # assert torch.allclose(xz.grad, xz_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
240
+ # assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
241
+ # assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
242
+ # assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
243
+ # atol=atolw if not is_variable_B else atol)
244
+ # assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
245
+ # atol=atolw if not is_variable_C else atol)
246
+ # assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
247
+ # assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
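
For orientation, a minimal standalone usage sketch of `selective_scan_fn`, mirroring the shapes exercised in `test_selective_scan` above (requires a CUDA device and the compiled kernels; treat the exact shapes as one valid configuration, not the only one):

```python
# Minimal usage sketch of selective_scan_fn with variable B and C.
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, dstate, seqlen = 2, 4, 8, 256
device = "cuda"

u = torch.randn(batch, dim, seqlen, device=device)
delta = 0.5 * torch.rand(batch, dim, seqlen, device=device)
A = -0.5 * torch.rand(dim, dstate, device=device)       # state decay
B = torch.randn(batch, dstate, seqlen, device=device)    # input-dependent B
C = torch.randn(batch, dstate, seqlen, device=device)    # input-dependent C
D = torch.randn(dim, device=device)                      # skip connection

out, last_state = selective_scan_fn(
    u, delta, A, B, C, D, delta_softplus=True, return_last_state=True
)
print(out.shape, last_state.shape)  # (batch, dim, seqlen), (batch, dim, dstate)
```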
tests/ops/triton/test_layernorm_gated.py ADDED
@@ -0,0 +1,103 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ import pytest
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref
11
+
12
+
13
+ @pytest.mark.parametrize("norm_before_gate", [True, False])
14
+ # @pytest.mark.parametrize("norm_before_gate", [False])
15
+ @pytest.mark.parametrize("has_group", [False, True])
16
+ # @pytest.mark.parametrize("has_group", [False])
17
+ @pytest.mark.parametrize("is_rms_norm", [False, True])
18
+ # @pytest.mark.parametrize("is_rms_norm", [True])
19
+ @pytest.mark.parametrize("has_z", [False, True])
20
+ # @pytest.mark.parametrize("has_z", [True])
21
+ @pytest.mark.parametrize("has_bias", [False, True])
22
+ # @pytest.mark.parametrize("has_bias", [False])
23
+ # @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
24
+ @pytest.mark.parametrize('dtype', [torch.float16])
25
+ # @pytest.mark.parametrize("wtype", [torch.float32, torch.float16, torch.bfloat16])
26
+ @pytest.mark.parametrize("wtype", [torch.float32])
27
+ @pytest.mark.parametrize('d', [2048, 4096])
28
+ # @pytest.mark.parametrize('d', [4096])
29
+ def test_layer_norm_gated(d, dtype, wtype, has_bias, has_z, is_rms_norm, has_group, norm_before_gate):
30
+ if not has_z and not norm_before_gate:
31
+ pytest.skip()
32
+ if not norm_before_gate and not is_rms_norm: # Reference LN isn't implemented for this case yet
33
+ pytest.skip()
34
+ device = 'cuda'
35
+ rtol, atol = (1e-5, 1e-5) if dtype == torch.float32 else (1e-2, 8e-3)
36
+ group_size = None if not has_group else 64
37
+ # set seed
38
+ torch.random.manual_seed(0)
39
+ batch = 16
40
+ seqlen = 1024
41
+ x = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
42
+ if has_z:
43
+ z = torch.randn(batch, seqlen, d, dtype=dtype, device=device, requires_grad=True)
44
+ else:
45
+ z = None
46
+ weight = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
47
+ if has_bias:
48
+ bias = torch.randn(d, dtype=wtype, device=device, requires_grad=True)
49
+ else:
50
+ bias = None
51
+ x_ref = x.detach().clone().requires_grad_()
52
+ x_pt = x.detach().clone().requires_grad_()
53
+ z_ref = z.detach().clone().requires_grad_() if z is not None else None
54
+ z_pt = z.detach().clone().requires_grad_() if z is not None else None
55
+ weight_ref = weight.detach().clone().requires_grad_()
56
+ weight_pt = weight.detach().clone().requires_grad_()
57
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
58
+ bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
59
+ out = layernorm_fn(x, weight, bias, z=z, eps=1e-5, group_size=group_size, norm_before_gate=norm_before_gate,
60
+ is_rms_norm=is_rms_norm)
61
+ if not is_rms_norm:
62
+ if not has_group:
63
+ out_ref = F.layer_norm(x_ref.float(), (d,), weight=weight_ref.float(), bias=bias_ref.float() if bias_ref is not None else None, eps=1e-5)
64
+ out_pt = F.layer_norm(x_pt.to(wtype), (d,), weight=weight_pt, bias=bias_pt, eps=1e-5)
65
+ else:
66
+ out_ref = rearrange(F.layer_norm(rearrange(x_ref, "... (g d) -> ... g d", d=group_size).float(), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_ref.float()
67
+ if has_bias:
68
+ out_ref = out_ref + bias_ref.float()
69
+ out_pt = rearrange(F.layer_norm(rearrange(x_pt, "... (g d) -> ... g d", d=group_size), (group_size,), eps=1e-5), "... g d -> ... (g d)") * weight_pt
70
+ if has_bias:
71
+ out_pt = out_pt + bias_pt
72
+ if has_z and norm_before_gate:
73
+ out_ref = out_ref * F.silu(z_ref.float())
74
+ out_pt = out_pt * F.silu(z_pt)
75
+ else:
76
+ out_ref = rms_norm_ref(x_ref, weight_ref, bias_ref, z=z_ref, eps=1e-5, group_size=group_size,
77
+ norm_before_gate=norm_before_gate)
78
+ out_pt = rms_norm_ref(x_pt, weight_pt, bias_pt, z=z_pt, eps=1e-5, group_size=group_size,
79
+ norm_before_gate=norm_before_gate, upcast=False)
80
+ print(f"Max diff = {(out - out_ref).abs().max().item()}")
81
+ print(f"Max diff Pytorch = {(out_pt - out_ref).abs().max().item()}")
82
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + atol
83
+
84
+ g = torch.randn_like(out)
85
+ out.backward(g)
86
+ out_ref.backward(g)
87
+ out_pt.backward(g)
88
+ print(f"Max dx diff = {(x.grad - x_ref.grad).abs().max().item()}")
89
+ print(f"Max dx diff Pytorch = {(x_pt.grad - x_ref.grad).abs().max().item()}")
90
+ if has_z:
91
+ print(f"Max dz diff = {(z.grad - z_ref.grad).abs().max().item()}")
92
+ print(f"Max dz diff Pytorch = {(z_pt.grad - z_ref.grad).abs().max().item()}")
93
+ print(f"Max dw diff = {(weight.grad - weight_ref.grad).abs().max().item()}")
94
+ print(f"Max dw diff Pytorch = {(weight_pt.grad - weight_ref.grad).abs().max().item()}")
95
+ if has_bias:
96
+ print(f"Max db diff = {(bias.grad - bias_ref.grad).abs().max().item()}")
97
+ print(f"Max db diff Pytorch = {(bias_pt.grad - bias_ref.grad).abs().max().item()}")
98
+ assert (x.grad - x_ref.grad).abs().max().item() <= 2 * (x_pt.grad - x_ref.grad).abs().max().item() + atol
99
+ if has_z:
100
+ assert (z.grad - z_ref.grad).abs().max().item() <= 2 * (z_pt.grad - z_ref.grad).abs().max().item() + atol
101
+ assert (weight.grad - weight_ref.grad).abs().max().item() <= 2 * (weight_pt.grad - weight_ref.grad).abs().max().item() + atol
102
+ if has_bias:
103
+ assert (bias.grad - bias_ref.grad).abs().max().item() <= 2 * (bias_pt.grad - bias_ref.grad).abs().max().item() + atol
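
For context on what `layernorm_fn` with `is_rms_norm=True` and a gate `z` computes in the `norm_before_gate=True` path, here is a plain-PyTorch sketch mirroring `rms_norm_ref` as used in the test above (no group support; treat the exact argument handling as an assumption):

```python
# Plain-PyTorch sketch of gated RMSNorm with norm_before_gate=True:
# y = rmsnorm(x) * weight (+ bias), then gated by silu(z).
import torch
import torch.nn.functional as F

def gated_rmsnorm_ref(x, weight, bias=None, z=None, eps=1e-5):
    xf = x.float()
    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = xf * rstd * weight.float()
    if bias is not None:
        out = out + bias.float()
    if z is not None:
        out = out * F.silu(z.float())   # gate applied after normalization
    return out.to(x.dtype)

x = torch.randn(2, 16, 64)
z = torch.randn(2, 16, 64)
w = torch.ones(64)
print(gated_rmsnorm_ref(x, w, z=z).shape)
```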
tests/ops/triton/test_selective_state_update.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (C) 2023, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import pytest
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref
12
+
13
+
14
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
15
+ # @pytest.mark.parametrize('itype', [torch.float16])
16
+ @pytest.mark.parametrize("has_z", [False, True])
17
+ # @pytest.mark.parametrize('has_z', [True])
18
+ @pytest.mark.parametrize("dstate", [16, 32, 64])
19
+ # @pytest.mark.parametrize("dstate", [16])
20
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
21
+ # @pytest.mark.parametrize("dim", [2048])
22
+ def test_selective_state_update(dim, dstate, has_z, itype):
23
+ device = "cuda"
24
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
25
+ if itype == torch.bfloat16:
26
+ rtol, atol = 1e-2, 5e-2
27
+ if torch.version.hip:
28
+ atol *= 2
29
+ # set seed
30
+ torch.random.manual_seed(0)
31
+ batch_size = 2
32
+ state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
33
+ x = torch.randn(batch_size, dim, device=device, dtype=itype)
34
+ dt = torch.randn(batch_size, dim, device=device, dtype=itype)
35
+ dt_bias = torch.rand(dim, device=device) - 4.0
36
+ A = -torch.rand(dim, dstate, device=device) - 1.0
37
+ B = torch.randn(batch_size, dstate, device=device)
38
+ C = torch.randn(batch_size, dstate, device=device)
39
+ D = torch.randn(dim, device=device)
40
+ if has_z:
41
+ z = torch.randn_like(x)
42
+ else:
43
+ z = None
44
+ state_ref = state.detach().clone()
45
+ out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
46
+ out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
47
+
48
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
49
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
50
+ assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
51
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
52
+
53
+
54
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
55
+ # @pytest.mark.parametrize('itype', [torch.float16])
56
+ @pytest.mark.parametrize("has_z", [False, True])
57
+ # @pytest.mark.parametrize('has_z', [True])
58
+ @pytest.mark.parametrize("tie_hdim", [False, True])
59
+ # @pytest.mark.parametrize('tie_hdim', [True])
60
+ @pytest.mark.parametrize("ngroups", [1, 2, 4])
61
+ # @pytest.mark.parametrize("ngroups", [2])
62
+ @pytest.mark.parametrize("dstate", [16, 32, 64])
63
+ # @pytest.mark.parametrize("dstate", [16])
64
+ @pytest.mark.parametrize("dim", [2048, 4096])
65
+ # @pytest.mark.parametrize("dim", [2048])
66
+ def test_selective_state_update_with_heads(dim, dstate, ngroups, has_z, tie_hdim, itype):
67
+ device = "cuda"
68
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
69
+ if itype == torch.bfloat16:
70
+ rtol, atol = 1e-2, 1e-1
71
+ # set seed
72
+ torch.random.manual_seed(0)
73
+ batch_size = 2
74
+ headdim = 64
75
+ nheads = dim // headdim
76
+ state = torch.randn(batch_size, nheads, headdim, dstate, dtype=itype, device=device)
77
+ x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
78
+ if not tie_hdim:
79
+ dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
80
+ dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
81
+ A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
82
+ D = torch.randn(nheads, headdim, device=device)
83
+ else:
84
+ dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
85
+ dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
86
+ A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
87
+ D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
88
+ B = torch.randn(batch_size, ngroups, dstate, device=device)
89
+ C = torch.randn(batch_size, ngroups, dstate, device=device)
90
+ if has_z:
91
+ z = torch.randn_like(x)
92
+ else:
93
+ z = None
94
+ state_ref = state.detach().clone()
95
+ state_og = state.detach().clone()
96
+ out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
97
+ out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
98
+
99
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
100
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
101
+ assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
102
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
103
+
104
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
105
+ # @pytest.mark.parametrize('itype', [torch.float16])
106
+ @pytest.mark.parametrize("has_z", [False, True])
107
+ # @pytest.mark.parametrize('has_z', [True])
108
+ @pytest.mark.parametrize("dstate", [16, 32, 64])
109
+ # @pytest.mark.parametrize("dstate", [16])
110
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
111
+ # @pytest.mark.parametrize("dim", [2048])
112
+ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
113
+ device = "cuda"
114
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
115
+ if itype == torch.bfloat16:
116
+ rtol, atol = 6e-2, 6e-2
117
+ if torch.version.hip:
118
+ atol *= 2
119
+ # set seed
120
+ torch.random.manual_seed(0)
121
+ batch_size = 16
122
+
123
+ total_entries = 10 * batch_size
124
+ state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
125
+ state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)
126
+
127
+ x = torch.randn(batch_size, dim, device=device, dtype=itype)
128
+ dt = torch.randn(batch_size, dim, device=device, dtype=itype)
129
+ dt_bias = torch.rand(dim, device=device) - 4.0
130
+ A = -torch.rand(dim, dstate, device=device) - 1.0
131
+ B = torch.randn(batch_size, dstate, device=device)
132
+ C = torch.randn(batch_size, dstate, device=device)
133
+ D = torch.randn(dim, device=device)
134
+ if has_z:
135
+ z = torch.randn_like(x)
136
+ else:
137
+ z = None
138
+ state_ref = state[state_indices,:].detach().clone()
139
+ out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
140
+ dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
141
+ out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
142
+
143
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
144
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
145
+ assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)
146
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
147
+
148
+
149
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
150
+ #@pytest.mark.parametrize('itype', [torch.float32])
151
+ @pytest.mark.parametrize("has_z", [False, True])
152
+ # @pytest.mark.parametrize('has_z', [True])
153
+ @pytest.mark.parametrize("tie_hdim", [False, True])
154
+ # @pytest.mark.parametrize('tie_hdim', [True])
155
+ @pytest.mark.parametrize("ngroups", [1, 2, 4])
156
+ # @pytest.mark.parametrize("ngroups", [2])
157
+ @pytest.mark.parametrize("dstate", [16, 32, 64])
158
+ # @pytest.mark.parametrize("dstate", [16])
159
+ @pytest.mark.parametrize("dim", [2048, 4096])
160
+ # @pytest.mark.parametrize("dim", [2048])
161
+ def test_selective_state_update_with_heads_with_batch_indices(dim, dstate, ngroups, has_z, tie_hdim, itype):
162
+ device = "cuda"
163
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
164
+ if itype == torch.bfloat16:
165
+ rtol, atol = 1e-1, 1e-1
166
+ # set seed
167
+ torch.random.manual_seed(0)
168
+ batch_size = 16
169
+ headdim = 64
170
+ nheads = dim // headdim
171
+
172
+ total_entries = 10 * batch_size
173
+ state = torch.randn(total_entries, nheads, headdim, dstate, dtype=itype, device=device)
174
+ state_indices = torch.randperm(total_entries)[:batch_size].to(dtype=torch.int32, device=device)
175
+
176
+ x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
177
+ if not tie_hdim:
178
+ dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
179
+ dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
180
+ A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
181
+ D = torch.randn(nheads, headdim, device=device)
182
+ else:
183
+ dt = repeat(torch.randn(batch_size, nheads, device=device, dtype=itype), "b h -> b h p", p=headdim)
184
+ dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
185
+ A = repeat(-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate)
186
+ D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
187
+ B = torch.randn(batch_size, ngroups, dstate, device=device)
188
+ C = torch.randn(batch_size, ngroups, dstate, device=device)
189
+ if has_z:
190
+ z = torch.randn_like(x)
191
+ else:
192
+ z = None
193
+ state_ref = state[state_indices,:].detach().clone()
194
+ state_og = state[state_indices,:].detach().clone()
195
+ out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, state_batch_indices=state_indices)
196
+ out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
197
+
198
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
199
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
200
+ assert torch.allclose(state[state_indices,:], state_ref, rtol=rtol, atol=atol)
201
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
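
The single-token update these tests exercise is one step of the discretized SSM recurrence. A plain-PyTorch sketch of the no-head, no-z case, written to mirror `selective_state_update_ref` as used above (illustrative only, not the kernel's exact numerics):

```python
# One step of the selective state update for shapes (batch, dim) x (dim, dstate).
import torch
import torch.nn.functional as F

def state_update_step(state, x, dt, A, B, C, D=None, dt_bias=None, dt_softplus=True):
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dA = torch.exp(dt.unsqueeze(-1) * A)             # (batch, dim, dstate)
    dB = dt.unsqueeze(-1) * B.unsqueeze(1)           # (batch, dim, dstate)
    state.copy_(state * dA + dB * x.unsqueeze(-1))   # in-place, like the kernel
    out = torch.einsum("bdn,bn->bd", state, C)
    if D is not None:
        out = out + D * x
    return out

batch, dim, dstate = 2, 8, 16
state = torch.zeros(batch, dim, dstate)
out = state_update_step(
    state,
    torch.randn(batch, dim), torch.randn(batch, dim),
    -torch.rand(dim, dstate) - 1.0,
    torch.randn(batch, dstate), torch.randn(batch, dstate),
    D=torch.randn(dim), dt_bias=torch.rand(dim) - 4.0,
)
print(out.shape)  # (batch, dim)
```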
tests/ops/triton/test_ssd.py ADDED
@@ -0,0 +1,78 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ import pytest
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
11
+ from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
12
+ from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
13
+ from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
14
+ from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd
15
+ from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
16
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan
17
+ from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref
18
+
19
+
20
+ def detach_clone(*args):
21
+ return tuple([arg.detach().clone().requires_grad_() if arg is not None else None for arg in args])
22
+
23
+
24
+ @pytest.mark.parametrize('dtype', [torch.float32, torch.float16, torch.bfloat16])
25
+ # @pytest.mark.parametrize('dtype', [torch.bfloat16])
26
+ @pytest.mark.parametrize('ngroups', [1, 2, 8, "max"])
27
+ # @pytest.mark.parametrize('ngroups', [1])
28
+ @pytest.mark.parametrize('chunk_size', [64, 128])
29
+ # @pytest.mark.parametrize('chunk_size', [128])
30
+ def test_chunk_state_varlen(chunk_size, ngroups, dtype):
31
+ device = 'cuda'
32
+ rtol, atol = (1e-2, 3e-3)
33
+ # set seed
34
+ torch.random.manual_seed(chunk_size + (ngroups if ngroups != "max" else 64))
35
+ batch = 300
36
+ seqlens = torch.randint(1, 200, (batch,), device=device)
37
+ # batch = 3
38
+ # seqlens = torch.tensor([201, 56, 5], device=device)
39
+ cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
40
+ total_seqlen = seqlens.sum().item()
41
+ seq_idx = torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(seqlens)], dim=0).unsqueeze(0)
42
+ dim = 4096
43
+ # dim = 64
44
+ headdim = 64
45
+ # dim = 32
46
+ dstate = 32
47
+ assert dim % headdim == 0
48
+ nheads = dim // headdim
49
+ if ngroups == "max":
50
+ ngroups = nheads
51
+ assert nheads % ngroups == 0
52
+ B = torch.randn(total_seqlen, ngroups, dstate, dtype=dtype, device=device) / 5
53
+ x = torch.randn(total_seqlen, nheads, headdim, dtype=dtype, device=device)
54
+ A = -0.1 * (torch.rand(nheads, device=device))
55
+ dt = F.softplus(torch.randn(total_seqlen, nheads, device=device, dtype=torch.float32) - 4)
56
+ dA_cumsum, dt_rounded = _chunk_cumsum_fwd(dt.unsqueeze(0), A, chunk_size)
57
+ chunk_states = _chunk_state_fwd(B.unsqueeze(0), x.unsqueeze(0), dt_rounded, dA_cumsum, seq_idx=seq_idx)
58
+ chunk_states, _ = _state_passing_fwd(rearrange(chunk_states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
59
+ seq_idx=seq_idx, chunk_size=chunk_size)
60
+ chunk_states = rearrange(chunk_states, "... (p n) -> ... p n", n=dstate)
61
+ chunk_states = chunk_states.squeeze(0)
62
+ dA_cumsum = dA_cumsum.squeeze(0)
63
+ dt_rounded = dt_rounded.squeeze(0)
64
+ out = chunk_state_varlen(B, x, dt_rounded, dA_cumsum, cu_seqlens, chunk_states)
65
+ out_ref = []
66
+ for b in range(batch):
67
+ x_s = x[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
68
+ B_s = B[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
69
+ dt_s = dt[cu_seqlens[b]:cu_seqlens[b + 1]].unsqueeze(0)
70
+ dA_cumsum_s, dt_rounded_s = _chunk_cumsum_fwd(dt_s, A, chunk_size)
71
+ states = chunk_state(B_s, x_s, dt_rounded_s, dA_cumsum_s)
72
+ _, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum_s[:, :, :, -1],
73
+ chunk_size=chunk_size)
74
+ final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate)
75
+ out_ref.append(final_states)
76
+ out_ref = torch.cat(out_ref, dim=0)
77
+ print(f"Max diff = {(out - out_ref).abs().max().item()}")
78
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
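
The varlen test above relies on two bookkeeping tensors: `cu_seqlens` (cumulative sequence boundaries, padded with a leading 0) and `seq_idx` (which packed sequence each token belongs to). A small sketch of how they are built, mirroring the construction in the test:

```python
# Build cu_seqlens and seq_idx for a packed batch of variable-length sequences.
import torch
import torch.nn.functional as F

seqlens = torch.tensor([5, 3, 4])
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))    # tensor([0, 5, 8, 12])
seq_idx = torch.cat([
    torch.full((int(s),), i, dtype=torch.int32) for i, s in enumerate(seqlens)
]).unsqueeze(0)                                   # shape (1, total_seqlen)
print(cu_seqlens.tolist())  # [0, 5, 8, 12]
print(seq_idx.tolist())     # [[0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2]]
```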
tests/test_generation.py ADDED
@@ -0,0 +1,113 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
5
+ from mamba_ssm.models.config_mamba import MambaConfig
6
+ from mamba_ssm.utils.generation import InferenceParams
7
+
8
+ import pytest
9
+
10
+ from einops import rearrange, repeat
11
+
12
+
13
+ def test_generation():
14
+ batch = 3
15
+ seqlen = 20
16
+ device = "cuda"
17
+ dtype = torch.float16
18
+
19
+ config = MambaConfig(
20
+ d_model=1024,
21
+ n_layer=4,
22
+ vocab_size=50277,
23
+ ssm_cfg=dict(layer="Mamba2"),
24
+ rms_norm=True,
25
+ residual_in_fp32=True,
26
+ fused_add_norm=True,
27
+ pad_vocab_size_multiple=16,
28
+ )
29
+ torch.manual_seed(2357)
30
+ model = MambaLMHeadModel(config, device=device, dtype=dtype)
31
+ x = torch.randint(0, 1000, (batch, seqlen), device=device, dtype=torch.long)
32
+ out_ref = model(x).logits
33
+ prompt_len = seqlen // 2
34
+ out = model.generate(
35
+ input_ids = x[:, :prompt_len], max_length=seqlen, output_scores=True, return_dict_in_generate=True,
36
+ cg=True, # Can turn off CUDA graph for easier debugging
37
+ # instead of sampling, we take output tokens from x, to get logits for testing
38
+ # For actual generation, don't pass in teacher_outputs
39
+ teacher_outputs=x,
40
+ )
41
+ out_scores = torch.stack(out.scores, dim=1)
42
+ print(f"Max diff: {(out_scores - out_ref[:, prompt_len - 1: -1]).abs().max()}")
43
+ assert torch.allclose(out_scores, out_ref[:, prompt_len - 1: -1], rtol=1e-3, atol=1e-2)
44
+
45
+
46
+ def test_generation_varlen():
47
+ seqlens = [170, 65, 100]
48
+ genlen = 20
49
+ total_seqlen = sum(seqlens)
50
+ device = "cuda"
51
+ dtype = torch.float16
52
+
53
+ config = MambaConfig(
54
+ d_model=1024,
55
+ n_layer=4,
56
+ vocab_size=50277,
57
+ ssm_cfg=dict(layer="Mamba2"),
58
+ rms_norm=True,
59
+ residual_in_fp32=True,
60
+ fused_add_norm=True,
61
+ pad_vocab_size_multiple=16,
62
+ )
63
+ torch.manual_seed(2357)
64
+ model = MambaLMHeadModel(config, device=device, dtype=dtype)
65
+ xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]
66
+
67
+ # Reference 1: Forward pass with seq_idx
68
+ x = torch.cat(xs, dim=1)
69
+ seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
70
+ for i, ids in enumerate(xs)], dim=0).unsqueeze(0)
71
+ cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
72
+ out_ref = model(x, seq_idx=seq_idx).logits
73
+ # Only take the last @genlen logits of each sequence
74
+ out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
75
+ for i in range(len(seqlens))], dim=0)
76
+
77
+ # Reference 2: Generate the last @genlen tokens of each sequence in a for loop
78
+ out_loop = []
79
+ for input_ids in xs:
80
+ out = model.generate(
81
+ input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
82
+ return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
83
+ ).scores
84
+ out_loop.append(torch.stack(out, dim=1))
85
+ out_loop = torch.cat(out_loop, dim=0)
86
+ print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")
87
+
88
+ # Varlen generation
89
+ input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
90
+ prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
91
+ cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
92
+ seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
93
+ for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
94
+ inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))
95
+
96
+ scores, sequences = [], []
97
+ # Both seq_idx and cu_seqlens must be passed in for varlen generation
98
+ logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits
99
+ logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
100
+ scores.append(logits)
101
+ # In practice we should sample. In this case we take from the teacher_output for testing
102
+ sampled_tokens = rearrange(torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1")
103
+ sequences.append(sampled_tokens)
104
+ for i in range(1, genlen):
105
+ inference_params.seqlen_offset += 1
106
+ logits = model(sampled_tokens, inference_params=inference_params, num_last_tokens=1).logits
107
+ scores.append(logits)
108
+ # In practice we should sample. In this case we take from the teacher_output for testing
109
+ sampled_tokens = rearrange(torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1")
110
+ sequences.append(sampled_tokens)
111
+ out_varlen = torch.cat(scores, dim=1)
112
+ print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
113
+ assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()
torch-ext/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ __version__ = "2.2.4"
2
+
3
+ from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
+ from .modules.mamba_simple import Mamba
5
+ from .modules.mamba2 import Mamba2
6
+ from .models.mixer_seq_simple import MambaLMHeadModel
7
+
8
+ __all__ = [
9
+ "selective_scan_fn",
10
+ "mamba_inner_fn",
11
+ "Mamba",
12
+ "Mamba2",
13
+ "MambaLMHeadModel",
14
+ ]
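
As a quick orientation for the entry points exported here, a minimal usage sketch of a single `Mamba` block, based on the upstream mamba-ssm README (requires a CUDA device):

```python
# Minimal forward pass through a single Mamba block.
import torch
from mamba_ssm import Mamba

batch, length, d_model = 2, 64, 256
x = torch.randn(batch, length, d_model, device="cuda")
block = Mamba(
    d_model=d_model,  # model dimension
    d_state=16,       # SSM state dimension
    d_conv=4,         # local convolution width
    expand=2,         # block expansion factor
).to("cuda")
y = block(x)
assert y.shape == x.shape
```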
torch-ext/mamba_ssm/distributed/__init__.py ADDED
File without changes
torch-ext/mamba_ssm/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.distributed import ProcessGroup
6
+
7
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
8
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
9
+ # version of PyTorch. The following 4 lines are for backward compatibility with
10
+ # older PyTorch.
11
+ if "all_gather_into_tensor" not in dir(torch.distributed):
12
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
13
+ if "reduce_scatter_tensor" not in dir(torch.distributed):
14
+ torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
15
+
16
+
17
+ # Raw operation, does not support autograd, but does support async
18
+ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
19
+ world_size = torch.distributed.get_world_size(process_group)
20
+ output = torch.empty(
21
+ world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
22
+ )
23
+ handle = torch.distributed.all_gather_into_tensor(
24
+ output, input_.contiguous(), group=process_group, async_op=async_op
25
+ )
26
+ return output, handle
27
+
28
+
29
+ # Raw operation, does not support autograd, but does support async
30
+ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
31
+ world_size = torch.distributed.get_world_size(process_group)
32
+ assert input_.shape[0] % world_size == 0
33
+ output = torch.empty(
34
+ input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
35
+ )
36
+ handle = torch.distributed.reduce_scatter_tensor(
37
+ output, input_.contiguous(), group=process_group, async_op=async_op
38
+ )
39
+ return output, handle
40
+
41
+
42
+ # Raw operation, does not support autograd, but does support async
43
+ def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
44
+ input_ = input_.contiguous()
45
+ handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
46
+ return input_, handle
47
+
48
+
49
+ class AllGatherFunc(torch.autograd.Function):
50
+ """Gather the input from sequence parallel region and concatenate."""
51
+
52
+ @staticmethod
53
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
54
+ ctx.process_group = process_group
55
+ output, _ = all_gather_raw(input_, process_group)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output: Tensor):
60
+ grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
61
+ return grad_input, None
62
+
63
+
64
+ # Supports autograd, but does not support async
65
+ all_gather = AllGatherFunc.apply
66
+
67
+
68
+ class ReduceScatterFunc(torch.autograd.Function):
69
+ """Reduce scatter the input from the sequence parallel region and concatenate."""
70
+
71
+ @staticmethod
72
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
73
+ ctx.process_group = process_group
74
+ output, _ = reduce_scatter_raw(input_, process_group)
75
+ return output
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output: Tensor):
79
+ grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
80
+ return grad_input, None
81
+
82
+
83
+ # Supports autograd, but does not support async
84
+ reduce_scatter = ReduceScatterFunc.apply
85
+
86
+
87
+ class AllReduceFunc(torch.autograd.Function):
88
+ """Gather the input from sequence parallel region and concatenate."""
89
+
90
+ @staticmethod
91
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
92
+ ctx.process_group = process_group
93
+ output, _ = all_reduce_raw(input_, process_group)
94
+ return output
95
+
96
+ @staticmethod
97
+ def backward(ctx, grad_output: Tensor):
98
+ return grad_output, None
99
+
100
+
101
+ # Supports autograd, but does not support async
102
+ all_reduce = AllReduceFunc.apply
103
+
104
+
105
+ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
106
+ # We want to iterate over parameters with _shared_params=True in the same order,
107
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
108
+ params_shared = {
109
+ name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
110
+ }
111
+ for _, p in sorted(params_shared.items()):
112
+ with torch.no_grad():
113
+ # Broadcast needs src to be global rank, not group rank
114
+ torch.distributed.broadcast(
115
+ p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
116
+ )
117
+
118
+
119
+ # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
120
+ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
121
+ # We want to iterate over parameters with _sequence_parallel=True in the same order,
122
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
123
+ params_seqparallel = {
124
+ name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
125
+ }
126
+ grads = [p.grad for _, p in sorted(params_seqparallel.items())]
127
+ if grads:
128
+ with torch.no_grad():
129
+ coalesced = torch._utils._flatten_dense_tensors(grads)
130
+ torch.distributed.all_reduce(coalesced, group=process_group)
131
+ for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
132
+ buf.copy_(synced)
133
+
134
+
135
+ def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
136
+ """Get the dim for the local rank derived from splitting dim on world_size processes.
137
+
138
+ The split may not be even across the world_size processes.
139
+ """
140
+ multiple = dim // multiple_of
141
+ div = multiple // world_size
142
+ mod = multiple % world_size
143
+ local_multiple = div + int(local_rank < mod)
144
+ return local_multiple * multiple_of
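
To make the uneven-split arithmetic of `get_dim_for_local_rank` concrete, a short worked example (pure illustration, re-stating the function above):

```python
# Splitting dim=10 (multiple_of=1) across 4 ranks: the first `mod` ranks get
# one extra unit, and the per-rank sizes sum back to the original dim.
def get_dim_for_local_rank(dim, world_size, local_rank, multiple_of=1):
    multiple = dim // multiple_of
    div, mod = divmod(multiple, world_size)
    return (div + int(local_rank < mod)) * multiple_of

sizes = [get_dim_for_local_rank(10, 4, r) for r in range(4)]
print(sizes, sum(sizes))  # [3, 3, 2, 2] 10
```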
torch-ext/mamba_ssm/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,326 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.distributed import ProcessGroup
10
+ from ..utils.torch import custom_bwd, custom_fwd
11
+
12
+ from einops import rearrange
13
+
14
+ from ..distributed.distributed_utils import (
15
+ all_gather_raw,
16
+ all_reduce,
17
+ all_reduce_raw,
18
+ reduce_scatter,
19
+ reduce_scatter_raw,
20
+ )
21
+
22
+
23
+ class ParallelLinearFunc(torch.autograd.Function):
24
+ @staticmethod
25
+ @custom_fwd
26
+ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
+ """
28
+ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
+ with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
+ """
31
+ ctx.compute_weight_gradient = weight.requires_grad
32
+ ctx.process_group = process_group
33
+ ctx.sequence_parallel = sequence_parallel
34
+
35
+ if torch.is_autocast_enabled():
36
+ x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
+ x = x.contiguous()
38
+ if process_group is not None and sequence_parallel:
39
+ # We want to kick off the all_gather early, before weight dtype conversion
40
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
+ else:
42
+ total_x = x
43
+
44
+ if torch.is_autocast_enabled():
45
+ weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
+ bias = (
47
+ bias.to(dtype=torch.get_autocast_gpu_dtype())
48
+ if bias is not None
49
+ else None
50
+ )
51
+ weight = weight.contiguous()
52
+ if process_group is not None and sequence_parallel:
53
+ handle_x.wait()
54
+ batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
55
+ batch_dim = batch_shape.numel()
56
+ # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
57
+ output = F.linear(total_x, weight, bias)
58
+ if ctx.compute_weight_gradient:
59
+ ctx.save_for_backward(x, weight)
60
+ else:
61
+ ctx.save_for_backward(weight)
62
+ return output
63
+
64
+ @staticmethod
65
+ @custom_bwd
66
+ def backward(ctx, grad_output):
67
+ grad_output = grad_output.contiguous()
68
+ process_group = ctx.process_group
69
+ sequence_parallel = ctx.sequence_parallel
70
+ if ctx.compute_weight_gradient:
71
+ x, weight = ctx.saved_tensors
72
+ if process_group is not None and sequence_parallel:
73
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
74
+ else:
75
+ total_x = x
76
+ else:
77
+ (weight,) = ctx.saved_tensors
78
+ total_x = None
79
+ batch_shape = grad_output.shape[:-1]
80
+ batch_dim = batch_shape.numel()
81
+ grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
82
+ if ctx.needs_input_grad[0]:
83
+ grad_input = F.linear(grad_output, weight.t())
84
+ grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
85
+ if process_group is not None:
86
+ reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
87
+ grad_input, handle_grad_input = reduce_fn(
88
+ grad_input, process_group, async_op=True
89
+ )
90
+ else:
91
+ grad_input = None
92
+ if ctx.needs_input_grad[1]:
93
+ assert ctx.compute_weight_gradient
94
+ if process_group is not None and sequence_parallel:
95
+ handle_x.wait()
96
+ grad_weight = torch.einsum(
97
+ "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
98
+ )
99
+ else:
100
+ grad_weight = None
101
+ grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
102
+ if process_group is not None and ctx.needs_input_grad[0]:
103
+ handle_grad_input.wait()
104
+ return grad_input, grad_weight, grad_bias, None, None
105
+
106
+
107
+ def parallel_linear_func(
108
+ x: Tensor,
109
+ weight: Tensor,
110
+ bias: Optional[Tensor] = None,
111
+ process_group: Optional[ProcessGroup] = None,
112
+ sequence_parallel: bool = True,
113
+ ):
114
+ return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
115
+
116
+
117
+ class ColumnParallelLinear(nn.Linear):
118
+ def __init__(
119
+ self,
120
+ in_features: int,
121
+ out_features: int,
122
+ process_group: ProcessGroup,
123
+ bias: bool = True,
124
+ sequence_parallel=True,
125
+ multiple_of=1,
126
+ device=None,
127
+ dtype=None,
128
+ ) -> None:
129
+ world_size = torch.distributed.get_world_size(process_group)
130
+ if out_features % multiple_of:
131
+ raise ValueError(
132
+ f"out_features ({out_features}) must be a multiple of {multiple_of}"
133
+ )
134
+ multiple = out_features // multiple_of
135
+ # We want to split @multiple across world_size, but it could be an uneven split
136
+ div = multiple // world_size
137
+ mod = multiple % world_size
138
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
139
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
140
+ super().__init__(
141
+ in_features,
142
+ local_multiple * multiple_of,
143
+ bias=bias,
144
+ device=device,
145
+ dtype=dtype,
146
+ )
147
+ self.process_group = process_group
148
+ self.sequence_parallel = sequence_parallel
149
+
150
+ def forward(self, x):
151
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
152
+ # we do an all_gather of x before doing the matmul.
153
+ # If not, then the input is already gathered.
154
+ return parallel_linear_func(
155
+ x,
156
+ self.weight,
157
+ self.bias,
158
+ process_group=self.process_group,
159
+ sequence_parallel=self.sequence_parallel,
160
+ )
161
+
162
+
163
+ class RowParallelLinear(nn.Linear):
164
+ def __init__(
165
+ self,
166
+ in_features: int,
167
+ out_features: int,
168
+ process_group: ProcessGroup,
169
+ bias: bool = True,
170
+ sequence_parallel=True,
171
+ multiple_of=1,
172
+ device=None,
173
+ dtype=None,
174
+ ) -> None:
175
+ world_size = torch.distributed.get_world_size(process_group)
176
+ rank = torch.distributed.get_rank(process_group)
177
+ if in_features % multiple_of:
178
+ raise ValueError(
179
+ f"in_features ({in_features}) must be a multiple of {multiple_of}"
180
+ )
181
+ multiple = in_features // multiple_of
182
+ # We want to split @multiple across world_size, but it could be an uneven split
183
+ div = multiple // world_size
184
+ mod = multiple % world_size
185
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
186
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
187
+ # Only rank 0 will have bias
188
+ super().__init__(
189
+ local_multiple * multiple_of,
190
+ out_features,
191
+ bias=bias and rank == 0,
192
+ device=device,
193
+ dtype=dtype,
194
+ )
195
+ self.process_group = process_group
196
+ self.sequence_parallel = sequence_parallel
197
+
198
+ def forward(self, x):
199
+ """
200
+ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
201
+ a reduce_scatter of the result.
202
+ """
203
+ out = parallel_linear_func(x, self.weight, self.bias)
204
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
205
+ return reduce_fn(out, self.process_group)
206
+
207
+
208
+ class VocabParallelEmbedding(nn.Embedding):
209
+ def __init__(
210
+ self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs
211
+ ):
212
+ self.process_group = process_group
213
+ if process_group is not None:
214
+ world_size = torch.distributed.get_world_size(process_group)
215
+ if num_embeddings % world_size != 0:
216
+ raise ValueError(
217
+ f"num_embeddings ({num_embeddings}) must be divisible by "
218
+ f"world_size ({world_size})"
219
+ )
220
+ if world_size > 1 and padding_idx is not None:
221
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
222
+ else:
223
+ world_size = 1
224
+ super().__init__(
225
+ num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs
226
+ )
227
+
228
+ def forward(self, input: Tensor) -> Tensor:
229
+ if self.process_group is None:
230
+ return super().forward(input)
231
+ else:
232
+ rank = torch.distributed.get_rank(self.process_group)
233
+ vocab_size = self.num_embeddings
234
+ vocab_start_index, vocab_end_index = (
235
+ rank * vocab_size,
236
+ (rank + 1) * vocab_size,
237
+ )
238
+ # Mask of out-of-range vocab ids (True means the id falls outside this rank's partition and is zeroed out).
239
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
240
+ input = input - vocab_start_index
241
+ input[input_ids_mask] = 0
242
+ embeddings = super().forward(input)
243
+ embeddings[input_ids_mask] = 0.0
244
+ return embeddings
245
+
246
+
247
+ class ColumnParallelEmbedding(nn.Embedding):
248
+ def __init__(
249
+ self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs
250
+ ):
251
+ self.process_group = process_group
252
+ if process_group is not None:
253
+ world_size = torch.distributed.get_world_size(process_group)
254
+ if embedding_dim % world_size != 0:
255
+ raise ValueError(
256
+ f"embedding_dim ({embedding_dim}) must be divisible by "
257
+ f"world_size ({world_size})"
258
+ )
259
+ else:
260
+ world_size = 1
261
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
262
+
263
+
264
+ class ParallelEmbeddings(nn.Module):
265
+ def __init__(
266
+ self,
267
+ embed_dim,
268
+ vocab_size,
269
+ max_position_embeddings,
270
+ process_group,
271
+ padding_idx=None,
272
+ sequence_parallel=True,
273
+ device=None,
274
+ dtype=None,
275
+ ):
276
+ """
277
+ If max_position_embeddings <= 0, there are no position embeddings.
278
+ """
279
+ factory_kwargs = {"device": device, "dtype": dtype}
280
+ super().__init__()
281
+ self.process_group = process_group
282
+ self.sequence_parallel = sequence_parallel
283
+ self.word_embeddings = VocabParallelEmbedding(
284
+ vocab_size,
285
+ embed_dim,
286
+ padding_idx=padding_idx,
287
+ process_group=process_group,
288
+ **factory_kwargs,
289
+ )
290
+ self.max_position_embeddings = max_position_embeddings
291
+ if self.max_position_embeddings > 0:
292
+ self.position_embeddings = ColumnParallelEmbedding(
293
+ max_position_embeddings,
294
+ embed_dim,
295
+ process_group=process_group,
296
+ **factory_kwargs,
297
+ )
298
+
299
+ def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
300
+ """
301
+ input_ids: (batch, seqlen)
302
+ position_ids: (batch, seqlen)
303
+ """
304
+ batch_size, seqlen = input_ids.shape
305
+ world_size = torch.distributed.get_world_size(self.process_group)
306
+ embeddings = self.word_embeddings(input_ids)
307
+ if self.max_position_embeddings > 0:
308
+ if position_ids is None:
309
+ position_ids = torch.arange(
310
+ seqlen, dtype=torch.long, device=input_ids.device
311
+ )
312
+ position_embeddings = self.position_embeddings(position_ids)
313
+ if world_size <= 1:
314
+ embeddings = embeddings + position_embeddings
315
+ else:
316
+ partition_dim = self.position_embeddings.embedding_dim
317
+ rank = torch.distributed.get_rank(self.process_group)
318
+ embeddings[
319
+ ..., rank * partition_dim : (rank + 1) * partition_dim
320
+ ] += position_embeddings
321
+ if combine_batch_seqlen_dim:
322
+ embeddings = rearrange(embeddings, "b s d -> (b s) d")
323
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
324
+ return (
325
+ embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
326
+ )
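
The per-rank sizing used by ColumnParallelLinear and RowParallelLinear above (split the feature count into chunks of `multiple_of`, give the first `mod` ranks one extra chunk) can be checked without any distributed setup. A minimal sketch, with a hypothetical helper name `local_features` that is not part of this module:

    def local_features(total_features: int, multiple_of: int, world_size: int, rank: int) -> int:
        # Same arithmetic as in the __init__ methods above.
        if total_features % multiple_of:
            raise ValueError(f"{total_features} must be a multiple of {multiple_of}")
        multiple = total_features // multiple_of
        div, mod = divmod(multiple, world_size)
        return (div + int(rank < mod)) * multiple_of

    # 640 features in chunks of 64 across 4 ranks -> [192, 192, 128, 128]
    sizes = [local_features(640, 64, world_size=4, rank=r) for r in range(4)]
    assert sum(sizes) == 640
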
torch-ext/mamba_ssm/models/__init__.py ADDED
File without changes
torch-ext/mamba_ssm/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MambaConfig:
6
+
7
+ d_model: int = 2560
8
+ d_intermediate: int = 0
9
+ n_layer: int = 64
10
+ vocab_size: int = 50277
11
+ ssm_cfg: dict = field(default_factory=dict)
12
+ attn_layer_idx: list = field(default_factory=list)
13
+ attn_cfg: dict = field(default_factory=dict)
14
+ rms_norm: bool = True
15
+ residual_in_fp32: bool = True
16
+ fused_add_norm: bool = True
17
+ pad_vocab_size_multiple: int = 8
18
+ tie_embeddings: bool = True
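
A short example of filling in this config for a Mamba2-based model; the hyperparameter values below are illustrative only, and the import assumes the package is installed as `mamba_ssm`. The `layer` key in `ssm_cfg` is consumed by `create_block` in mixer_seq_simple.py (next file).

    from mamba_ssm.models.config_mamba import MambaConfig

    cfg = MambaConfig(
        d_model=768,
        n_layer=24,
        vocab_size=50277,
        ssm_cfg={"layer": "Mamba2"},  # "Mamba1" (default) or "Mamba2"
        rms_norm=True,
        fused_add_norm=True,
    )
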
torch-ext/mamba_ssm/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,338 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+ import copy
8
+
9
+ from collections import namedtuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .config_mamba import MambaConfig
15
+ from ..modules.mamba_simple import Mamba
16
+ from ..modules.mamba2 import Mamba2
17
+ from ..modules.mha import MHA
18
+ from ..modules.mlp import GatedMLP
19
+ from ..modules.block import Block
20
+ from ..utils.generation import GenerationMixin
21
+ from ..utils.hf import load_config_hf, load_state_dict_hf
22
+
23
+ try:
24
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
+ except ImportError:
26
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
+
28
+
29
+ def create_block(
30
+ d_model,
31
+ d_intermediate,
32
+ ssm_cfg=None,
33
+ attn_layer_idx=None,
34
+ attn_cfg=None,
35
+ norm_epsilon=1e-5,
36
+ rms_norm=False,
37
+ residual_in_fp32=False,
38
+ fused_add_norm=False,
39
+ layer_idx=None,
40
+ device=None,
41
+ dtype=None,
42
+ ):
43
+ if ssm_cfg is None:
44
+ ssm_cfg = {}
45
+ if attn_layer_idx is None:
46
+ attn_layer_idx = []
47
+ if attn_cfg is None:
48
+ attn_cfg = {}
49
+ factory_kwargs = {"device": device, "dtype": dtype}
50
+ if layer_idx not in attn_layer_idx:
51
+ # Create a copy of the config to modify
52
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
55
+ raise ValueError(
56
+ f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2"
57
+ )
58
+ mixer_cls = partial(
59
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
60
+ layer_idx=layer_idx,
61
+ **ssm_cfg,
62
+ **factory_kwargs,
63
+ )
64
+ else:
65
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
66
+ norm_cls = partial(
67
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
+ )
69
+ if d_intermediate == 0:
70
+ mlp_cls = nn.Identity
71
+ else:
72
+ mlp_cls = partial(
73
+ GatedMLP,
74
+ hidden_features=d_intermediate,
75
+ out_features=d_model,
76
+ **factory_kwargs,
77
+ )
78
+ block = Block(
79
+ d_model,
80
+ mixer_cls,
81
+ mlp_cls,
82
+ norm_cls=norm_cls,
83
+ fused_add_norm=fused_add_norm,
84
+ residual_in_fp32=residual_in_fp32,
85
+ )
86
+ block.layer_idx = layer_idx
87
+ return block
88
+
89
+
90
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
91
+ def _init_weights(
92
+ module,
93
+ n_layer,
94
+ initializer_range=0.02, # Now only used for embedding layer.
95
+ rescale_prenorm_residual=True,
96
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
97
+ ):
98
+ if isinstance(module, nn.Linear):
99
+ if module.bias is not None:
100
+ if not getattr(module.bias, "_no_reinit", False):
101
+ nn.init.zeros_(module.bias)
102
+ elif isinstance(module, nn.Embedding):
103
+ nn.init.normal_(module.weight, std=initializer_range)
104
+
105
+ if rescale_prenorm_residual:
106
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
107
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
108
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
109
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
110
+ #
111
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
112
+ for name, p in module.named_parameters():
113
+ if name in ["out_proj.weight", "fc2.weight"]:
114
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
115
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
116
+ # We need to reinit p since this code could be called multiple times
117
+ # Having just p *= scale would repeatedly scale it down
118
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
119
+ with torch.no_grad():
120
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
121
+
122
+
123
+ class MixerModel(nn.Module):
124
+ def __init__(
125
+ self,
126
+ d_model: int,
127
+ n_layer: int,
128
+ d_intermediate: int,
129
+ vocab_size: int,
130
+ ssm_cfg=None,
131
+ attn_layer_idx=None,
132
+ attn_cfg=None,
133
+ norm_epsilon: float = 1e-5,
134
+ rms_norm: bool = False,
135
+ initializer_cfg=None,
136
+ fused_add_norm=False,
137
+ residual_in_fp32=False,
138
+ device=None,
139
+ dtype=None,
140
+ ) -> None:
141
+ factory_kwargs = {"device": device, "dtype": dtype}
142
+ super().__init__()
143
+ self.residual_in_fp32 = residual_in_fp32
144
+
145
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
146
+
147
+ # We change the order of residual and layer norm:
148
+ # Instead of LN -> Attn / MLP -> Add, we do:
149
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
150
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
151
+ # This is for performance reasons: we can fuse add + layer_norm.
152
+ self.fused_add_norm = fused_add_norm
153
+ if self.fused_add_norm:
154
+ if layer_norm_fn is None or rms_norm_fn is None:
155
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
156
+
157
+ self.layers = nn.ModuleList(
158
+ [
159
+ create_block(
160
+ d_model,
161
+ d_intermediate=d_intermediate,
162
+ ssm_cfg=ssm_cfg,
163
+ attn_layer_idx=attn_layer_idx,
164
+ attn_cfg=attn_cfg,
165
+ norm_epsilon=norm_epsilon,
166
+ rms_norm=rms_norm,
167
+ residual_in_fp32=residual_in_fp32,
168
+ fused_add_norm=fused_add_norm,
169
+ layer_idx=i,
170
+ **factory_kwargs,
171
+ )
172
+ for i in range(n_layer)
173
+ ]
174
+ )
175
+
176
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
177
+ d_model, eps=norm_epsilon, **factory_kwargs
178
+ )
179
+
180
+ self.apply(
181
+ partial(
182
+ _init_weights,
183
+ n_layer=n_layer,
184
+ **(initializer_cfg if initializer_cfg is not None else {}),
185
+ n_residuals_per_layer=(
186
+ 1 if d_intermediate == 0 else 2
187
+ ), # 2 if we have MLP
188
+ )
189
+ )
190
+
191
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
192
+ return {
193
+ i: layer.allocate_inference_cache(
194
+ batch_size, max_seqlen, dtype=dtype, **kwargs
195
+ )
196
+ for i, layer in enumerate(self.layers)
197
+ }
198
+
199
+ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
200
+ hidden_states = self.embedding(input_ids)
201
+ residual = None
202
+ for layer in self.layers:
203
+ hidden_states, residual = layer(
204
+ hidden_states,
205
+ residual,
206
+ inference_params=inference_params,
207
+ **mixer_kwargs,
208
+ )
209
+ if not self.fused_add_norm:
210
+ residual = (
211
+ (hidden_states + residual) if residual is not None else hidden_states
212
+ )
213
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
214
+ else:
215
+ # Set prenorm=False here since we don't need the residual
216
+ hidden_states = layer_norm_fn(
217
+ hidden_states,
218
+ self.norm_f.weight,
219
+ self.norm_f.bias,
220
+ eps=self.norm_f.eps,
221
+ residual=residual,
222
+ prenorm=False,
223
+ residual_in_fp32=self.residual_in_fp32,
224
+ is_rms_norm=isinstance(self.norm_f, RMSNorm),
225
+ )
226
+ return hidden_states
227
+
228
+
229
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
230
+
231
+ def __init__(
232
+ self,
233
+ config: MambaConfig,
234
+ initializer_cfg=None,
235
+ device=None,
236
+ dtype=None,
237
+ ) -> None:
238
+ self.config = config
239
+ d_model = config.d_model
240
+ n_layer = config.n_layer
241
+ d_intermediate = config.d_intermediate
242
+ vocab_size = config.vocab_size
243
+ ssm_cfg = config.ssm_cfg
244
+ attn_layer_idx = config.attn_layer_idx
245
+ attn_cfg = config.attn_cfg
246
+ rms_norm = config.rms_norm
247
+ residual_in_fp32 = config.residual_in_fp32
248
+ fused_add_norm = config.fused_add_norm
249
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
250
+ factory_kwargs = {"device": device, "dtype": dtype}
251
+
252
+ super().__init__()
253
+ if vocab_size % pad_vocab_size_multiple != 0:
254
+ vocab_size += pad_vocab_size_multiple - (
255
+ vocab_size % pad_vocab_size_multiple
256
+ )
257
+ self.backbone = MixerModel(
258
+ d_model=d_model,
259
+ n_layer=n_layer,
260
+ d_intermediate=d_intermediate,
261
+ vocab_size=vocab_size,
262
+ ssm_cfg=ssm_cfg,
263
+ attn_layer_idx=attn_layer_idx,
264
+ attn_cfg=attn_cfg,
265
+ rms_norm=rms_norm,
266
+ initializer_cfg=initializer_cfg,
267
+ fused_add_norm=fused_add_norm,
268
+ residual_in_fp32=residual_in_fp32,
269
+ **factory_kwargs,
270
+ )
271
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
272
+
273
+ # Initialize weights and apply final processing
274
+ self.apply(
275
+ partial(
276
+ _init_weights,
277
+ n_layer=n_layer,
278
+ **(initializer_cfg if initializer_cfg is not None else {}),
279
+ )
280
+ )
281
+ self.tie_weights()
282
+
283
+ def tie_weights(self):
284
+ if self.config.tie_embeddings:
285
+ self.lm_head.weight = self.backbone.embedding.weight
286
+
287
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
288
+ return self.backbone.allocate_inference_cache(
289
+ batch_size, max_seqlen, dtype=dtype, **kwargs
290
+ )
291
+
292
+ def forward(
293
+ self,
294
+ input_ids,
295
+ position_ids=None,
296
+ inference_params=None,
297
+ num_last_tokens=0,
298
+ **mixer_kwargs,
299
+ ):
300
+ """
301
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
302
+ num_last_tokens: if > 0, only return the logits for the last n tokens
303
+ """
304
+ hidden_states = self.backbone(
305
+ input_ids, inference_params=inference_params, **mixer_kwargs
306
+ )
307
+ if num_last_tokens > 0:
308
+ hidden_states = hidden_states[:, -num_last_tokens:]
309
+ lm_logits = self.lm_head(hidden_states)
310
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
311
+ return CausalLMOutput(logits=lm_logits)
312
+
313
+ @classmethod
314
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
315
+ config_data = load_config_hf(pretrained_model_name)
316
+ config = MambaConfig(**config_data)
317
+ model = cls(config, device=device, dtype=dtype, **kwargs)
318
+ model.load_state_dict(
319
+ load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
320
+ )
321
+ return model
322
+
323
+ def save_pretrained(self, save_directory):
324
+ """
325
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
326
+ Save the model and its configuration file to a directory.
327
+ """
328
+ # Ensure save_directory exists
329
+ os.makedirs(save_directory, exist_ok=True)
330
+
331
+ # Save the model's state_dict
332
+ model_path = os.path.join(save_directory, "pytorch_model.bin")
333
+ torch.save(self.state_dict(), model_path)
334
+
335
+ # Save the configuration of the model
336
+ config_path = os.path.join(save_directory, "config.json")
337
+ with open(config_path, "w") as f:
338
+ json.dump(self.config.__dict__, f, indent=4)
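
A minimal usage sketch for MambaLMHeadModel, assuming a CUDA device, the compiled kernels, and that the package is importable as `mamba_ssm`; the sizes are illustrative:

    import torch
    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    cfg = MambaConfig(d_model=256, n_layer=4, vocab_size=1000, ssm_cfg={"layer": "Mamba2"})
    model = MambaLMHeadModel(cfg, device="cuda", dtype=torch.bfloat16)

    input_ids = torch.randint(0, cfg.vocab_size, (2, 128), device="cuda")
    out = model(input_ids, num_last_tokens=1)  # CausalLMOutput namedtuple
    print(out.logits.shape)                    # torch.Size([2, 1, 1000])
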
torch-ext/mamba_ssm/modules/__init__.py ADDED
File without changes
torch-ext/mamba_ssm/modules/block.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
8
+
9
+
10
+ class Block(nn.Module):
11
+ def __init__(
12
+ self,
13
+ dim,
14
+ mixer_cls,
15
+ mlp_cls,
16
+ norm_cls=nn.LayerNorm,
17
+ fused_add_norm=False,
18
+ residual_in_fp32=False,
19
+ ):
20
+ """
21
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
22
+
23
+ This Block has a slightly different structure compared to a regular
24
+ prenorm Transformer block.
25
+ The standard block is: LN -> MHA/MLP -> Add.
26
+ [Ref: https://arxiv.org/abs/2002.04745]
27
+ Here we have: Add -> LN -> Mixer, returning both
28
+ the hidden_states (output of the mixer) and the residual.
29
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
30
+ The residual needs to be provided (except for the very first block).
31
+ """
32
+ super().__init__()
33
+ self.residual_in_fp32 = residual_in_fp32
34
+ self.fused_add_norm = fused_add_norm
35
+ self.norm = norm_cls(dim)
36
+ self.mixer = mixer_cls(dim)
37
+ if mlp_cls is not nn.Identity:
38
+ self.norm2 = norm_cls(dim)
39
+ self.mlp = mlp_cls(dim)
40
+ else:
41
+ self.mlp = None
42
+ if self.fused_add_norm:
43
+ assert RMSNorm is not None, "RMSNorm import failed"
44
+ assert isinstance(
45
+ self.norm, (nn.LayerNorm, RMSNorm)
46
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: Tensor,
51
+ residual: Optional[Tensor] = None,
52
+ inference_params=None,
53
+ **mixer_kwargs
54
+ ):
55
+ r"""Pass the input through the encoder layer.
56
+
57
+ Args:
58
+ hidden_states: the sequence to the encoder layer (required).
59
+ residual: hidden_states = Mixer(LN(residual))
60
+ """
61
+ if not self.fused_add_norm:
62
+ residual = (
63
+ (hidden_states + residual) if residual is not None else hidden_states
64
+ )
65
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
66
+ if self.residual_in_fp32:
67
+ residual = residual.to(torch.float32)
68
+ else:
69
+ hidden_states, residual = layer_norm_fn(
70
+ hidden_states,
71
+ self.norm.weight,
72
+ self.norm.bias,
73
+ residual=residual,
74
+ prenorm=True,
75
+ residual_in_fp32=self.residual_in_fp32,
76
+ eps=self.norm.eps,
77
+ is_rms_norm=isinstance(self.norm, RMSNorm),
78
+ )
79
+ hidden_states = self.mixer(
80
+ hidden_states, inference_params=inference_params, **mixer_kwargs
81
+ )
82
+
83
+ if self.mlp is not None:
84
+ if not self.fused_add_norm:
85
+ residual = hidden_states + residual
86
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
87
+ if self.residual_in_fp32:
88
+ residual = residual.to(torch.float32)
89
+ else:
90
+ hidden_states, residual = layer_norm_fn(
91
+ hidden_states,
92
+ self.norm2.weight,
93
+ self.norm2.bias,
94
+ residual=residual,
95
+ prenorm=True,
96
+ residual_in_fp32=self.residual_in_fp32,
97
+ eps=self.norm2.eps,
98
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
99
+ )
100
+ hidden_states = self.mlp(hidden_states)
101
+
102
+ return hidden_states, residual
103
+
104
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
105
+ return self.mixer.allocate_inference_cache(
106
+ batch_size, max_seqlen, dtype=dtype, **kwargs
107
+ )
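
For reference, the contract described in the Block docstring (Add -> LN -> Mixer, returning the new residual alongside the mixer output) can be written out in a few lines of plain PyTorch. This is only a sketch of the fused_add_norm=False path, ignoring the optional MLP branch:

    def reference_block_step(hidden_states, residual, norm, mixer):
        # residual carries the running sum of all previous mixer outputs plus the embedding.
        residual = hidden_states + residual if residual is not None else hidden_states
        hidden_states = mixer(norm(residual))  # "LN -> Mixer" on the updated residual stream
        return hidden_states, residual         # the caller (or the next block) does the next Add
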
torch-ext/mamba_ssm/modules/mamba2.py ADDED
@@ -0,0 +1,502 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ try:
22
+ from ..ops.triton.selective_state_update import selective_state_update
23
+ except ImportError:
24
+ selective_state_update = None
25
+
26
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
+
28
+ from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+ from huggingface_hub import PyTorchModelHubMixin
35
+
36
+
37
+ class Mamba2(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_state=128,
42
+ d_conv=4,
43
+ conv_init=None,
44
+ expand=2,
45
+ headdim=64,
46
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
+ ngroups=1,
48
+ A_init_range=(1, 16),
49
+ D_has_hdim=False,
50
+ rmsnorm=True,
51
+ norm_before_gate=False,
52
+ dt_min=0.001,
53
+ dt_max=0.1,
54
+ dt_init_floor=1e-4,
55
+ dt_limit=(0.0, float("inf")),
56
+ bias=False,
57
+ conv_bias=True,
58
+ # Fused kernel and sharding options
59
+ chunk_size=256,
60
+ use_mem_eff_path=True,
61
+ layer_idx=None, # Absorb kwarg for general module
62
+ process_group=None,
63
+ sequence_parallel=True,
64
+ device=None,
65
+ dtype=None,
66
+ ):
67
+ factory_kwargs = {"device": device, "dtype": dtype}
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.d_state = d_state
71
+ self.d_conv = d_conv
72
+ self.conv_init = conv_init
73
+ self.expand = expand
74
+ self.process_group = process_group
75
+ self.sequence_parallel = sequence_parallel
76
+ self.world_size = 1 if process_group is None else process_group.size()
77
+ self.local_rank = 0 if process_group is None else process_group.rank()
78
+ self.d_inner = (self.expand * self.d_model) // self.world_size
79
+ assert self.d_inner * self.world_size == self.expand * self.d_model
80
+ self.headdim = headdim
81
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
+ assert ngroups % self.world_size == 0
83
+ self.ngroups = ngroups // self.world_size
84
+ assert self.d_ssm % self.headdim == 0
85
+ self.nheads = self.d_ssm // self.headdim
86
+ self.D_has_hdim = D_has_hdim
87
+ self.rmsnorm = rmsnorm
88
+ self.norm_before_gate = norm_before_gate
89
+ self.dt_limit = dt_limit
90
+ self.activation = "silu"
91
+ self.chunk_size = chunk_size
92
+ self.use_mem_eff_path = use_mem_eff_path
93
+ self.layer_idx = layer_idx
94
+
95
+ # Order: [z, x, B, C, dt]
96
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
+ if self.process_group is None:
98
+ self.in_proj = nn.Linear(
99
+ self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
+ )
101
+ else:
102
+ self.in_proj = ColumnParallelLinear(
103
+ self.d_model,
104
+ d_in_proj * self.world_size,
105
+ bias=bias,
106
+ process_group=self.process_group,
107
+ sequence_parallel=self.sequence_parallel,
108
+ **factory_kwargs,
109
+ )
110
+
111
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
+ self.conv1d = nn.Conv1d(
113
+ in_channels=conv_dim,
114
+ out_channels=conv_dim,
115
+ bias=conv_bias,
116
+ kernel_size=d_conv,
117
+ groups=conv_dim,
118
+ padding=d_conv - 1,
119
+ **factory_kwargs,
120
+ )
121
+ if self.conv_init is not None:
122
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
+
124
+ self.act = nn.SiLU()
125
+
126
+ # Initialize log dt bias
127
+ dt = torch.exp(
128
+ torch.rand(self.nheads, **factory_kwargs)
129
+ * (math.log(dt_max) - math.log(dt_min))
130
+ + math.log(dt_min)
131
+ )
132
+ dt = torch.clamp(dt, min=dt_init_floor)
133
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
135
+ self.dt_bias = nn.Parameter(inv_dt)
136
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
+ # name.endswith("bias") in param_grouping.py
138
+ self.dt_bias._no_weight_decay = True
139
+
140
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
+ *A_init_range
143
+ )
144
+ A_log = torch.log(A).to(dtype=dtype)
145
+ self.A_log = nn.Parameter(A_log)
146
+ self.A_log._no_weight_decay = True
147
+
148
+ # D "skip" parameter
149
+ self.D = nn.Parameter(
150
+ torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
+ )
152
+ self.D._no_weight_decay = True
153
+
154
+ if self.rmsnorm:
155
+ assert RMSNormGated is not None
156
+ self.norm = RMSNormGated(
157
+ self.d_ssm,
158
+ eps=1e-5,
159
+ norm_before_gate=self.norm_before_gate,
160
+ group_size=self.d_ssm // ngroups,
161
+ **factory_kwargs,
162
+ )
163
+
164
+ if self.process_group is None:
165
+ self.out_proj = nn.Linear(
166
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
+ )
168
+ else:
169
+ self.out_proj = RowParallelLinear(
170
+ self.d_inner * self.world_size,
171
+ self.d_model,
172
+ bias=bias,
173
+ process_group=self.process_group,
174
+ sequence_parallel=self.sequence_parallel,
175
+ **factory_kwargs,
176
+ )
177
+
178
+ def forward(
179
+ self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
180
+ ):
181
+ """
182
+ u: (batch, seqlen, hidden_dim) if seqlen=None.
183
+ If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
184
+ split u during sequence parallel, we split the batch * seqlen dimension
185
+ (in case batch is small).
186
+ Returns: same shape as u
187
+ """
188
+ seqlen_og = seqlen
189
+ if seqlen is None:
190
+ batch, seqlen, dim = u.shape
191
+ else:
192
+ batch_seqlen, dim = u.shape
193
+ batch = batch_seqlen // seqlen
194
+
195
+ conv_state, ssm_state = None, None
196
+ if inference_params is not None:
197
+ inference_batch = (
198
+ cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
199
+ )
200
+ conv_state, ssm_state = self._get_states_from_cache(
201
+ inference_params, inference_batch
202
+ )
203
+ if inference_params.seqlen_offset > 0:
204
+ # The states are updated inplace
205
+ out, _, _ = self.step(u, conv_state, ssm_state)
206
+ return out
207
+
208
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
209
+ if seqlen_og is not None:
210
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
211
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
212
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
213
+ dt_limit_kwargs = (
214
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
215
+ )
216
+ if self.use_mem_eff_path and inference_params is None:
217
+ out = mamba_split_conv1d_scan_combined(
218
+ zxbcdt,
219
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
220
+ self.conv1d.bias,
221
+ self.dt_bias,
222
+ A,
223
+ D=(
224
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
225
+ if self.D_has_hdim
226
+ else self.D
227
+ ),
228
+ chunk_size=self.chunk_size,
229
+ seq_idx=seq_idx,
230
+ activation=self.activation,
231
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
232
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
233
+ outproj_weight=self.out_proj.weight,
234
+ outproj_bias=self.out_proj.bias,
235
+ headdim=None if self.D_has_hdim else self.headdim,
236
+ ngroups=self.ngroups,
237
+ norm_before_gate=self.norm_before_gate,
238
+ **dt_limit_kwargs,
239
+ )
240
+ if seqlen_og is not None:
241
+ out = rearrange(out, "b l d -> (b l) d")
242
+ if self.process_group is not None:
243
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
244
+ out = reduce_fn(out, self.process_group)
245
+ else:
246
+ d_mlp = (
247
+ zxbcdt.shape[-1]
248
+ - 2 * self.d_ssm
249
+ - 2 * self.ngroups * self.d_state
250
+ - self.nheads
251
+ ) // 2
252
+ z0, x0, z, xBC, dt = torch.split(
253
+ zxbcdt,
254
+ [
255
+ d_mlp,
256
+ d_mlp,
257
+ self.d_ssm,
258
+ self.d_ssm + 2 * self.ngroups * self.d_state,
259
+ self.nheads,
260
+ ],
261
+ dim=-1,
262
+ )
263
+ if conv_state is not None:
264
+ if cu_seqlens is None:
265
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
266
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
267
+ xBC_t = rearrange(xBC, "b l d -> b d l")
268
+ conv_state.copy_(
269
+ F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
270
+ ) # Update state (B D W)
271
+ else:
272
+ assert (
273
+ causal_conv1d_varlen_states is not None
274
+ ), "varlen inference requires causal_conv1d package"
275
+ assert (
276
+ batch == 1
277
+ ), "varlen inference only supports batch dimension 1"
278
+ conv_varlen_states = causal_conv1d_varlen_states(
279
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
280
+ )
281
+ conv_state.copy_(conv_varlen_states)
282
+ assert self.activation in ["silu", "swish"]
283
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
284
+ assert (
285
+ seq_idx is None
286
+ ), "varlen conv1d requires the causal_conv1d package"
287
+ xBC = self.act(
288
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
289
+ :, : -(self.d_conv - 1)
290
+ ]
291
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
292
+ else:
293
+ xBC = causal_conv1d_fn(
294
+ xBC.transpose(1, 2),
295
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
296
+ bias=self.conv1d.bias,
297
+ activation=self.activation,
298
+ seq_idx=seq_idx,
299
+ ).transpose(1, 2)
300
+ x, B, C = torch.split(
301
+ xBC,
302
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
303
+ dim=-1,
304
+ )
305
+ y = mamba_chunk_scan_combined(
306
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
307
+ dt,
308
+ A,
309
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
310
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
311
+ chunk_size=self.chunk_size,
312
+ D=(
313
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
314
+ if self.D_has_hdim
315
+ else self.D
316
+ ),
317
+ z=(
318
+ rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
319
+ if not self.rmsnorm
320
+ else None
321
+ ),
322
+ dt_bias=self.dt_bias,
323
+ dt_softplus=True,
324
+ seq_idx=seq_idx,
325
+ cu_seqlens=cu_seqlens,
326
+ **dt_limit_kwargs,
327
+ return_final_states=ssm_state is not None,
328
+ return_varlen_states=cu_seqlens is not None
329
+ and inference_params is not None,
330
+ )
331
+ if ssm_state is not None:
332
+ y, last_state, *rest = y
333
+ if cu_seqlens is None:
334
+ ssm_state.copy_(last_state)
335
+ else:
336
+ varlen_states = rest[0]
337
+ ssm_state.copy_(varlen_states)
338
+ y = rearrange(y, "b l h p -> b l (h p)")
339
+ if self.rmsnorm:
340
+ y = self.norm(y, z)
341
+ if d_mlp > 0:
342
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
343
+ if seqlen_og is not None:
344
+ y = rearrange(y, "b l d -> (b l) d")
345
+ out = self.out_proj(y)
346
+ return out
347
+
348
+ def step(self, hidden_states, conv_state, ssm_state):
349
+ dtype = hidden_states.dtype
350
+ assert (
351
+ hidden_states.shape[1] == 1
352
+ ), "Only support decoding with 1 token at a time for now"
353
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
354
+ d_mlp = (
355
+ zxbcdt.shape[-1]
356
+ - 2 * self.d_ssm
357
+ - 2 * self.ngroups * self.d_state
358
+ - self.nheads
359
+ ) // 2
360
+ z0, x0, z, xBC, dt = torch.split(
361
+ zxbcdt,
362
+ [
363
+ d_mlp,
364
+ d_mlp,
365
+ self.d_ssm,
366
+ self.d_ssm + 2 * self.ngroups * self.d_state,
367
+ self.nheads,
368
+ ],
369
+ dim=-1,
370
+ )
371
+
372
+ # Conv step
373
+ if causal_conv1d_update is None:
374
+ conv_state.copy_(
375
+ torch.roll(conv_state, shifts=-1, dims=-1)
376
+ ) # Update state (B D W)
377
+ conv_state[:, :, -1] = xBC
378
+ xBC = torch.sum(
379
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
380
+ ) # (B D)
381
+ if self.conv1d.bias is not None:
382
+ xBC = xBC + self.conv1d.bias
383
+ xBC = self.act(xBC).to(dtype=dtype)
384
+ else:
385
+ xBC = causal_conv1d_update(
386
+ xBC,
387
+ conv_state,
388
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
389
+ self.conv1d.bias,
390
+ self.activation,
391
+ )
392
+
393
+ x, B, C = torch.split(
394
+ xBC,
395
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
396
+ dim=-1,
397
+ )
398
+ A = -torch.exp(self.A_log.float()) # (nheads,)
399
+
400
+ # SSM step
401
+ if selective_state_update is None:
402
+ assert (
403
+ self.ngroups == 1
404
+ ), "Only support ngroups=1 for this inference code path"
405
+ # Discretize A and B
406
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
407
+ dA = torch.exp(dt * A) # (batch, nheads)
408
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
409
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
410
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
411
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
412
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
413
+ y = rearrange(y, "b h p -> b (h p)")
414
+ if not self.rmsnorm:
415
+ y = y * self.act(z) # (B D)
416
+ else:
417
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
418
+ dtype=torch.float32
419
+ )
420
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
421
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
422
+ D = repeat(self.D, "h -> h p", p=self.headdim)
423
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
424
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
425
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
426
+ if not self.rmsnorm:
427
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
428
+ y = selective_state_update(
429
+ ssm_state,
430
+ x_reshaped,
431
+ dt,
432
+ A,
433
+ B,
434
+ C,
435
+ D,
436
+ z=z if not self.rmsnorm else None,
437
+ dt_bias=dt_bias,
438
+ dt_softplus=True,
439
+ )
440
+ y = rearrange(y, "b h p -> b (h p)")
441
+ if self.rmsnorm:
442
+ y = self.norm(y, z)
443
+ if d_mlp > 0:
444
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
445
+ out = self.out_proj(y)
446
+ return out.unsqueeze(1), conv_state, ssm_state
447
+
448
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
449
+ device = self.out_proj.weight.device
450
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
451
+ conv_state = torch.zeros(
452
+ batch_size,
453
+ self.d_conv,
454
+ self.conv1d.weight.shape[0],
455
+ device=device,
456
+ dtype=conv_dtype,
457
+ ).transpose(1, 2)
458
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
459
+ ssm_state = torch.zeros(
460
+ batch_size,
461
+ self.nheads,
462
+ self.headdim,
463
+ self.d_state,
464
+ device=device,
465
+ dtype=ssm_dtype,
466
+ )
467
+ return conv_state, ssm_state
468
+
469
+ def _get_states_from_cache(
470
+ self, inference_params, batch_size, initialize_states=False
471
+ ):
472
+ assert self.layer_idx is not None
473
+ if self.layer_idx not in inference_params.key_value_memory_dict:
474
+ batch_shape = (batch_size,)
475
+ conv_state = torch.zeros(
476
+ batch_size,
477
+ self.d_conv,
478
+ self.conv1d.weight.shape[0],
479
+ device=self.conv1d.weight.device,
480
+ dtype=self.conv1d.weight.dtype,
481
+ ).transpose(1, 2)
482
+ ssm_state = torch.zeros(
483
+ batch_size,
484
+ self.nheads,
485
+ self.headdim,
486
+ self.d_state,
487
+ device=self.in_proj.weight.device,
488
+ dtype=self.in_proj.weight.dtype,
489
+ )
490
+ inference_params.key_value_memory_dict[self.layer_idx] = (
491
+ conv_state,
492
+ ssm_state,
493
+ )
494
+ else:
495
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
496
+ self.layer_idx
497
+ ]
498
+ # TODO: What if batch size changes between generation, and we reuse the same states?
499
+ if initialize_states:
500
+ conv_state.zero_()
501
+ ssm_state.zero_()
502
+ return conv_state, ssm_state
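
A hedged usage sketch for the Mamba2 mixer above on a single GPU (no tensor parallelism); it requires the Triton/CUDA kernels and a CUDA device, and the shapes are illustrative:

    import torch
    from mamba_ssm.modules.mamba2 import Mamba2

    layer = Mamba2(d_model=256, d_state=128, headdim=64, layer_idx=0,
                   device="cuda", dtype=torch.bfloat16)
    u = torch.randn(2, 512, 256, device="cuda", dtype=torch.bfloat16)
    y = layer(u)  # same shape as u: (batch, seqlen, d_model)
    assert y.shape == u.shape
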
torch-ext/mamba_ssm/modules/mamba2_simple.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from causal_conv1d import causal_conv1d_fn
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+
15
+ try:
16
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
+ except ImportError:
18
+ RMSNormGated, LayerNorm = None, None
19
+
20
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
21
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
+
23
+
24
+ class Mamba2Simple(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model,
28
+ d_state=64,
29
+ d_conv=4,
30
+ conv_init=None,
31
+ expand=2,
32
+ headdim=128,
33
+ ngroups=1,
34
+ A_init_range=(1, 16),
35
+ dt_min=0.001,
36
+ dt_max=0.1,
37
+ dt_init_floor=1e-4,
38
+ dt_limit=(0.0, float("inf")),
39
+ learnable_init_states=False,
40
+ activation="swish",
41
+ bias=False,
42
+ conv_bias=True,
43
+ # Fused kernel and sharding options
44
+ chunk_size=256,
45
+ use_mem_eff_path=True,
46
+ layer_idx=None, # Absorb kwarg for general module
47
+ device=None,
48
+ dtype=None,
49
+ ):
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.d_model = d_model
53
+ self.d_state = d_state
54
+ self.d_conv = d_conv
55
+ self.conv_init = conv_init
56
+ self.expand = expand
57
+ self.d_inner = self.expand * self.d_model
58
+ self.headdim = headdim
59
+ self.ngroups = ngroups
60
+ assert self.d_inner % self.headdim == 0
61
+ self.nheads = self.d_inner // self.headdim
62
+ self.dt_limit = dt_limit
63
+ self.learnable_init_states = learnable_init_states
64
+ self.activation = activation
65
+ self.chunk_size = chunk_size
66
+ self.use_mem_eff_path = use_mem_eff_path
67
+ self.layer_idx = layer_idx
68
+
69
+ # Order: [z, x, B, C, dt]
70
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
+
73
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
+ self.conv1d = nn.Conv1d(
75
+ in_channels=conv_dim,
76
+ out_channels=conv_dim,
77
+ bias=conv_bias,
78
+ kernel_size=d_conv,
79
+ groups=conv_dim,
80
+ padding=d_conv - 1,
81
+ **factory_kwargs,
82
+ )
83
+ if self.conv_init is not None:
84
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
+ # self.conv1d.weight._no_weight_decay = True
86
+
87
+ if self.learnable_init_states:
88
+ self.init_states = nn.Parameter(
89
+ torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
90
+ )
91
+ self.init_states._no_weight_decay = True
92
+
93
+ self.act = nn.SiLU()
94
+
95
+ # Initialize log dt bias
96
+ dt = torch.exp(
97
+ torch.rand(self.nheads, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ )
101
+ dt = torch.clamp(dt, min=dt_init_floor)
102
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
103
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
104
+ self.dt_bias = nn.Parameter(inv_dt)
105
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
106
+ # name.endswith("bias") in param_grouping.py
107
+ self.dt_bias._no_weight_decay = True
108
+
109
+ # A parameter
110
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
111
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
112
+ *A_init_range
113
+ )
114
+ A_log = torch.log(A).to(dtype=dtype)
115
+ self.A_log = nn.Parameter(A_log)
116
+ # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
117
+ self.A_log._no_weight_decay = True
118
+
119
+ # D "skip" parameter
120
+ self.D = nn.Parameter(torch.ones(self.nheads, device=device))
121
+ self.D._no_weight_decay = True
122
+
123
+ # Extra normalization layer right before output projection
124
+ assert RMSNormGated is not None
125
+ self.norm = RMSNormGated(
126
+ self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
127
+ )
128
+
129
+ self.out_proj = nn.Linear(
130
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
131
+ )
132
+
133
+ def forward(self, u, seq_idx=None):
134
+ """
135
+ u: (B, L, D)
136
+ Returns: same shape as u
137
+ """
138
+ batch, seqlen, dim = u.shape
139
+
140
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
141
+ A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
142
+ initial_states = (
143
+ repeat(self.init_states, "... -> b ...", b=batch)
144
+ if self.learnable_init_states
145
+ else None
146
+ )
147
+ dt_limit_kwargs = (
148
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
149
+ )
150
+
151
+ if self.use_mem_eff_path:
152
+ # Fully fused path
153
+ out = mamba_split_conv1d_scan_combined(
154
+ zxbcdt,
155
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
156
+ self.conv1d.bias,
157
+ self.dt_bias,
158
+ A,
159
+ D=self.D,
160
+ chunk_size=self.chunk_size,
161
+ seq_idx=seq_idx,
162
+ activation=self.activation,
163
+ rmsnorm_weight=self.norm.weight,
164
+ rmsnorm_eps=self.norm.eps,
165
+ outproj_weight=self.out_proj.weight,
166
+ outproj_bias=self.out_proj.bias,
167
+ headdim=self.headdim,
168
+ ngroups=self.ngroups,
169
+ norm_before_gate=False,
170
+ initial_states=initial_states,
171
+ **dt_limit_kwargs,
172
+ )
173
+ else:
174
+ z, xBC, dt = torch.split(
175
+ zxbcdt,
176
+ [
177
+ self.d_inner,
178
+ self.d_inner + 2 * self.ngroups * self.d_state,
179
+ self.nheads,
180
+ ],
181
+ dim=-1,
182
+ )
183
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
184
+ assert self.activation in ["silu", "swish"]
185
+
186
+ # 1D Convolution
187
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
188
+ xBC = self.act(
189
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
190
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
191
+ xBC = xBC[:, :seqlen, :]
192
+ else:
193
+ xBC = causal_conv1d_fn(
194
+ x=xBC.transpose(1, 2),
195
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
196
+ bias=self.conv1d.bias,
197
+ activation=self.activation,
198
+ ).transpose(1, 2)
199
+
200
+ # Split into 3 main branches: X, B, C
201
+ # These correspond to V, K, Q respectively in the SSM/attention duality
202
+ x, B, C = torch.split(
203
+ xBC,
204
+ [
205
+ self.d_inner,
206
+ self.ngroups * self.d_state,
207
+ self.ngroups * self.d_state,
208
+ ],
209
+ dim=-1,
210
+ )
211
+ y = mamba_chunk_scan_combined(
212
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
213
+ dt,
214
+ A,
215
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
216
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
217
+ chunk_size=self.chunk_size,
218
+ D=self.D,
219
+ z=None,
220
+ seq_idx=seq_idx,
221
+ initial_states=initial_states,
222
+ **dt_limit_kwargs,
223
+ )
224
+ y = rearrange(y, "b l h p -> b l (h p)")
225
+
226
+ # Multiply "gate" branch and apply extra normalization layer
227
+ y = self.norm(y, z)
228
+ out = self.out_proj(y)
229
+ return out
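
The "inverse of softplus" used for dt_bias here (and in Mamba2 above) guarantees that F.softplus(dt_bias) reproduces the sampled dt at initialization. A small numerical check, assuming only PyTorch:

    import math
    import torch
    import torch.nn.functional as F

    dt_min, dt_max = 1e-3, 0.1
    # Log-uniform sample in [dt_min, dt_max], as in __init__ above.
    dt = torch.exp(torch.rand(16) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
    # softplus(x) = log(1 + exp(x)), so its inverse is x = dt + log(-expm1(-dt)).
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)
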
torch-ext/mamba_ssm/modules/mamba_simple.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from ..ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(
63
+ self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
64
+ )
65
+
66
+ self.conv1d = nn.Conv1d(
67
+ in_channels=self.d_inner,
68
+ out_channels=self.d_inner,
69
+ bias=conv_bias,
70
+ kernel_size=d_conv,
71
+ groups=self.d_inner,
72
+ padding=d_conv - 1,
73
+ **factory_kwargs,
74
+ )
75
+
76
+ self.activation = "silu"
77
+ self.act = nn.SiLU()
78
+
79
+ self.x_proj = nn.Linear(
80
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
+ )
82
+ self.dt_proj = nn.Linear(
83
+ self.dt_rank, self.d_inner, bias=True, **factory_kwargs
84
+ )
85
+
86
+ # Initialize special dt projection to preserve variance at initialization
87
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
88
+ if dt_init == "constant":
89
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
90
+ elif dt_init == "random":
91
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
96
+ dt = torch.exp(
97
+ torch.rand(self.d_inner, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ ).clamp(min=dt_init_floor)
101
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
103
+ with torch.no_grad():
104
+ self.dt_proj.bias.copy_(inv_dt)
105
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
+ self.dt_proj.bias._no_reinit = True
107
+
108
+ # S4D real initialization
109
+ A = repeat(
110
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
111
+ "n -> d n",
112
+ d=self.d_inner,
113
+ ).contiguous()
114
+ A_log = torch.log(A) # Keep A_log in fp32
115
+ self.A_log = nn.Parameter(A_log)
116
+ self.A_log._no_weight_decay = True
117
+
118
+ # D "skip" parameter
119
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
120
+ self.D._no_weight_decay = True
121
+
122
+ self.out_proj = nn.Linear(
123
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
124
+ )
125
+
126
+ def forward(self, hidden_states, inference_params=None):
127
+ """
128
+ hidden_states: (B, L, D)
129
+ Returns: same shape as hidden_states
130
+ """
131
+ batch, seqlen, dim = hidden_states.shape
132
+
133
+ conv_state, ssm_state = None, None
134
+ if inference_params is not None:
135
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
136
+ if inference_params.seqlen_offset > 0:
137
+ # The states are updated inplace
138
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
139
+ return out
140
+
141
+ # We do matmul and transpose BLH -> HBL at the same time
142
+ xz = rearrange(
143
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
144
+ "d (b l) -> b d l",
145
+ l=seqlen,
146
+ )
147
+ if self.in_proj.bias is not None:
148
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
149
+
150
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
151
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
152
+ if (
153
+ self.use_fast_path
154
+ and causal_conv1d_fn is not None
155
+ and inference_params is None
156
+ ): # Doesn't support outputting the states
157
+ out = mamba_inner_fn(
158
+ xz,
159
+ self.conv1d.weight,
160
+ self.conv1d.bias,
161
+ self.x_proj.weight,
162
+ self.dt_proj.weight,
163
+ self.out_proj.weight,
164
+ self.out_proj.bias,
165
+ A,
166
+ None, # input-dependent B
167
+ None, # input-dependent C
168
+ self.D.float(),
169
+ delta_bias=self.dt_proj.bias.float(),
170
+ delta_softplus=True,
171
+ )
172
+ else:
173
+ x, z = xz.chunk(2, dim=1)
174
+ # Compute short convolution
175
+ if conv_state is not None:
176
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
177
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
178
+ conv_state.copy_(
179
+ F.pad(x, (self.d_conv - x.shape[-1], 0))
180
+ ) # Update state (B D W)
181
+ if causal_conv1d_fn is None:
182
+ x = self.act(self.conv1d(x)[..., :seqlen])
183
+ else:
184
+ assert self.activation in ["silu", "swish"]
185
+ x = causal_conv1d_fn(
186
+ x=x,
187
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
+ bias=self.conv1d.bias,
189
+ activation=self.activation,
190
+ )
191
+
192
+ # We're careful here about the layout, to avoid extra transposes.
193
+ # We want dt to have d as the slowest moving dimension
194
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
195
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
196
+ dt, B, C = torch.split(
197
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
198
+ )
199
+ dt = self.dt_proj.weight @ dt.t()
200
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
201
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
202
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
203
+ assert self.activation in ["silu", "swish"]
204
+ y = selective_scan_fn(
205
+ x,
206
+ dt,
207
+ A,
208
+ B,
209
+ C,
210
+ self.D.float(),
211
+ z=z,
212
+ delta_bias=self.dt_proj.bias.float(),
213
+ delta_softplus=True,
214
+ return_last_state=ssm_state is not None,
215
+ )
216
+ if ssm_state is not None:
217
+ y, last_state = y
218
+ ssm_state.copy_(last_state)
219
+ y = rearrange(y, "b d l -> b l d")
220
+ out = self.out_proj(y)
221
+ return out
222
+
223
+ def step(self, hidden_states, conv_state, ssm_state):
224
+ dtype = hidden_states.dtype
225
+ assert (
226
+ hidden_states.shape[1] == 1
227
+ ), "Only support decoding with 1 token at a time for now"
228
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
229
+ x, z = xz.chunk(2, dim=-1) # (B D)
230
+
231
+ # Conv step
232
+ if causal_conv1d_update is None:
233
+ conv_state.copy_(
234
+ torch.roll(conv_state, shifts=-1, dims=-1)
235
+ ) # Update state (B D W)
236
+ conv_state[:, :, -1] = x
237
+ x = torch.sum(
238
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
239
+ ) # (B D)
240
+ if self.conv1d.bias is not None:
241
+ x = x + self.conv1d.bias
242
+ x = self.act(x).to(dtype=dtype)
243
+ else:
244
+ x = causal_conv1d_update(
245
+ x,
246
+ conv_state,
247
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
248
+ self.conv1d.bias,
249
+ self.activation,
250
+ )
251
+
252
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
253
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
254
+ # Don't add dt_bias here
255
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
256
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
257
+
258
+ # SSM step
259
+ if selective_state_update is None:
260
+ # Discretize A and B
261
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
262
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
263
+ dB = torch.einsum("bd,bn->bdn", dt, B)
264
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
265
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
266
+ y = y + self.D.to(dtype) * x
267
+ y = y * self.act(z) # (B D)
268
+ else:
269
+ y = selective_state_update(
270
+ ssm_state,
271
+ x,
272
+ dt,
273
+ A,
274
+ B,
275
+ C,
276
+ self.D,
277
+ z=z,
278
+ dt_bias=self.dt_proj.bias,
279
+ dt_softplus=True,
280
+ )
281
+
282
+ out = self.out_proj(y)
283
+ return out.unsqueeze(1), conv_state, ssm_state
284
+
285
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
286
+ device = self.out_proj.weight.device
287
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
288
+ conv_state = torch.zeros(
289
+ batch_size,
290
+ self.d_model * self.expand,
291
+ self.d_conv,
292
+ device=device,
293
+ dtype=conv_dtype,
294
+ )
295
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
296
+ # ssm_dtype = torch.float32
297
+ ssm_state = torch.zeros(
298
+ batch_size,
299
+ self.d_model * self.expand,
300
+ self.d_state,
301
+ device=device,
302
+ dtype=ssm_dtype,
303
+ )
304
+ return conv_state, ssm_state
305
+
306
+ def _get_states_from_cache(
307
+ self, inference_params, batch_size, initialize_states=False
308
+ ):
309
+ assert self.layer_idx is not None
310
+ if self.layer_idx not in inference_params.key_value_memory_dict:
311
+ batch_shape = (batch_size,)
312
+ conv_state = torch.zeros(
313
+ batch_size,
314
+ self.d_model * self.expand,
315
+ self.d_conv,
316
+ device=self.conv1d.weight.device,
317
+ dtype=self.conv1d.weight.dtype,
318
+ )
319
+ ssm_state = torch.zeros(
320
+ batch_size,
321
+ self.d_model * self.expand,
322
+ self.d_state,
323
+ device=self.dt_proj.weight.device,
324
+ dtype=self.dt_proj.weight.dtype,
325
+ # dtype=torch.float32,
326
+ )
327
+ inference_params.key_value_memory_dict[self.layer_idx] = (
328
+ conv_state,
329
+ ssm_state,
330
+ )
331
+ else:
332
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
333
+ self.layer_idx
334
+ ]
335
+ # TODO: What if batch size changes between generation, and we reuse the same states?
336
+ if initialize_states:
337
+ conv_state.zero_()
338
+ ssm_state.zero_()
339
+ return conv_state, ssm_state
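The pure-PyTorch fallback in step() above performs one step of the discretized SSM recurrence: dA = exp(dt * A), dB = dt * B, state <- state * dA + x * dB, then y = C·state + D*x gated by silu(z). A minimal standalone sketch of that single-token update (the function name and the explicit softplus call are illustrative, not part of this module):

import torch
import torch.nn.functional as F

def ssm_single_step(ssm_state, x, dt, A, B, C, D, z):
    # ssm_state: (batch, d_inner, d_state); x, dt, z: (batch, d_inner)
    # A: (d_inner, d_state); B, C: (batch, d_state); D: (d_inner)
    dt = F.softplus(dt)                                   # delta_softplus
    dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))     # discretize A
    dB = torch.einsum("bd,bn->bdn", dt, B)                # discretize B
    ssm_state = ssm_state * dA + x.unsqueeze(-1) * dB     # state update
    y = torch.einsum("bdn,bn->bd", ssm_state, C) + D * x  # readout + skip
    return y * F.silu(z), ssm_state                       # silu(z) gating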
torch-ext/mamba_ssm/modules/mha.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ try:
11
+ from flash_attn import flash_attn_with_kvcache
12
+ except ImportError:
13
+ flash_attn_with_kvcache = None
14
+
15
+ try:
16
+ from flash_attn.layers.rotary import RotaryEmbedding
17
+ except ImportError:
18
+ RotaryEmbedding = None
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn, causal_conv1d_update = None, None
24
+
25
+
26
+ def _update_kv_cache(kv, inference_params, layer_idx):
27
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
+ # Pre-allocate memory for key-values for inference.
29
+ num_heads, head_dim = kv.shape[-2:]
30
+ assert layer_idx in inference_params.key_value_memory_dict
31
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
+ # Adjust key and value for inference
33
+ batch_start = inference_params.batch_size_offset
34
+ batch_end = batch_start + kv.shape[0]
35
+ sequence_start = inference_params.seqlen_offset
36
+ sequence_end = sequence_start + kv.shape[1]
37
+ assert batch_end <= kv_cache.shape[0]
38
+ assert sequence_end <= kv_cache.shape[1]
39
+ assert kv_cache is not None
40
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
+
43
+
44
+ class MHA(nn.Module):
45
+ """Multi-head self-attention and cross-attention"""
46
+
47
+ def __init__(
48
+ self,
49
+ embed_dim,
50
+ num_heads,
51
+ num_heads_kv=None,
52
+ head_dim=None, # If None, use embed_dim // num_heads
53
+ mlp_dim=0,
54
+ qkv_proj_bias=True,
55
+ out_proj_bias=True,
56
+ softmax_scale=None,
57
+ causal=False,
58
+ layer_idx=None,
59
+ d_conv=0,
60
+ rotary_emb_dim=0,
61
+ rotary_emb_base=10000.0,
62
+ rotary_emb_interleaved=False,
63
+ device=None,
64
+ dtype=None,
65
+ ) -> None:
66
+ """
67
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
+ return_residual: whether to return the input x along with the output. This is for
69
+ performance reasons: for post-norm architectures, returning the input allows us
70
+ to fuse the backward of nn.Linear with the residual connection.
71
+ """
72
+ factory_kwargs = {"device": device, "dtype": dtype}
73
+ super().__init__()
74
+ self.embed_dim = embed_dim
75
+ self.layer_idx = layer_idx
76
+ self.d_conv = d_conv
77
+ self.rotary_emb_dim = rotary_emb_dim
78
+ self.softmax_scale = softmax_scale
79
+ self.causal = causal
80
+
81
+ self.num_heads = num_heads
82
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
+ assert (
84
+ self.num_heads % self.num_heads_kv == 0
85
+ ), "num_heads must be divisible by num_heads_kv"
86
+ if head_dim is None:
87
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
+ self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
+ self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
+ out_dim = self.head_dim * self.num_heads
92
+
93
+ if self.rotary_emb_dim > 0:
94
+ assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
+ self.rotary_emb = RotaryEmbedding(
96
+ self.rotary_emb_dim,
97
+ base=rotary_emb_base,
98
+ interleaved=rotary_emb_interleaved,
99
+ device=device,
100
+ )
101
+
102
+ self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
+ if self.d_conv > 0:
104
+ self.conv1d = nn.Conv1d(
105
+ qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
+ **factory_kwargs
107
+ )
108
+ self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
+
110
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
+ device = self.out_proj.weight.device
113
+ if self.d_conv > 0:
114
+ conv_state = torch.zeros(
115
+ batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
+ )
117
+ else:
118
+ conv_state = None
119
+ kv_cache = torch.empty(
120
+ batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
+ )
122
+ return kv_cache, conv_state
123
+
124
+ def _update_kv_cache(self, kv, inference_params):
125
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
128
+
129
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
+ """
131
+ Fast path that combines 3 steps: apply rotary to Q and K, update the KV cache, and apply attention.
132
+ q: (batch_size, seqlen_q, nheads, head_dim)
133
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
+ """
135
+ assert inference_params is not None and inference_params.seqlen_offset > 0
136
+ if self.rotary_emb_dim > 0:
137
+ self.rotary_emb._update_cos_sin_cache(
138
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
+ )
140
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
+ else:
142
+ rotary_cos, rotary_sin = None, None
143
+ batch = q.shape[0]
144
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
+ kv_cache = kv_cache[:batch]
146
+ cache_seqlens = (
147
+ inference_params.lengths_per_sample[:batch]
148
+ if inference_params.lengths_per_sample is not None
149
+ else inference_params.seqlen_offset
150
+ )
151
+ assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
+ context = flash_attn_with_kvcache(
153
+ q,
154
+ kv_cache[:, :, 0],
155
+ kv_cache[:, :, 1],
156
+ kv[:, :, 0],
157
+ kv[:, :, 1],
158
+ rotary_cos=rotary_cos,
159
+ rotary_sin=rotary_sin,
160
+ cache_seqlens=cache_seqlens,
161
+ softmax_scale=self.softmax_scale,
162
+ causal=self.causal,
163
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
+ )
165
+ return context
166
+
167
+ def _update_kvcache_attention(self, q, kv, inference_params):
168
+ """Write kv to inference_params, then do attention"""
169
+ if (
170
+ inference_params.seqlen_offset == 0
171
+ or flash_attn_with_kvcache is None
172
+ ):
173
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
+ kv = self._update_kv_cache(kv, inference_params)
175
+ k, v = kv.unbind(dim=-3)
176
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
+ return F.scaled_dot_product_attention(
179
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
+ ).transpose(1, 2)
181
+ else:
182
+ batch = q.shape[0]
183
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
+ kv_cache = kv_cache[:batch]
185
+ cache_seqlens = (
186
+ inference_params.lengths_per_sample[:batch]
187
+ if inference_params.lengths_per_sample is not None
188
+ else inference_params.seqlen_offset
189
+ )
190
+ return flash_attn_with_kvcache(
191
+ q,
192
+ kv_cache[:, :, 0],
193
+ kv_cache[:, :, 1],
194
+ kv[:, :, 0],
195
+ kv[:, :, 1],
196
+ cache_seqlens=cache_seqlens,
197
+ softmax_scale=self.softmax_scale,
198
+ causal=self.causal,
199
+ )
200
+
201
+ def forward(self, x, inference_params=None):
202
+ """
203
+ Arguments:
204
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
+ is the sum of the sequence lengths in the batch.
207
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
+ """
210
+ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
+ inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
+ x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
+ )
214
+ seqlen_offset = (
215
+ 0
216
+ if inference_params is None
217
+ else (
218
+ inference_params.lengths_per_sample
219
+ if inference_params.lengths_per_sample is not None
220
+ else inference_params.seqlen_offset
221
+ )
222
+ )
223
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
+ qkv = self.in_proj(x)
225
+ if self.mlp_dim > 0:
226
+ qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
+ x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
+ x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
+ if self.d_conv > 0:
230
+ # The inference code for conv1d is pretty messy, should clean it up
231
+ if (inference_params is None or inference_params.seqlen_offset == 0):
232
+ if causal_conv1d_fn is None:
233
+ qkv = rearrange(
234
+ self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
+ ).contiguous()
236
+ else:
237
+ qkv = causal_conv1d_fn(
238
+ qkv.transpose(1, 2),
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias
241
+ ).transpose(1, 2)
242
+ if inference_params is not None:
243
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
+ # If we just took the last d_conv timesteps (qkv_t[:, :, -self.d_conv :] below), it would error if seqlen < self.d_conv.
245
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
+ qkv_t = rearrange(qkv, "b l d -> b d l")
247
+ conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
+ else:
249
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
+ assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
251
+ qkv = qkv.squeeze(1)
252
+ # Conv step
253
+ if causal_conv1d_update is None:
254
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
+ conv_state[:, :, -1] = qkv
256
+ qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
+ if self.conv1d.bias is not None:
258
+ qkv = qkv + self.conv1d.bias
259
+ else:
260
+ qkv = causal_conv1d_update(
261
+ qkv,
262
+ conv_state,
263
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
+ self.conv1d.bias
265
+ )
266
+ qkv = qkv.unsqueeze(1)
267
+ q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
+ if (
271
+ inference_params is None
272
+ or inference_params.seqlen_offset == 0
273
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
+ ):
275
+ if self.rotary_emb_dim > 0:
276
+ q, kv = self.rotary_emb(
277
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
+ )
279
+ if inference_params is None:
280
+ k, v = kv.unbind(dim=-3)
281
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
+ context = F.scaled_dot_product_attention(
284
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
+ ).transpose(1, 2)
286
+ else:
287
+ context = self._update_kvcache_attention(q, kv, inference_params)
288
+ else:
289
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
+ context = rearrange(context, "... h d -> ... (h d)")
291
+ if self.mlp_dim > 0:
292
+ context = torch.cat([context, x_mlp], dim=-1)
293
+ out = self.out_proj(context)
294
+ return out
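When num_heads_kv < num_heads (MQA/GQA), the non-flash path above widens K and V with repeat_interleave so that scaled_dot_product_attention sees matching head counts. A toy shape sketch of that expansion (sizes are illustrative):

import torch
import torch.nn.functional as F

batch, seqlen, nheads, nheads_kv, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, seqlen, nheads, head_dim)
kv = torch.randn(batch, seqlen, 2, nheads_kv, head_dim)
k, v = kv.unbind(dim=-3)                                            # each (b, s, hkv, d)
k = torch.repeat_interleave(k, dim=2, repeats=nheads // nheads_kv)  # (b, s, h, d)
v = torch.repeat_interleave(v, dim=2, repeats=nheads // nheads_kv)
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)                                                   # (b, s, h, d)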
torch-ext/mamba_ssm/modules/mlp.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class GatedMLP(nn.Module):
7
+ def __init__(
8
+ self,
9
+ in_features,
10
+ hidden_features=None,
11
+ out_features=None,
12
+ activation=F.silu,
13
+ bias=False,
14
+ multiple_of=128,
15
+ device=None,
16
+ dtype=None,
17
+ ):
18
+ factory_kwargs = {"device": device, "dtype": dtype}
19
+ super().__init__()
20
+ out_features = out_features if out_features is not None else in_features
21
+ hidden_features = (
22
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
23
+ )
24
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
25
+ self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
26
+ self.activation = activation
27
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
28
+
29
+ def forward(self, x):
30
+ y = self.fc1(x)
31
+ y, gate = y.chunk(2, dim=-1)
32
+ y = y * self.activation(gate)
33
+ y = self.fc2(y)
34
+ return y
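For reference, GatedMLP's default hidden size is 8 * in_features / 3 rounded up to a multiple of multiple_of, and fc1 emits both the value and gate halves that forward() splits and gates. A quick arithmetic check with an illustrative in_features:

in_features, multiple_of = 1024, 128
hidden_features = int(8 * in_features / 3)                                           # 2730
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of   # 2816
# fc1 maps 1024 -> 2 * 2816 (value and gate halves); fc2 maps 2816 -> 1024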
torch-ext/mamba_ssm/modules/ssd_minimal.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2024, Albert Gu and Tri Dao.
2
+ """Minimal implementation of SSD.
3
+
4
+ This is the same as Listing 1 from the paper.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
12
+
13
+
14
+ def segsum_unstable(x):
15
+ """Naive segment sum calculation."""
16
+ T = x.size(-1)
17
+ x_cumsum = torch.cumsum(x, dim=-1)
18
+ x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
19
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
20
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
21
+ return x_segsum
22
+
23
+
24
+ def segsum(x):
25
+ """More stable segment sum calculation."""
26
+ T = x.size(-1)
27
+ x = repeat(x, "... d -> ... d e", e=T)
28
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
29
+ x = x.masked_fill(~mask, 0)
30
+ x_segsum = torch.cumsum(x, dim=-2)
31
+ mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
32
+ x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
33
+ return x_segsum
34
+
35
+
36
+ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
37
+ """
38
+ Arguments:
39
+ X: (batch, length, n_heads, d_head)
40
+ A: (batch, length, n_heads)
41
+ B: (batch, length, n_heads, d_state)
42
+ C: (batch, length, n_heads, d_state)
43
+ Return:
44
+ Y: (batch, length, n_heads, d_head)
45
+ """
46
+ assert X.dtype == A.dtype == B.dtype == C.dtype
47
+ assert X.shape[1] % block_len == 0
48
+
49
+ # Rearrange into blocks/chunks
50
+ X, A, B, C = [
51
+ rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
52
+ ]
53
+
54
+ A = rearrange(A, "b c l h -> b h c l")
55
+ A_cumsum = torch.cumsum(A, dim=-1)
56
+
57
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
58
+ L = torch.exp(segsum(A))
59
+ Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
60
+
61
+ # 2. Compute the state for each intra-chunk
62
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
63
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
64
+ states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
65
+
66
+ # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
67
+ # (middle term of factorization of off-diag blocks; A terms)
68
+ if initial_states is None:
69
+ initial_states = torch.zeros_like(states[:, :1])
70
+ states = torch.cat([initial_states, states], dim=1)
71
+ decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
72
+ new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
73
+ states, final_state = new_states[:, :-1], new_states[:, -1]
74
+
75
+ # 4. Compute state -> output conversion per chunk
76
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
77
+ state_decay_out = torch.exp(A_cumsum)
78
+ Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
79
+
80
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
81
+ Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
82
+ return Y, final_state
83
+
84
+
85
+ # Simple test
86
+ def test_correctness():
87
+ torch.manual_seed(42)
88
+
89
+ ## Dimensions
90
+ # Denoted (B, T, Q, D, P) in the paper
91
+ batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
92
+ nheads = dim // headdim # (H) in the paper
93
+ ngroups = 1 # (G) in the paper
94
+ dstate = 64 # (N) in the paper
95
+ dtype = torch.float32
96
+ device = "cuda"
97
+
98
+ x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
99
+ dt = F.softplus(
100
+ torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
101
+ ).requires_grad_()
102
+ A = (
103
+ -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
104
+ ).requires_grad_()
105
+ B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
106
+ C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
107
+ D = torch.randn(nheads, dtype=dtype, device=device)
108
+
109
+ # Comparing fused version and minimal version
110
+ y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
111
+ y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
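segsum above produces S[i, j] = a_{j+1} + ... + a_i below the diagonal and -inf above it, so exp(segsum(A)) is the causal decay matrix L used in step 1. A tiny numeric check mirroring the naive segsum_unstable formulation (standalone, nothing imported from this module):

import torch

a = torch.tensor([0.1, 0.2, 0.3, 0.4])
T = a.numel()
cumsum = torch.cumsum(a, dim=-1)
S = cumsum[:, None] - cumsum[None, :]              # S[i, j] = a[j+1] + ... + a[i]
S = S.masked_fill(~torch.tril(torch.ones(T, T, dtype=torch.bool)), float("-inf"))
L = torch.exp(S)                                   # lower-triangular decay matrix
assert torch.isclose(L[3, 1], torch.exp(a[2] + a[3]))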
torch-ext/mamba_ssm/ops/__init__.py ADDED
File without changes
torch-ext/mamba_ssm/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,659 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from ..utils.torch import custom_fwd, custom_bwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ from .triton.layer_norm import _layer_norm_fwd
17
+
18
+ from .._ops import ops
19
+
20
+
21
+ class SelectiveScanFn(torch.autograd.Function):
22
+
23
+ @staticmethod
24
+ def forward(
25
+ ctx,
26
+ u,
27
+ delta,
28
+ A,
29
+ B,
30
+ C,
31
+ D=None,
32
+ z=None,
33
+ delta_bias=None,
34
+ delta_softplus=False,
35
+ return_last_state=False,
36
+ ):
37
+ if u.stride(-1) != 1:
38
+ u = u.contiguous()
39
+ if delta.stride(-1) != 1:
40
+ delta = delta.contiguous()
41
+ if D is not None:
42
+ D = D.contiguous()
43
+ if B.stride(-1) != 1:
44
+ B = B.contiguous()
45
+ if C.stride(-1) != 1:
46
+ C = C.contiguous()
47
+ if z is not None and z.stride(-1) != 1:
48
+ z = z.contiguous()
49
+ if B.dim() == 3:
50
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
51
+ ctx.squeeze_B = True
52
+ if C.dim() == 3:
53
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
54
+ ctx.squeeze_C = True
55
+ out, x, *rest = ops.selective_scan_fwd(
56
+ u, delta, A, B, C, D, z, delta_bias, delta_softplus
57
+ )
58
+ ctx.delta_softplus = delta_softplus
59
+ ctx.has_z = z is not None
60
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
61
+ if not ctx.has_z:
62
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
63
+ return out if not return_last_state else (out, last_state)
64
+ else:
65
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
66
+ out_z = rest[0]
67
+ return out_z if not return_last_state else (out_z, last_state)
68
+
69
+ @staticmethod
70
+ def backward(ctx, dout, *args):
71
+ if not ctx.has_z:
72
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
73
+ z = None
74
+ out = None
75
+ else:
76
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
77
+ if dout.stride(-1) != 1:
78
+ dout = dout.contiguous()
79
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
80
+ # backward of selective_scan_cuda with the backward of chunk).
81
+ # Here we just pass in None and dz will be allocated in the C++ code.
82
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
83
+ u,
84
+ delta,
85
+ A,
86
+ B,
87
+ C,
88
+ D,
89
+ z,
90
+ delta_bias,
91
+ dout,
92
+ x,
93
+ out,
94
+ None,
95
+ ctx.delta_softplus,
96
+ False, # option to recompute out_z, not used here
97
+ )
98
+ dz = rest[0] if ctx.has_z else None
99
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
100
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
101
+ return (
102
+ du,
103
+ ddelta,
104
+ dA,
105
+ dB,
106
+ dC,
107
+ dD if D is not None else None,
108
+ dz,
109
+ ddelta_bias if delta_bias is not None else None,
110
+ None,
111
+ None,
112
+ )
113
+
114
+
115
+ def rms_norm_forward(
116
+ x,
117
+ weight,
118
+ bias,
119
+ eps=1e-6,
120
+ is_rms_norm=True,
121
+ ):
122
+ # x (b l) d
123
+ if x.stride(-1) != 1:
124
+ x = x.contiguous()
125
+ weight = weight.contiguous()
126
+ if bias is not None:
127
+ bias = bias.contiguous()
128
+ y = _layer_norm_fwd(
129
+ x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
130
+ )[0]
131
+ # y (b l) d
132
+ return y
133
+
134
+
135
+ def selective_scan_fn(
136
+ u,
137
+ delta,
138
+ A,
139
+ B,
140
+ C,
141
+ D=None,
142
+ z=None,
143
+ delta_bias=None,
144
+ delta_softplus=False,
145
+ return_last_state=False,
146
+ ):
147
+ """if return_last_state is True, returns (out, last_state)
148
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
149
+ not considered in the backward pass.
150
+ """
151
+ return SelectiveScanFn.apply(
152
+ u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state
153
+ )
154
+
155
+
156
+ def selective_scan_ref(
157
+ u,
158
+ delta,
159
+ A,
160
+ B,
161
+ C,
162
+ D=None,
163
+ z=None,
164
+ delta_bias=None,
165
+ delta_softplus=False,
166
+ return_last_state=False,
167
+ ):
168
+ """
169
+ u: r(B D L)
170
+ delta: r(B D L)
171
+ A: c(D N) or r(D N)
172
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
173
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
174
+ D: r(D)
175
+ z: r(B D L)
176
+ delta_bias: r(D), fp32
177
+
178
+ out: r(B D L)
179
+ last_state (optional): r(B D dstate) or c(B D dstate)
180
+ """
181
+ dtype_in = u.dtype
182
+ u = u.float()
183
+ delta = delta.float()
184
+ if delta_bias is not None:
185
+ delta = delta + delta_bias[..., None].float()
186
+ if delta_softplus:
187
+ delta = F.softplus(delta)
188
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
189
+ is_variable_B = B.dim() >= 3
190
+ is_variable_C = C.dim() >= 3
191
+ if A.is_complex():
192
+ if is_variable_B:
193
+ B = torch.view_as_complex(
194
+ rearrange(B.float(), "... (L two) -> ... L two", two=2)
195
+ )
196
+ if is_variable_C:
197
+ C = torch.view_as_complex(
198
+ rearrange(C.float(), "... (L two) -> ... L two", two=2)
199
+ )
200
+ else:
201
+ B = B.float()
202
+ C = C.float()
203
+ x = A.new_zeros((batch, dim, dstate))
204
+ ys = []
205
+ deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
206
+ if not is_variable_B:
207
+ deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
208
+ else:
209
+ if B.dim() == 3:
210
+ deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
211
+ else:
212
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
213
+ deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
214
+ if is_variable_C and C.dim() == 4:
215
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
216
+ last_state = None
217
+ for i in range(u.shape[2]):
218
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
219
+ if not is_variable_C:
220
+ y = torch.einsum("bdn,dn->bd", x, C)
221
+ else:
222
+ if C.dim() == 3:
223
+ y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
224
+ else:
225
+ y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
226
+ if i == u.shape[2] - 1:
227
+ last_state = x
228
+ if y.is_complex():
229
+ y = y.real * 2
230
+ ys.append(y)
231
+ y = torch.stack(ys, dim=2) # (batch dim L)
232
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
233
+ if z is not None:
234
+ out = out * F.silu(z)
235
+ out = out.to(dtype=dtype_in)
236
+ return out if not return_last_state else (out, last_state)
237
+
238
+
239
+ class MambaInnerFn(torch.autograd.Function):
240
+
241
+ @staticmethod
242
+ @custom_fwd
243
+ def forward(
244
+ ctx,
245
+ xz,
246
+ conv1d_weight,
247
+ conv1d_bias,
248
+ x_proj_weight,
249
+ delta_proj_weight,
250
+ out_proj_weight,
251
+ out_proj_bias,
252
+ A,
253
+ B=None,
254
+ C=None,
255
+ D=None,
256
+ delta_bias=None,
257
+ B_proj_bias=None,
258
+ C_proj_bias=None,
259
+ delta_softplus=True,
260
+ checkpoint_lvl=1,
261
+ b_rms_weight=None,
262
+ c_rms_weight=None,
263
+ dt_rms_weight=None,
264
+ b_c_dt_rms_eps=1e-6,
265
+ ):
266
+ """
267
+ xz: (batch, dim, seqlen)
268
+ """
269
+ assert (
270
+ causal_conv1d_cuda is not None
271
+ ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
272
+ assert checkpoint_lvl in [0, 1]
273
+ L = xz.shape[-1]
274
+ delta_rank = delta_proj_weight.shape[1]
275
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
276
+ if torch.is_autocast_enabled():
277
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
278
+ delta_proj_weight = delta_proj_weight.to(
279
+ dtype=torch.get_autocast_gpu_dtype()
280
+ )
281
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
282
+ out_proj_bias = (
283
+ out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
284
+ if out_proj_bias is not None
285
+ else None
286
+ )
287
+ if xz.stride(-1) != 1:
288
+ xz = xz.contiguous()
289
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
290
+ x, z = xz.chunk(2, dim=1)
291
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
292
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
293
+ x, conv1d_weight, conv1d_bias, None, None, None, True
294
+ )
295
+ # We're being very careful here about the layout, to avoid extra transposes.
296
+ # We want delta to have d as the slowest moving dimension
297
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
298
+ x_dbl = F.linear(
299
+ rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight
300
+ ) # (bl d)
301
+ delta = rearrange(
302
+ delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
303
+ )
304
+ ctx.is_variable_B = B is None
305
+ ctx.is_variable_C = C is None
306
+ ctx.B_proj_bias_is_None = B_proj_bias is None
307
+ ctx.C_proj_bias_is_None = C_proj_bias is None
308
+ if B is None: # variable B
309
+ B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl dstate)
310
+ if B_proj_bias is not None:
311
+ B = B + B_proj_bias.to(dtype=B.dtype)
312
+ if not A.is_complex():
313
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
314
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
315
+ else:
316
+ B = rearrange(
317
+ B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
318
+ ).contiguous()
319
+ else:
320
+ if B.stride(-1) != 1:
321
+ B = B.contiguous()
322
+ if C is None: # variable C
323
+ C = x_dbl[:, -d_state:] # (bl dstate)
324
+ if C_proj_bias is not None:
325
+ C = C + C_proj_bias.to(dtype=C.dtype)
326
+ if not A.is_complex():
327
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
328
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
+ else:
330
+ C = rearrange(
331
+ C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2
332
+ ).contiguous()
333
+ else:
334
+ if C.stride(-1) != 1:
335
+ C = C.contiguous()
336
+ if D is not None:
337
+ D = D.contiguous()
338
+
339
+ if b_rms_weight is not None:
340
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
341
+ B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
342
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
343
+ if c_rms_weight is not None:
344
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
345
+ C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
346
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
347
+ if dt_rms_weight is not None:
348
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
349
+ delta = rms_norm_forward(
350
+ delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps
351
+ )
352
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
353
+
354
+ out, scan_intermediates, out_z = ops.selective_scan_fwd(
355
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
356
+ )
357
+ ctx.delta_softplus = delta_softplus
358
+ ctx.out_proj_bias_is_None = out_proj_bias is None
359
+ ctx.checkpoint_lvl = checkpoint_lvl
360
+ ctx.b_rms_weight = b_rms_weight
361
+ ctx.c_rms_weight = c_rms_weight
362
+ ctx.dt_rms_weight = dt_rms_weight
363
+ ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
364
+ if (
365
+ checkpoint_lvl >= 1
366
+ ): # Will recompute conv1d_out and delta in the backward pass
367
+ conv1d_out, delta = None, None
368
+ ctx.save_for_backward(
369
+ xz,
370
+ conv1d_weight,
371
+ conv1d_bias,
372
+ x_dbl,
373
+ x_proj_weight,
374
+ delta_proj_weight,
375
+ out_proj_weight,
376
+ conv1d_out,
377
+ delta,
378
+ A,
379
+ B,
380
+ C,
381
+ D,
382
+ delta_bias,
383
+ scan_intermediates,
384
+ b_rms_weight,
385
+ c_rms_weight,
386
+ dt_rms_weight,
387
+ out,
388
+ )
389
+ return F.linear(
390
+ rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias
391
+ )
392
+
393
+ @staticmethod
394
+ @custom_bwd
395
+ def backward(ctx, dout):
396
+ # dout: (batch, seqlen, dim)
397
+ assert (
398
+ causal_conv1d_cuda is not None
399
+ ), "causal_conv1d_cuda is not available. Please install causal-conv1d."
400
+ (
401
+ xz,
402
+ conv1d_weight,
403
+ conv1d_bias,
404
+ x_dbl,
405
+ x_proj_weight,
406
+ delta_proj_weight,
407
+ out_proj_weight,
408
+ conv1d_out,
409
+ delta,
410
+ A,
411
+ B,
412
+ C,
413
+ D,
414
+ delta_bias,
415
+ scan_intermediates,
416
+ b_rms_weight,
417
+ c_rms_weight,
418
+ dt_rms_weight,
419
+ out,
420
+ ) = ctx.saved_tensors
421
+ L = xz.shape[-1]
422
+ delta_rank = delta_proj_weight.shape[1]
423
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
424
+ x, z = xz.chunk(2, dim=1)
425
+ if dout.stride(-1) != 1:
426
+ dout = dout.contiguous()
427
+ if ctx.checkpoint_lvl == 1:
428
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
429
+ x, conv1d_weight, conv1d_bias, None, None, None, True
430
+ )
431
+ delta = rearrange(
432
+ delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L
433
+ )
434
+ if dt_rms_weight is not None:
435
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
436
+ delta = rms_norm_forward(
437
+ delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps
438
+ )
439
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
440
+ if b_rms_weight is not None:
441
+ # Recompute & RMSNorm B
442
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
443
+ B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
444
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
445
+ if c_rms_weight is not None:
446
+ # Recompute & RMSNorm C
447
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
448
+ C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
449
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
450
+
451
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
452
+ # backward of selective_scan_cuda with the backward of chunk).
453
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
454
+ dx, dz = dxz.chunk(2, dim=1)
455
+ dout = rearrange(dout, "b l e -> e (b l)")
456
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
457
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
458
+ ops.selective_scan_bwd(
459
+ conv1d_out,
460
+ delta,
461
+ A,
462
+ B,
463
+ C,
464
+ D,
465
+ z,
466
+ delta_bias,
467
+ dout_y,
468
+ scan_intermediates,
469
+ out,
470
+ dz,
471
+ ctx.delta_softplus,
472
+ True, # option to recompute out_z
473
+ )
474
+ )
475
+ dout_proj_weight = torch.einsum(
476
+ "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
477
+ )
478
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
479
+ dD = dD if D is not None else None
480
+ dx_dbl = torch.empty_like(x_dbl)
481
+ dB_proj_bias = None
482
+ if ctx.is_variable_B:
483
+ if not A.is_complex():
484
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
485
+ else:
486
+ dB = rearrange(
487
+ dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
488
+ ).contiguous()
489
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
490
+ dx_dbl[:, delta_rank : delta_rank + d_state] = dB # (bl d)
491
+ dB = None
492
+ dC_proj_bias = None
493
+ if ctx.is_variable_C:
494
+ if not A.is_complex():
495
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
496
+ else:
497
+ dC = rearrange(
498
+ dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2
499
+ ).contiguous()
500
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
501
+ dx_dbl[:, -d_state:] = dC # (bl d)
502
+ dC = None
503
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
504
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
505
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
506
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
507
+ dx_proj_weight = torch.einsum(
508
+ "Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")
509
+ )
510
+ dconv1d_out = torch.addmm(
511
+ dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out
512
+ )
513
+ dconv1d_out = rearrange(
514
+ dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]
515
+ )
516
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
517
+ # backward of conv1d with the backward of chunk).
518
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
519
+ x,
520
+ conv1d_weight,
521
+ conv1d_bias,
522
+ dconv1d_out,
523
+ None,
524
+ None,
525
+ None,
526
+ dx,
527
+ False,
528
+ True,
529
+ )
530
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
531
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
532
+ return (
533
+ dxz,
534
+ dconv1d_weight,
535
+ dconv1d_bias,
536
+ dx_proj_weight,
537
+ ddelta_proj_weight,
538
+ dout_proj_weight,
539
+ dout_proj_bias,
540
+ dA,
541
+ dB,
542
+ dC,
543
+ dD,
544
+ ddelta_bias if delta_bias is not None else None,
545
+ # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
546
+ dB_proj_bias,
547
+ dC_proj_bias,
548
+ None,
549
+ None,
550
+ None,
551
+ None,
552
+ None,
553
+ None,
554
+ )
555
+
556
+
557
+ def mamba_inner_fn(
558
+ xz,
559
+ conv1d_weight,
560
+ conv1d_bias,
561
+ x_proj_weight,
562
+ delta_proj_weight,
563
+ out_proj_weight,
564
+ out_proj_bias,
565
+ A,
566
+ B=None,
567
+ C=None,
568
+ D=None,
569
+ delta_bias=None,
570
+ B_proj_bias=None,
571
+ C_proj_bias=None,
572
+ delta_softplus=True,
573
+ checkpoint_lvl=1,
574
+ b_rms_weight=None,
575
+ c_rms_weight=None,
576
+ dt_rms_weight=None,
577
+ b_c_dt_rms_eps=1e-6,
578
+ ):
579
+ return MambaInnerFn.apply(
580
+ xz,
581
+ conv1d_weight,
582
+ conv1d_bias,
583
+ x_proj_weight,
584
+ delta_proj_weight,
585
+ out_proj_weight,
586
+ out_proj_bias,
587
+ A,
588
+ B,
589
+ C,
590
+ D,
591
+ delta_bias,
592
+ B_proj_bias,
593
+ C_proj_bias,
594
+ delta_softplus,
595
+ checkpoint_lvl,
596
+ b_rms_weight,
597
+ c_rms_weight,
598
+ dt_rms_weight,
599
+ b_c_dt_rms_eps,
600
+ )
601
+
602
+
603
+ def mamba_inner_ref(
604
+ xz,
605
+ conv1d_weight,
606
+ conv1d_bias,
607
+ x_proj_weight,
608
+ delta_proj_weight,
609
+ out_proj_weight,
610
+ out_proj_bias,
611
+ A,
612
+ B=None,
613
+ C=None,
614
+ D=None,
615
+ delta_bias=None,
616
+ B_proj_bias=None,
617
+ C_proj_bias=None,
618
+ delta_softplus=True,
619
+ ):
620
+ assert (
621
+ causal_conv1d_fn is not None
622
+ ), "causal_conv1d_fn is not available. Please install causal-conv1d."
623
+ L = xz.shape[-1]
624
+ delta_rank = delta_proj_weight.shape[1]
625
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
626
+ x, z = xz.chunk(2, dim=1)
627
+ x = causal_conv1d_fn(
628
+ x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
629
+ )
630
+ # We're being very careful here about the layout, to avoid extra transposes.
631
+ # We want delta to have d as the slowest moving dimension
632
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
633
+ x_dbl = F.linear(rearrange(x, "b d l -> (b l) d"), x_proj_weight) # (bl d)
634
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
635
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
636
+ if B is None: # variable B
637
+ B = x_dbl[:, delta_rank : delta_rank + d_state] # (bl d)
638
+ if B_proj_bias is not None:
639
+ B = B + B_proj_bias.to(dtype=B.dtype)
640
+ if not A.is_complex():
641
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
642
+ else:
643
+ B = rearrange(
644
+ B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
645
+ ).contiguous()
646
+ if C is None: # variable B
647
+ C = x_dbl[:, -d_state:] # (bl d)
648
+ if C_proj_bias is not None:
649
+ C = C + C_proj_bias.to(dtype=C.dtype)
650
+ if not A.is_complex():
651
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
652
+ else:
653
+ C = rearrange(
654
+ C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2
655
+ ).contiguous()
656
+ y = selective_scan_fn(
657
+ x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True
658
+ )
659
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
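selective_scan_ref above spells out the recurrence x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t and y_t = C_t · x_t + D * u_t, optionally gated by silu(z). A usage sketch with the variable-B/C layout described in its docstring (the import path is an assumption; selective_scan_fn takes the same arguments but dispatches to the fused CUDA kernel):

import torch
# from mamba_ssm.ops.selective_scan_interface import selective_scan_ref  # assumed import path

batch, dim, dstate, seqlen = 2, 4, 8, 32
u = torch.randn(batch, dim, seqlen)
delta = torch.rand(batch, dim, seqlen)
A = -torch.rand(dim, dstate)                 # negative real A
B = torch.randn(batch, dstate, seqlen)       # variable B: (B N L)
C = torch.randn(batch, dstate, seqlen)       # variable C: (B N L)
D = torch.randn(dim)
z = torch.randn(batch, dim, seqlen)

out, last_state = selective_scan_ref(
    u, delta, A, B, C, D=D, z=z, delta_bias=None,
    delta_softplus=True, return_last_state=True,
)
print(out.shape, last_state.shape)           # (2, 4, 32) and (2, 4, 8)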
torch-ext/mamba_ssm/ops/triton/__init__.py ADDED
File without changes
torch-ext/mamba_ssm/ops/triton/k_activations.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ @triton.autotune(
10
+ configs=[
11
+ triton.Config({'BLOCK_N': 32}),
12
+ triton.Config({'BLOCK_N': 64}),
13
+ triton.Config({'BLOCK_N': 128}),
14
+ triton.Config({'BLOCK_N': 256}),
15
+ triton.Config({'BLOCK_N': 512}),
16
+ triton.Config({'BLOCK_N': 1024}),
17
+ ],
18
+ key=['ncols'],
19
+ )
20
+ @triton.jit
21
+ def _swiglu_fwd_kernel(
22
+ X,
23
+ Y,
24
+ OUT,
25
+ stride_x_row, # how much to increase the pointer when moving by 1 row
26
+ stride_y_row,
27
+ stride_out_row,
28
+ ncols,
29
+ BLOCK_N: tl.constexpr,
30
+ ):
31
+ # Map the program id to the row of X and Y it should compute.
32
+ row = tl.program_id(0)
33
+ start_col = tl.program_id(1) * BLOCK_N
34
+ X += row * stride_x_row
35
+ Y += row * stride_y_row
36
+ OUT += row * stride_out_row
37
+ cols = start_col + tl.arange(0, BLOCK_N)
38
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
39
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
40
+ out = x * tl.sigmoid(x) * y
41
+ tl.store(OUT + cols, out, mask=cols < ncols)
42
+
43
+
44
+ def _swiglu_fwd(xy, out=None):
45
+ if xy.stride(-1) != 1:
46
+ xy = xy.contiguous()
47
+ batch_shape = xy.shape[:-1]
48
+ xy = xy.reshape(-1, xy.shape[-1])
49
+ x, y = xy.chunk(2, dim=-1)
50
+ if out is None:
51
+ out = torch.empty_like(x)
52
+ else:
53
+ out = out.reshape(-1, out.shape[-1])
54
+ assert out.shape == x.shape
55
+ assert out.stride(-1) == 1
56
+ M, N = x.shape
57
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
58
+ with torch.cuda.device(x.device.index):
59
+ _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
60
+ return out.reshape(*batch_shape, out.shape[-1])
61
+
62
+
63
+ @triton.autotune(
64
+ configs=[
65
+ triton.Config({'BLOCK_N': 32}),
66
+ triton.Config({'BLOCK_N': 64}),
67
+ triton.Config({'BLOCK_N': 128}),
68
+ triton.Config({'BLOCK_N': 256}),
69
+ triton.Config({'BLOCK_N': 512}),
70
+ triton.Config({'BLOCK_N': 1024}),
71
+ ],
72
+ key=['ncols'],
73
+ )
74
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
75
+ @triton.jit
76
+ def _swiglu_bwd_kernel(
77
+ X,
78
+ Y,
79
+ DOUT,
80
+ OUT,
81
+ DX,
82
+ DY,
83
+ stride_x_row, # how much to increase the pointer when moving by 1 row
84
+ stride_y_row,
85
+ stride_dout_row,
86
+ stride_out_row,
87
+ stride_dx_row,
88
+ stride_dy_row,
89
+ ncols,
90
+ BLOCK_N: tl.constexpr,
91
+ RECOMPUTE_OUTPUT: tl.constexpr,
92
+ ):
93
+ # Map the program id to the row of X and Y it should compute.
94
+ row = tl.program_id(0)
95
+ start_col = tl.program_id(1) * BLOCK_N
96
+ X += row * stride_x_row
97
+ Y += row * stride_y_row
98
+ DOUT += row * stride_dout_row
99
+ if RECOMPUTE_OUTPUT:
100
+ OUT += row * stride_out_row
101
+ DX += row * stride_dx_row
102
+ DY += row * stride_dy_row
103
+ cols = start_col + tl.arange(0, BLOCK_N)
104
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
105
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
106
+ dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
107
+ x_sigmoid = tl.sigmoid(x)
108
+ dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
109
+ dy = x * x_sigmoid * dout
110
+ tl.store(DX + cols, dx, mask=cols < ncols)
111
+ tl.store(DY + cols, dy, mask=cols < ncols)
112
+ if RECOMPUTE_OUTPUT:
113
+ out = x * x_sigmoid * y
114
+ tl.store(OUT + cols, out, mask=cols < ncols)
115
+
116
+
117
+ def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
118
+ if xy.stride(-1) != 1:
119
+ xy = xy.contiguous()
120
+ if dout.stride(-1) != 1:
121
+ dout = dout.contiguous()
122
+ batch_shape = xy.shape[:-1]
123
+ xy = xy.reshape(-1, xy.shape[-1])
124
+ x, y = xy.chunk(2, dim=-1)
125
+ dout = dout.reshape(-1, dout.shape[-1])
126
+ assert dout.shape == x.shape
127
+ if dxy is None:
128
+ dxy = torch.empty_like(xy)
129
+ else:
130
+ dxy = dxy.reshape(-1, dxy.shape[-1])
131
+ assert dxy.shape == xy.shape
132
+ dx, dy = dxy.chunk(2, dim=-1)
133
+ assert dx.stride(-1) == 1
134
+ assert dy.stride(-1) == 1
135
+ if recompute_output:
136
+ if out is None:
137
+ out = torch.empty_like(x)
138
+ else:
139
+ out = out.reshape(-1, out.shape[-1])
140
+ assert out.shape == x.shape
141
+ assert out.stride(-1) == 1
142
+ M, N = x.shape
143
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
144
+ with torch.cuda.device(x.device.index):
145
+ _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
146
+ x.stride(0), y.stride(0), dout.stride(0),
147
+ out.stride(0) if recompute_output else 0,
148
+ dx.stride(0), dy.stride(0),
149
+ N)
150
+ if not recompute_output:
151
+ return dxy.reshape(*batch_shape, dxy.shape[-1])
152
+ else:
153
+ return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
154
+
155
+
156
+ class SwiGLU(torch.autograd.Function):
157
+
158
+ @staticmethod
159
+ def forward(ctx, xy):
160
+ ctx.save_for_backward(xy)
161
+ return _swiglu_fwd(xy)
162
+
163
+ @staticmethod
164
+ def backward(ctx, dout):
165
+ xy, = ctx.saved_tensors
166
+ return _swiglu_bwd(xy, dout)
167
+
168
+
169
+ swiglu = SwiGLU.apply
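The fused forward/backward kernels above implement SwiGLU, out = silu(x) * y, where x and y are the two halves of the last dimension of xy. A plain-PyTorch reference that one might use to sanity-check the Triton path (the commented comparison assumes a CUDA device):

import torch
import torch.nn.functional as F

def swiglu_ref(xy):
    # Matches the fused kernel: out = x * sigmoid(x) * y
    x, y = xy.chunk(2, dim=-1)
    return F.silu(x) * y

xy = torch.randn(4, 2 * 512)
ref = swiglu_ref(xy)
# On GPU: torch.testing.assert_close(swiglu(xy.cuda()).cpu(), ref, rtol=1e-3, atol=1e-3)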
torch-ext/mamba_ssm/ops/triton/layer_norm.py ADDED
@@ -0,0 +1,1166 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from ...utils.torch import custom_bwd, custom_fwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+
20
+ def layer_norm_ref(
21
+ x,
22
+ weight,
23
+ bias,
24
+ residual=None,
25
+ x1=None,
26
+ weight1=None,
27
+ bias1=None,
28
+ eps=1e-6,
29
+ dropout_p=0.0,
30
+ rowscale=None,
31
+ prenorm=False,
32
+ dropout_mask=None,
33
+ dropout_mask1=None,
34
+ upcast=False,
35
+ ):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ x = x.float()
39
+ weight = weight.float()
40
+ bias = bias.float() if bias is not None else None
41
+ residual = residual.float() if residual is not None else residual
42
+ x1 = x1.float() if x1 is not None else None
43
+ weight1 = weight1.float() if weight1 is not None else None
44
+ bias1 = bias1.float() if bias1 is not None else None
45
+ if x1 is not None:
46
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
+ if rowscale is not None:
48
+ x = x * rowscale[..., None]
49
+ if dropout_p > 0.0:
50
+ if dropout_mask is not None:
51
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
+ else:
53
+ x = F.dropout(x, p=dropout_p)
54
+ if x1 is not None:
55
+ if dropout_mask1 is not None:
56
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
+ else:
58
+ x1 = F.dropout(x1, p=dropout_p)
59
+ if x1 is not None:
60
+ x = x + x1
61
+ if residual is not None:
62
+ x = (x + residual).to(x.dtype)
63
+ out = F.layer_norm(
64
+ x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
65
+ ).to(dtype)
66
+ if weight1 is None:
67
+ return out if not prenorm else (out, x)
68
+ else:
69
+ out1 = F.layer_norm(
70
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
+ ).to(dtype)
72
+ return (out, out1) if not prenorm else (out, out1, x)
73
+
74
+
75
+ def rms_norm_ref(
76
+ x,
77
+ weight,
78
+ bias,
79
+ residual=None,
80
+ x1=None,
81
+ weight1=None,
82
+ bias1=None,
83
+ eps=1e-6,
84
+ dropout_p=0.0,
85
+ rowscale=None,
86
+ prenorm=False,
87
+ dropout_mask=None,
88
+ dropout_mask1=None,
89
+ upcast=False,
90
+ ):
91
+ dtype = x.dtype
92
+ if upcast:
93
+ x = x.float()
94
+ weight = weight.float()
95
+ bias = bias.float() if bias is not None else None
96
+ residual = residual.float() if residual is not None else residual
97
+ x1 = x1.float() if x1 is not None else None
98
+ weight1 = weight1.float() if weight1 is not None else None
99
+ bias1 = bias1.float() if bias1 is not None else None
100
+ if x1 is not None:
101
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
+ if rowscale is not None:
103
+ x = x * rowscale[..., None]
104
+ if dropout_p > 0.0:
105
+ if dropout_mask is not None:
106
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
+ else:
108
+ x = F.dropout(x, p=dropout_p)
109
+ if x1 is not None:
110
+ if dropout_mask1 is not None:
111
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
+ else:
113
+ x1 = F.dropout(x1, p=dropout_p)
114
+ if x1 is not None:
115
+ x = x + x1
116
+ if residual is not None:
117
+ x = (x + residual).to(x.dtype)
118
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
120
+ dtype
121
+ )
122
+ if weight1 is None:
123
+ return out if not prenorm else (out, x)
124
+ else:
125
+ out1 = (
126
+ (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
127
+ ).to(dtype)
128
+ return (out, out1) if not prenorm else (out, out1, x)
129
+
130
+
131
+ def config_prune(configs):
132
+
133
+ if torch.version.hip:
134
+ try:
135
+ # set warp size based on gcn architecure
136
+ gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
137
+ if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
138
+ # radeon
139
+ warp_size = 32
140
+ else:
141
+ # instinct
142
+ warp_size = 64
143
+ except AttributeError as e:
144
+ # fall back to crude method to set warp size
145
+ device_name = torch.cuda.get_device_properties(0).name
146
+ if "instinct" in device_name.lower():
147
+ warp_size = 64
148
+ else:
149
+ warp_size = 32
150
+ warnings.warn(
151
+ f"{e}, warp size set to {warp_size} based on device name: {device_name}",
152
+ UserWarning,
153
+ )
154
+
155
+ else:
156
+ # cuda
157
+ warp_size = 32
158
+
159
+ max_block_sz = 1024
160
+ max_num_warps = max_block_sz // warp_size
161
+ pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
162
+ return pruned_configs
163
+
164
+
165
+ configs_autotune = [
166
+ triton.Config({}, num_warps=1),
167
+ triton.Config({}, num_warps=2),
168
+ triton.Config({}, num_warps=4),
169
+ triton.Config({}, num_warps=8),
170
+ triton.Config({}, num_warps=16),
171
+ triton.Config({}, num_warps=32),
172
+ ]
173
+
174
+ pruned_configs_autotune = config_prune(configs_autotune)
175
+
176
+
177
+ @triton.autotune(
178
+ configs=pruned_configs_autotune,
179
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
180
+ )
181
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
182
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
183
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
184
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
185
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
186
+ @triton.jit
187
+ def _layer_norm_fwd_1pass_kernel(
188
+ X, # pointer to the input
189
+ Y, # pointer to the output
190
+ W, # pointer to the weights
191
+ B, # pointer to the biases
192
+ RESIDUAL, # pointer to the residual
193
+ X1,
194
+ W1,
195
+ B1,
196
+ Y1,
197
+ RESIDUAL_OUT, # pointer to the residual
198
+ ROWSCALE,
199
+ SEEDS, # Dropout seeds for each row
200
+ DROPOUT_MASK,
201
+ Mean, # pointer to the mean
202
+ Rstd, # pointer to the 1/std
203
+ stride_x_row, # how much to increase the pointer when moving by 1 row
204
+ stride_y_row,
205
+ stride_res_row,
206
+ stride_res_out_row,
207
+ stride_x1_row,
208
+ stride_y1_row,
209
+ M, # number of rows in X
210
+ N, # number of columns in X
211
+ eps, # epsilon to avoid division by zero
212
+ dropout_p, # Dropout probability
213
+ IS_RMS_NORM: tl.constexpr,
214
+ BLOCK_N: tl.constexpr,
215
+ HAS_RESIDUAL: tl.constexpr,
216
+ STORE_RESIDUAL_OUT: tl.constexpr,
217
+ HAS_BIAS: tl.constexpr,
218
+ HAS_DROPOUT: tl.constexpr,
219
+ STORE_DROPOUT_MASK: tl.constexpr,
220
+ HAS_ROWSCALE: tl.constexpr,
221
+ HAS_X1: tl.constexpr,
222
+ HAS_W1: tl.constexpr,
223
+ HAS_B1: tl.constexpr,
224
+ ):
225
+ # Map the program id to the row of X and Y it should compute.
226
+ row = tl.program_id(0)
227
+ X += row * stride_x_row
228
+ Y += row * stride_y_row
229
+ if HAS_RESIDUAL:
230
+ RESIDUAL += row * stride_res_row
231
+ if STORE_RESIDUAL_OUT:
232
+ RESIDUAL_OUT += row * stride_res_out_row
233
+ if HAS_X1:
234
+ X1 += row * stride_x1_row
235
+ if HAS_W1:
236
+ Y1 += row * stride_y1_row
237
+ # Compute mean and variance
238
+ cols = tl.arange(0, BLOCK_N)
239
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
240
+ if HAS_ROWSCALE:
241
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
242
+ x *= rowscale
243
+ if HAS_DROPOUT:
244
+ # Compute dropout mask
245
+ # 7 rounds is good enough, and reduces register pressure
246
+ keep_mask = (
247
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
248
+ )
249
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
250
+ if STORE_DROPOUT_MASK:
251
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
252
+ if HAS_X1:
253
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
254
+ if HAS_ROWSCALE:
255
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
256
+ x1 *= rowscale
257
+ if HAS_DROPOUT:
258
+ # Compute dropout mask
259
+ # 7 rounds is good enough, and reduces register pressure
260
+ keep_mask = (
261
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
262
+ > dropout_p
263
+ )
264
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
265
+ if STORE_DROPOUT_MASK:
266
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
267
+ x += x1
268
+ if HAS_RESIDUAL:
269
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
270
+ x += residual
271
+ if STORE_RESIDUAL_OUT:
272
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
273
+ if not IS_RMS_NORM:
274
+ mean = tl.sum(x, axis=0) / N
275
+ tl.store(Mean + row, mean)
276
+ xbar = tl.where(cols < N, x - mean, 0.0)
277
+ var = tl.sum(xbar * xbar, axis=0) / N
278
+ else:
279
+ xbar = tl.where(cols < N, x, 0.0)
280
+ var = tl.sum(xbar * xbar, axis=0) / N
281
+ rstd = 1 / tl.sqrt(var + eps)
282
+ tl.store(Rstd + row, rstd)
283
+ # Normalize and apply linear transformation
284
+ mask = cols < N
285
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
286
+ if HAS_BIAS:
287
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
288
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
289
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
290
+ # Write output
291
+ tl.store(Y + cols, y, mask=mask)
292
+ if HAS_W1:
293
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
294
+ if HAS_B1:
295
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
296
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
297
+ tl.store(Y1 + cols, y1, mask=mask)
298
+
299
+
300
+ def _layer_norm_fwd(
301
+ x,
302
+ weight,
303
+ bias,
304
+ eps,
305
+ residual=None,
306
+ x1=None,
307
+ weight1=None,
308
+ bias1=None,
309
+ dropout_p=0.0,
310
+ rowscale=None,
311
+ out_dtype=None,
312
+ residual_dtype=None,
313
+ is_rms_norm=False,
314
+ return_dropout_mask=False,
315
+ ):
316
+ if residual is not None:
317
+ residual_dtype = residual.dtype
318
+ M, N = x.shape
319
+ assert x.stride(-1) == 1
320
+ if residual is not None:
321
+ assert residual.stride(-1) == 1
322
+ assert residual.shape == (M, N)
323
+ assert weight.shape == (N,)
324
+ assert weight.stride(-1) == 1
325
+ if bias is not None:
326
+ assert bias.stride(-1) == 1
327
+ assert bias.shape == (N,)
328
+ if x1 is not None:
329
+ assert x1.shape == x.shape
330
+ assert rowscale is None
331
+ assert x1.stride(-1) == 1
332
+ if weight1 is not None:
333
+ assert weight1.shape == (N,)
334
+ assert weight1.stride(-1) == 1
335
+ if bias1 is not None:
336
+ assert bias1.shape == (N,)
337
+ assert bias1.stride(-1) == 1
338
+ if rowscale is not None:
339
+ assert rowscale.is_contiguous()
340
+ assert rowscale.shape == (M,)
341
+ # allocate output
342
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
343
+ assert y.stride(-1) == 1
344
+ if weight1 is not None:
345
+ y1 = torch.empty_like(y)
346
+ assert y1.stride(-1) == 1
347
+ else:
348
+ y1 = None
349
+ if (
350
+ residual is not None
351
+ or (residual_dtype is not None and residual_dtype != x.dtype)
352
+ or dropout_p > 0.0
353
+ or rowscale is not None
354
+ or x1 is not None
355
+ ):
356
+ residual_out = torch.empty(
357
+ M,
358
+ N,
359
+ device=x.device,
360
+ dtype=residual_dtype if residual_dtype is not None else x.dtype,
361
+ )
362
+ assert residual_out.stride(-1) == 1
363
+ else:
364
+ residual_out = None
365
+ mean = (
366
+ torch.empty((M,), dtype=torch.float32, device=x.device)
367
+ if not is_rms_norm
368
+ else None
369
+ )
370
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
371
+ if dropout_p > 0.0:
372
+ seeds = torch.randint(
373
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
374
+ )
375
+ else:
376
+ seeds = None
377
+ if return_dropout_mask and dropout_p > 0.0:
378
+ dropout_mask = torch.empty(
379
+ M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
380
+ )
381
+ else:
382
+ dropout_mask = None
383
+ # Less than 64KB per feature: enqueue fused kernel
384
+ MAX_FUSED_SIZE = 65536 // x.element_size()
385
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
386
+ if N > BLOCK_N:
387
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
388
+ with torch.cuda.device(x.device.index):
389
+ _layer_norm_fwd_1pass_kernel[(M,)](
390
+ x,
391
+ y,
392
+ weight,
393
+ bias,
394
+ residual,
395
+ x1,
396
+ weight1,
397
+ bias1,
398
+ y1,
399
+ residual_out,
400
+ rowscale,
401
+ seeds,
402
+ dropout_mask,
403
+ mean,
404
+ rstd,
405
+ x.stride(0),
406
+ y.stride(0),
407
+ residual.stride(0) if residual is not None else 0,
408
+ residual_out.stride(0) if residual_out is not None else 0,
409
+ x1.stride(0) if x1 is not None else 0,
410
+ y1.stride(0) if y1 is not None else 0,
411
+ M,
412
+ N,
413
+ eps,
414
+ dropout_p,
415
+ is_rms_norm,
416
+ BLOCK_N,
417
+ residual is not None,
418
+ residual_out is not None,
419
+ bias is not None,
420
+ dropout_p > 0.0,
421
+ dropout_mask is not None,
422
+ rowscale is not None,
423
+ )
424
+ # residual_out is None only if residual is None, residual_dtype is None or equals the input dtype, dropout_p == 0.0, and rowscale/x1 are None
425
+ if dropout_mask is not None and x1 is not None:
426
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
427
+ else:
428
+ dropout_mask1 = None
429
+ return (
430
+ y,
431
+ y1,
432
+ mean,
433
+ rstd,
434
+ residual_out if residual_out is not None else x,
435
+ seeds,
436
+ dropout_mask,
437
+ dropout_mask1,
438
+ )
439
+
440
+
441
+ @triton.autotune(
442
+ configs=pruned_configs_autotune,
443
+ key=[
444
+ "N",
445
+ "HAS_DRESIDUAL",
446
+ "STORE_DRESIDUAL",
447
+ "IS_RMS_NORM",
448
+ "HAS_BIAS",
449
+ "HAS_DROPOUT",
450
+ ],
451
+ )
452
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
453
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
454
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
455
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
456
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
457
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
458
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
459
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
460
+ @triton.jit
461
+ def _layer_norm_bwd_kernel(
462
+ X, # pointer to the input
463
+ W, # pointer to the weights
464
+ B, # pointer to the biases
465
+ Y, # pointer to the output to be recomputed
466
+ DY, # pointer to the output gradient
467
+ DX, # pointer to the input gradient
468
+ DW, # pointer to the partial sum of weights gradient
469
+ DB, # pointer to the partial sum of biases gradient
470
+ DRESIDUAL,
471
+ W1,
472
+ DY1,
473
+ DX1,
474
+ DW1,
475
+ DB1,
476
+ DRESIDUAL_IN,
477
+ ROWSCALE,
478
+ SEEDS,
479
+ Mean, # pointer to the mean
480
+ Rstd, # pointer to the 1/std
481
+ stride_x_row, # how much to increase the pointer when moving by 1 row
482
+ stride_y_row,
483
+ stride_dy_row,
484
+ stride_dx_row,
485
+ stride_dres_row,
486
+ stride_dy1_row,
487
+ stride_dx1_row,
488
+ stride_dres_in_row,
489
+ M, # number of rows in X
490
+ N, # number of columns in X
491
+ eps, # epsilon to avoid division by zero
492
+ dropout_p,
493
+ rows_per_program,
494
+ IS_RMS_NORM: tl.constexpr,
495
+ BLOCK_N: tl.constexpr,
496
+ HAS_DRESIDUAL: tl.constexpr,
497
+ STORE_DRESIDUAL: tl.constexpr,
498
+ HAS_BIAS: tl.constexpr,
499
+ HAS_DROPOUT: tl.constexpr,
500
+ HAS_ROWSCALE: tl.constexpr,
501
+ HAS_DY1: tl.constexpr,
502
+ HAS_DX1: tl.constexpr,
503
+ HAS_B1: tl.constexpr,
504
+ RECOMPUTE_OUTPUT: tl.constexpr,
505
+ ):
506
+ # Map the program id to the elements of X, DX, and DY it should compute.
507
+ row_block_id = tl.program_id(0)
508
+ row_start = row_block_id * rows_per_program
509
+ # Do not early exit if row_start >= M, because we need to write DW and DB
510
+ cols = tl.arange(0, BLOCK_N)
511
+ mask = cols < N
512
+ X += row_start * stride_x_row
513
+ if HAS_DRESIDUAL:
514
+ DRESIDUAL += row_start * stride_dres_row
515
+ if STORE_DRESIDUAL:
516
+ DRESIDUAL_IN += row_start * stride_dres_in_row
517
+ DY += row_start * stride_dy_row
518
+ DX += row_start * stride_dx_row
519
+ if HAS_DY1:
520
+ DY1 += row_start * stride_dy1_row
521
+ if HAS_DX1:
522
+ DX1 += row_start * stride_dx1_row
523
+ if RECOMPUTE_OUTPUT:
524
+ Y += row_start * stride_y_row
525
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
526
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
527
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
528
+ if HAS_DY1:
529
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
530
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
531
+ if HAS_BIAS:
532
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
533
+ if HAS_DY1:
534
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
535
+ if HAS_B1:
536
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
537
+ row_end = min((row_block_id + 1) * rows_per_program, M)
538
+ for row in range(row_start, row_end):
539
+ # Load data to SRAM
540
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
541
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
542
+ if HAS_DY1:
543
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
544
+ if not IS_RMS_NORM:
545
+ mean = tl.load(Mean + row)
546
+ rstd = tl.load(Rstd + row)
547
+ # Compute dx
548
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
549
+ xhat = tl.where(mask, xhat, 0.0)
550
+ if RECOMPUTE_OUTPUT:
551
+ y = xhat * w + b if HAS_BIAS else xhat * w
552
+ tl.store(Y + cols, y, mask=mask)
553
+ wdy = w * dy
554
+ dw += dy * xhat
555
+ if HAS_BIAS:
556
+ db += dy
557
+ if HAS_DY1:
558
+ wdy += w1 * dy1
559
+ dw1 += dy1 * xhat
560
+ if HAS_B1:
561
+ db1 += dy1
562
+ if not IS_RMS_NORM:
563
+ c1 = tl.sum(xhat * wdy, axis=0) / N
564
+ c2 = tl.sum(wdy, axis=0) / N
565
+ dx = (wdy - (xhat * c1 + c2)) * rstd
566
+ else:
567
+ c1 = tl.sum(xhat * wdy, axis=0) / N
568
+ dx = (wdy - xhat * c1) * rstd
569
+ if HAS_DRESIDUAL:
570
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
571
+ dx += dres
572
+ # Write dx
573
+ if STORE_DRESIDUAL:
574
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
575
+ if HAS_DX1:
576
+ if HAS_DROPOUT:
577
+ keep_mask = (
578
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
579
+ > dropout_p
580
+ )
581
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
582
+ else:
583
+ dx1 = dx
584
+ tl.store(DX1 + cols, dx1, mask=mask)
585
+ if HAS_DROPOUT:
586
+ keep_mask = (
587
+ tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
588
+ > dropout_p
589
+ )
590
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
591
+ if HAS_ROWSCALE:
592
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
593
+ dx *= rowscale
594
+ tl.store(DX + cols, dx, mask=mask)
595
+
596
+ X += stride_x_row
597
+ if HAS_DRESIDUAL:
598
+ DRESIDUAL += stride_dres_row
599
+ if STORE_DRESIDUAL:
600
+ DRESIDUAL_IN += stride_dres_in_row
601
+ if RECOMPUTE_OUTPUT:
602
+ Y += stride_y_row
603
+ DY += stride_dy_row
604
+ DX += stride_dx_row
605
+ if HAS_DY1:
606
+ DY1 += stride_dy1_row
607
+ if HAS_DX1:
608
+ DX1 += stride_dx1_row
609
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
610
+ if HAS_BIAS:
611
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
612
+ if HAS_DY1:
613
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
614
+ if HAS_B1:
615
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
616
+
617
+
618
+ def _layer_norm_bwd(
619
+ dy,
620
+ x,
621
+ weight,
622
+ bias,
623
+ eps,
624
+ mean,
625
+ rstd,
626
+ dresidual=None,
627
+ dy1=None,
628
+ weight1=None,
629
+ bias1=None,
630
+ seeds=None,
631
+ dropout_p=0.0,
632
+ rowscale=None,
633
+ has_residual=False,
634
+ has_x1=False,
635
+ is_rms_norm=False,
636
+ x_dtype=None,
637
+ recompute_output=False,
638
+ ):
639
+ M, N = x.shape
640
+ assert x.stride(-1) == 1
641
+ assert dy.stride(-1) == 1
642
+ assert dy.shape == (M, N)
643
+ if dresidual is not None:
644
+ assert dresidual.stride(-1) == 1
645
+ assert dresidual.shape == (M, N)
646
+ assert weight.shape == (N,)
647
+ assert weight.stride(-1) == 1
648
+ if bias is not None:
649
+ assert bias.stride(-1) == 1
650
+ assert bias.shape == (N,)
651
+ if dy1 is not None:
652
+ assert weight1 is not None
653
+ assert dy1.shape == dy.shape
654
+ assert dy1.stride(-1) == 1
655
+ if weight1 is not None:
656
+ assert weight1.shape == (N,)
657
+ assert weight1.stride(-1) == 1
658
+ if bias1 is not None:
659
+ assert bias1.shape == (N,)
660
+ assert bias1.stride(-1) == 1
661
+ if seeds is not None:
662
+ assert seeds.is_contiguous()
663
+ assert seeds.shape == (M if not has_x1 else M * 2,)
664
+ if rowscale is not None:
665
+ assert rowscale.is_contiguous()
666
+ assert rowscale.shape == (M,)
667
+ # allocate output
668
+ dx = (
669
+ torch.empty_like(x)
670
+ if x_dtype is None
671
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
672
+ )
673
+ dresidual_in = (
674
+ torch.empty_like(x)
675
+ if has_residual
676
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
677
+ else None
678
+ )
679
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
680
+ y = (
681
+ torch.empty(M, N, dtype=dy.dtype, device=dy.device)
682
+ if recompute_output
683
+ else None
684
+ )
685
+ if recompute_output:
686
+ assert (
687
+ weight1 is None
688
+ ), "recompute_output is not supported with parallel LayerNorm"
689
+
690
+ # Less than 64KB per feature: enqueue fused kernel
691
+ MAX_FUSED_SIZE = 65536 // x.element_size()
692
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
693
+ if N > BLOCK_N:
694
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
695
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
696
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
697
+ _db = (
698
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
699
+ if bias is not None
700
+ else None
701
+ )
702
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
703
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
704
+ rows_per_program = math.ceil(M / sm_count)
705
+ grid = (sm_count,)
706
+ with torch.cuda.device(x.device.index):
707
+ _layer_norm_bwd_kernel[grid](
708
+ x,
709
+ weight,
710
+ bias,
711
+ y,
712
+ dy,
713
+ dx,
714
+ _dw,
715
+ _db,
716
+ dresidual,
717
+ weight1,
718
+ dy1,
719
+ dx1,
720
+ _dw1,
721
+ _db1,
722
+ dresidual_in,
723
+ rowscale,
724
+ seeds,
725
+ mean,
726
+ rstd,
727
+ x.stride(0),
728
+ 0 if not recompute_output else y.stride(0),
729
+ dy.stride(0),
730
+ dx.stride(0),
731
+ dresidual.stride(0) if dresidual is not None else 0,
732
+ dy1.stride(0) if dy1 is not None else 0,
733
+ dx1.stride(0) if dx1 is not None else 0,
734
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
735
+ M,
736
+ N,
737
+ eps,
738
+ dropout_p,
739
+ rows_per_program,
740
+ is_rms_norm,
741
+ BLOCK_N,
742
+ dresidual is not None,
743
+ dresidual_in is not None,
744
+ bias is not None,
745
+ dropout_p > 0.0,
746
+ )
747
+ dw = _dw.sum(0).to(weight.dtype)
748
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
749
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
750
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
751
+ # Don't need to compute dresidual_in separately in this case
752
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
753
+ dresidual_in = dx
754
+ if has_x1 and dropout_p == 0.0:
755
+ dx1 = dx
756
+ return (
757
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
758
+ if not recompute_output
759
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
760
+ )
761
+
762
+
763
+ class LayerNormFn(torch.autograd.Function):
764
+ @staticmethod
765
+ def forward(
766
+ ctx,
767
+ x,
768
+ weight,
769
+ bias,
770
+ residual=None,
771
+ x1=None,
772
+ weight1=None,
773
+ bias1=None,
774
+ eps=1e-6,
775
+ dropout_p=0.0,
776
+ rowscale=None,
777
+ prenorm=False,
778
+ residual_in_fp32=False,
779
+ is_rms_norm=False,
780
+ return_dropout_mask=False,
781
+ ):
782
+ x_shape_og = x.shape
783
+ # reshape input data into 2D tensor
784
+ x = x.reshape(-1, x.shape[-1])
785
+ if x.stride(-1) != 1:
786
+ x = x.contiguous()
787
+ if residual is not None:
788
+ assert residual.shape == x_shape_og
789
+ residual = residual.reshape(-1, residual.shape[-1])
790
+ if residual.stride(-1) != 1:
791
+ residual = residual.contiguous()
792
+ if x1 is not None:
793
+ assert x1.shape == x_shape_og
794
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
795
+ x1 = x1.reshape(-1, x1.shape[-1])
796
+ if x1.stride(-1) != 1:
797
+ x1 = x1.contiguous()
798
+ weight = weight.contiguous()
799
+ if bias is not None:
800
+ bias = bias.contiguous()
801
+ if weight1 is not None:
802
+ weight1 = weight1.contiguous()
803
+ if bias1 is not None:
804
+ bias1 = bias1.contiguous()
805
+ if rowscale is not None:
806
+ rowscale = rowscale.reshape(-1).contiguous()
807
+ residual_dtype = (
808
+ residual.dtype
809
+ if residual is not None
810
+ else (torch.float32 if residual_in_fp32 else None)
811
+ )
812
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
813
+ _layer_norm_fwd(
814
+ x,
815
+ weight,
816
+ bias,
817
+ eps,
818
+ residual,
819
+ x1,
820
+ weight1,
821
+ bias1,
822
+ dropout_p=dropout_p,
823
+ rowscale=rowscale,
824
+ residual_dtype=residual_dtype,
825
+ is_rms_norm=is_rms_norm,
826
+ return_dropout_mask=return_dropout_mask,
827
+ )
828
+ )
829
+ ctx.save_for_backward(
830
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
831
+ )
832
+ ctx.x_shape_og = x_shape_og
833
+ ctx.eps = eps
834
+ ctx.dropout_p = dropout_p
835
+ ctx.is_rms_norm = is_rms_norm
836
+ ctx.has_residual = residual is not None
837
+ ctx.has_x1 = x1 is not None
838
+ ctx.prenorm = prenorm
839
+ ctx.x_dtype = x.dtype
840
+ y = y.reshape(x_shape_og)
841
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
842
+ residual_out = (
843
+ residual_out.reshape(x_shape_og) if residual_out is not None else None
844
+ )
845
+ dropout_mask = (
846
+ dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
847
+ )
848
+ dropout_mask1 = (
849
+ dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
850
+ )
851
+ if not return_dropout_mask:
852
+ if weight1 is None:
853
+ return y if not prenorm else (y, residual_out)
854
+ else:
855
+ return (y, y1) if not prenorm else (y, y1, residual_out)
856
+ else:
857
+ if weight1 is None:
858
+ return (
859
+ (y, dropout_mask, dropout_mask1)
860
+ if not prenorm
861
+ else (y, residual_out, dropout_mask, dropout_mask1)
862
+ )
863
+ else:
864
+ return (
865
+ (y, y1, dropout_mask, dropout_mask1)
866
+ if not prenorm
867
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
868
+ )
869
+
870
+ @staticmethod
871
+ def backward(ctx, dy, *args):
872
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
873
+ dy = dy.reshape(-1, dy.shape[-1])
874
+ if dy.stride(-1) != 1:
875
+ dy = dy.contiguous()
876
+ assert dy.shape == x.shape
877
+ if weight1 is not None:
878
+ dy1, args = args[0], args[1:]
879
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
880
+ if dy1.stride(-1) != 1:
881
+ dy1 = dy1.contiguous()
882
+ assert dy1.shape == x.shape
883
+ else:
884
+ dy1 = None
885
+ if ctx.prenorm:
886
+ dresidual = args[0]
887
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
888
+ if dresidual.stride(-1) != 1:
889
+ dresidual = dresidual.contiguous()
890
+ assert dresidual.shape == x.shape
891
+ else:
892
+ dresidual = None
893
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
894
+ dy,
895
+ x,
896
+ weight,
897
+ bias,
898
+ ctx.eps,
899
+ mean,
900
+ rstd,
901
+ dresidual,
902
+ dy1,
903
+ weight1,
904
+ bias1,
905
+ seeds,
906
+ ctx.dropout_p,
907
+ rowscale,
908
+ ctx.has_residual,
909
+ ctx.has_x1,
910
+ ctx.is_rms_norm,
911
+ x_dtype=ctx.x_dtype,
912
+ )
913
+ return (
914
+ dx.reshape(ctx.x_shape_og),
915
+ dw,
916
+ db,
917
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
918
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
919
+ dw1,
920
+ db1,
921
+ None,
922
+ None,
923
+ None,
924
+ None,
925
+ None,
926
+ None,
927
+ None,
928
+ )
929
+
930
+
931
+ def layer_norm_fn(
932
+ x,
933
+ weight,
934
+ bias,
935
+ residual=None,
936
+ x1=None,
937
+ weight1=None,
938
+ bias1=None,
939
+ eps=1e-6,
940
+ dropout_p=0.0,
941
+ rowscale=None,
942
+ prenorm=False,
943
+ residual_in_fp32=False,
944
+ is_rms_norm=False,
945
+ return_dropout_mask=False,
946
+ ):
947
+ return LayerNormFn.apply(
948
+ x,
949
+ weight,
950
+ bias,
951
+ residual,
952
+ x1,
953
+ weight1,
954
+ bias1,
955
+ eps,
956
+ dropout_p,
957
+ rowscale,
958
+ prenorm,
959
+ residual_in_fp32,
960
+ is_rms_norm,
961
+ return_dropout_mask,
962
+ )
963
+
964
+
965
+ def rms_norm_fn(
966
+ x,
967
+ weight,
968
+ bias,
969
+ residual=None,
970
+ x1=None,
971
+ weight1=None,
972
+ bias1=None,
973
+ eps=1e-6,
974
+ dropout_p=0.0,
975
+ rowscale=None,
976
+ prenorm=False,
977
+ residual_in_fp32=False,
978
+ return_dropout_mask=False,
979
+ ):
980
+ return LayerNormFn.apply(
981
+ x,
982
+ weight,
983
+ bias,
984
+ residual,
985
+ x1,
986
+ weight1,
987
+ bias1,
988
+ eps,
989
+ dropout_p,
990
+ rowscale,
991
+ prenorm,
992
+ residual_in_fp32,
993
+ True,
994
+ return_dropout_mask,
995
+ )
996
+
997
+
998
+ class RMSNorm(torch.nn.Module):
999
+
1000
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
1001
+ factory_kwargs = {"device": device, "dtype": dtype}
1002
+ super().__init__()
1003
+ self.eps = eps
1004
+ if dropout_p > 0.0:
1005
+ self.drop = torch.nn.Dropout(dropout_p)
1006
+ else:
1007
+ self.drop = None
1008
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1009
+ self.register_parameter("bias", None)
1010
+ self.reset_parameters()
1011
+
1012
+ def reset_parameters(self):
1013
+ torch.nn.init.ones_(self.weight)
1014
+
1015
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1016
+ return rms_norm_fn(
1017
+ x,
1018
+ self.weight,
1019
+ self.bias,
1020
+ residual=residual,
1021
+ eps=self.eps,
1022
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1023
+ prenorm=prenorm,
1024
+ residual_in_fp32=residual_in_fp32,
1025
+ )
1026
+
1027
+
1028
+ class LayerNormLinearFn(torch.autograd.Function):
1029
+ @staticmethod
1030
+ @custom_fwd
1031
+ def forward(
1032
+ ctx,
1033
+ x,
1034
+ norm_weight,
1035
+ norm_bias,
1036
+ linear_weight,
1037
+ linear_bias,
1038
+ residual=None,
1039
+ eps=1e-6,
1040
+ prenorm=False,
1041
+ residual_in_fp32=False,
1042
+ is_rms_norm=False,
1043
+ ):
1044
+ x_shape_og = x.shape
1045
+ # reshape input data into 2D tensor
1046
+ x = x.reshape(-1, x.shape[-1])
1047
+ if x.stride(-1) != 1:
1048
+ x = x.contiguous()
1049
+ if residual is not None:
1050
+ assert residual.shape == x_shape_og
1051
+ residual = residual.reshape(-1, residual.shape[-1])
1052
+ if residual.stride(-1) != 1:
1053
+ residual = residual.contiguous()
1054
+ norm_weight = norm_weight.contiguous()
1055
+ if norm_bias is not None:
1056
+ norm_bias = norm_bias.contiguous()
1057
+ residual_dtype = (
1058
+ residual.dtype
1059
+ if residual is not None
1060
+ else (torch.float32 if residual_in_fp32 else None)
1061
+ )
1062
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1063
+ x,
1064
+ norm_weight,
1065
+ norm_bias,
1066
+ eps,
1067
+ residual,
1068
+ out_dtype=(
1069
+ None
1070
+ if not torch.is_autocast_enabled()
1071
+ else torch.get_autocast_gpu_dtype()
1072
+ ),
1073
+ residual_dtype=residual_dtype,
1074
+ is_rms_norm=is_rms_norm,
1075
+ )
1076
+ y = y.reshape(x_shape_og)
1077
+ dtype = (
1078
+ torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1079
+ )
1080
+ linear_weight = linear_weight.to(dtype)
1081
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1082
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1083
+ # We don't store y, will be recomputed in the backward pass to save memory
1084
+ ctx.save_for_backward(
1085
+ residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1086
+ )
1087
+ ctx.x_shape_og = x_shape_og
1088
+ ctx.eps = eps
1089
+ ctx.is_rms_norm = is_rms_norm
1090
+ ctx.has_residual = residual is not None
1091
+ ctx.prenorm = prenorm
1092
+ ctx.x_dtype = x.dtype
1093
+ ctx.linear_bias_is_none = linear_bias is None
1094
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1095
+
1096
+ @staticmethod
1097
+ @custom_bwd
1098
+ def backward(ctx, dout, *args):
1099
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1100
+ dout = dout.reshape(-1, dout.shape[-1])
1101
+ dy = F.linear(dout, linear_weight.t())
1102
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1103
+ if dy.stride(-1) != 1:
1104
+ dy = dy.contiguous()
1105
+ assert dy.shape == x.shape
1106
+ if ctx.prenorm:
1107
+ dresidual = args[0]
1108
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1109
+ if dresidual.stride(-1) != 1:
1110
+ dresidual = dresidual.contiguous()
1111
+ assert dresidual.shape == x.shape
1112
+ else:
1113
+ dresidual = None
1114
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1115
+ dy,
1116
+ x,
1117
+ norm_weight,
1118
+ norm_bias,
1119
+ ctx.eps,
1120
+ mean,
1121
+ rstd,
1122
+ dresidual=dresidual,
1123
+ has_residual=ctx.has_residual,
1124
+ is_rms_norm=ctx.is_rms_norm,
1125
+ x_dtype=ctx.x_dtype,
1126
+ recompute_output=True,
1127
+ )
1128
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1129
+ return (
1130
+ dx.reshape(ctx.x_shape_og),
1131
+ dnorm_weight,
1132
+ dnorm_bias,
1133
+ dlinear_weight,
1134
+ dlinear_bias,
1135
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1136
+ None,
1137
+ None,
1138
+ None,
1139
+ None,
1140
+ )
1141
+
1142
+
1143
+ def layer_norm_linear_fn(
1144
+ x,
1145
+ norm_weight,
1146
+ norm_bias,
1147
+ linear_weight,
1148
+ linear_bias,
1149
+ residual=None,
1150
+ eps=1e-6,
1151
+ prenorm=False,
1152
+ residual_in_fp32=False,
1153
+ is_rms_norm=False,
1154
+ ):
1155
+ return LayerNormLinearFn.apply(
1156
+ x,
1157
+ norm_weight,
1158
+ norm_bias,
1159
+ linear_weight,
1160
+ linear_bias,
1161
+ residual,
1162
+ eps,
1163
+ prenorm,
1164
+ residual_in_fp32,
1165
+ is_rms_norm,
1166
+ )
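A minimal usage sketch of the public entry points defined above (layer_norm_fn, rms_norm_fn, and the RMSNorm module). The shapes, dtypes, and the import path are illustrative assumptions, and a CUDA device is required since these are Triton kernels.

import torch
# assumed import path for this commit's layout:
# from mamba_ssm.ops.triton.layer_norm import layer_norm_fn, RMSNorm

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
residual = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float32)

norm = RMSNorm(1024, eps=1e-5, device="cuda", dtype=torch.float16)
# prenorm=True returns (normalized hidden states, updated fp32 residual stream)
y, new_residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)

# the functional form also exposes the fused dropout / rowscale / second-weight paths
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
y2 = layer_norm_fn(x, weight, None, eps=1e-5, dropout_p=0.1, is_rms_norm=True)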
torch-ext/mamba_ssm/ops/triton/layernorm_gated.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
3
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
4
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
5
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from einops import rearrange
16
+
17
+
18
+ def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
19
+ dtype = x.dtype
20
+ N = x.shape[-1]
21
+ weight = weight.float()
22
+ bias = bias.float() if bias is not None else None
23
+ if upcast:
24
+ x = x.float()
25
+ z = z.float() if z is not None else z
26
+ if z is not None and not norm_before_gate:
27
+ x = x * F.silu(z)
28
+ if group_size is None:
29
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
30
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
31
+ else:
32
+ x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
33
+ rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
34
+ out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
35
+ if bias is not None:
36
+ out = out + bias
37
+ if z is not None and norm_before_gate:
38
+ out *= F.silu(z)
39
+ return out.to(dtype)
40
+
41
+
42
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
43
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
44
+ @triton.jit
45
+ def _layer_norm_fwd_1pass_kernel(
46
+ X, # pointer to the input
47
+ Y, # pointer to the output
48
+ W, # pointer to the weights
49
+ B, # pointer to the biases
50
+ Z, # pointer to the other branch
51
+ Mean, # pointer to the mean
52
+ Rstd, # pointer to the 1/std
53
+ stride_x_row, # how much to increase the pointer when moving by 1 row
54
+ stride_y_row,
55
+ stride_z_row,
56
+ M, # number of rows in X
57
+ N, # number of columns in X
58
+ eps, # epsilon to avoid division by zero
59
+ BLOCK_N: tl.constexpr,
60
+ HAS_BIAS: tl.constexpr,
61
+ HAS_Z: tl.constexpr,
62
+ NORM_BEFORE_GATE: tl.constexpr,
63
+ IS_RMS_NORM: tl.constexpr,
64
+ ):
65
+ # Map the program id to the row of X and Y it should compute.
66
+ row = tl.program_id(0)
67
+ group = tl.program_id(1)
68
+ X += row * stride_x_row + group * N
69
+ Y += row * stride_y_row + group * N
70
+ if HAS_Z:
71
+ Z += row * stride_z_row + group * N
72
+ if not IS_RMS_NORM:
73
+ Mean += group * M
74
+ Rstd += group * M
75
+ W += group * N
76
+ if HAS_BIAS:
77
+ B += group * N
78
+ # Compute mean and variance
79
+ cols = tl.arange(0, BLOCK_N)
80
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
81
+ if HAS_Z and not NORM_BEFORE_GATE:
82
+ z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
83
+ x *= z * tl.sigmoid(z)
84
+ if not IS_RMS_NORM:
85
+ mean = tl.sum(x, axis=0) / N
86
+ tl.store(Mean + row, mean)
87
+ xbar = tl.where(cols < N, x - mean, 0.)
88
+ var = tl.sum(xbar * xbar, axis=0) / N
89
+ else:
90
+ xbar = tl.where(cols < N, x, 0.)
91
+ var = tl.sum(xbar * xbar, axis=0) / N
92
+ rstd = 1 / tl.sqrt(var + eps)
93
+ tl.store(Rstd + row, rstd)
94
+ # Normalize and apply linear transformation
95
+ mask = cols < N
96
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
97
+ if HAS_BIAS:
98
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
99
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
100
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
101
+ if HAS_Z and NORM_BEFORE_GATE:
102
+ z = tl.load(Z + cols, mask=mask).to(tl.float32)
103
+ y *= z * tl.sigmoid(z)
104
+ # Write output
105
+ tl.store(Y + cols, y, mask=mask)
106
+
107
+
108
+ def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
109
+ M, N = x.shape
110
+ if group_size is None:
111
+ group_size = N
112
+ assert N % group_size == 0
113
+ ngroups = N // group_size
114
+ assert x.stride(-1) == 1
115
+ if z is not None:
116
+ assert z.stride(-1) == 1
117
+ assert z.shape == (M, N)
118
+ assert weight.shape == (N,)
119
+ assert weight.stride(-1) == 1
120
+ if bias is not None:
121
+ assert bias.stride(-1) == 1
122
+ assert bias.shape == (N,)
123
+ # allocate output
124
+ if out is not None:
125
+ assert out.shape == x.shape
126
+ else:
127
+ out = torch.empty_like(x)
128
+ assert out.stride(-1) == 1
129
+ mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
130
+ rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
131
+ # Less than 64KB per feature: enqueue fused kernel
132
+ MAX_FUSED_SIZE = 65536 // x.element_size()
133
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
134
+ if group_size > BLOCK_N:
135
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
136
+ # heuristics for number of warps
137
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
138
+ grid = (M, ngroups)
139
+ with torch.cuda.device(x.device.index):
140
+ _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
141
+ x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
142
+ M, group_size, eps,
143
+ BLOCK_N=BLOCK_N,
144
+ NORM_BEFORE_GATE=norm_before_gate,
145
+ IS_RMS_NORM=is_rms_norm,
146
+ num_warps=num_warps)
147
+ return out, mean, rstd
148
+
149
+
150
+
151
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
152
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
153
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
154
+ @triton.jit
155
+ def _layer_norm_bwd_kernel(
156
+ X, # pointer to the input
157
+ W, # pointer to the weights
158
+ B, # pointer to the biases
159
+ Z, # pointer to the other branch
160
+ Y, # pointer to the output to be recomputed
161
+ DY, # pointer to the output gradient
162
+ DX, # pointer to the input gradient
163
+ DW, # pointer to the partial sum of weights gradient
164
+ DB, # pointer to the partial sum of biases gradient
165
+ DZ, # pointer to the other branch
166
+ Mean, # pointer to the mean
167
+ Rstd, # pointer to the 1/std
168
+ stride_x_row, # how much to increase the pointer when moving by 1 row
169
+ stride_z_row,
170
+ stride_y_row,
171
+ stride_dy_row,
172
+ stride_dx_row,
173
+ stride_dz_row,
174
+ stride_dw_row,
175
+ stride_db_row,
176
+ M, # number of rows in X
177
+ N, # number of columns in X
178
+ eps, # epsilon to avoid division by zero
179
+ rows_per_program,
180
+ NORM_BEFORE_GATE: tl.constexpr,
181
+ IS_RMS_NORM: tl.constexpr,
182
+ HAS_BIAS: tl.constexpr,
183
+ HAS_Z: tl.constexpr,
184
+ RECOMPUTE_OUTPUT: tl.constexpr,
185
+ BLOCK_N: tl.constexpr,
186
+ ):
187
+ # Map the program id to the elements of X, DX, and DY it should compute.
188
+ row_block_id = tl.program_id(0)
189
+ group = tl.program_id(1)
190
+ row_start = row_block_id * rows_per_program
191
+ cols = tl.arange(0, BLOCK_N)
192
+ mask = cols < N
193
+ X += row_start * stride_x_row + group * N
194
+ if HAS_Z:
195
+ Z += row_start * stride_z_row + group * N
196
+ DZ += row_start * stride_dz_row + group * N
197
+ DY += row_start * stride_dy_row + group * N
198
+ DX += row_start * stride_dx_row + group * N
199
+ if RECOMPUTE_OUTPUT:
200
+ Y += row_start * stride_y_row + group * N
201
+ if not IS_RMS_NORM:
202
+ Mean += group * M
203
+ Rstd += group * M
204
+ W += group * N
205
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
206
+ if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
207
+ B += group * N
208
+ b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
209
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
210
+ if HAS_BIAS:
211
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
212
+ row_end = min((row_block_id + 1) * rows_per_program, M)
213
+ for row in range(row_start, row_end):
214
+ # Load data to SRAM
215
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
216
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
217
+ if not IS_RMS_NORM:
218
+ mean = tl.load(Mean + row)
219
+ if HAS_Z and not NORM_BEFORE_GATE:
220
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
221
+ x_og = x
222
+ x = x_og * z * tl.sigmoid(z)
223
+ rstd = tl.load(Rstd + row)
224
+ # Compute dx
225
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
226
+ xhat = tl.where(mask, xhat, 0.)
227
+ if HAS_Z and NORM_BEFORE_GATE:
228
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
229
+ z_sigmoid = tl.sigmoid(z)
230
+ y = xhat * w + b if HAS_BIAS else xhat * w
231
+ if RECOMPUTE_OUTPUT:
232
+ tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
233
+ dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
234
+ tl.store(DZ + cols, dz, mask=mask)
235
+ dy *= z * z_sigmoid
236
+ else:
237
+ if RECOMPUTE_OUTPUT:
238
+ y = xhat * w + b if HAS_BIAS else xhat * w
239
+ tl.store(Y + cols, y, mask=mask)
240
+ wdy = w * dy
241
+ c1 = tl.sum(xhat * wdy, axis=0) / N
242
+ if not IS_RMS_NORM:
243
+ c2 = tl.sum(wdy, axis=0) / N
244
+ dx = (wdy - (xhat * c1 + c2)) * rstd
245
+ else:
246
+ dx = (wdy - xhat * c1) * rstd
247
+ dw += dy * xhat
248
+ if HAS_BIAS:
249
+ db += dy
250
+ if HAS_Z and not NORM_BEFORE_GATE:
251
+ z_sigmoid = tl.sigmoid(z)
252
+ dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
253
+ tl.store(DZ + cols, dz, mask=mask)
254
+ dx *= z * z_sigmoid
255
+ # Write dx
256
+ tl.store(DX + cols, dx, mask=mask)
257
+
258
+ X += stride_x_row
259
+ if HAS_Z:
260
+ Z += stride_z_row
261
+ DZ += stride_dz_row
262
+ if RECOMPUTE_OUTPUT:
263
+ Y += stride_y_row
264
+ DY += stride_dy_row
265
+ DX += stride_dx_row
266
+ tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
267
+ if HAS_BIAS:
268
+ tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
269
+
270
+
271
+ def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
272
+ norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
273
+ M, N = x.shape
274
+ if group_size is None:
275
+ group_size = N
276
+ assert N % group_size == 0
277
+ ngroups = N // group_size
278
+ assert x.stride(-1) == 1
279
+ assert dy.stride(-1) == 1
280
+ assert dy.shape == (M, N)
281
+ if z is not None:
282
+ assert z.stride(-1) == 1
283
+ assert z.shape == (M, N)
284
+ assert weight.shape == (N,)
285
+ assert weight.stride(-1) == 1
286
+ if bias is not None:
287
+ assert bias.stride(-1) == 1
288
+ assert bias.shape == (N,)
289
+ # allocate output
290
+ dx = torch.empty_like(x)
291
+ if dz is not None:
292
+ assert z is not None
293
+ assert dz.shape == z.shape
294
+ assert dz.stride(-1) == 1
295
+ else:
296
+ dz = torch.empty_like(z) if z is not None else None
297
+ if recompute_output:
298
+ if out is None:
299
+ out = torch.empty_like(x)
300
+ assert out.shape == x.shape
301
+
302
+ # Less than 64KB per feature: enqueue fused kernel
303
+ MAX_FUSED_SIZE = 65536 // x.element_size()
304
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
305
+ if group_size > BLOCK_N:
306
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
307
+ # heuristics for number of warps
308
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
309
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
310
+ # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
311
+ # would limit the occupancy.
312
+ nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
313
+ _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
314
+ _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
315
+ rows_per_program = math.ceil(M / nrow_groups)
316
+ grid = (nrow_groups, ngroups)
317
+ with torch.cuda.device(x.device.index):
318
+ _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
319
+ dy, dx, _dw, _db, dz, mean, rstd,
320
+ x.stride(0),
321
+ z.stride(0) if z is not None else 0,
322
+ 0 if not recompute_output else out.stride(0),
323
+ dy.stride(0), dx.stride(0),
324
+ dz.stride(0) if dz is not None else 0,
325
+ _dw.stride(0),
326
+ _db.stride(0) if _db is not None else 0,
327
+ M, group_size, eps,
328
+ rows_per_program,
329
+ BLOCK_N=BLOCK_N,
330
+ NORM_BEFORE_GATE=norm_before_gate,
331
+ IS_RMS_NORM=is_rms_norm,
332
+ num_warps=num_warps)
333
+ dw = _dw.sum(0).to(weight.dtype)
334
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
335
+ return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
336
+
337
+
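To make the occupancy comment in _layer_norm_bwd above concrete (illustrative numbers): on a GPU with sm_count = 108, a small group_size of 64 gives BLOCK_N = 64 and num_warps = 1, so nrow_groups = ceil(108 * ceil(4 / 1) / ngroups) = ceil(432 / ngroups) programs per group rather than just 108, which keeps roughly four warps of work per SM even though each program only uses one warp.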
338
+ class LayerNormFn(torch.autograd.Function):
339
+
340
+ @staticmethod
341
+ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
342
+ is_rms_norm=False):
343
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
344
+ """
345
+
346
+ x_shape_og = x.shape
347
+ # reshape input data into 2D tensor
348
+ x = x.reshape(-1, x.shape[-1])
349
+ if x.stride(-1) != 1:
350
+ x = x.contiguous()
351
+ if z is not None:
352
+ assert z.shape == x_shape_og
353
+ z = z.reshape(-1, z.shape[-1])
354
+ if z.stride(-1) != 1:
355
+ z = z.contiguous()
356
+ weight = weight.contiguous()
357
+ if bias is not None:
358
+ bias = bias.contiguous()
359
+ y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
360
+ ctx.save_for_backward(x, weight, bias, mean, rstd, z)
361
+ ctx.x_shape_og = x_shape_og
362
+ ctx.eps = eps
363
+ ctx.group_size = group_size
364
+ ctx.norm_before_gate = norm_before_gate
365
+ ctx.is_rms_norm = is_rms_norm
366
+ return y.reshape(x_shape_og)
367
+
368
+ @staticmethod
369
+ def backward(ctx, dy):
370
+ x, weight, bias, mean, rstd, z = ctx.saved_tensors
371
+ dy = dy.reshape(-1, dy.shape[-1])
372
+ if dy.stride(-1) != 1:
373
+ dy = dy.contiguous()
374
+ assert dy.shape == x.shape
375
+ dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
376
+ ctx.norm_before_gate, ctx.is_rms_norm)
377
+ return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
378
+
379
+
380
+ def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
381
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
382
+
383
+
384
+ def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
385
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
386
+
387
+
388
+ class LayerNorm(torch.nn.Module):
389
+
390
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
391
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
392
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
393
+ """
394
+
395
+ factory_kwargs = {"device": device, "dtype": dtype}
396
+ super().__init__()
397
+ self.eps = eps
398
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
399
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
400
+ self.group_size = group_size
401
+ self.norm_before_gate = norm_before_gate
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ torch.nn.init.ones_(self.weight)
406
+ torch.nn.init.zeros_(self.bias)
407
+
408
+ def forward(self, x, z=None):
409
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
410
+ """
411
+ return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
412
+ norm_before_gate=self.norm_before_gate)
413
+
414
+
415
+ class RMSNorm(torch.nn.Module):
416
+
417
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
418
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
419
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
420
+ """
421
+ factory_kwargs = {"device": device, "dtype": dtype}
422
+ super().__init__()
423
+ self.eps = eps
424
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
425
+ self.register_parameter("bias", None)
426
+ self.group_size = group_size
427
+ self.norm_before_gate = norm_before_gate
428
+ self.reset_parameters()
429
+
430
+ def reset_parameters(self):
431
+ torch.nn.init.ones_(self.weight)
432
+
433
+ def forward(self, x, z=None):
434
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
435
+ """
436
+ return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
437
+ norm_before_gate=self.norm_before_gate)
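A minimal sketch of the gated interface defined above, cross-checked against rms_norm_ref (shapes and the import path are illustrative assumptions; a CUDA device is required):

import torch
# assumed import path for this commit's layout:
# from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, rms_norm_ref

x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
z = torch.randn(4, 512, device="cuda", dtype=torch.float16)  # gate branch
w = torch.ones(512, device="cuda", dtype=torch.float32)

# norm_before_gate=True computes rmsnorm(x) * silu(z); group_size=128 -> 4 groups
out = rmsnorm_fn(x, w, None, z=z, eps=1e-5, group_size=128, norm_before_gate=True)
ref = rms_norm_ref(x, w, None, z=z, eps=1e-5, group_size=128, norm_before_gate=True)
print((out - ref).abs().max())  # expected to be at the fp16 rounding level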
torch-ext/mamba_ssm/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from .softplus import softplus
16
+
17
+
18
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
+ @triton.heuristics(
22
+ {
23
+ "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
24
+ is not None
25
+ }
26
+ )
27
+ @triton.heuristics(
28
+ {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
29
+ )
30
+ @triton.jit
31
+ def _selective_scan_update_kernel(
32
+ # Pointers to matrices
33
+ state_ptr,
34
+ x_ptr,
35
+ dt_ptr,
36
+ dt_bias_ptr,
37
+ A_ptr,
38
+ B_ptr,
39
+ C_ptr,
40
+ D_ptr,
41
+ z_ptr,
42
+ out_ptr,
43
+ state_batch_indices_ptr,
44
+ # Matrix dimensions
45
+ batch,
46
+ nheads,
47
+ dim,
48
+ dstate,
49
+ nheads_ngroups_ratio,
50
+ # Strides
51
+ stride_state_batch,
52
+ stride_state_head,
53
+ stride_state_dim,
54
+ stride_state_dstate,
55
+ stride_x_batch,
56
+ stride_x_head,
57
+ stride_x_dim,
58
+ stride_dt_batch,
59
+ stride_dt_head,
60
+ stride_dt_dim,
61
+ stride_dt_bias_head,
62
+ stride_dt_bias_dim,
63
+ stride_A_head,
64
+ stride_A_dim,
65
+ stride_A_dstate,
66
+ stride_B_batch,
67
+ stride_B_group,
68
+ stride_B_dstate,
69
+ stride_C_batch,
70
+ stride_C_group,
71
+ stride_C_dstate,
72
+ stride_D_head,
73
+ stride_D_dim,
74
+ stride_z_batch,
75
+ stride_z_head,
76
+ stride_z_dim,
77
+ stride_out_batch,
78
+ stride_out_head,
79
+ stride_out_dim,
80
+ # Meta-parameters
81
+ DT_SOFTPLUS: tl.constexpr,
82
+ TIE_HDIM: tl.constexpr,
83
+ BLOCK_SIZE_M: tl.constexpr,
84
+ HAS_DT_BIAS: tl.constexpr,
85
+ HAS_D: tl.constexpr,
86
+ HAS_Z: tl.constexpr,
87
+ HAS_STATE_BATCH_INDICES: tl.constexpr,
88
+ BLOCK_SIZE_DSTATE: tl.constexpr,
89
+ ):
90
+ pid_m = tl.program_id(axis=0)
91
+ pid_b = tl.program_id(axis=1)
92
+ pid_h = tl.program_id(axis=2)
93
+
94
+ if HAS_STATE_BATCH_INDICES:
95
+ state_batch_indices_ptr += pid_b
96
+ state_batch_idx = tl.load(state_batch_indices_ptr)
97
+ state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
98
+ else:
99
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
100
+
101
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
102
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
103
+ if HAS_DT_BIAS:
104
+ dt_bias_ptr += pid_h * stride_dt_bias_head
105
+ A_ptr += pid_h * stride_A_head
106
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
107
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
108
+ if HAS_Z:
109
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
110
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
111
+
112
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
113
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
114
+ state_ptrs = state_ptr + (
115
+ offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
116
+ )
117
+ x_ptrs = x_ptr + offs_m * stride_x_dim
118
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
119
+ if HAS_DT_BIAS:
120
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
121
+ if HAS_D:
122
+ D_ptr += pid_h * stride_D_head
123
+ A_ptrs = A_ptr + (
124
+ offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
125
+ )
126
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
127
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
128
+ if HAS_D:
129
+ D_ptrs = D_ptr + offs_m * stride_D_dim
130
+ if HAS_Z:
131
+ z_ptrs = z_ptr + offs_m * stride_z_dim
132
+ out_ptrs = out_ptr + offs_m * stride_out_dim
133
+
134
+ state = tl.load(
135
+ state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
136
+ )
137
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
138
+ if not TIE_HDIM:
139
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
140
+ if HAS_DT_BIAS:
141
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
142
+ if DT_SOFTPLUS:
143
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
144
+ A = tl.load(
145
+ A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
146
+ ).to(tl.float32)
147
+ dA = tl.exp(A * dt[:, None])
148
+ else:
149
+ dt = tl.load(dt_ptr).to(tl.float32)
150
+ if HAS_DT_BIAS:
151
+ dt += tl.load(dt_bias_ptr).to(tl.float32)
152
+ if DT_SOFTPLUS:
153
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
154
+ A = tl.load(A_ptr).to(tl.float32)
155
+ dA = tl.exp(A * dt) # scalar, not a matrix
156
+
157
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
158
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
159
+ if HAS_D:
160
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
161
+ if HAS_Z:
162
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
163
+
164
+ if not TIE_HDIM:
165
+ dB = B[None, :] * dt[:, None]
166
+ else:
167
+ dB = B * dt # vector of size (dstate,)
168
+ state = state * dA + dB * x[:, None]
169
+ tl.store(
170
+ state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
171
+ )
172
+ out = tl.sum(state * C[None, :], axis=1)
173
+ if HAS_D:
174
+ out += x * D
175
+ if HAS_Z:
176
+ out *= z * tl.sigmoid(z)
177
+ tl.store(out_ptrs, out, mask=offs_m < dim)
178
+
179
+
180
+ def selective_state_update(
181
+ state,
182
+ x,
183
+ dt,
184
+ A,
185
+ B,
186
+ C,
187
+ D=None,
188
+ z=None,
189
+ dt_bias=None,
190
+ dt_softplus=False,
191
+ state_batch_indices=None,
192
+ ):
193
+ """
194
+ Argument:
195
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
196
+ x: (batch, dim) or (batch, nheads, dim)
197
+ dt: (batch, dim) or (batch, nheads, dim)
198
+ A: (dim, dstate) or (nheads, dim, dstate)
199
+ B: (batch, dstate) or (batch, ngroups, dstate)
200
+ C: (batch, dstate) or (batch, ngroups, dstate)
201
+ D: (dim,) or (nheads, dim)
202
+ z: (batch, dim) or (batch, nheads, dim)
203
+ dt_bias: (dim,) or (nheads, dim)
204
+ Return:
205
+ out: (batch, dim) or (batch, nheads, dim)
206
+ """
207
+ has_heads = state.dim() > 3
208
+ if state.dim() == 3:
209
+ state = state.unsqueeze(1)
210
+ if x.dim() == 2:
211
+ x = x.unsqueeze(1)
212
+ if dt.dim() == 2:
213
+ dt = dt.unsqueeze(1)
214
+ if A.dim() == 2:
215
+ A = A.unsqueeze(0)
216
+ if B.dim() == 2:
217
+ B = B.unsqueeze(1)
218
+ if C.dim() == 2:
219
+ C = C.unsqueeze(1)
220
+ if D is not None and D.dim() == 1:
221
+ D = D.unsqueeze(0)
222
+ if z is not None and z.dim() == 2:
223
+ z = z.unsqueeze(1)
224
+ if dt_bias is not None and dt_bias.dim() == 1:
225
+ dt_bias = dt_bias.unsqueeze(0)
226
+ _, nheads, dim, dstate = state.shape
227
+ batch = x.shape[0]
228
+ if x.shape != (batch, nheads, dim):
229
+ print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
230
+ assert x.shape == (batch, nheads, dim)
231
+ assert dt.shape == x.shape
232
+ assert A.shape == (nheads, dim, dstate)
233
+ ngroups = B.shape[1]
234
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
235
+ assert B.shape == (batch, ngroups, dstate)
236
+ assert C.shape == B.shape
237
+ if D is not None:
238
+ assert D.shape == (nheads, dim)
239
+ if z is not None:
240
+ assert z.shape == x.shape
241
+ if dt_bias is not None:
242
+ assert dt_bias.shape == (nheads, dim)
243
+ if state_batch_indices is not None:
244
+ assert state_batch_indices.shape == (batch,)
245
+ out = torch.empty_like(x)
246
+ grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
247
+ z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
248
+ # We don't want autotune since it will overwrite the state
249
+ # We instead tune by hand.
250
+ BLOCK_SIZE_M, num_warps = (
251
+ (32, 4)
252
+ if dstate <= 16
253
+ else (
254
+ (16, 4)
255
+ if dstate <= 32
256
+ else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
257
+ )
258
+ )
259
+ tie_hdim = (
260
+ A.stride(-1) == 0
261
+ and A.stride(-2) == 0
262
+ and dt.stride(-1) == 0
263
+ and dt_bias.stride(-1) == 0
264
+ )
265
+ with torch.cuda.device(x.device.index):
266
+ _selective_scan_update_kernel[grid](
267
+ state,
268
+ x,
269
+ dt,
270
+ dt_bias,
271
+ A,
272
+ B,
273
+ C,
274
+ D,
275
+ z,
276
+ out,
277
+ state_batch_indices,
278
+ batch,
279
+ nheads,
280
+ dim,
281
+ dstate,
282
+ nheads // ngroups,
283
+ state.stride(0),
284
+ state.stride(1),
285
+ state.stride(2),
286
+ state.stride(3),
287
+ x.stride(0),
288
+ x.stride(1),
289
+ x.stride(2),
290
+ dt.stride(0),
291
+ dt.stride(1),
292
+ dt.stride(2),
293
+ *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
294
+ A.stride(0),
295
+ A.stride(1),
296
+ A.stride(2),
297
+ B.stride(0),
298
+ B.stride(1),
299
+ B.stride(2),
300
+ C.stride(0),
301
+ C.stride(1),
302
+ C.stride(2),
303
+ *(D.stride(0), D.stride(1)) if D is not None else 0,
304
+ z_strides[0],
305
+ z_strides[1],
306
+ z_strides[2],
307
+ out.stride(0),
308
+ out.stride(1),
309
+ out.stride(2),
310
+ dt_softplus,
311
+ tie_hdim,
312
+ BLOCK_SIZE_M,
313
+ num_warps=num_warps,
314
+ )
315
+ if not has_heads:
316
+ out = out.squeeze(1)
317
+ return out
318
+
319
+
320
+ def selective_state_update_ref(
321
+ state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
322
+ ):
323
+ """
324
+ Argument:
325
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
326
+ x: (batch, dim) or (batch, nheads, dim)
327
+ dt: (batch, dim) or (batch, nheads, dim)
328
+ A: (dim, dstate) or (nheads, dim, dstate)
329
+ B: (batch, dstate) or (batch, ngroups, dstate)
330
+ C: (batch, dstate) or (batch, ngroups, dstate)
331
+ D: (dim,) or (nheads, dim)
332
+ z: (batch, dim) or (batch, nheads, dim)
333
+ dt_bias: (dim,) or (nheads, dim)
334
+ Return:
335
+ out: (batch, dim) or (batch, nheads, dim)
336
+ """
337
+ has_heads = state.dim() > 3
338
+ if state.dim() == 3:
339
+ state = state.unsqueeze(1)
340
+ if x.dim() == 2:
341
+ x = x.unsqueeze(1)
342
+ if dt.dim() == 2:
343
+ dt = dt.unsqueeze(1)
344
+ if A.dim() == 2:
345
+ A = A.unsqueeze(0)
346
+ if B.dim() == 2:
347
+ B = B.unsqueeze(1)
348
+ if C.dim() == 2:
349
+ C = C.unsqueeze(1)
350
+ if D is not None and D.dim() == 1:
351
+ D = D.unsqueeze(0)
352
+ if z is not None and z.dim() == 2:
353
+ z = z.unsqueeze(1)
354
+ if dt_bias is not None and dt_bias.dim() == 1:
355
+ dt_bias = dt_bias.unsqueeze(0)
356
+ batch, nheads, dim, dstate = state.shape
357
+ assert x.shape == (batch, nheads, dim)
358
+ assert dt.shape == x.shape
359
+ assert A.shape == (nheads, dim, dstate)
360
+ ngroups = B.shape[1]
361
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
362
+ assert B.shape == (batch, ngroups, dstate)
363
+ assert C.shape == B.shape
364
+ if D is not None:
365
+ assert D.shape == (nheads, dim)
366
+ if z is not None:
367
+ assert z.shape == x.shape
368
+ if dt_bias is not None:
369
+ assert dt_bias.shape == (nheads, dim)
370
+ dt = dt + dt_bias
371
+ dt = F.softplus(dt) if dt_softplus else dt
372
+ dA = torch.exp(
373
+ rearrange(dt, "b h d -> b h d 1") * A
374
+ ) # (batch, nheads, dim, dstate)
375
+ B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
376
+ C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
377
+ dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
378
+ B, "b h n -> b h 1 n"
379
+ ) # (batch, nheads, dim, dstate)
380
+ state.copy_(
381
+ state * dA + dB * rearrange(x, "b h d -> b h d 1")
382
+ ) # (batch, nheads, dim, dstate)
383
+ out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
384
+ if D is not None:
385
+ out += (x * D).to(out.dtype)
386
+ out = (out if z is None else out * F.silu(z)).to(x.dtype)
387
+ if not has_heads:
388
+ out = out.squeeze(1)
389
+ return out
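To make the intended semantics concrete, here is a minimal usage sketch of the pure-PyTorch reference above. The shapes, the `ngroups = 1` choice, and the import path are illustrative assumptions; the Triton `selective_state_update` takes the same arguments and is meant to agree with the reference up to numerical tolerance (it additionally requires CUDA tensors).

```python
# Hedged usage sketch (illustrative shapes; assumes the package is importable as `mamba_ssm`).
import torch

from mamba_ssm.ops.triton.selective_state_update import selective_state_update_ref

batch, nheads, dim, dstate, ngroups = 2, 4, 64, 16, 1
state = torch.randn(batch, nheads, dim, dstate)
x = torch.randn(batch, nheads, dim)
dt = torch.rand(batch, nheads, dim)
A = -torch.rand(nheads, dim, dstate)        # negative entries give a decaying state
B = torch.randn(batch, ngroups, dstate)
C = torch.randn(batch, ngroups, dstate)
D = torch.randn(nheads, dim)
dt_bias = torch.rand(nheads, dim)

out = selective_state_update_ref(state, x, dt, A, B, C, D=D, dt_bias=dt_bias, dt_softplus=True)
assert out.shape == (batch, nheads, dim)    # `state` has been updated in place
```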
torch-ext/mamba_ssm/ops/triton/softplus.py ADDED
@@ -0,0 +1,15 @@
1
+ import triton
2
+ import triton.language as tl
3
+ from packaging import version
4
+
5
+ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
6
+
7
+
8
+ if TRITON3:
9
+ @triton.jit
10
+ def softplus(dt):
11
+ return tl.math.log(tl.math.exp(dt) + 1)
12
+ else:
13
+ @triton.jit
14
+ def softplus(dt):
15
+ return tl.math.log1p(tl.exp(dt))
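For context, the version-dispatched `softplus` above is only meant to be called from inside other `@triton.jit` kernels (as `ssd_chunk_state.py` does). The elementwise kernel below is a minimal sketch of such a caller; the kernel and wrapper names are illustrative, not part of this module, and callers in this package additionally guard large inputs with a `dt <= 20.0` check to avoid overflow in `exp`.

```python
# Minimal sketch: an elementwise kernel that applies the softplus helper above.
import torch
import triton
import triton.language as tl

from mamba_ssm.ops.triton.softplus import softplus


@triton.jit
def _softplus_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + offs, softplus(x), mask=mask)


def softplus_triton(x: torch.Tensor) -> torch.Tensor:
    x = x.contiguous()
    out = torch.empty_like(x, dtype=torch.float32)
    n = x.numel()
    grid = lambda META: (triton.cdiv(n, META["BLOCK"]),)
    _softplus_kernel[grid](x, out, n, BLOCK=1024)
    return out
```

On a CUDA tensor with moderate values, `softplus_triton(x)` should match `torch.nn.functional.softplus(x.float())`.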
torch-ext/mamba_ssm/ops/triton/ssd_bmm.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ def init_to_zero(names):
17
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
18
+
19
+
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
23
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
24
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
25
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
26
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
27
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
28
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
29
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
30
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
31
+ ],
32
+ key=['chunk_size', 'K', 'IS_CAUSAL'],
33
+ )
34
+ @triton.jit
35
+ def _bmm_chunk_fwd_kernel(
36
+ # Pointers to matrices
37
+ a_ptr, b_ptr, out_ptr, seq_idx_ptr,
38
+ # Matrix dimensions
39
+ seqlen, chunk_size, K, ngroups,
40
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
41
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
42
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
43
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
44
+ # Meta-parameters
45
+ IS_CAUSAL: tl.constexpr,
46
+ dot_dtype: tl.constexpr,
47
+ HAS_SEQ_IDX: tl.constexpr,
48
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
49
+ ):
50
+ pid_b = tl.program_id(axis=1)
51
+ pid_ch = tl.program_id(axis=2)
52
+ pid_c = pid_ch // ngroups
53
+ pid_h = pid_ch - pid_c * ngroups
54
+ num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
55
+ pid_m = tl.program_id(axis=0) // num_pid_n
56
+ pid_n = tl.program_id(axis=0) % num_pid_n
57
+ if IS_CAUSAL:
58
+ if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
59
+ return
60
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
61
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
62
+ if HAS_SEQ_IDX:
63
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
64
+
65
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
66
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
67
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
68
+ a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
69
+ b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
70
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
71
+
72
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
74
+ a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
75
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
76
+ acc += tl.dot(a, b)
77
+ a_ptrs += BLOCK_SIZE_K * stride_ak
78
+ b_ptrs += BLOCK_SIZE_K * stride_bk
79
+
80
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
81
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
82
+ if HAS_SEQ_IDX:
83
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
84
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
85
+ seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
86
+ acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
87
+ out = acc.to(out_ptr.dtype.element_ty)
88
+
89
+ out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
90
+ out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
91
+ tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
92
+
93
+
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
97
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
98
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
99
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
100
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
101
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
102
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
103
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
104
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
105
+ ],
106
+ key=['chunk_size', 'K'],
107
+ )
108
+ @triton.jit
109
+ def _bmm_chunk_bwd_kernel(
110
+ # Pointers to matrices
111
+ a_ptr, dout_ptr, db_ptr, res_ptr,
112
+ # Matrix dimensions
113
+ seqlen, chunk_size, K, ngroups,
114
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
115
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
116
+ stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
117
+ stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
118
+ # Meta-parameters
119
+ dot_dtype: tl.constexpr,
120
+ HAS_RESIDUAL: tl.constexpr,
121
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
122
+ ):
123
+ pid_b = tl.program_id(axis=1)
124
+ pid_ch = tl.program_id(axis=2)
125
+ pid_c = pid_ch // ngroups
126
+ pid_h = pid_ch - pid_c * ngroups
127
+ num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
128
+ pid_m = tl.program_id(axis=0) // num_pid_n
129
+ pid_n = tl.program_id(axis=0) % num_pid_n
130
+
131
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
132
+ dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
133
+
134
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
135
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
136
+ offs_cs = tl.arange(0, BLOCK_SIZE_CS)
137
+ dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
138
+ a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
139
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
140
+
141
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
142
+ for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
143
+ dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
144
+ a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
145
+ acc += tl.dot(dout, a)
146
+ dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
147
+ a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
148
+
149
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
150
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
151
+ if HAS_RESIDUAL:
152
+ res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
153
+ res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
154
+ res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
155
+ acc += res
156
+ db = acc.to(db_ptr.dtype.element_ty)
157
+
158
+ db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
159
+ db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
160
+ tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
161
+
162
+
163
+ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
164
+ """
165
+ Argument:
166
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
167
+ b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
168
+ seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
169
+ causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
170
+ guaranteed to be correct.
171
+ Return:
172
+ out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
173
+ """
174
+ # Check constraints.
175
+ has_groups = a.dim() == 4
176
+ if not has_groups:
177
+ batch, seqlen, k = a.shape
178
+ else:
179
+ batch, seqlen, ngroups, k = a.shape
180
+ assert b.shape == a.shape
181
+ if seq_idx is not None:
182
+ assert seq_idx.shape == (batch, seqlen)
183
+ if a.stride(-1) != 1 and a.stride(1) != 1:
184
+ a = a.contiguous()
185
+ if b.stride(-1) != 1 and b.stride(1) != 1:
186
+ b = b.contiguous()
187
+ nchunks = math.ceil(seqlen / chunk_size)
188
+ # Allocates output.
189
+ out_dtype = a.dtype if output_dtype is None else output_dtype
190
+ out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
191
+ device=a.device, dtype=out_dtype)
192
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
193
+ (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
194
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
195
+ batch, nchunks if not has_groups else nchunks * ngroups)
196
+ with torch.cuda.device(a.device.index):
197
+ _bmm_chunk_fwd_kernel[grid](
198
+ a, b, out, seq_idx,
199
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
200
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
201
+ b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
202
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
203
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
204
+ causal,
205
+ dot_dtype,
206
+ HAS_SEQ_IDX=seq_idx is not None,
207
+ )
208
+ return out
209
+
210
+
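As a readability aid, what `_bmm_chunk_fwd` computes can be written as a short einops/einsum sketch (grouped case, assuming `seqlen` is a multiple of `chunk_size`; the helper name is illustrative). With `causal=True` the kernel only guarantees the `i >= j` entries of each chunk-local matrix.

```python
# Reference sketch of the chunk-local batched matmul: out[..., i, j] = <a_i, b_j> within a chunk.
import torch
from einops import rearrange


def bmm_chunk_fwd_ref(a, b, chunk_size):
    # a, b: (batch, seqlen, ngroups, k), seqlen % chunk_size == 0 (assumed)
    a = rearrange(a, "b (c l) g k -> b c l g k", l=chunk_size)
    b = rearrange(b, "b (c l) g k -> b c l g k", l=chunk_size)
    # (batch, nchunks, ngroups, chunk_size, chunk_size)
    return torch.einsum("bcigk,bcjgk->bcgij", a, b)
```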
211
+ def _bmm_chunk_bwd(a, dout, residual=None, out=None):
212
+ """
213
+ Argument:
214
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
215
+ dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
216
+ residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
217
+ Return:
218
+ out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
219
+
220
+ If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
221
+ zeroed out before calling this function.
222
+ """
223
+ # Check constraints.
224
+ has_groups = a.dim() == 4
225
+ if not has_groups:
226
+ batch, seqlen, k = a.shape
227
+ else:
228
+ batch, seqlen, ngroups, k = a.shape
229
+ nchunks, chunk_size = dout.shape[1], dout.shape[-1]
230
+ if a.stride(-1) != 1 and a.stride(-2) != 1:
231
+ a = a.contiguous()
232
+ if dout.stride(-1) != 1 and dout.stride(-2) != 1:
233
+ dout = dout.contiguous()
234
+ if residual is not None:
235
+ assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
236
+ if residual.stride(-1) != 1 and residual.stride(1) != 1:
237
+ residual = residual.contiguous()
238
+ # Allocates output.
239
+ if out is not None:
240
+ assert out.shape == a.shape
241
+ assert out.stride(-1) == 1 or out.stride(1) == 1
242
+ else:
243
+ out = torch.empty_like(a)
244
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
245
+ (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
246
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
247
+ nchunks if not has_groups else nchunks * ngroups)
248
+ residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
249
+ residual.stride(-1))
250
+ if residual is not None else (0, 0, 0, 0))
251
+ with torch.cuda.device(a.device.index):
252
+ _bmm_chunk_bwd_kernel[grid](
253
+ a, dout, out, residual,
254
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
255
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
256
+ dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
257
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
258
+ residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
259
+ dot_dtype,
260
+ HAS_RESIDUAL=residual is not None,
261
+ )
262
+ return out
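Similarly, a compact sketch of the backward pass above: since the forward computes `out[i, j] = sum_k a[i, k] * b[j, k]` per chunk, the gradient with respect to `b` is `db[j, k] = sum_i dout[i, j] * a[i, k]`. The helper name and the divisibility assumption are illustrative.

```python
# Reference sketch of _bmm_chunk_bwd (grouped case, seqlen % chunk_size == 0 assumed).
import torch
from einops import rearrange


def bmm_chunk_bwd_ref(a, dout, chunk_size, residual=None):
    # a: (batch, seqlen, ngroups, k); dout: (batch, nchunks, ngroups, chunk_size, chunk_size)
    a = rearrange(a, "b (c l) g k -> b c l g k", l=chunk_size)
    db = torch.einsum("bcgij,bcigk->bcjgk", dout, a)   # sum over i within each chunk
    db = rearrange(db, "b c l g k -> b (c l) g k")
    return db if residual is None else db + residual
```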
torch-ext/mamba_ssm/ops/triton/ssd_chunk_scan.py ADDED
The diff for this file is too large to render. See raw diff
 
torch-ext/mamba_ssm/ops/triton/ssd_chunk_state.py ADDED
@@ -0,0 +1,2012 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from .softplus import softplus
16
+
17
+
18
+ def init_to_zero(names):
19
+ return lambda nargs: [
20
+ nargs[name].zero_() for name in names if nargs[name] is not None
21
+ ]
22
+
23
+
24
+ @triton.autotune(
25
+ configs=[
26
+ triton.Config({"BLOCK_SIZE_H": 1}),
27
+ triton.Config({"BLOCK_SIZE_H": 2}),
28
+ triton.Config({"BLOCK_SIZE_H": 4}),
29
+ triton.Config({"BLOCK_SIZE_H": 8}),
30
+ triton.Config({"BLOCK_SIZE_H": 16}),
31
+ triton.Config({"BLOCK_SIZE_H": 32}),
32
+ triton.Config({"BLOCK_SIZE_H": 64}),
33
+ ],
34
+ key=["chunk_size", "nheads"],
35
+ )
36
+ @triton.jit
37
+ def _chunk_cumsum_fwd_kernel(
38
+ # Pointers to matrices
39
+ dt_ptr,
40
+ A_ptr,
41
+ dt_bias_ptr,
42
+ dt_out_ptr,
43
+ dA_cumsum_ptr,
44
+ # Matrix dimension
45
+ batch,
46
+ seqlen,
47
+ nheads,
48
+ chunk_size,
49
+ dt_min,
50
+ dt_max,
51
+ # Strides
52
+ stride_dt_batch,
53
+ stride_dt_seqlen,
54
+ stride_dt_head,
55
+ stride_A_head,
56
+ stride_dt_bias_head,
57
+ stride_dt_out_batch,
58
+ stride_dt_out_chunk,
59
+ stride_dt_out_head,
60
+ stride_dt_out_csize,
61
+ stride_dA_cs_batch,
62
+ stride_dA_cs_chunk,
63
+ stride_dA_cs_head,
64
+ stride_dA_cs_csize,
65
+ # Meta-parameters
66
+ DT_SOFTPLUS: tl.constexpr,
67
+ HAS_DT_BIAS: tl.constexpr,
68
+ BLOCK_SIZE_H: tl.constexpr,
69
+ BLOCK_SIZE_CHUNK: tl.constexpr,
70
+ ):
71
+ pid_b = tl.program_id(axis=0)
72
+ pid_c = tl.program_id(axis=1)
73
+ pid_h = tl.program_id(axis=2)
74
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
75
+ dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
76
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
77
+
78
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
79
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
80
+ dt_ptrs = dt_ptr + (
81
+ offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
82
+ )
83
+ A_ptrs = A_ptr + offs_h * stride_A_head
84
+ dt_out_ptrs = dt_out_ptr + (
85
+ offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
86
+ )
87
+ dA_cs_ptrs = dA_cumsum_ptr + (
88
+ offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
89
+ )
90
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
91
+
92
+ dt = tl.load(
93
+ dt_ptrs,
94
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
95
+ other=0.0,
96
+ ).to(tl.float32)
97
+ if HAS_DT_BIAS:
98
+ dt_bias = tl.load(
99
+ dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
100
+ ).to(tl.float32)
101
+ dt += dt_bias[:, None]
102
+ if DT_SOFTPLUS:
103
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
104
+ # As of Triton 2.2.0, tl.clamp is not available yet
105
+ # dt = tl.clamp(dt, dt_min, dt_max)
106
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
107
+ dt = tl.where(
108
+ (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
109
+ )
110
+ tl.store(
111
+ dt_out_ptrs,
112
+ dt,
113
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
114
+ )
115
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
116
+ dA = dt * A[:, None]
117
+ dA_cs = tl.cumsum(dA, axis=1)
118
+ tl.store(
119
+ dA_cs_ptrs,
120
+ dA_cs,
121
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
122
+ )
123
+
124
+
125
+ @triton.autotune(
126
+ configs=[
127
+ triton.Config(
128
+ {"BLOCK_SIZE_H": 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
129
+ ),
130
+ triton.Config(
131
+ {"BLOCK_SIZE_H": 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
132
+ ),
133
+ triton.Config(
134
+ {"BLOCK_SIZE_H": 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
135
+ ),
136
+ triton.Config(
137
+ {"BLOCK_SIZE_H": 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
138
+ ),
139
+ triton.Config(
140
+ {"BLOCK_SIZE_H": 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
141
+ ),
142
+ triton.Config(
143
+ {"BLOCK_SIZE_H": 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
144
+ ),
145
+ triton.Config(
146
+ {"BLOCK_SIZE_H": 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])
147
+ ),
148
+ ],
149
+ key=["chunk_size", "nheads"],
150
+ )
151
+ @triton.jit
152
+ def _chunk_cumsum_bwd_kernel(
153
+ # Pointers to matrices
154
+ ddA_ptr,
155
+ ddt_out_ptr,
156
+ dt_ptr,
157
+ A_ptr,
158
+ dt_bias_ptr,
159
+ ddt_ptr,
160
+ dA_ptr,
161
+ ddt_bias_ptr,
162
+ # Matrix dimensions
163
+ batch,
164
+ seqlen,
165
+ nheads,
166
+ chunk_size,
167
+ dt_min,
168
+ dt_max,
169
+ # Strides
170
+ stride_ddA_batch,
171
+ stride_ddA_chunk,
172
+ stride_ddA_head,
173
+ stride_ddA_csize,
174
+ stride_ddt_out_batch,
175
+ stride_ddt_out_chunk,
176
+ stride_ddt_out_head,
177
+ stride_ddt_out_csize,
178
+ stride_dt_batch,
179
+ stride_dt_seqlen,
180
+ stride_dt_head,
181
+ stride_A_head,
182
+ stride_dt_bias_head,
183
+ stride_ddt_batch,
184
+ stride_ddt_seqlen,
185
+ stride_ddt_head,
186
+ stride_dA_head,
187
+ stride_ddt_bias_head,
188
+ # Meta-parameters
189
+ DT_SOFTPLUS: tl.constexpr,
190
+ HAS_DT_BIAS: tl.constexpr,
191
+ BLOCK_SIZE_H: tl.constexpr,
192
+ BLOCK_SIZE_CHUNK: tl.constexpr,
193
+ ):
194
+ pid_b = tl.program_id(axis=0)
195
+ pid_c = tl.program_id(axis=1)
196
+ pid_h = tl.program_id(axis=2)
197
+ ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
198
+ ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
199
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
200
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
201
+
202
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
203
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
204
+ ddt_out_ptrs = ddt_out_ptr + (
205
+ offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize
206
+ )
207
+ ddA_ptrs = ddA_ptr + (
208
+ offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize
209
+ )
210
+ dt_ptrs = dt_ptr + (
211
+ offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
212
+ )
213
+ ddt_ptrs = ddt_ptr + (
214
+ offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen
215
+ )
216
+ A_ptrs = A_ptr + offs_h * stride_A_head
217
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
218
+
219
+ ddA = tl.load(
220
+ ddA_ptrs,
221
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
222
+ other=0.0,
223
+ ).to(tl.float32)
224
+ ddt_out = tl.load(
225
+ ddt_out_ptrs,
226
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
227
+ other=0.0,
228
+ ).to(tl.float32)
229
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
230
+ ddt = ddA * A[:, None] + ddt_out
231
+ dt = tl.load(
232
+ dt_ptrs,
233
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
234
+ other=0.0,
235
+ ).to(tl.float32)
236
+ if HAS_DT_BIAS:
237
+ dt_bias = tl.load(
238
+ dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
239
+ ).to(tl.float32)
240
+ dt += dt_bias[:, None]
241
+ if DT_SOFTPLUS:
242
+ dt_presoftplus = dt
243
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)  # where softplus would overflow, dt is already ~softplus(dt)
244
+ clamp_mask = (dt < dt_min) | (dt > dt_max)
245
+ # As of Triton 2.2.0, tl.clamp is not available yet
246
+ # dt = tl.clamp(dt, dt_min, dt_max)
247
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
248
+ dt = tl.where(
249
+ (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
250
+ )
251
+ ddt = tl.where(
252
+ (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0
253
+ )
254
+ ddt = tl.where(clamp_mask, 0.0, ddt)
255
+ if DT_SOFTPLUS:
256
+ ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
257
+ tl.store(
258
+ ddt_ptrs,
259
+ ddt,
260
+ mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
261
+ )
262
+ dA = tl.sum(ddA * dt, axis=1)
263
+ tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
264
+ if HAS_DT_BIAS:
265
+ ddt_bias = tl.sum(ddt, axis=1)
266
+ tl.atomic_add(
267
+ ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads
268
+ )
269
+
270
+
271
+ @triton.autotune(
272
+ configs=[
273
+ triton.Config(
274
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
275
+ num_stages=3,
276
+ num_warps=8,
277
+ ),
278
+ triton.Config(
279
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
280
+ num_stages=4,
281
+ num_warps=4,
282
+ ),
283
+ triton.Config(
284
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
285
+ num_stages=4,
286
+ num_warps=4,
287
+ ),
288
+ triton.Config(
289
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
290
+ num_stages=4,
291
+ num_warps=4,
292
+ ),
293
+ triton.Config(
294
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
295
+ num_stages=4,
296
+ num_warps=4,
297
+ ),
298
+ triton.Config(
299
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
300
+ num_stages=4,
301
+ num_warps=4,
302
+ ),
303
+ triton.Config(
304
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
305
+ num_stages=5,
306
+ num_warps=2,
307
+ ),
308
+ triton.Config(
309
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
310
+ num_stages=5,
311
+ num_warps=2,
312
+ ),
313
+ triton.Config(
314
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
315
+ num_stages=4,
316
+ num_warps=2,
317
+ ),
318
+ ],
319
+ key=["hdim", "dstate", "chunk_size"],
320
+ )
321
+ @triton.jit
322
+ def _chunk_state_fwd_kernel(
323
+ # Pointers to matrices
324
+ x_ptr,
325
+ b_ptr,
326
+ states_ptr,
327
+ dt_ptr,
328
+ dA_cumsum_ptr,
329
+ seq_idx_ptr,
330
+ # Matrix dimensions
331
+ hdim,
332
+ dstate,
333
+ chunk_size,
334
+ batch,
335
+ seqlen,
336
+ nheads_ngroups_ratio,
337
+ # Strides
338
+ stride_x_batch,
339
+ stride_x_seqlen,
340
+ stride_x_head,
341
+ stride_x_hdim,
342
+ stride_b_batch,
343
+ stride_b_seqlen,
344
+ stride_b_head,
345
+ stride_b_dstate,
346
+ stride_states_batch,
347
+ stride_states_chunk,
348
+ stride_states_head,
349
+ stride_states_hdim,
350
+ stride_states_dstate,
351
+ stride_dt_batch,
352
+ stride_dt_chunk,
353
+ stride_dt_head,
354
+ stride_dt_csize,
355
+ stride_dA_cs_batch,
356
+ stride_dA_cs_chunk,
357
+ stride_dA_cs_head,
358
+ stride_dA_cs_csize,
359
+ stride_seq_idx_batch,
360
+ stride_seq_idx_seqlen,
361
+ # Meta-parameters
362
+ HAS_SEQ_IDX: tl.constexpr,
363
+ BLOCK_SIZE_M: tl.constexpr,
364
+ BLOCK_SIZE_N: tl.constexpr,
365
+ BLOCK_SIZE_K: tl.constexpr,
366
+ ):
367
+ pid_bc = tl.program_id(axis=1)
368
+ pid_c = pid_bc // batch
369
+ pid_b = pid_bc - pid_c * batch
370
+ pid_h = tl.program_id(axis=2)
371
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
372
+ pid_m = tl.program_id(axis=0) // num_pid_n
373
+ pid_n = tl.program_id(axis=0) % num_pid_n
374
+ b_ptr += (
375
+ pid_b * stride_b_batch
376
+ + pid_c * chunk_size * stride_b_seqlen
377
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
378
+ )
379
+ x_ptr += (
380
+ pid_b * stride_x_batch
381
+ + pid_c * chunk_size * stride_x_seqlen
382
+ + pid_h * stride_x_head
383
+ )
384
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
385
+ dA_cumsum_ptr += (
386
+ pid_b * stride_dA_cs_batch
387
+ + pid_c * stride_dA_cs_chunk
388
+ + pid_h * stride_dA_cs_head
389
+ )
390
+ if HAS_SEQ_IDX:
391
+ seq_idx_ptr += (
392
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
393
+ )
394
+
395
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
396
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
397
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
398
+ x_ptrs = x_ptr + (
399
+ offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
400
+ )
401
+ b_ptrs = b_ptr + (
402
+ offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
403
+ )
404
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
405
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
406
+ tl.float32
407
+ )
408
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
409
+ if HAS_SEQ_IDX:
410
+ seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
411
+
412
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
413
+ if HAS_SEQ_IDX:
414
+ seq_idx_last = tl.load(
415
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
416
+ )
417
+
418
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
419
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
420
+ x = tl.load(
421
+ x_ptrs,
422
+ mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
423
+ other=0.0,
424
+ )
425
+ b = tl.load(
426
+ b_ptrs,
427
+ mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
428
+ other=0.0,
429
+ ).to(tl.float32)
430
+ dA_cs_k = tl.load(
431
+ dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
432
+ ).to(tl.float32)
433
+ if HAS_SEQ_IDX:
434
+ seq_idx_k = tl.load(
435
+ seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
436
+ )
437
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
438
+ tl.float32
439
+ )
440
+ if not HAS_SEQ_IDX:
441
+ scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
442
+ else:
443
+ scale = tl.where(
444
+ seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0
445
+ )
446
+ b *= scale[:, None]
447
+ b = b.to(x_ptr.dtype.element_ty)
448
+ acc += tl.dot(x, b)
449
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
450
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
451
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
452
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
453
+ if HAS_SEQ_IDX:
454
+ seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
455
+ states = acc.to(states_ptr.dtype.element_ty)
456
+
457
+ states_ptr += (
458
+ pid_b * stride_states_batch
459
+ + pid_c * stride_states_chunk
460
+ + pid_h * stride_states_head
461
+ )
462
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
463
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
464
+ states_ptrs = states_ptr + (
465
+ offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
466
+ )
467
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
468
+ tl.store(states_ptrs, states, mask=c_mask)
469
+
470
+
471
+ @triton.autotune(
472
+ configs=[
473
+ triton.Config(
474
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
475
+ num_stages=3,
476
+ num_warps=8,
477
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
478
+ ),
479
+ triton.Config(
480
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
481
+ num_stages=4,
482
+ num_warps=4,
483
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
484
+ ),
485
+ triton.Config(
486
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
487
+ num_stages=4,
488
+ num_warps=4,
489
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
490
+ ),
491
+ triton.Config(
492
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
493
+ num_stages=4,
494
+ num_warps=4,
495
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
496
+ ),
497
+ triton.Config(
498
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
499
+ num_stages=4,
500
+ num_warps=4,
501
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
502
+ ),
503
+ triton.Config(
504
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
505
+ num_stages=4,
506
+ num_warps=4,
507
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
508
+ ),
509
+ triton.Config(
510
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
511
+ num_stages=5,
512
+ num_warps=4,
513
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
514
+ ),
515
+ triton.Config(
516
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
517
+ num_stages=5,
518
+ num_warps=4,
519
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
520
+ ),
521
+ triton.Config(
522
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
523
+ num_stages=4,
524
+ num_warps=4,
525
+ pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"]),
526
+ ),
527
+ ],
528
+ key=["chunk_size", "hdim", "dstate"],
529
+ )
530
+ @triton.jit
531
+ def _chunk_state_bwd_dx_kernel(
532
+ # Pointers to matrices
533
+ x_ptr,
534
+ b_ptr,
535
+ dstates_ptr,
536
+ dt_ptr,
537
+ dA_cumsum_ptr,
538
+ dx_ptr,
539
+ ddt_ptr,
540
+ ddA_cumsum_ptr,
541
+ # Matrix dimensions
542
+ chunk_size,
543
+ hdim,
544
+ dstate,
545
+ batch,
546
+ seqlen,
547
+ nheads_ngroups_ratio,
548
+ # Strides
549
+ stride_x_batch,
550
+ stride_x_seqlen,
551
+ stride_x_head,
552
+ stride_x_hdim,
553
+ stride_b_batch,
554
+ stride_b_seqlen,
555
+ stride_b_head,
556
+ stride_b_dstate,
557
+ stride_dstates_batch,
558
+ stride_dstates_chunk,
559
+ stride_states_head,
560
+ stride_states_hdim,
561
+ stride_states_dstate,
562
+ stride_dt_batch,
563
+ stride_dt_chunk,
564
+ stride_dt_head,
565
+ stride_dt_csize,
566
+ stride_dA_cs_batch,
567
+ stride_dA_cs_chunk,
568
+ stride_dA_cs_head,
569
+ stride_dA_cs_csize,
570
+ stride_dx_batch,
571
+ stride_dx_seqlen,
572
+ stride_dx_head,
573
+ stride_dx_hdim,
574
+ stride_ddt_batch,
575
+ stride_ddt_chunk,
576
+ stride_ddt_head,
577
+ stride_ddt_csize,
578
+ stride_ddA_cs_batch,
579
+ stride_ddA_cs_chunk,
580
+ stride_ddA_cs_head,
581
+ stride_ddA_cs_csize,
582
+ # Meta-parameters
583
+ BLOCK_SIZE_M: tl.constexpr,
584
+ BLOCK_SIZE_N: tl.constexpr,
585
+ BLOCK_SIZE_K: tl.constexpr,
586
+ BLOCK_SIZE_DSTATE: tl.constexpr,
587
+ ):
588
+ pid_bc = tl.program_id(axis=1)
589
+ pid_c = pid_bc // batch
590
+ pid_b = pid_bc - pid_c * batch
591
+ pid_h = tl.program_id(axis=2)
592
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
593
+ pid_m = tl.program_id(axis=0) // num_pid_n
594
+ pid_n = tl.program_id(axis=0) % num_pid_n
595
+ x_ptr += (
596
+ pid_b * stride_x_batch
597
+ + pid_c * chunk_size * stride_x_seqlen
598
+ + pid_h * stride_x_head
599
+ )
600
+ b_ptr += (
601
+ pid_b * stride_b_batch
602
+ + pid_c * chunk_size * stride_b_seqlen
603
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
604
+ )
605
+ dstates_ptr += (
606
+ pid_b * stride_dstates_batch
607
+ + pid_c * stride_dstates_chunk
608
+ + pid_h * stride_states_head
609
+ )
610
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
611
+ ddt_ptr += (
612
+ pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
613
+ )
614
+ ddA_cumsum_ptr += (
615
+ pid_b * stride_ddA_cs_batch
616
+ + pid_c * stride_ddA_cs_chunk
617
+ + pid_h * stride_ddA_cs_head
618
+ )
619
+ dA_cumsum_ptr += (
620
+ pid_b * stride_dA_cs_batch
621
+ + pid_c * stride_dA_cs_chunk
622
+ + pid_h * stride_dA_cs_head
623
+ )
624
+
625
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
+
628
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
629
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
630
+ offs_k = tl.arange(
631
+ 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
632
+ )
633
+ b_ptrs = b_ptr + (
634
+ offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
635
+ )
636
+ dstates_ptrs = dstates_ptr + (
637
+ offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
638
+ )
639
+ if BLOCK_SIZE_DSTATE <= 128:
640
+ b = tl.load(
641
+ b_ptrs,
642
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
643
+ other=0.0,
644
+ )
645
+ dstates = tl.load(
646
+ dstates_ptrs,
647
+ mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
648
+ other=0.0,
649
+ )
650
+ dstates = dstates.to(b_ptr.dtype.element_ty)
651
+ acc = tl.dot(b, dstates)
652
+ else:
653
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
654
+ for k in range(0, dstate, BLOCK_SIZE_K):
655
+ b = tl.load(
656
+ b_ptrs,
657
+ mask=(offs_m[:, None] < chunk_size_limit)
658
+ & (offs_k[None, :] < dstate - k),
659
+ other=0.0,
660
+ )
661
+ dstates = tl.load(
662
+ dstates_ptrs,
663
+ mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
664
+ other=0.0,
665
+ )
666
+ dstates = dstates.to(b_ptr.dtype.element_ty)
667
+ acc += tl.dot(b, dstates)
668
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
669
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
670
+
671
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
672
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
673
+
674
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
675
+ tl.float32
676
+ )
677
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
678
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
679
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
680
+ tl.float32
681
+ )
682
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
683
+ acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
684
+
685
+ x_ptrs = x_ptr + (
686
+ offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
687
+ )
688
+ x = tl.load(
689
+ x_ptrs,
690
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
691
+ other=0.0,
692
+ ).to(tl.float32)
693
+ ddt = tl.sum(acc * x, axis=1)
694
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
695
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
696
+ ddA_cs = -(ddt * dt_m)
697
+ ddA_cs_last = -tl.sum(ddA_cs)
698
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
699
+ tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
700
+ tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
701
+
702
+ dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
703
+ dx_ptr += (
704
+ pid_b * stride_dx_batch
705
+ + pid_c * chunk_size * stride_dx_seqlen
706
+ + pid_h * stride_dx_head
707
+ )
708
+ dx_ptrs = dx_ptr + (
709
+ offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim
710
+ )
711
+ tl.store(
712
+ dx_ptrs,
713
+ dx,
714
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
715
+ )
716
+
717
+
718
+ @triton.autotune(
719
+ configs=[
720
+ triton.Config(
721
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
722
+ num_stages=3,
723
+ num_warps=4,
724
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
725
+ ),
726
+ triton.Config(
727
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32},
728
+ num_stages=3,
729
+ num_warps=4,
730
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
731
+ ),
732
+ triton.Config(
733
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},
734
+ num_stages=3,
735
+ num_warps=4,
736
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
737
+ ),
738
+ triton.Config(
739
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
740
+ num_stages=3,
741
+ num_warps=4,
742
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
743
+ ),
744
+ triton.Config(
745
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
746
+ num_stages=3,
747
+ num_warps=4,
748
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
749
+ ),
750
+ triton.Config(
751
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32},
752
+ num_stages=3,
753
+ num_warps=4,
754
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
755
+ ),
756
+ triton.Config(
757
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64},
758
+ num_stages=3,
759
+ num_warps=4,
760
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
761
+ ),
762
+ triton.Config(
763
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32},
764
+ num_stages=3,
765
+ num_warps=4,
766
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
767
+ ),
768
+ ],
769
+ key=["chunk_size", "dstate", "hdim"],
770
+ )
771
+ @triton.jit
772
+ def _chunk_state_bwd_db_kernel(
773
+ # Pointers to matrices
774
+ x_ptr,
775
+ dstates_ptr,
776
+ b_ptr,
777
+ dt_ptr,
778
+ dA_cumsum_ptr,
779
+ seq_idx_ptr,
780
+ db_ptr,
781
+ ddA_cumsum_ptr,
782
+ # Matrix dimensions
783
+ chunk_size,
784
+ dstate,
785
+ hdim,
786
+ batch,
787
+ seqlen,
788
+ nheads,
789
+ nheads_per_program,
790
+ ngroups,
791
+ # Strides
792
+ stride_x_batch,
793
+ stride_x_seqlen,
794
+ stride_x_head,
795
+ stride_x_hdim,
796
+ stride_dstates_batch,
797
+ stride_dstates_chunk,
798
+ stride_states_head,
799
+ stride_states_hdim,
800
+ stride_states_dstate,
801
+ stride_b_batch,
802
+ stride_b_seqlen,
803
+ stride_b_head,
804
+ stride_b_dstate,
805
+ stride_dt_batch,
806
+ stride_dt_chunk,
807
+ stride_dt_head,
808
+ stride_dt_csize,
809
+ stride_dA_cs_batch,
810
+ stride_dA_cs_chunk,
811
+ stride_dA_cs_head,
812
+ stride_dA_cs_csize,
813
+ stride_seq_idx_batch,
814
+ stride_seq_idx_seqlen,
815
+ stride_db_batch,
816
+ stride_db_seqlen,
817
+ stride_db_split,
818
+ stride_db_group,
819
+ stride_db_dstate,
820
+ stride_ddA_cs_batch,
821
+ stride_ddA_cs_chunk,
822
+ stride_ddA_cs_head,
823
+ stride_ddA_cs_csize,
824
+ # Meta-parameters
825
+ HAS_DDA_CS: tl.constexpr,
826
+ HAS_SEQ_IDX: tl.constexpr,
827
+ BLOCK_SIZE_M: tl.constexpr,
828
+ BLOCK_SIZE_N: tl.constexpr,
829
+ BLOCK_SIZE_K: tl.constexpr,
830
+ ):
831
+ pid_bc = tl.program_id(axis=1)
832
+ pid_c = pid_bc // batch
833
+ pid_b = pid_bc - pid_c * batch
834
+ pid_sg = tl.program_id(axis=2)
835
+ pid_s = pid_sg // ngroups
836
+ pid_g = pid_sg - pid_s * ngroups
837
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
838
+ pid_m = tl.program_id(axis=0) // num_pid_n
839
+ pid_n = tl.program_id(axis=0) % num_pid_n
840
+ x_ptr += (
841
+ pid_b * stride_x_batch
842
+ + pid_c * chunk_size * stride_x_seqlen
843
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
844
+ )
845
+ db_ptr += (
846
+ pid_b * stride_db_batch
847
+ + pid_c * chunk_size * stride_db_seqlen
848
+ + pid_g * stride_db_group
849
+ + pid_s * stride_db_split
850
+ )
851
+ dstates_ptr += (
852
+ pid_b * stride_dstates_batch
853
+ + pid_c * stride_dstates_chunk
854
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
855
+ * stride_states_head
856
+ )
857
+ dt_ptr += (
858
+ pid_b * stride_dt_batch
859
+ + pid_c * stride_dt_chunk
860
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
861
+ )
862
+ dA_cumsum_ptr += (
863
+ pid_b * stride_dA_cs_batch
864
+ + pid_c * stride_dA_cs_chunk
865
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
866
+ )
867
+ if HAS_DDA_CS:
868
+ b_ptr += (
869
+ pid_b * stride_b_batch
870
+ + pid_c * chunk_size * stride_b_seqlen
871
+ + pid_g * stride_b_head
872
+ )
873
+ ddA_cumsum_ptr += (
874
+ pid_b * stride_ddA_cs_batch
875
+ + pid_c * stride_ddA_cs_chunk
876
+ + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program)
877
+ * stride_ddA_cs_head
878
+ )
879
+ if HAS_SEQ_IDX:
880
+ seq_idx_ptr += (
881
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
882
+ )
883
+
884
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
885
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
886
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
887
+ x_ptrs = x_ptr + (
888
+ offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim
889
+ )
890
+ dstates_ptrs = dstates_ptr + (
891
+ offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim
892
+ )
893
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
894
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
895
+ if HAS_DDA_CS:
896
+ b_ptrs = b_ptr + (
897
+ offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate
898
+ )
899
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
900
+
901
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
902
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
903
+ if HAS_DDA_CS:
904
+ b = tl.load(
905
+ b_ptrs,
906
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
907
+ other=0.0,
908
+ ).to(tl.float32)
909
+ if HAS_SEQ_IDX:
910
+ seq_idx_m = tl.load(
911
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
912
+ mask=offs_m < chunk_size_limit,
913
+ other=-1,
914
+ )
915
+ seq_idx_last = tl.load(
916
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
917
+ )
918
+ nheads_iter = min(
919
+ nheads_per_program, nheads // ngroups - pid_s * nheads_per_program
920
+ )
921
+ for h in range(nheads_iter):
922
+ x = tl.load(
923
+ x_ptrs,
924
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
925
+ other=0.0,
926
+ )
927
+ dstates = tl.load(
928
+ dstates_ptrs,
929
+ mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
930
+ other=0.0,
931
+ )
932
+ dstates = dstates.to(x_ptrs.dtype.element_ty)
933
+ db = tl.dot(x, dstates)
934
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
935
+ tl.float32
936
+ )
937
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(
938
+ tl.float32
939
+ )
940
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
941
+ if not HAS_SEQ_IDX:
942
+ scale = tl.exp(dA_cs_last - dA_cs_m)
943
+ else:
944
+ scale = tl.where(
945
+ seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0
946
+ )
947
+ db *= (scale * dt_m)[:, None]
948
+ if HAS_DDA_CS:
949
+ # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
950
+ ddA_cs = tl.sum(db * b, axis=1)
951
+ tl.atomic_add(
952
+ ddA_cumsum_ptrs + stride_ddA_cs_csize,
953
+ ddA_cs,
954
+ mask=offs_m < chunk_size - 1,
955
+ )
956
+ acc += db
957
+ x_ptrs += stride_x_head
958
+ dstates_ptrs += stride_states_head
959
+ dt_ptrs += stride_dt_head
960
+ dA_cumsum_ptr += stride_dA_cs_head
961
+ dA_cumsum_ptrs += stride_dA_cs_head
962
+ if HAS_DDA_CS:
963
+ ddA_cumsum_ptrs += stride_ddA_cs_head
964
+
965
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
966
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
967
+ # if HAS_SEQ_IDX:
968
+ # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
969
+ # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
970
+ # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
971
+ db_ptrs = db_ptr + (
972
+ offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate
973
+ )
974
+ tl.store(
975
+ db_ptrs,
976
+ acc,
977
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate),
978
+ )
979
+
980
+
981
+ @triton.autotune(
982
+ configs=[
983
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
984
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
985
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
986
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
987
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
988
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
989
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
990
+ # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
991
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
992
+ triton.Config(
993
+ {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
994
+ num_stages=3,
995
+ num_warps=4,
996
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
997
+ ),
998
+ triton.Config(
999
+ {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1000
+ num_stages=3,
1001
+ num_warps=4,
1002
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1003
+ ),
1004
+ triton.Config(
1005
+ {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1006
+ num_stages=3,
1007
+ num_warps=4,
1008
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1009
+ ),
1010
+ triton.Config(
1011
+ {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1012
+ num_stages=3,
1013
+ num_warps=4,
1014
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1015
+ ),
1016
+ triton.Config(
1017
+ {"BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 32},
1018
+ num_stages=4,
1019
+ num_warps=8,
1020
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1021
+ ),
1022
+ triton.Config(
1023
+ {"BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1024
+ num_stages=4,
1025
+ num_warps=8,
1026
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1027
+ ),
1028
+ triton.Config(
1029
+ {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1030
+ num_stages=4,
1031
+ num_warps=8,
1032
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1033
+ ),
1034
+ triton.Config(
1035
+ {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1036
+ num_stages=4,
1037
+ num_warps=8,
1038
+ pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
1039
+ ),
1040
+ ],
1041
+ key=["chunk_size", "hdim", "dstate"],
1042
+ )
1043
+ @triton.jit
1044
+ def _chunk_state_bwd_ddAcs_stable_kernel(
1045
+ # Pointers to matrices
1046
+ x_ptr,
1047
+ b_ptr,
1048
+ dstates_ptr,
1049
+ dt_ptr,
1050
+ dA_cumsum_ptr,
1051
+ seq_idx_ptr,
1052
+ ddA_cumsum_ptr,
1053
+ # Matrix dimensions
1054
+ chunk_size,
1055
+ hdim,
1056
+ dstate,
1057
+ batch,
1058
+ seqlen,
1059
+ nheads_ngroups_ratio,
1060
+ # Strides
1061
+ stride_x_batch,
1062
+ stride_x_seqlen,
1063
+ stride_x_head,
1064
+ stride_x_hdim,
1065
+ stride_b_batch,
1066
+ stride_b_seqlen,
1067
+ stride_b_head,
1068
+ stride_b_dstate,
1069
+ stride_dstates_batch,
1070
+ stride_dstates_chunk,
1071
+ stride_states_head,
1072
+ stride_states_hdim,
1073
+ stride_states_dstate,
1074
+ stride_dt_batch,
1075
+ stride_dt_chunk,
1076
+ stride_dt_head,
1077
+ stride_dt_csize,
1078
+ stride_dA_cs_batch,
1079
+ stride_dA_cs_chunk,
1080
+ stride_dA_cs_head,
1081
+ stride_dA_cs_csize,
1082
+ stride_seq_idx_batch,
1083
+ stride_seq_idx_seqlen,
1084
+ stride_ddA_cs_batch,
1085
+ stride_ddA_cs_chunk,
1086
+ stride_ddA_cs_head,
1087
+ stride_ddA_cs_csize,
1088
+ # Meta-parameters
1089
+ HAS_SEQ_IDX: tl.constexpr,
1090
+ BLOCK_SIZE_M: tl.constexpr,
1091
+ BLOCK_SIZE_N: tl.constexpr,
1092
+ BLOCK_SIZE_K: tl.constexpr,
1093
+ BLOCK_SIZE_DSTATE: tl.constexpr,
1094
+ ):
1095
+ pid_bc = tl.program_id(axis=1)
1096
+ pid_c = pid_bc // batch
1097
+ pid_b = pid_bc - pid_c * batch
1098
+ pid_h = tl.program_id(axis=2)
1099
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
1100
+ pid_m = tl.program_id(axis=0) // num_pid_n
1101
+ pid_n = tl.program_id(axis=0) % num_pid_n
1102
+ x_ptr += (
1103
+ pid_b * stride_x_batch
1104
+ + pid_c * chunk_size * stride_x_seqlen
1105
+ + pid_h * stride_x_head
1106
+ )
1107
+ b_ptr += (
1108
+ pid_b * stride_b_batch
1109
+ + pid_c * chunk_size * stride_b_seqlen
1110
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
1111
+ )
1112
+ dstates_ptr += (
1113
+ pid_b * stride_dstates_batch
1114
+ + pid_c * stride_dstates_chunk
1115
+ + pid_h * stride_states_head
1116
+ )
1117
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
1118
+ ddA_cumsum_ptr += (
1119
+ pid_b * stride_ddA_cs_batch
1120
+ + pid_c * stride_ddA_cs_chunk
1121
+ + pid_h * stride_ddA_cs_head
1122
+ )
1123
+ dA_cumsum_ptr += (
1124
+ pid_b * stride_dA_cs_batch
1125
+ + pid_c * stride_dA_cs_chunk
1126
+ + pid_h * stride_dA_cs_head
1127
+ )
1128
+ if HAS_SEQ_IDX:
1129
+ seq_idx_ptr += (
1130
+ pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
1131
+ )
1132
+
1133
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1134
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1135
+
1136
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
1137
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
1138
+ offs_k = tl.arange(
1139
+ 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
1140
+ )
1141
+ b_ptrs = b_ptr + (
1142
+ offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate
1143
+ )
1144
+ dstates_ptrs = dstates_ptr + (
1145
+ offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate
1146
+ )
1147
+ if BLOCK_SIZE_DSTATE <= 128:
1148
+ b = tl.load(
1149
+ b_ptrs,
1150
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate),
1151
+ other=0.0,
1152
+ )
1153
+ dstates = tl.load(
1154
+ dstates_ptrs,
1155
+ mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim),
1156
+ other=0.0,
1157
+ )
1158
+ dstates = dstates.to(b_ptr.dtype.element_ty)
1159
+ acc = tl.dot(b, dstates)
1160
+ else:
1161
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1162
+ for k in range(0, dstate, BLOCK_SIZE_K):
1163
+ b = tl.load(
1164
+ b_ptrs,
1165
+ mask=(offs_m[:, None] < chunk_size_limit)
1166
+ & (offs_k[None, :] < dstate - k),
1167
+ other=0.0,
1168
+ )
1169
+ dstates = tl.load(
1170
+ dstates_ptrs,
1171
+ mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim),
1172
+ other=0.0,
1173
+ )
1174
+ dstates = dstates.to(b_ptr.dtype.element_ty)
1175
+ acc += tl.dot(b, dstates)
1176
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
1177
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
1178
+
1179
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1180
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1181
+
1182
+ dA_cs_m = tl.load(
1183
+ dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
1184
+ ).to(tl.float32)
1185
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
1186
+ tl.float32
1187
+ )
1188
+ if not HAS_SEQ_IDX:
1189
+ scale = tl.exp(dA_cs_last - dA_cs_m)
1190
+ else:
1191
+ seq_idx_m = tl.load(
1192
+ seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
1193
+ mask=offs_m < chunk_size_limit,
1194
+ other=-1,
1195
+ )
1196
+ seq_idx_last = tl.load(
1197
+ seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
1198
+ )
1199
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
1200
+ acc *= scale[:, None]
1201
+
1202
+ x_ptrs = x_ptr + (
1203
+ offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
1204
+ )
1205
+ x = tl.load(
1206
+ x_ptrs,
1207
+ mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
1208
+ other=0.0,
1209
+ ).to(tl.float32)
1210
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
1211
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
1212
+ ddt = tl.sum(acc * x, axis=1)
1213
+ # ddA_cs = -(ddt * dt_m)
1214
+ # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
1215
+ # then call torch.cumsum outside this kernel.
1216
+ # ddA_cs = tl.cumsum(ddt * dt_m)
1217
+ ddA_cs = ddt * dt_m
1218
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
1219
+ # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
1220
+ tl.atomic_add(
1221
+ ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1
1222
+ )
1223
+
1224
+
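# The chunk-state kernels weight position l of a chunk by exp(dA_cs_last - dA_cs[l]),
# the decay accumulated from step l to the end of the chunk (times dt[l] in the forward
# path). A minimal numeric sketch of that identity; the _demo_* name is illustrative
# and relies only on the module-level torch import:
def _demo_chunk_decay_scale():
    dA = torch.tensor([-0.10, -0.20, -0.05, -0.30])  # per-step dA = A * dt for one head/chunk
    dA_cs = torch.cumsum(dA, dim=-1)                  # what _chunk_cumsum_fwd produces
    scale = torch.exp(dA_cs[-1] - dA_cs)              # decay from each step to the chunk end
    # The same factor written as an explicit product of the remaining per-step decays
    manual = torch.stack([torch.exp(dA[l + 1:].sum()) for l in range(dA.numel())])
    assert torch.allclose(scale, manual)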
1225
+ @triton.autotune(
1226
+ configs=[
1227
+ triton.Config(
1228
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
1229
+ num_stages=3,
1230
+ num_warps=8,
1231
+ ),
1232
+ triton.Config(
1233
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
1234
+ num_stages=4,
1235
+ num_warps=4,
1236
+ ),
1237
+ triton.Config(
1238
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1239
+ num_stages=4,
1240
+ num_warps=4,
1241
+ ),
1242
+ triton.Config(
1243
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1244
+ num_stages=4,
1245
+ num_warps=4,
1246
+ ),
1247
+ triton.Config(
1248
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
1249
+ num_stages=4,
1250
+ num_warps=4,
1251
+ ),
1252
+ triton.Config(
1253
+ {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1254
+ num_stages=4,
1255
+ num_warps=4,
1256
+ ),
1257
+ triton.Config(
1258
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
1259
+ num_stages=5,
1260
+ num_warps=2,
1261
+ ),
1262
+ triton.Config(
1263
+ {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1264
+ num_stages=5,
1265
+ num_warps=2,
1266
+ ),
1267
+ triton.Config(
1268
+ {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
1269
+ num_stages=4,
1270
+ num_warps=2,
1271
+ ),
1272
+ ],
1273
+ key=["hdim", "dstate", "chunk_size"],
1274
+ )
1275
+ @triton.jit
1276
+ def _chunk_state_varlen_kernel(
1277
+ # Pointers to matrices
1278
+ x_ptr,
1279
+ b_ptr,
1280
+ dt_ptr,
1281
+ dA_cumsum_ptr,
1282
+ chunk_states_ptr,
1283
+ cu_seqlens_ptr,
1284
+ states_ptr,
1285
+ # Matrix dimensions
1286
+ hdim,
1287
+ dstate,
1288
+ chunk_size,
1289
+ seqlen,
1290
+ nheads_ngroups_ratio,
1291
+ # Strides
1292
+ stride_x_seqlen,
1293
+ stride_x_head,
1294
+ stride_x_hdim,
1295
+ stride_b_seqlen,
1296
+ stride_b_head,
1297
+ stride_b_dstate,
1298
+ stride_dt_chunk,
1299
+ stride_dt_head,
1300
+ stride_dt_csize,
1301
+ stride_dA_cs_chunk,
1302
+ stride_dA_cs_head,
1303
+ stride_dA_cs_csize,
1304
+ stride_chunk_states_chunk,
1305
+ stride_chunk_states_head,
1306
+ stride_chunk_states_hdim,
1307
+ stride_chunk_states_dstate,
1308
+ stride_states_batch,
1309
+ stride_states_head,
1310
+ stride_states_hdim,
1311
+ stride_states_dstate,
1312
+ # Meta-parameters
1313
+ BLOCK_SIZE_M: tl.constexpr,
1314
+ BLOCK_SIZE_N: tl.constexpr,
1315
+ BLOCK_SIZE_K: tl.constexpr,
1316
+ ):
1317
+ pid_b = tl.program_id(axis=1)
1318
+ pid_h = tl.program_id(axis=2)
1319
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
1320
+ pid_m = tl.program_id(axis=0) // num_pid_n
1321
+ pid_n = tl.program_id(axis=0) % num_pid_n
1322
+ end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
1323
+ pid_c = (end_idx - 1) // chunk_size
1324
+ b_ptr += (
1325
+ pid_c * chunk_size * stride_b_seqlen
1326
+ + (pid_h // nheads_ngroups_ratio) * stride_b_head
1327
+ )
1328
+ x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
1329
+ dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
1330
+ dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
1331
+ chunk_states_ptr += (
1332
+ pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
1333
+ )
1334
+
1335
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1336
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1337
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
1338
+ x_ptrs = x_ptr + (
1339
+ offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
1340
+ )
1341
+ b_ptrs = b_ptr + (
1342
+ offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
1343
+ )
1344
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
1345
+ dA_cs_last = tl.load(
1346
+ dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
1347
+ ).to(tl.float32)
1348
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
1349
+
1350
+ chunk_size_limit = end_idx - pid_c * chunk_size
1351
+ start_idx = tl.load(cu_seqlens_ptr + pid_b)
1352
+ start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
1353
+
1354
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
1355
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
1356
+ x = tl.load(
1357
+ x_ptrs,
1358
+ mask=(offs_m[:, None] < hdim)
1359
+ & (offs_k[None, :] < chunk_size_limit - k)
1360
+ & (offs_k[None, :] >= start_idx_cur - k),
1361
+ other=0.0,
1362
+ )
1363
+ b = tl.load(
1364
+ b_ptrs,
1365
+ mask=(offs_k[:, None] < chunk_size_limit - k)
1366
+ & (offs_n[None, :] < dstate)
1367
+ & (offs_k[:, None] >= start_idx_cur - k),
1368
+ other=0.0,
1369
+ ).to(tl.float32)
1370
+ dA_cs_k = tl.load(
1371
+ dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
1372
+ ).to(tl.float32)
1373
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
1374
+ tl.float32
1375
+ )
1376
+ scale = tl.where(
1377
+ (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
1378
+ tl.exp((dA_cs_last - dA_cs_k)) * dt_k,
1379
+ 0.0,
1380
+ )
1381
+ b *= scale[:, None]
1382
+ b = b.to(x_ptr.dtype.element_ty)
1383
+ acc += tl.dot(x, b)
1384
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
1385
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
1386
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
1387
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
1388
+
1389
+ # Only fold in the state carried over from earlier chunks when the sequence started before this chunk; if it starts within this chunk (start_idx >= pid_c * chunk_size) there is no earlier contribution to add.
1390
+ if start_idx < pid_c * chunk_size:
1391
+ chunk_states_ptrs = chunk_states_ptr + (
1392
+ offs_m[:, None] * stride_chunk_states_hdim
1393
+ + offs_n[None, :] * stride_chunk_states_dstate
1394
+ )
1395
+ chunk_states = tl.load(
1396
+ chunk_states_ptrs,
1397
+ mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
1398
+ other=0.0,
1399
+ ).to(tl.float32)
1400
+ # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
1401
+ scale = tl.exp(dA_cs_last)
1402
+ acc += chunk_states * scale
1403
+
1404
+ states = acc.to(states_ptr.dtype.element_ty)
1405
+
1406
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
1407
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
1408
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
1409
+ states_ptrs = states_ptr + (
1410
+ offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
1411
+ )
1412
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
1413
+ tl.store(states_ptrs, states, mask=c_mask)
1414
+
1415
+
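# Worked example of the indexing above (hypothetical numbers, pure Python): with
# chunk_size = 128 and a packed sequence occupying tokens [300, 512) of the stream,
# the kernel only processes the last chunk that the sequence touches and, if the
# sequence started in an earlier chunk, folds in the carried chunk state.
def _demo_varlen_indexing():
    chunk_size, start_idx, end_idx = 128, 300, 512
    pid_c = (end_idx - 1) // chunk_size                      # last chunk touched: 3
    start_idx_cur = max(start_idx - pid_c * chunk_size, 0)   # sequence start within that chunk
    assert (pid_c, start_idx_cur) == (3, 0)
    assert start_idx < pid_c * chunk_size  # so the carried chunk state is added in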
1416
+ def _chunk_cumsum_fwd(
1417
+ dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
1418
+ ):
1419
+ batch, seqlen, nheads = dt.shape
1420
+ assert A.shape == (nheads,)
1421
+ if dt_bias is not None:
1422
+ assert dt_bias.shape == (nheads,)
1423
+ nchunks = math.ceil(seqlen / chunk_size)
1424
+ dt_out = torch.empty(
1425
+ batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1426
+ )
1427
+ dA_cumsum = torch.empty(
1428
+ batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1429
+ )
1430
+ grid_chunk_cs = lambda META: (
1431
+ batch,
1432
+ nchunks,
1433
+ triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1434
+ )
1435
+ with torch.cuda.device(dt.device.index):
1436
+ _chunk_cumsum_fwd_kernel[grid_chunk_cs](
1437
+ dt,
1438
+ A,
1439
+ dt_bias,
1440
+ dt_out,
1441
+ dA_cumsum,
1442
+ batch,
1443
+ seqlen,
1444
+ nheads,
1445
+ chunk_size,
1446
+ dt_limit[0],
1447
+ dt_limit[1],
1448
+ dt.stride(0),
1449
+ dt.stride(1),
1450
+ dt.stride(2),
1451
+ A.stride(0),
1452
+ dt_bias.stride(0) if dt_bias is not None else 0,
1453
+ dt_out.stride(0),
1454
+ dt_out.stride(2),
1455
+ dt_out.stride(1),
1456
+ dt_out.stride(3),
1457
+ dA_cumsum.stride(0),
1458
+ dA_cumsum.stride(2),
1459
+ dA_cumsum.stride(1),
1460
+ dA_cumsum.stride(3),
1461
+ dt_softplus,
1462
+ HAS_DT_BIAS=dt_bias is not None,
1463
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1464
+ )
1465
+ return dA_cumsum, dt_out
1466
+
1467
+
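# A minimal pure-PyTorch sketch of what _chunk_cumsum_fwd computes, for reference only.
# It assumes the module-level torch / F / math / rearrange imports; the _ref name is
# illustrative and not part of the kernel API.
def _chunk_cumsum_fwd_ref(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                          dt_limit=(0.0, float("inf"))):
    batch, seqlen, nheads = dt.shape
    nchunks = math.ceil(seqlen / chunk_size)
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()
    if dt_softplus:
        dt = F.softplus(dt)
    dt = dt.clamp(dt_limit[0], dt_limit[1])
    # Pad the sequence to a whole number of chunks; padded positions contribute dt = 0
    dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
    dA_cumsum = torch.cumsum(dt * A.float()[None, :, None, None], dim=-1)
    return dA_cumsum, dt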
1468
+ def _chunk_cumsum_bwd(
1469
+ ddA,
1470
+ ddt_out,
1471
+ dt,
1472
+ A,
1473
+ dt_bias=None,
1474
+ dt_softplus=False,
1475
+ dt_limit=(0.0, float("inf")),
1476
+ ddt=None,
1477
+ ):
1478
+ batch, seqlen, nheads = dt.shape
1479
+ _, _, nchunks, chunk_size = ddA.shape
1480
+ assert ddA.shape == (batch, nheads, nchunks, chunk_size)
1481
+ assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
1482
+ assert A.shape == (nheads,)
1483
+ if dt_bias is not None:
1484
+ assert dt_bias.shape == (nheads,)
1485
+ ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
1486
+ else:
1487
+ ddt_bias = None
1488
+ if ddt is not None:
1489
+ assert ddt.shape == dt.shape
1490
+ else:
1491
+ ddt = torch.empty_like(dt)
1492
+ dA = torch.empty_like(A, dtype=torch.float32)
1493
+ grid_chunk_cs = lambda META: (
1494
+ batch,
1495
+ nchunks,
1496
+ triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
1497
+ )
1498
+ with torch.cuda.device(dt.device.index):
1499
+ _chunk_cumsum_bwd_kernel[grid_chunk_cs](
1500
+ ddA,
1501
+ ddt_out,
1502
+ dt,
1503
+ A,
1504
+ dt_bias,
1505
+ ddt,
1506
+ dA,
1507
+ ddt_bias,
1508
+ batch,
1509
+ seqlen,
1510
+ nheads,
1511
+ chunk_size,
1512
+ dt_limit[0],
1513
+ dt_limit[1],
1514
+ ddA.stride(0),
1515
+ ddA.stride(2),
1516
+ ddA.stride(1),
1517
+ ddA.stride(3),
1518
+ ddt_out.stride(0),
1519
+ ddt_out.stride(2),
1520
+ ddt_out.stride(1),
1521
+ ddt_out.stride(3),
1522
+ dt.stride(0),
1523
+ dt.stride(1),
1524
+ dt.stride(2),
1525
+ A.stride(0),
1526
+ dt_bias.stride(0) if dt_bias is not None else 0,
1527
+ ddt.stride(0),
1528
+ ddt.stride(1),
1529
+ ddt.stride(2),
1530
+ dA.stride(0),
1531
+ ddt_bias.stride(0) if ddt_bias is not None else 0,
1532
+ dt_softplus,
1533
+ HAS_DT_BIAS=dt_bias is not None,
1534
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1535
+ )
1536
+ return ddt, dA, ddt_bias
1537
+
1538
+
1539
+ def _chunk_state_fwd(
1540
+ B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
1541
+ ):
1542
+ batch, seqlen, nheads, headdim = x.shape
1543
+ _, _, nchunks, chunk_size = dt.shape
1544
+ _, _, ngroups, dstate = B.shape
1545
+ assert nheads % ngroups == 0
1546
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1547
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1548
+ assert dA_cumsum.shape == dt.shape
1549
+ if seq_idx is not None:
1550
+ assert seq_idx.shape == (batch, seqlen)
1551
+ if states is not None:
1552
+ assert states.shape == (batch, nchunks, nheads, headdim, dstate)
1553
+ else:
1554
+ states_dtype = torch.float32 if states_in_fp32 else B.dtype
1555
+ states = torch.empty(
1556
+ (batch, nchunks, nheads, headdim, dstate),
1557
+ device=x.device,
1558
+ dtype=states_dtype,
1559
+ )
1560
+ grid = lambda META: (
1561
+ triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1562
+ * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1563
+ batch * nchunks,
1564
+ nheads,
1565
+ )
1566
+ with torch.cuda.device(x.device.index):
1567
+ _chunk_state_fwd_kernel[grid](
1568
+ x,
1569
+ B,
1570
+ states,
1571
+ dt,
1572
+ dA_cumsum,
1573
+ seq_idx,
1574
+ headdim,
1575
+ dstate,
1576
+ chunk_size,
1577
+ batch,
1578
+ seqlen,
1579
+ nheads // ngroups,
1580
+ x.stride(0),
1581
+ x.stride(1),
1582
+ x.stride(2),
1583
+ x.stride(3),
1584
+ B.stride(0),
1585
+ B.stride(1),
1586
+ B.stride(2),
1587
+ B.stride(-1),
1588
+ states.stride(0),
1589
+ states.stride(1),
1590
+ states.stride(2),
1591
+ states.stride(3),
1592
+ states.stride(4),
1593
+ dt.stride(0),
1594
+ dt.stride(2),
1595
+ dt.stride(1),
1596
+ dt.stride(3),
1597
+ dA_cumsum.stride(0),
1598
+ dA_cumsum.stride(2),
1599
+ dA_cumsum.stride(1),
1600
+ dA_cumsum.stride(3),
1601
+ *(
1602
+ (seq_idx.stride(0), seq_idx.stride(1))
1603
+ if seq_idx is not None
1604
+ else (0, 0)
1605
+ ),
1606
+ HAS_SEQ_IDX=seq_idx is not None,
1607
+ )
1608
+ return states
1609
+
1610
+
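# Usage sketch chaining the two forward helpers above (hypothetical sizes; requires a
# CUDA device since both launch Triton kernels; the _demo_* name is illustrative):
def _demo_chunk_state_fwd():
    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 128
    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
    dt = torch.randn(batch, seqlen, nheads, device="cuda", dtype=torch.float32)
    A = -torch.rand(nheads, device="cuda", dtype=torch.float32)
    dA_cumsum, dt_chunked = _chunk_cumsum_fwd(dt, A, chunk_size, dt_softplus=True)
    states = _chunk_state_fwd(B, x, dt_chunked, dA_cumsum)  # float32 by default
    assert states.shape == (batch, seqlen // chunk_size, nheads, headdim, dstate)
    return states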
1611
+ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
1612
+ batch, seqlen, nheads, headdim = x.shape
1613
+ _, _, nchunks, chunk_size = dt.shape
1614
+ _, _, ngroups, dstate = B.shape
1615
+ assert nheads % ngroups == 0
1616
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1617
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1618
+ assert dA_cumsum.shape == dt.shape
1619
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1620
+ if dx is not None:
1621
+ assert dx.shape == x.shape
1622
+ else:
1623
+ dx = torch.empty_like(x)
1624
+ ddt = torch.empty(
1625
+ batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
1626
+ )
1627
+ ddA_cumsum = torch.empty(
1628
+ batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32
1629
+ )
1630
+ grid_dx = lambda META: (
1631
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1632
+ * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1633
+ batch * nchunks,
1634
+ nheads,
1635
+ )
1636
+ with torch.cuda.device(x.device.index):
1637
+ _chunk_state_bwd_dx_kernel[grid_dx](
1638
+ x,
1639
+ B,
1640
+ dstates,
1641
+ dt,
1642
+ dA_cumsum,
1643
+ dx,
1644
+ ddt,
1645
+ ddA_cumsum,
1646
+ chunk_size,
1647
+ headdim,
1648
+ dstate,
1649
+ batch,
1650
+ seqlen,
1651
+ nheads // ngroups,
1652
+ x.stride(0),
1653
+ x.stride(1),
1654
+ x.stride(2),
1655
+ x.stride(3),
1656
+ B.stride(0),
1657
+ B.stride(1),
1658
+ B.stride(2),
1659
+ B.stride(-1),
1660
+ dstates.stride(0),
1661
+ dstates.stride(1),
1662
+ dstates.stride(2),
1663
+ dstates.stride(3),
1664
+ dstates.stride(4),
1665
+ dt.stride(0),
1666
+ dt.stride(2),
1667
+ dt.stride(1),
1668
+ dt.stride(3),
1669
+ dA_cumsum.stride(0),
1670
+ dA_cumsum.stride(2),
1671
+ dA_cumsum.stride(1),
1672
+ dA_cumsum.stride(3),
1673
+ dx.stride(0),
1674
+ dx.stride(1),
1675
+ dx.stride(2),
1676
+ dx.stride(3),
1677
+ ddt.stride(0),
1678
+ ddt.stride(2),
1679
+ ddt.stride(1),
1680
+ ddt.stride(3),
1681
+ ddA_cumsum.stride(0),
1682
+ ddA_cumsum.stride(2),
1683
+ ddA_cumsum.stride(1),
1684
+ ddA_cumsum.stride(3),
1685
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1686
+ )
1687
+ return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
1688
+
1689
+
1690
+ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
1691
+ batch, seqlen, nheads, headdim = x.shape
1692
+ _, _, nchunks, chunk_size = dt.shape
1693
+ dstate = dstates.shape[-1]
1694
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1695
+ assert dA_cumsum.shape == dt.shape
1696
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1697
+ if seq_idx is not None:
1698
+ assert seq_idx.shape == (batch, seqlen)
1699
+ if B is not None:
1700
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1701
+ B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
1702
+ # Use torch.empty here: the autotuner's pre_hook (init_to_zero) zeroes ddA_cumsum before the kernel runs
1703
+ ddA_cumsum = torch.empty(
1704
+ batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1705
+ )
1706
+ ddA_cumsum_strides = (
1707
+ ddA_cumsum.stride(0),
1708
+ ddA_cumsum.stride(2),
1709
+ ddA_cumsum.stride(1),
1710
+ ddA_cumsum.stride(3),
1711
+ )
1712
+ else:
1713
+ B_strides = (0, 0, 0, 0)
1714
+ ddA_cumsum = None
1715
+ ddA_cumsum_strides = (0, 0, 0, 0)
1716
+ nheads_ngroups_ratio = nheads // ngroups
1717
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
1718
+ nheads_per_program = max(
1719
+ min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
1720
+ )
1721
+ nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
1722
+ dB = torch.empty(
1723
+ batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32
1724
+ )
1725
+ grid_db = lambda META: (
1726
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1727
+ * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1728
+ batch * nchunks,
1729
+ nsplits * ngroups,
1730
+ )
1731
+ with torch.cuda.device(x.device.index):
1732
+ _chunk_state_bwd_db_kernel[grid_db](
1733
+ x,
1734
+ dstates,
1735
+ B,
1736
+ dt,
1737
+ dA_cumsum,
1738
+ seq_idx,
1739
+ dB,
1740
+ ddA_cumsum,
1741
+ chunk_size,
1742
+ dstate,
1743
+ headdim,
1744
+ batch,
1745
+ seqlen,
1746
+ nheads,
1747
+ nheads_per_program,
1748
+ ngroups,
1749
+ x.stride(0),
1750
+ x.stride(1),
1751
+ x.stride(2),
1752
+ x.stride(3),
1753
+ dstates.stride(0),
1754
+ dstates.stride(1),
1755
+ dstates.stride(2),
1756
+ dstates.stride(3),
1757
+ dstates.stride(4),
1758
+ *B_strides,
1759
+ dt.stride(0),
1760
+ dt.stride(2),
1761
+ dt.stride(1),
1762
+ dt.stride(3),
1763
+ dA_cumsum.stride(0),
1764
+ dA_cumsum.stride(2),
1765
+ dA_cumsum.stride(1),
1766
+ dA_cumsum.stride(3),
1767
+ *(
1768
+ (seq_idx.stride(0), seq_idx.stride(1))
1769
+ if seq_idx is not None
1770
+ else (0, 0)
1771
+ ),
1772
+ dB.stride(0),
1773
+ dB.stride(1),
1774
+ dB.stride(2),
1775
+ dB.stride(3),
1776
+ dB.stride(4),
1777
+ *ddA_cumsum_strides,
1778
+ HAS_DDA_CS=ddA_cumsum is not None,
1779
+ HAS_SEQ_IDX=seq_idx is not None,
1780
+ BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
1781
+ )
1782
+ dB = dB.sum(2)
1783
+ if ddA_cumsum is not None:
1784
+ # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
1785
+ # to the state of the chunk.
1786
+ # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1787
+ # But it's easier to just do the cumsum over all elements; since the first one is zero, the result is the same.
1788
+ torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
1789
+ return dB if B is None else (dB, ddA_cumsum)
1790
+
1791
+
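# Worked example of the work-splitting arithmetic above (hypothetical sizes): dB sums the
# gradient over all heads sharing one B group, and that sum is split across nsplits
# program instances whose partial results are combined afterwards with dB.sum(2).
def _demo_db_split_arithmetic():
    sm_count, batch, nchunks, nheads, ngroups = 108, 2, 4, 32, 1
    nheads_ngroups_ratio = nheads // ngroups
    nheads_per_program = max(
        min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1
    )
    nsplits = -(-nheads_ngroups_ratio // nheads_per_program)  # ceil division, as triton.cdiv does
    assert (nheads_per_program, nsplits) == (3, 11)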
1792
+ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
1793
+ batch, seqlen, nheads, headdim = x.shape
1794
+ _, _, nchunks, chunk_size = dt.shape
1795
+ _, _, ngroups, dstate = B.shape
1796
+ assert nheads % ngroups == 0
1797
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1798
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1799
+ assert dA_cumsum.shape == dt.shape
1800
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1801
+ if seq_idx is not None:
1802
+ assert seq_idx.shape == (batch, seqlen)
1803
+ # Use torch.empty here: the autotuner's pre_hook (init_to_zero) zeroes ddA_cumsum before the kernel runs
1804
+ ddA_cumsum = torch.empty(
1805
+ batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32
1806
+ )
1807
+ grid_ddtcs = lambda META: (
1808
+ triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
1809
+ * triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
1810
+ batch * nchunks,
1811
+ nheads,
1812
+ )
1813
+ with torch.cuda.device(x.device.index):
1814
+ _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
1815
+ x,
1816
+ B,
1817
+ dstates,
1818
+ dt,
1819
+ dA_cumsum,
1820
+ seq_idx,
1821
+ ddA_cumsum,
1822
+ chunk_size,
1823
+ headdim,
1824
+ dstate,
1825
+ batch,
1826
+ seqlen,
1827
+ nheads // ngroups,
1828
+ x.stride(0),
1829
+ x.stride(1),
1830
+ x.stride(2),
1831
+ x.stride(3),
1832
+ B.stride(0),
1833
+ B.stride(1),
1834
+ B.stride(2),
1835
+ B.stride(-1),
1836
+ dstates.stride(0),
1837
+ dstates.stride(1),
1838
+ dstates.stride(2),
1839
+ dstates.stride(3),
1840
+ dstates.stride(4),
1841
+ dt.stride(0),
1842
+ dt.stride(2),
1843
+ dt.stride(1),
1844
+ dt.stride(3),
1845
+ dA_cumsum.stride(0),
1846
+ dA_cumsum.stride(2),
1847
+ dA_cumsum.stride(1),
1848
+ dA_cumsum.stride(3),
1849
+ *(
1850
+ (seq_idx.stride(0), seq_idx.stride(1))
1851
+ if seq_idx is not None
1852
+ else (0, 0)
1853
+ ),
1854
+ ddA_cumsum.stride(0),
1855
+ ddA_cumsum.stride(2),
1856
+ ddA_cumsum.stride(1),
1857
+ ddA_cumsum.stride(3),
1858
+ HAS_SEQ_IDX=seq_idx is not None,
1859
+ BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
1860
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1861
+ )
1862
+ torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
1863
+ return ddA_cumsum
1864
+
1865
+
1866
+ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
1867
+ total_seqlen, nheads, headdim = x.shape
1868
+ _, nchunks, chunk_size = dt.shape
1869
+ _, ngroups, dstate = B.shape
1870
+ batch = cu_seqlens.shape[0] - 1
1871
+ cu_seqlens = cu_seqlens.contiguous()
1872
+ assert nheads % ngroups == 0
1873
+ assert B.shape == (total_seqlen, ngroups, dstate)
1874
+ assert dt.shape == (nheads, nchunks, chunk_size)
1875
+ assert dA_cumsum.shape == dt.shape
1876
+ assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
1877
+ states = torch.empty(
1878
+ batch,
1879
+ nheads,
1880
+ headdim,
1881
+ dstate,
1882
+ dtype=chunk_states.dtype,
1883
+ device=chunk_states.device,
1884
+ )
1885
+ grid = lambda META: (
1886
+ triton.cdiv(headdim, META["BLOCK_SIZE_M"])
1887
+ * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
1888
+ batch,
1889
+ nheads,
1890
+ )
1891
+ with torch.cuda.device(x.device.index):
1892
+ _chunk_state_varlen_kernel[grid](
1893
+ x,
1894
+ B,
1895
+ dt,
1896
+ dA_cumsum,
1897
+ chunk_states,
1898
+ cu_seqlens,
1899
+ states,
1900
+ headdim,
1901
+ dstate,
1902
+ chunk_size,
1903
+ total_seqlen,
1904
+ nheads // ngroups,
1905
+ x.stride(0),
1906
+ x.stride(1),
1907
+ x.stride(2),
1908
+ B.stride(0),
1909
+ B.stride(1),
1910
+ B.stride(2),
1911
+ dt.stride(1),
1912
+ dt.stride(0),
1913
+ dt.stride(2),
1914
+ dA_cumsum.stride(1),
1915
+ dA_cumsum.stride(0),
1916
+ dA_cumsum.stride(2),
1917
+ chunk_states.stride(0),
1918
+ chunk_states.stride(1),
1919
+ chunk_states.stride(2),
1920
+ chunk_states.stride(3),
1921
+ states.stride(0),
1922
+ states.stride(1),
1923
+ states.stride(2),
1924
+ states.stride(3),
1925
+ )
1926
+ return states
1927
+
1928
+
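# cu_seqlens sketch for the varlen entry point above: the boundaries of the packed
# sequences starting at 0, so batch = cu_seqlens.shape[0] - 1. For example, two sequences
# of lengths 300 and 212 packed into one 512-token stream (the _demo_* name is illustrative):
def _demo_cu_seqlens():
    lengths = torch.tensor([300, 212])
    cu_seqlens = F.pad(torch.cumsum(lengths, dim=0), (1, 0))
    assert cu_seqlens.tolist() == [0, 300, 512]
    # chunk_state_varlen then returns one final state per sequence: (2, nheads, headdim, dstate)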
1929
+ class ChunkStateFn(torch.autograd.Function):
1930
+
1931
+ @staticmethod
1932
+ def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
1933
+ batch, seqlen, nheads, headdim = x.shape
1934
+ _, _, nchunks, chunk_size = dt.shape
1935
+ assert seqlen <= nchunks * chunk_size
1936
+ _, _, ngroups, dstate = B.shape
1937
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1938
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1939
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
1940
+ if B.stride(-1) != 1:
1941
+ B = B.contiguous()
1942
+ if (
1943
+ x.stride(-1) != 1 and x.stride(1) != 1
1944
+ ): # Either M or K dimension should be contiguous
1945
+ x = x.contiguous()
1946
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
1947
+ ctx.save_for_backward(B, x, dt, dA_cumsum)
1948
+ return states
1949
+
1950
+ @staticmethod
1951
+ def backward(ctx, dstates):
1952
+ B, x, dt, dA_cumsum = ctx.saved_tensors
1953
+ batch, seqlen, nheads, headdim = x.shape
1954
+ _, _, nchunks, chunk_size = dt.shape
1955
+ _, _, ngroups, dstate = B.shape
1956
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
1957
+ if dstates.stride(-1) != 1:
1958
+ dstates = dstates.contiguous()
1959
+ dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
1960
+ dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
1961
+ dB = dB.to(B.dtype)
1962
+ return dB, dx, ddt, ddA_cumsum, None
1963
+
1964
+
1965
+ def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
1966
+ """
1967
+ Argument:
1968
+ B: (batch, seqlen, ngroups, dstate)
1969
+ x: (batch, seqlen, nheads, headdim)
1970
+ dt: (batch, nheads, nchunks, chunk_size)
1971
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
1972
+ Return:
1973
+ states: (batch, nchunks, nheads, headdim, dstate)
1974
+ """
1975
+ return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
1976
+
1977
+
1978
+ def chunk_state_ref(B, x, dt, dA_cumsum):
1979
+ """
1980
+ Argument:
1981
+ B: (batch, seqlen, ngroups, dstate)
1982
+ x: (batch, seqlen, nheads, headdim)
1983
+ dt: (batch, nheads, nchunks, chunk_size)
1984
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
1985
+ Return:
1986
+ states: (batch, nchunks, nheads, headdim, dstate)
1987
+ """
1988
+ # Check constraints.
1989
+ batch, seqlen, nheads, headdim = x.shape
1990
+ dstate = B.shape[-1]
1991
+ _, _, nchunks, chunk_size = dt.shape
1992
+ assert seqlen <= nchunks * chunk_size
1993
+ assert x.shape == (batch, seqlen, nheads, headdim)
1994
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
1995
+ ngroups = B.shape[2]
1996
+ assert nheads % ngroups == 0
1997
+ assert B.shape == (batch, seqlen, ngroups, dstate)
1998
+ B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
1999
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
2000
+ if seqlen < nchunks * chunk_size:
2001
+ x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2002
+ B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
2003
+ x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
2004
+ B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
2005
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
2006
+ return torch.einsum(
2007
+ "bclhn,bhcl,bhcl,bclhp->bchpn",
2008
+ B.to(x.dtype),
2009
+ decay_states.to(x.dtype),
2010
+ dt.to(x.dtype),
2011
+ x,
2012
+ )
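

# Consistency sketch comparing the Triton path with the einsum reference above
# (hypothetical sizes, illustrative tolerances; requires a CUDA device):
def _demo_chunk_state_vs_ref():
    torch.manual_seed(0)
    batch, seqlen, nheads, headdim = 2, 512, 8, 64
    ngroups, dstate, chunk_size = 1, 16, 128
    x = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
    B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
    dt = torch.randn(batch, seqlen, nheads, device="cuda")
    A = -torch.rand(nheads, device="cuda")
    dA_cumsum, dt_chunked = _chunk_cumsum_fwd(dt, A, chunk_size, dt_softplus=True)
    out = chunk_state(B, x, dt_chunked, dA_cumsum)
    out_ref = chunk_state_ref(B, x, dt_chunked, dA_cumsum)
    assert torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2)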