File size: 13,193 Bytes
d6cc1b0 a7165c8 d6cc1b0 a7165c8 d6cc1b0 a7165c8 d6cc1b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
from typing import Optional, List
import torch
from ._ops import ops
def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Run the multi-head attention forward op on fixed-length (padded) batches.

    Every argument is handed unchanged to the compiled ``ops.mha_fwd`` operator,
    and whatever it returns is passed straight back to the caller.

    Args:
        q: Queries, shaped [batch_size, seqlen_q, num_heads, head_size].
        k: Keys, shaped [batch_size, seqlen_k, num_heads_k, head_size].
        v: Values, shaped [batch_size, seqlen_k, num_heads_k, head_size].
        out: Optional preallocated output buffer with the same shape as ``q``.
        alibi_slopes: Optional ALiBi slopes, shaped [num_heads] or
            [batch_size, num_heads].
        p_dropout: Dropout probability.
        softmax_scale: Factor applied to attention scores for the softmax.
            The default of 1.0 leaves scores unscaled; callers wanting the usual
            scaled dot-product typically pass ``head_size ** -0.5``.
        is_causal: Use a causal mask when True.
        window_size_left: Left context window; -1 means unlimited.
        window_size_right: Right context window; -1 means unlimited.
        softcap: Soft cap applied to attention weights.
        return_softmax: Also return the softmax weights when True.
        gen: Optional random number generator (presumably consumed by dropout).

    Returns:
        The op's list of tensors: [output, softmax_lse, (softmax if return_softmax)].
        NOTE(review): ``mha_bwd`` accepts an ``rng_state`` from the forward pass,
        so the op may append further tensors to this list — confirm.
    """
    # Order matters: this tuple is splatted positionally into the compiled op.
    forwarded = (
        q, k, v, out, alibi_slopes,
        p_dropout, softmax_scale, is_causal,
        window_size_left, window_size_right,
        softcap, return_softmax, gen,
    )
    return ops.mha_fwd(*forwarded)
def mha_varlen_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Run the multi-head attention forward op on packed variable-length sequences.

    Sequences are concatenated along the first axis and delimited by the
    cumulative-length tensors. Arguments are forwarded untouched to the compiled
    ``ops.mha_varlen_fwd`` operator, and its result is returned as-is.

    Args:
        q: Queries, shaped [total_q, num_heads, head_size].
        k: Keys, shaped [total_k, num_heads_k, head_size], or
            [num_blocks, page_block_size, num_heads_k, head_size] when paged.
        v: Values, shaped like ``k``.
        cu_seqlens_q: Cumulative query lengths, shaped [batch_size + 1].
        cu_seqlens_k: Cumulative key lengths, shaped [batch_size + 1].
        out: Optional preallocated output, shaped [total_q, num_heads, head_size].
        seqused_k: Optional per-batch count of keys to use, shaped [batch_size].
        leftpad_k: Optional per-batch left padding of keys, shaped [batch_size].
        block_table: Optional block table, shaped [batch_size, max_num_blocks_per_seq].
        alibi_slopes: Optional ALiBi slopes, shaped [num_heads] or
            [batch_size, num_heads].
        max_seqlen_q: Longest query sequence in the batch.
        max_seqlen_k: Longest key sequence in the batch.
            NOTE(review): both default to 0, which is presumably not a usable
            value for non-empty inputs; callers should pass the real maxima.
        p_dropout: Dropout probability.
        softmax_scale: Factor applied to attention scores (default 1.0 = unscaled).
        zero_tensors: Whether to zero tensors before computation.
        is_causal: Use a causal mask when True.
        window_size_left: Left context window; -1 means unlimited.
        window_size_right: Right context window; -1 means unlimited.
        softcap: Soft cap applied to attention weights.
        return_softmax: Also return the softmax weights when True.
        gen: Optional random number generator.

    Returns:
        The op's list of tensors: [output, softmax_lse, (softmax if return_softmax)].
    """
    # The op's positional order differs from this signature: ``out`` comes
    # *before* the cu_seqlens tensors. Do not reorder this tuple.
    forwarded = (
        q, k, v, out,
        cu_seqlens_q, cu_seqlens_k,
        seqused_k, leftpad_k, block_table, alibi_slopes,
        max_seqlen_q, max_seqlen_k,
        p_dropout, softmax_scale, zero_tensors, is_causal,
        window_size_left, window_size_right,
        softcap, return_softmax, gen,
    )
    return ops.mha_varlen_fwd(*forwarded)
def mha_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Run the multi-head attention backward op on fixed-length (padded) batches.

    Every argument is passed unchanged to the compiled ``ops.mha_bwd`` operator,
    and its result is returned directly.

    Args:
        dout: Upstream gradient, shaped [batch_size, seqlen_q, num_heads, head_size].
        q: Queries, shaped [batch_size, seqlen_q, num_heads, head_size].
        k: Keys, shaped [batch_size, seqlen_k, num_heads_k, head_size].
        v: Values, shaped [batch_size, seqlen_k, num_heads_k, head_size].
        out: Forward output, shaped [batch_size, seqlen_q, num_heads, head_size].
        softmax_lse: Forward log-sum-exp, shaped [batch_size, num_heads, seqlen_q].
        dq: Optional preallocated query gradient, shaped like ``q``.
        dk: Optional preallocated key gradient, shaped like ``k``.
        dv: Optional preallocated value gradient, shaped like ``v``.
        alibi_slopes: Optional ALiBi slopes, shaped [num_heads] or
            [batch_size, num_heads].
        p_dropout: Dropout probability.
        softmax_scale: Factor applied to attention scores (default 1.0 = unscaled).
        is_causal: Use a causal mask when True.
        window_size_left: Left context window; -1 means unlimited.
        window_size_right: Right context window; -1 means unlimited.
        softcap: Soft cap applied to attention weights.
        deterministic: Request deterministic algorithms.
        gen: Optional random number generator.
        rng_state: Optional RNG state captured during the forward pass.

    Returns:
        The op's list of tensors: [dq, dk, dv].
    """
    # Order matters: this tuple is splatted positionally into the compiled op.
    forwarded = (
        dout, q, k, v, out, softmax_lse,
        dq, dk, dv, alibi_slopes,
        p_dropout, softmax_scale, is_causal,
        window_size_left, window_size_right,
        softcap, deterministic, gen, rng_state,
    )
    return ops.mha_bwd(*forwarded)
def mha_varlen_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention with variable sequence lengths.

    Inputs use the same packed layout as ``mha_varlen_fwd``. Sequences are
    concatenated along the first axis and delimited by ``cu_seqlens_q`` /
    ``cu_seqlens_k``, so there is no separate batch axis. All arguments are
    forwarded to the compiled ``ops.mha_varlen_bwd`` operator.

    Args:
        dout: Gradient of the output, shaped [total_q, num_heads, head_size].
        q: Query tensor, shaped [total_q, num_heads, head_size].
        k: Key tensor, shaped [total_k, num_heads_k, head_size].
        v: Value tensor, shaped [total_k, num_heads_k, head_size].
        out: Output from the forward pass, shaped [total_q, num_heads, head_size].
        softmax_lse: Log-sum-exp values returned by ``mha_varlen_fwd``.
            NOTE(review): its exact layout is defined by the compiled op,
            presumably [num_heads, total_q] for the packed variant — confirm.
        cu_seqlens_q: Cumulative query lengths, shaped [batch_size + 1].
        cu_seqlens_k: Cumulative key lengths, shaped [batch_size + 1].
        dq: Optional preallocated query gradient, same shape as ``q``.
        dk: Optional preallocated key gradient, same shape as ``k``.
        dv: Optional preallocated value gradient, same shape as ``v``.
        alibi_slopes: Optional ALiBi slopes, shaped [num_heads] or
            [batch_size, num_heads].
        max_seqlen_q: Maximum sequence length for queries.
        max_seqlen_k: Maximum sequence length for keys.
        p_dropout: Dropout probability.
        softmax_scale: Scale factor for softmax (default 1.0 = unscaled).
        zero_tensors: Whether to zero tensors before computation.
        is_causal: Whether to use causal attention.
        window_size_left: Window size for left context (-1 for unlimited).
        window_size_right: Window size for right context (-1 for unlimited).
        softcap: Soft cap for attention weights.
        deterministic: Whether to use deterministic algorithms.
        gen: Optional random number generator.
        rng_state: Optional RNG state from the forward pass.

    Returns:
        List of tensors: [dq, dk, dv]
    """
    # The op's positional order differs from this signature: the optional
    # gradient buffers (dq, dk, dv) come *before* cu_seqlens_q / cu_seqlens_k.
    return ops.mha_varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )
def mha_fwd_kvcache(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    seqlens_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    is_rotary_interleaved: bool = False,
    num_splits: int = 1,
) -> List[torch.Tensor]:
    """
    Run the multi-head attention forward op against a key/value cache.

    All arguments go unchanged to the compiled ``ops.mha_fwd_kvcache`` operator,
    and its result is returned directly.

    Args:
        q: Queries, shaped [batch_size, seqlen_q, num_heads, head_size].
        kcache: Key cache, shaped [batch_size_c, seqlen_k, num_heads_k, head_size]
            or [num_blocks, page_block_size, num_heads_k, head_size].
        vcache: Value cache, shaped like ``kcache``.
        k: Optional new keys, shaped [batch_size, seqlen_knew, num_heads_k, head_size].
        v: Optional new values, shaped [batch_size, seqlen_knew, num_heads_k, head_size].
        seqlens_k: Optional key sequence lengths, shaped [batch_size].
        rotary_cos: Optional rotary cosine table, shaped [seqlen_ro, rotary_dim / 2].
        rotary_sin: Optional rotary sine table, shaped [seqlen_ro, rotary_dim / 2].
        cache_batch_idx: Optional indices used to index into the KV cache.
        leftpad_k: Optional per-batch left padding of keys, shaped [batch_size].
        block_table: Optional block table, shaped [batch_size, max_num_blocks_per_seq].
        alibi_slopes: Optional ALiBi slopes, shaped [num_heads] or
            [batch_size, num_heads].
        out: Optional preallocated output buffer with the same shape as ``q``.
        softmax_scale: Factor applied to attention scores (default 1.0 = unscaled).
        is_causal: Use a causal mask when True.
        window_size_left: Left context window; -1 means unlimited.
        window_size_right: Right context window; -1 means unlimited.
        softcap: Soft cap applied to attention weights.
        is_rotary_interleaved: Whether the rotary embedding uses interleaved layout.
        num_splits: Number of splits to use for the computation.

    Returns:
        The op's list of tensors: [output, softmax_lse].
    """
    # Order matters: this tuple is splatted positionally into the compiled op.
    forwarded = (
        q, kcache, vcache, k, v,
        seqlens_k, rotary_cos, rotary_sin,
        cache_batch_idx, leftpad_k, block_table, alibi_slopes,
        out, softmax_scale, is_causal,
        window_size_left, window_size_right,
        softcap, is_rotary_interleaved, num_splits,
    )
    return ops.mha_fwd_kvcache(*forwarded)
|