from typing import Optional, List
import torch
from ._ops import ops


def mha_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
p_dropout: float = 0.0,
softmax_scale: float = 1.0,
is_causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
return_softmax: bool = False,
gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
"""
Forward pass for multi-head attention.
Args:
q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
out: Optional output tensor, same shape as q
alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
p_dropout: Dropout probability
        softmax_scale: Scale applied to QK^T before the softmax (typically 1/sqrt(head_size))
is_causal: Whether to use causal attention
window_size_left: Window size for left context (-1 for unlimited)
window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables softcapping)
return_softmax: Whether to return softmax weights
gen: Optional random number generator
Returns:
List of tensors: [output, softmax_lse, (softmax if return_softmax)]
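    Example:
        Illustrative sketch only (not from the library's docs; assumes a CUDA
        device, fp16 inputs, and the shapes documented above):

            import math
            import torch

            # [batch_size=2, seqlen=128, num_heads=8, head_size=64]
            q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
            k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
            v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

            out, softmax_lse = mha_fwd(
                q, k, v,
                softmax_scale=1.0 / math.sqrt(q.shape[-1]),
                is_causal=True,
            )[:2]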
"""
return ops.mha_fwd(
q,
k,
v,
out,
alibi_slopes,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
softcap,
return_softmax,
gen,
)


def mha_varlen_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
out: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
leftpad_k: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
max_seqlen_q: int = 0,
max_seqlen_k: int = 0,
p_dropout: float = 0.0,
softmax_scale: float = 1.0,
zero_tensors: bool = False,
is_causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
return_softmax: bool = False,
gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
"""
Forward pass for multi-head attention with variable sequence lengths.
Args:
q: Query tensor of shape [total_q, num_heads, head_size]
k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
out: Optional output tensor of shape [total_q, num_heads, head_size]
        seqused_k: Optional tensor of shape [batch_size] giving the number of keys to use per batch element
leftpad_k: Optional left padding for keys of shape [batch_size]
block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
max_seqlen_q: Maximum sequence length for queries
max_seqlen_k: Maximum sequence length for keys
p_dropout: Dropout probability
        softmax_scale: Scale applied to QK^T before the softmax (typically 1/sqrt(head_size))
zero_tensors: Whether to zero tensors before computation
is_causal: Whether to use causal attention
window_size_left: Window size for left context (-1 for unlimited)
window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables softcapping)
return_softmax: Whether to return softmax weights
gen: Optional random number generator
Returns:
List of tensors: [output, softmax_lse, (softmax if return_softmax)]
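    Example:
        Illustrative sketch only (assumes a CUDA device and fp16 inputs). Two
        sequences of lengths 3 and 5 are packed into one [total_tokens,
        num_heads, head_size] tensor and described by int32 cumulative
        sequence lengths:

            import math
            import torch

            # total_q = total_k = 3 + 5 = 8 packed tokens
            q = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
            k = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
            v = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
            cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

            out, softmax_lse = mha_varlen_fwd(
                q, k, v, cu_seqlens, cu_seqlens,
                max_seqlen_q=5,
                max_seqlen_k=5,
                softmax_scale=1.0 / math.sqrt(q.shape[-1]),
                is_causal=True,
            )[:2]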
"""
return ops.mha_varlen_fwd(
q,
k,
v,
out,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
leftpad_k,
block_table,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
softcap,
return_softmax,
gen,
)


def mha_bwd(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
dq: Optional[torch.Tensor] = None,
dk: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
p_dropout: float = 0.0,
softmax_scale: float = 1.0,
is_causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
deterministic: bool = False,
gen: Optional[torch.Generator] = None,
rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""
Backward pass for multi-head attention.
Args:
dout: Gradient tensor of shape [batch_size, seqlen_q, num_heads, head_size]
q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
out: Output tensor from forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
dq: Optional gradient tensor for queries, same shape as q
dk: Optional gradient tensor for keys, same shape as k
dv: Optional gradient tensor for values, same shape as v
alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
p_dropout: Dropout probability
        softmax_scale: Scale applied to QK^T before the softmax (typically 1/sqrt(head_size))
is_causal: Whether to use causal attention
window_size_left: Window size for left context (-1 for unlimited)
window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables softcapping)
deterministic: Whether to use deterministic algorithms
gen: Optional random number generator
rng_state: Optional RNG state from forward pass
Returns:
List of tensors: [dq, dk, dv]
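    Example:
        Illustrative sketch only (assumes a CUDA device, fp16 inputs, and an
        earlier mha_fwd call that produced out and softmax_lse):

            import math
            import torch

            q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
            k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
            v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
            scale = 1.0 / math.sqrt(q.shape[-1])

            out, softmax_lse = mha_fwd(
                q, k, v, softmax_scale=scale, is_causal=True
            )[:2]

            dout = torch.randn_like(out)  # upstream gradient
            dq, dk, dv = mha_bwd(
                dout, q, k, v, out, softmax_lse,
                softmax_scale=scale,
                is_causal=True,
            )[:3]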
"""
return ops.mha_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
alibi_slopes,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
softcap,
deterministic,
gen,
rng_state,
)


def mha_varlen_bwd(
dout: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
softmax_lse: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
dq: Optional[torch.Tensor] = None,
dk: Optional[torch.Tensor] = None,
dv: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
max_seqlen_q: int = 0,
max_seqlen_k: int = 0,
p_dropout: float = 0.0,
softmax_scale: float = 1.0,
zero_tensors: bool = False,
is_causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
deterministic: bool = False,
gen: Optional[torch.Generator] = None,
rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""
Backward pass for multi-head attention with variable sequence lengths.
Args:
        dout: Gradient tensor of shape [total_q, num_heads, head_size]
        q: Query tensor of shape [total_q, num_heads, head_size]
        k: Key tensor of shape [total_k, num_heads_k, head_size]
        v: Value tensor of shape [total_k, num_heads_k, head_size]
        out: Output tensor from forward pass of shape [total_q, num_heads, head_size]
softmax_lse: Log-sum-exp values from forward pass of shape [batch_size, num_heads, seqlen_q]
cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
dq: Optional gradient tensor for queries, same shape as q
dk: Optional gradient tensor for keys, same shape as k
dv: Optional gradient tensor for values, same shape as v
alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
max_seqlen_q: Maximum sequence length for queries
max_seqlen_k: Maximum sequence length for keys
p_dropout: Dropout probability
        softmax_scale: Scale applied to QK^T before the softmax (typically 1/sqrt(head_size))
zero_tensors: Whether to zero tensors before computation
is_causal: Whether to use causal attention
window_size_left: Window size for left context (-1 for unlimited)
window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables softcapping)
deterministic: Whether to use deterministic algorithms
gen: Optional random number generator
rng_state: Optional RNG state from forward pass
Returns:
List of tensors: [dq, dk, dv]
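    Example:
        Illustrative sketch only (assumes a CUDA device, fp16 inputs, and an
        earlier mha_varlen_fwd call over the same packed layout):

            import math
            import torch

            # two sequences of lengths 3 and 5, packed into 8 tokens
            q = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
            k = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
            v = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
            cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
            scale = 1.0 / math.sqrt(q.shape[-1])

            out, softmax_lse = mha_varlen_fwd(
                q, k, v, cu_seqlens, cu_seqlens,
                max_seqlen_q=5, max_seqlen_k=5,
                softmax_scale=scale, is_causal=True,
            )[:2]

            dq, dk, dv = mha_varlen_bwd(
                torch.randn_like(out), q, k, v, out, softmax_lse,
                cu_seqlens, cu_seqlens,
                max_seqlen_q=5, max_seqlen_k=5,
                softmax_scale=scale, is_causal=True,
            )[:3]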
"""
return ops.mha_varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
softcap,
deterministic,
gen,
rng_state,
)


def mha_fwd_kvcache(
q: torch.Tensor,
kcache: torch.Tensor,
vcache: torch.Tensor,
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
seqlens_k: Optional[torch.Tensor] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
leftpad_k: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
softmax_scale: float = 1.0,
is_causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
is_rotary_interleaved: bool = False,
num_splits: int = 1,
) -> List[torch.Tensor]:
"""
Forward pass for multi-head attention with KV cache.
Args:
q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
seqlens_k: Optional sequence lengths for keys of shape [batch_size]
rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
        cache_batch_idx: Optional indices into the KV cache, of shape [batch_size]
leftpad_k: Optional left padding for keys of shape [batch_size]
block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
out: Optional output tensor, same shape as q
        softmax_scale: Scale applied to QK^T before the softmax (typically 1/sqrt(head_size))
is_causal: Whether to use causal attention
window_size_left: Window size for left context (-1 for unlimited)
window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables softcapping)
is_rotary_interleaved: Whether rotary embeddings are interleaved
        num_splits: Number of splits along the key/value sequence length for the split-KV kernel (1 means no splitting)
Returns:
List of tensors: [output, softmax_lse]
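    Example:
        Illustrative sketch only (assumes a CUDA device and fp16 tensors). A
        single new query token per sequence attends to a pre-allocated KV
        cache in which the first 16 positions are valid:

            import math
            import torch

            q = torch.randn(2, 1, 8, 64, device="cuda", dtype=torch.float16)
            kcache = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16)
            vcache = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16)
            seqlens_k = torch.full((2,), 16, device="cuda", dtype=torch.int32)

            out, softmax_lse = mha_fwd_kvcache(
                q, kcache, vcache,
                seqlens_k=seqlens_k,
                softmax_scale=1.0 / math.sqrt(q.shape[-1]),
                is_causal=True,
            )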
"""
return ops.mha_fwd_kvcache(
q,
kcache,
vcache,
k,
v,
seqlens_k,
rotary_cos,
rotary_sin,
cache_batch_idx,
leftpad_k,
block_table,
alibi_slopes,
out,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
softcap,
is_rotary_interleaved,
num_splits,
)