kernel
danieldk (HF Staff) committed
Commit f849035 · 1 Parent(s): b58ed97

Build (aarch64)
build/torch26-cxx11-cu126-aarch64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
+ from typing import Optional, List
+ import torch
+ from ._ops import ops
+
+
+ def mha_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Optional output tensor, same shape as q
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_fwd(
+         q,
+         k,
+         v,
+         out,
+         alibi_slopes,
+         p_dropout,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         return_softmax,
+         gen,
+     )
+
+
+ def mha_varlen_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     seqused_k: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         out: Optional output tensor of shape [total_q, num_heads, head_size]
+         seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_varlen_fwd(
+         q,
+         k,
+         v,
+         out,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         seqused_k,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         return_softmax,
+         gen,
+     )
+
+
+ def mha_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention.
+
+     Args:
+         dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         alibi_slopes,
+         p_dropout,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_varlen_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_varlen_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_fwd_kvcache(
+     q: torch.Tensor,
+     kcache: torch.Tensor,
+     vcache: torch.Tensor,
+     k: Optional[torch.Tensor] = None,
+     v: Optional[torch.Tensor] = None,
+     seqlens_k: Optional[torch.Tensor] = None,
+     rotary_cos: Optional[torch.Tensor] = None,
+     rotary_sin: Optional[torch.Tensor] = None,
+     cache_batch_idx: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     out: Optional[torch.Tensor] = None,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     is_rotary_interleaved: bool = False,
+     num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with KV cache.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+         rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+         rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+         cache_batch_idx: Optional indices to index into the KV cache
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         out: Optional output tensor, same shape as q
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         is_rotary_interleaved: Whether rotary embeddings are interleaved
+         num_splits: Number of splits for computation
+
+     Returns:
+         List of tensors: [output, softmax_lse]
+     """
+     return ops.mha_fwd_kvcache(
+         q,
+         kcache,
+         vcache,
+         k,
+         v,
+         seqlens_k,
+         rotary_cos,
+         rotary_sin,
+         cache_batch_idx,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         out,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         is_rotary_interleaved,
+         num_splits,
+     )
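
The wrappers above are thin pass-throughs to the compiled kernel ops. A minimal usage sketch for the dense mha_fwd entry point, assuming this build is importable as flash_attn and a CUDA device is available; the shapes, dtype, and the 1/sqrt(head_size) scale below are illustrative choices, not requirements stated in the file above:

import math
import torch
import flash_attn  # assumes the build directory for this variant is on the Python path

# Illustrative shapes: [batch_size, seqlen, num_heads, head_size], fp16 on GPU.
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# softmax_scale defaults to 1.0 in the wrapper, so pass 1/sqrt(head_size) explicitly.
result = flash_attn.mha_fwd(
    q, k, v,
    softmax_scale=1.0 / math.sqrt(q.shape[-1]),
    is_causal=True,
)
out, softmax_lse = result[0], result[1]  # per the docstring: [output, softmax_lse, ...]

The varlen variants take the same tensors flattened over the batch dimension, plus cu_seqlens_q/cu_seqlens_k offsets of shape [batch_size+1] and explicit max_seqlen_q/max_seqlen_k values.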
build/torch26-cxx11-cu126-aarch64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80fb5d7d2d79174b113a48fcf87f1ee99e58cb10e37525ce6d7fbe88b92b2b4e
+ size 646378472
build/torch26-cxx11-cu126-aarch64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"
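
The _ops.py shim binds the hash-versioned native extension to a stable ops handle, so the wrappers in __init__.py never hard-code the build hash; add_op_namespace_prefix builds fully qualified op names for that same namespace. A small sketch, assuming the package is importable as flash_attn:

import torch
from flash_attn._ops import ops, add_op_namespace_prefix

# The helper only formats the namespace-qualified name; the namespace string
# "_flash_attn_dd2f0f9" is the build hash baked into this commit.
qualified = add_op_namespace_prefix("mha_fwd")
print(qualified)  # _flash_attn_dd2f0f9::mha_fwd

# The same op is reachable as an attribute of the ops handle registered by the shared object.
mha_fwd_op = getattr(ops, "mha_fwd")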
build/torch26-cxx98-cu126-aarch64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
+ (contents identical to build/torch26-cxx11-cu126-aarch64-linux/flash_attn/__init__.py above)
build/torch26-cxx98-cu126-aarch64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d2d761fb906a27113da522801b91d3e2c6db042253caaf9bceb6b893cf76964a
+ size 646373888
build/torch26-cxx98-cu126-aarch64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch27-cxx11-cu126-aarch64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
+ (contents identical to build/torch26-cxx11-cu126-aarch64-linux/flash_attn/__init__.py above)
build/torch27-cxx11-cu126-aarch64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4bdf651fa75409d3d8e04e85dd7e2ade1f263114e5c58fae0f1e2dde76f3554c
+ size 646378696
build/torch27-cxx11-cu126-aarch64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch27-cxx11-cu128-aarch64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
+ (contents identical to build/torch26-cxx11-cu126-aarch64-linux/flash_attn/__init__.py above)
build/torch27-cxx11-cu128-aarch64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a23b3783a4b8aa7f7f03d964cc476baa3521e8d2bad7fd0f7376afbed2640dac
+ size 1503161696
build/torch27-cxx11-cu128-aarch64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"