kernel
danieldk (HF Staff) committed
Commit b58ed97 · 1 Parent(s): dd2f0f9
Files changed (42)
  1. build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py +343 -16
  2. build/torch25-cxx11-cu118-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} +2 -2
  3. build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +3 -3
  4. build/torch25-cxx11-cu121-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} +2 -2
  5. build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_ops.py +3 -3
  6. build/torch25-cxx11-cu124-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} +2 -2
  7. build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_ops.py +3 -3
  8. build/torch25-cxx98-cu118-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} +2 -2
  9. build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +3 -3
  10. build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  11. build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  12. build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_ops.py +3 -3
  13. build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  14. build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  15. build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +3 -3
  16. build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  17. build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  18. build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +3 -3
  19. build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  20. build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  21. build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py +3 -3
  22. build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  23. build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  24. build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py +3 -3
  25. build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  26. build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  27. build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py +3 -3
  28. build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  29. build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  30. build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py +3 -3
  31. build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so +0 -3
  32. build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  33. build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py +3 -3
  34. build/torch27-cxx11-cu118-x86_64-linux/flash_attn/__init__.py +364 -0
  35. build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  36. build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_ops.py +9 -0
  37. build/torch27-cxx11-cu126-x86_64-linux/flash_attn/__init__.py +364 -0
  38. build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  39. build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_ops.py +9 -0
  40. build/torch27-cxx11-cu128-x86_64-linux/flash_attn/__init__.py +364 -0
  41. build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so +3 -0
  42. build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_ops.py +9 -0
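Each build directory name encodes the variant of the prebuilt extension it holds: PyTorch release (torch25/26/27), C++ ABI (cxx11 or cxx98), CUDA toolkit (cu118 through cu128), and platform (x86_64-linux). As an illustration only (not part of this commit, and the helper name is hypothetical), a loader could derive the matching directory from the running environment roughly like this:

    # Hypothetical sketch: map the running torch/CUDA/ABI combination to a
    # build-variant directory name such as "torch27-cxx11-cu128-x86_64-linux".
    import platform
    import torch

    def build_variant() -> str:
        major, minor = torch.__version__.split("+")[0].split(".")[:2]
        abi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
        cuda = "cu" + torch.version.cuda.replace(".", "")  # assumes a CUDA build of torch
        return f"torch{major}{minor}-{abi}-{cuda}-{platform.machine()}-linux"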
build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py CHANGED
@@ -1,25 +1,45 @@
- from typing import Optional
-
+ from typing import Optional, List
  import torch
-
  from ._ops import ops

+
  def mha_fwd(
      q: torch.Tensor,
      k: torch.Tensor,
      v: torch.Tensor,
-     out: torch.Tensor,
-     alibi_slopes: torch.Tensor,
-     p_dropout: float,
-     softmax_scale: float,
-     is_causal: bool,
-     window_size_left: int,
-     window_size_right: int,
-     softcap: float,
-     return_softmax: bool,
-     gen: Optional[torch.Generator],
- ) -> torch.Tensor:
-     ops.mha_fwd(
+     out: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Optional output tensor, same shape as q
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_fwd(
          q,
          k,
          v,
@@ -34,4 +54,311 @@ def mha_fwd(
          return_softmax,
          gen,
      )
-     return out
+
+
+ def mha_varlen_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     seqused_k: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         out: Optional output tensor of shape [total_q, num_heads, head_size]
+         seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_varlen_fwd(
+         q,
+         k,
+         v,
+         out,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         seqused_k,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         return_softmax,
+         gen,
+     )
+
+
+ def mha_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention.
+
+     Args:
+         dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         alibi_slopes,
+         p_dropout,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_varlen_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_varlen_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_fwd_kvcache(
+     q: torch.Tensor,
+     kcache: torch.Tensor,
+     vcache: torch.Tensor,
+     k: Optional[torch.Tensor] = None,
+     v: Optional[torch.Tensor] = None,
+     seqlens_k: Optional[torch.Tensor] = None,
+     rotary_cos: Optional[torch.Tensor] = None,
+     rotary_sin: Optional[torch.Tensor] = None,
+     cache_batch_idx: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     out: Optional[torch.Tensor] = None,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     is_rotary_interleaved: bool = False,
+     num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with KV cache.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+         rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+         rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+         cache_batch_idx: Optional indices to index into the KV cache
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         out: Optional output tensor, same shape as q
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         is_rotary_interleaved: Whether rotary embeddings are interleaved
+         num_splits: Number of splits for computation
+
+     Returns:
+         List of tensors: [output, softmax_lse]
+     """
+     return ops.mha_fwd_kvcache(
+         q,
+         kcache,
+         vcache,
+         k,
+         v,
+         seqlens_k,
+         rotary_cos,
+         rotary_sin,
+         cache_batch_idx,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         out,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         is_rotary_interleaved,
+         num_splits,
+     )
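The wrappers above now expose keyword defaults, so a minimal call only needs q, k and v. A usage sketch (not part of the commit; shapes and the 1/sqrt(head_size) scaling follow the mha_fwd docstring, and the package is assumed to be importable as flash_attn):

    import math
    import torch
    import flash_attn  # assumed import name for the package in this diff

    batch, seqlen, num_heads, head_size = 2, 1024, 16, 64
    q = torch.randn(batch, seqlen, num_heads, head_size, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # softmax_scale defaults to 1.0, so pass the conventional 1/sqrt(head_size) explicitly.
    out, softmax_lse, *rest = flash_attn.mha_fwd(
        q, k, v,
        softmax_scale=1.0 / math.sqrt(head_size),
        is_causal=True,
    )
    # out: [batch, seqlen, num_heads, head_size]; softmax_lse: [batch, num_heads, seqlen]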
build/torch25-cxx11-cu118-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:85f70ae9ee6f5b27b149808f14aedf0dbb327fcfac6e6320c48d17810009dc77
- size 1301385392
+ oid sha256:14e43c95a52d7b6a974bc54b6ec30068ae8fa513583a686494caf123137dc2e5
+ size 658100376
build/torch25-cxx11-cu118-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
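The extension module is now named after the clean commit hash (dd2f0f9) rather than a -dirty build, and importing it registers its custom ops under a torch.ops namespace of the same name; add_op_namespace_prefix simply builds that fully qualified op name. An illustrative sketch (not part of the commit) of resolving an op through that namespace:

    import torch
    from flash_attn import _ops  # importing _ops loads the extension and registers the ops

    qualified = _ops.add_op_namespace_prefix("mha_fwd")  # "_flash_attn_dd2f0f9::mha_fwd"
    namespace, op_name = qualified.split("::")
    mha_fwd_op = getattr(getattr(torch.ops, namespace), op_name)  # same op as _ops.ops.mha_fwd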
build/torch25-cxx11-cu121-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:91b3c70a49f7d039bc7a238d0147dabe94cffd2485463bdd641bb74b395ada99
- size 1295653368
+ oid sha256:5405ad9405b2c3ded5f971fdc7a7fdfa0531eb2f1aca2e37e396003a149b1379
+ size 653617624
build/torch25-cxx11-cu121-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch25-cxx11-cu124-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:0e074eed034da9275d49c87d904babd0a718c8e22d12cdedfae01e7c38260113
- size 1262747328
+ oid sha256:a2b7aeea4adc77aefd217ccdfa7bcadcaef3dc6d0d7567f4b1c2c5f0321738fe
+ size 640704152
build/torch25-cxx11-cu124-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch25-cxx98-cu118-x86_64-linux/flash_attn/{_flash_attn_ab4cc6a_dirty.abi3.so → _flash_attn_dd2f0f9.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:7b245e7fe66f20cef74aaab7c86d1e33913faeff9d6dae530763d4a5dd256af5
- size 1301380832
+ oid sha256:73f7d75dcba8295aa14721ffc8c1aca0d86872ae03af18ab2e5149c043201d2a
+ size 658091712
build/torch25-cxx98-cu118-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:347160ae8e05c11d1a99da542ecb4c2f6dbd30627cc6002b08c107b9d3d8af3c
- size 1295640880
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:534cc70a7659f0acfca6fbd135d229af62f11557f910360939f55454cd2f6ce3
+ size 653605136
build/torch25-cxx98-cu121-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:c79d6703c033ea9e1bfcc6fc3006ac88a9713d8371ea3a96d70e8495c7692f68
- size 1262738936
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8051fcc2f91c45d364292f8e2f04804a93b2e78844747a46afcfa926007769be
+ size 640695760
build/torch25-cxx98-cu124-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d8cc0f02a6eea5c9fe8e5bc7b0138cef9bf77c026dc26b08f878bd809799189e
- size 1301389752
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af9873164f58acc3dcea5a1ca046af1872f3a5d2061668edcec7e7802c02a0a6
+ size 658100640
build/torch26-cxx11-cu118-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:fc6aaa5d51f3d329ec4d6fe7422ff8ff5223fa1a1e01644da196504534bd4fb6
- size 1262747768
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:469a5bac698ba7be5a9aecd831b5ca5fd21ff37843d603d3e39888fad477d6e6
+ size 640704600
build/torch26-cxx11-cu124-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:afce8d0bc6516f4e2ade3b45453d6370ead51ab9d368786b20109544cc8b4772
- size 1273150064
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:475a51fa6fe806195457f3ea76e64343bb7b0beca8be1f128d24c0672de6a5ee
+ size 646613576
build/torch26-cxx11-cu126-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:0f65a659aa158221085014ade1e92475fe08871894796ca8db38ef2d2dbbcb99
- size 1301381128
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ca74dec8c6dcab25c0359d147b769d1025c46fb5a8ea81dd87ca4d03876044b
+ size 658092008
build/torch26-cxx98-cu118-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7ad532518c0a821e096e21c16bd89ec4c0b57b5b9cae92daa4c75100cfe712c6
- size 1262739232
build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c366befff828dfa0bccf8de7de54c5fa6b3f796d55690e5e0623e518de89e4f2
+ size 640696056
build/torch26-cxx98-cu124-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_ab4cc6a_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:300868f1f33c620a923efa0629916bb0afda4763af425de233e48389eede6db4
- size 1273141520
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a0753c6dad0da882bae86dae5658e5915eb20c89f19cb69352f239c597b5d697
+ size 646605032
build/torch26-cxx98-cu126-x86_64-linux/flash_attn/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _flash_attn_ab4cc6a_dirty
- ops = torch.ops._flash_attn_ab4cc6a_dirty
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9

  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_flash_attn_ab4cc6a_dirty::{op_name}"
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
+ [364 added lines — this new __init__.py is identical to the updated build/torch25-cxx11-cu118-x86_64-linux/flash_attn/__init__.py listed above.]
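The variable-length entry points in that file (mha_varlen_fwd / mha_varlen_bwd) expect all sequences packed into a single [total, num_heads, head_size] tensor plus cumulative-length offsets. A hedged sketch of building cu_seqlens for two sequences (illustrative only; the int32 offsets and the flash_attn import name are assumptions, not part of the commit):

    import math
    import torch
    import flash_attn  # assumed import name

    num_heads, head_size = 16, 64
    seqlens = [500, 1000]
    total = sum(seqlens)

    q = torch.randn(total, num_heads, head_size, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # cu_seqlens has shape [batch_size + 1]: 0 followed by the running sum of lengths.
    cu_seqlens = torch.tensor([0, 500, 1500], device="cuda", dtype=torch.int32)

    out, softmax_lse, *rest = flash_attn.mha_varlen_fwd(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens),
        max_seqlen_k=max(seqlens),
        softmax_scale=1.0 / math.sqrt(head_size),
        is_causal=True,
    )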
build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcb1788a80f0624ec6532ea3abdbd1ef504364006129ef4564d131f2a44dc916
+ size 658100920
build/torch27-cxx11-cu118-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"
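The decode-time entry point mha_fwd_kvcache, exposed by the __init__.py files in this commit, reads from and optionally appends to a preallocated key/value cache. An illustrative single-token decode step (not part of the commit; shapes follow the docstring above, and the import name is an assumption):

    import math
    import torch
    import flash_attn  # assumed import name

    batch, max_cache_len, num_heads, head_size = 2, 4096, 16, 64

    # Preallocated cache plus the number of valid cached tokens per sequence.
    kcache = torch.zeros(batch, max_cache_len, num_heads, head_size, device="cuda", dtype=torch.float16)
    vcache = torch.zeros_like(kcache)
    seqlens_k = torch.full((batch,), 128, device="cuda", dtype=torch.int32)

    # One new query token per sequence; its key/value get written into the cache.
    q = torch.randn(batch, 1, num_heads, head_size, device="cuda", dtype=torch.float16)
    k_new = torch.randn_like(q)
    v_new = torch.randn_like(q)

    out, softmax_lse = flash_attn.mha_fwd_kvcache(
        q, kcache, vcache,
        k=k_new, v=v_new,
        seqlens_k=seqlens_k,
        softmax_scale=1.0 / math.sqrt(head_size),
        is_causal=True,
    )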
build/torch27-cxx11-cu126-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ import torch
3
+ from ._ops import ops
4
+
5
+
6
+ def mha_fwd(
7
+ q: torch.Tensor,
8
+ k: torch.Tensor,
9
+ v: torch.Tensor,
10
+ out: Optional[torch.Tensor] = None,
11
+ alibi_slopes: Optional[torch.Tensor] = None,
12
+ p_dropout: float = 0.0,
13
+ softmax_scale: float = 1.0,
14
+ is_causal: bool = False,
15
+ window_size_left: int = -1,
16
+ window_size_right: int = -1,
17
+ softcap: float = 0.0,
18
+ return_softmax: bool = False,
19
+ gen: Optional[torch.Generator] = None,
20
+ ) -> List[torch.Tensor]:
21
+ """
22
+ Forward pass for multi-head attention.
23
+
24
+ Args:
25
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
26
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
27
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
28
+ out: Optional output tensor, same shape as q
29
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
30
+ p_dropout: Dropout probability
31
+ softmax_scale: Scale factor for softmax
32
+ is_causal: Whether to use causal attention
33
+ window_size_left: Window size for left context (-1 for unlimited)
34
+ window_size_right: Window size for right context (-1 for unlimited)
35
+ softcap: Soft cap for attention weights
36
+ return_softmax: Whether to return softmax weights
37
+ gen: Optional random number generator
38
+
39
+ Returns:
40
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
41
+ """
42
+ return ops.mha_fwd(
43
+ q,
44
+ k,
45
+ v,
46
+ out,
47
+ alibi_slopes,
48
+ p_dropout,
49
+ softmax_scale,
50
+ is_causal,
51
+ window_size_left,
52
+ window_size_right,
53
+ softcap,
54
+ return_softmax,
55
+ gen,
56
+ )
57
+
58
+
59
+ def mha_varlen_fwd(
60
+ q: torch.Tensor,
61
+ k: torch.Tensor,
62
+ v: torch.Tensor,
63
+ cu_seqlens_q: torch.Tensor,
64
+ cu_seqlens_k: torch.Tensor,
65
+ out: Optional[torch.Tensor] = None,
66
+ seqused_k: Optional[torch.Tensor] = None,
67
+ leftpad_k: Optional[torch.Tensor] = None,
68
+ block_table: Optional[torch.Tensor] = None,
69
+ alibi_slopes: Optional[torch.Tensor] = None,
70
+ max_seqlen_q: int = 0,
71
+ max_seqlen_k: int = 0,
72
+ p_dropout: float = 0.0,
73
+ softmax_scale: float = 1.0,
74
+ zero_tensors: bool = False,
75
+ is_causal: bool = False,
76
+ window_size_left: int = -1,
77
+ window_size_right: int = -1,
78
+ softcap: float = 0.0,
79
+ return_softmax: bool = False,
80
+ gen: Optional[torch.Generator] = None,
81
+ ) -> List[torch.Tensor]:
82
+ """
83
+ Forward pass for multi-head attention with variable sequence lengths.
84
+
85
+ Args:
86
+ q: Query tensor of shape [total_q, num_heads, head_size]
87
+ k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
88
+ v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
89
+ cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
90
+ cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
91
+ out: Optional output tensor of shape [total_q, num_heads, head_size]
92
+ seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
93
+ leftpad_k: Optional left padding for keys of shape [batch_size]
94
+ block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
95
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
96
+ max_seqlen_q: Maximum sequence length for queries
97
+ max_seqlen_k: Maximum sequence length for keys
98
+ p_dropout: Dropout probability
99
+ softmax_scale: Scale factor for softmax
100
+ zero_tensors: Whether to zero tensors before computation
101
+ is_causal: Whether to use causal attention
102
+ window_size_left: Window size for left context (-1 for unlimited)
103
+ window_size_right: Window size for right context (-1 for unlimited)
104
+ softcap: Soft cap for attention weights
105
+ return_softmax: Whether to return softmax weights
106
+ gen: Optional random number generator
107
+
108
+ Returns:
109
+ List of tensors: [output, softmax_lse, (softmax if return_softmax)]
110
+ """
111
+ return ops.mha_varlen_fwd(
112
+ q,
113
+ k,
114
+ v,
115
+ out,
116
+ cu_seqlens_q,
117
+ cu_seqlens_k,
118
+ seqused_k,
119
+ leftpad_k,
120
+ block_table,
121
+ alibi_slopes,
122
+ max_seqlen_q,
123
+ max_seqlen_k,
124
+ p_dropout,
125
+ softmax_scale,
126
+ zero_tensors,
127
+ is_causal,
128
+ window_size_left,
129
+ window_size_right,
130
+ softcap,
131
+ return_softmax,
132
+ gen,
133
+ )
134
+
135
+
136
+ def mha_bwd(
137
+ dout: torch.Tensor,
138
+ q: torch.Tensor,
139
+ k: torch.Tensor,
140
+ v: torch.Tensor,
141
+ out: torch.Tensor,
142
+ softmax_lse: torch.Tensor,
143
+ dq: Optional[torch.Tensor] = None,
144
+ dk: Optional[torch.Tensor] = None,
145
+ dv: Optional[torch.Tensor] = None,
146
+ alibi_slopes: Optional[torch.Tensor] = None,
147
+ p_dropout: float = 0.0,
148
+ softmax_scale: float = 1.0,
149
+ is_causal: bool = False,
150
+ window_size_left: int = -1,
151
+ window_size_right: int = -1,
152
+ softcap: float = 0.0,
153
+ deterministic: bool = False,
154
+ gen: Optional[torch.Generator] = None,
155
+ rng_state: Optional[torch.Tensor] = None,
156
+ ) -> List[torch.Tensor]:
157
+ """
158
+ Backward pass for multi-head attention.
159
+
160
+ Args:
161
+ dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
162
+ q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
163
+ k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
164
+ v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
165
+ out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
166
+ softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
167
+ dq: Optional gradient tensor for queries, same shape as q
168
+ dk: Optional gradient tensor for keys, same shape as k
169
+ dv: Optional gradient tensor for values, same shape as v
170
+ alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
171
+ p_dropout: Dropout probability
172
+ softmax_scale: Scale factor for softmax
173
+ is_causal: Whether to use causal attention
174
+ window_size_left: Window size for left context (-1 for unlimited)
175
+ window_size_right: Window size for right context (-1 for unlimited)
176
+ softcap: Soft cap for attention weights
177
+ deterministic: Whether to use deterministic algorithms
178
+ gen: Optional random number generator
179
+ rng_state: Optional RNG state from forward pass
180
+
181
+ Returns:
182
+ List of tensors: [dq, dk, dv]
183
+ """
184
+ return ops.mha_bwd(
185
+ dout,
186
+ q,
187
+ k,
188
+ v,
189
+ out,
190
+ softmax_lse,
191
+ dq,
192
+ dk,
193
+ dv,
194
+ alibi_slopes,
195
+ p_dropout,
196
+ softmax_scale,
197
+ is_causal,
198
+ window_size_left,
199
+ window_size_right,
200
+ softcap,
201
+ deterministic,
202
+ gen,
203
+ rng_state,
204
+ )
205
+
206
+
207
+ def mha_varlen_bwd(
208
+ dout: torch.Tensor,
209
+ q: torch.Tensor,
210
+ k: torch.Tensor,
211
+ v: torch.Tensor,
212
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         dout: Gradient tensor of shape [total_q, num_heads, head_size]
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [total_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_varlen_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_fwd_kvcache(
+     q: torch.Tensor,
+     kcache: torch.Tensor,
+     vcache: torch.Tensor,
+     k: Optional[torch.Tensor] = None,
+     v: Optional[torch.Tensor] = None,
+     seqlens_k: Optional[torch.Tensor] = None,
+     rotary_cos: Optional[torch.Tensor] = None,
+     rotary_sin: Optional[torch.Tensor] = None,
+     cache_batch_idx: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     out: Optional[torch.Tensor] = None,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     is_rotary_interleaved: bool = False,
+     num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with KV cache.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+         rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+         rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+         cache_batch_idx: Optional indices to index into the KV cache
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         out: Optional output tensor, same shape as q
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         is_rotary_interleaved: Whether rotary embeddings are interleaved
+         num_splits: Number of splits for computation
+
+     Returns:
+         List of tensors: [output, softmax_lse]
+     """
+     return ops.mha_fwd_kvcache(
+         q,
+         kcache,
+         vcache,
+         k,
+         v,
+         seqlens_k,
+         rotary_cos,
+         rotary_sin,
+         cache_batch_idx,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         out,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         is_rotary_interleaved,
+         num_splits,
+     )
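A minimal decode-step sketch for the mha_fwd_kvcache wrapper above, assuming the built package is importable as flash_attn, a CUDA device, fp16 tensors with the shapes given in the docstring, and the documented [output, softmax_lse] return order; all sizes are illustrative.

import torch
import flash_attn  # assumed import path for this build

batch, heads, dim, max_cache_len = 2, 16, 128, 1024
q = torch.randn(batch, 1, heads, dim, device="cuda", dtype=torch.float16)
kcache = torch.zeros(batch, max_cache_len, heads, dim, device="cuda", dtype=torch.float16)
vcache = torch.zeros_like(kcache)
k_new = torch.randn(batch, 1, heads, dim, device="cuda", dtype=torch.float16)
v_new = torch.randn_like(k_new)
# Number of tokens already stored in the cache for each batch element.
seqlens_k = torch.full((batch,), 17, device="cuda", dtype=torch.int32)

out, softmax_lse = flash_attn.mha_fwd_kvcache(
    q,
    kcache,
    vcache,
    k=k_new,
    v=v_new,
    seqlens_k=seqlens_k,
    softmax_scale=dim**-0.5,
    is_causal=True,
)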
build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e7da95f182ca7f57b45cfe9387045d78397c312283d8a5eecd9bce96e6888ea8
+ size 646613312
build/torch27-cxx11-cu126-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/flash_attn/__init__.py ADDED
@@ -0,0 +1,364 @@
+ from typing import Optional, List
+ import torch
+ from ._ops import ops
+
+
+ def mha_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Optional output tensor, same shape as q
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_fwd(
+         q,
+         k,
+         v,
+         out,
+         alibi_slopes,
+         p_dropout,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         return_softmax,
+         gen,
+     )
+
+
+ def mha_varlen_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     out: Optional[torch.Tensor] = None,
+     seqused_k: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     return_softmax: bool = False,
+     gen: Optional[torch.Generator] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         out: Optional output tensor of shape [total_q, num_heads, head_size]
+         seqused_k: Optional tensor specifying how many keys to use per batch element [batch_size]
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         return_softmax: Whether to return softmax weights
+         gen: Optional random number generator
+
+     Returns:
+         List of tensors: [output, softmax_lse, (softmax if return_softmax)]
+     """
+     return ops.mha_varlen_fwd(
+         q,
+         k,
+         v,
+         out,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         seqused_k,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         return_softmax,
+         gen,
+     )
+
+
+ def mha_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention.
+
+     Args:
+         dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         alibi_slopes,
+         p_dropout,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_varlen_bwd(
+     dout: torch.Tensor,
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     out: torch.Tensor,
+     softmax_lse: torch.Tensor,
+     cu_seqlens_q: torch.Tensor,
+     cu_seqlens_k: torch.Tensor,
+     dq: Optional[torch.Tensor] = None,
+     dk: Optional[torch.Tensor] = None,
+     dv: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     max_seqlen_q: int = 0,
+     max_seqlen_k: int = 0,
+     p_dropout: float = 0.0,
+     softmax_scale: float = 1.0,
+     zero_tensors: bool = False,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     deterministic: bool = False,
+     gen: Optional[torch.Generator] = None,
+     rng_state: Optional[torch.Tensor] = None,
+ ) -> List[torch.Tensor]:
+     """
+     Backward pass for multi-head attention with variable sequence lengths.
+
+     Args:
+         dout: Gradient tensor of shape [total_q, num_heads, head_size]
+         q: Query tensor of shape [total_q, num_heads, head_size]
+         k: Key tensor of shape [total_k, num_heads_k, head_size]
+         v: Value tensor of shape [total_k, num_heads_k, head_size]
+         out: Output tensor from forward pass of shape [total_q, num_heads, head_size]
+         softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
+         cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
+         cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
+         dq: Optional gradient tensor for queries, same shape as q
+         dk: Optional gradient tensor for keys, same shape as k
+         dv: Optional gradient tensor for values, same shape as v
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         max_seqlen_q: Maximum sequence length for queries
+         max_seqlen_k: Maximum sequence length for keys
+         p_dropout: Dropout probability
+         softmax_scale: Scale factor for softmax
+         zero_tensors: Whether to zero tensors before computation
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         deterministic: Whether to use deterministic algorithms
+         gen: Optional random number generator
+         rng_state: Optional RNG state from forward pass
+
+     Returns:
+         List of tensors: [dq, dk, dv]
+     """
+     return ops.mha_varlen_bwd(
+         dout,
+         q,
+         k,
+         v,
+         out,
+         softmax_lse,
+         dq,
+         dk,
+         dv,
+         cu_seqlens_q,
+         cu_seqlens_k,
+         alibi_slopes,
+         max_seqlen_q,
+         max_seqlen_k,
+         p_dropout,
+         softmax_scale,
+         zero_tensors,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         deterministic,
+         gen,
+         rng_state,
+     )
+
+
+ def mha_fwd_kvcache(
+     q: torch.Tensor,
+     kcache: torch.Tensor,
+     vcache: torch.Tensor,
+     k: Optional[torch.Tensor] = None,
+     v: Optional[torch.Tensor] = None,
+     seqlens_k: Optional[torch.Tensor] = None,
+     rotary_cos: Optional[torch.Tensor] = None,
+     rotary_sin: Optional[torch.Tensor] = None,
+     cache_batch_idx: Optional[torch.Tensor] = None,
+     leftpad_k: Optional[torch.Tensor] = None,
+     block_table: Optional[torch.Tensor] = None,
+     alibi_slopes: Optional[torch.Tensor] = None,
+     out: Optional[torch.Tensor] = None,
+     softmax_scale: float = 1.0,
+     is_causal: bool = False,
+     window_size_left: int = -1,
+     window_size_right: int = -1,
+     softcap: float = 0.0,
+     is_rotary_interleaved: bool = False,
+     num_splits: int = 1,
+ ) -> List[torch.Tensor]:
+     """
+     Forward pass for multi-head attention with KV cache.
+
+     Args:
+         q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
+         kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
+         k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
+         seqlens_k: Optional sequence lengths for keys of shape [batch_size]
+         rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
+         rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
+         cache_batch_idx: Optional indices to index into the KV cache
+         leftpad_k: Optional left padding for keys of shape [batch_size]
+         block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
+         alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
+         out: Optional output tensor, same shape as q
+         softmax_scale: Scale factor for softmax
+         is_causal: Whether to use causal attention
+         window_size_left: Window size for left context (-1 for unlimited)
+         window_size_right: Window size for right context (-1 for unlimited)
+         softcap: Soft cap for attention weights
+         is_rotary_interleaved: Whether rotary embeddings are interleaved
+         num_splits: Number of splits for computation
+
+     Returns:
+         List of tensors: [output, softmax_lse]
+     """
+     return ops.mha_fwd_kvcache(
+         q,
+         kcache,
+         vcache,
+         k,
+         v,
+         seqlens_k,
+         rotary_cos,
+         rotary_sin,
+         cache_batch_idx,
+         leftpad_k,
+         block_table,
+         alibi_slopes,
+         out,
+         softmax_scale,
+         is_causal,
+         window_size_left,
+         window_size_right,
+         softcap,
+         is_rotary_interleaved,
+         num_splits,
+     )
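A minimal forward/backward sketch for the mha_fwd and mha_bwd wrappers defined above, assuming the package is importable as flash_attn, fp16 CUDA tensors, no dropout, and the documented packing of the return lists (any extra trailing tensors are ignored by the slicing):

import torch
import flash_attn  # assumed import path for this build

batch, seqlen, heads, dim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, heads, dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

fwd = flash_attn.mha_fwd(q, k, v, softmax_scale=dim**-0.5, is_causal=True)
out, softmax_lse = fwd[0], fwd[1]

dout = torch.randn_like(out)
dq, dk, dv = flash_attn.mha_bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale=dim**-0.5,
    is_causal=True,
)[:3]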
build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_flash_attn_dd2f0f9.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ba8c31bf3488a6f0a93e2d5d83d28a27daa26c156ed357ba5443ac66e3809fc
+ size 1502967480
build/torch27-cxx11-cu128-x86_64-linux/flash_attn/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _flash_attn_dd2f0f9
+ ops = torch.ops._flash_attn_dd2f0f9
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_flash_attn_dd2f0f9::{op_name}"