from typing import Optional, List

import torch

from ._ops import ops


def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention.

    Args:
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Optional output tensor, same shape as q
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        p_dropout: Dropout probability
        softmax_scale: Scale factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Softcap value applied to the attention scores (0.0 disables softcapping)
        return_softmax: Whether to return the softmax weights
        gen: Optional random number generator

    Returns:
        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
    """
    return ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
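# A minimal usage sketch for mha_fwd (not part of the module API): assumes a CUDA
# device, an fp16 build of the underlying kernel, and illustrative shapes
# (batch 2, 128 query / 256 key tokens, 8 heads, head_size 64).
#
#   q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
#   k = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16)
#   v = torch.randn(2, 256, 8, 64, device="cuda", dtype=torch.float16)
#   out, softmax_lse = mha_fwd(
#       q, k, v,
#       softmax_scale=64 ** -0.5,  # 1/sqrt(head_size)
#       is_causal=True,
#   )[:2]
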
def mha_varlen_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention with variable sequence lengths.

    Args:
        q: Query tensor of shape [total_q, num_heads, head_size]
        k: Key tensor of shape [total_k, num_heads_k, head_size] or
            [num_blocks, page_block_size, num_heads_k, head_size]
        v: Value tensor of shape [total_k, num_heads_k, head_size] or
            [num_blocks, page_block_size, num_heads_k, head_size]
        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
        out: Optional output tensor of shape [total_q, num_heads, head_size]
        seqused_k: Optional tensor specifying how many keys to use per batch element, of shape [batch_size]
        leftpad_k: Optional left padding for keys of shape [batch_size]
        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        max_seqlen_q: Maximum sequence length for queries
        max_seqlen_k: Maximum sequence length for keys
        p_dropout: Dropout probability
        softmax_scale: Scale factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        zero_tensors: Whether to zero tensors before computation
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Softcap value applied to the attention scores (0.0 disables softcapping)
        return_softmax: Whether to return the softmax weights
        gen: Optional random number generator

    Returns:
        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
    """
    return ops.mha_varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
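# A minimal usage sketch for mha_varlen_fwd (illustrative, assuming a CUDA device
# and fp16 inputs): two sequences of lengths 3 and 5 packed along the first
# dimension, with int32 cumulative-length tensors.
#
#   cu_q = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
#   cu_k = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
#   q = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)  # [total_q, num_heads, head_size]
#   k = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)  # [total_k, num_heads_k, head_size]
#   v = torch.randn(8, 8, 64, device="cuda", dtype=torch.float16)
#   out, softmax_lse = mha_varlen_fwd(
#       q, k, v, cu_q, cu_k,
#       max_seqlen_q=5,
#       max_seqlen_k=5,
#       softmax_scale=64 ** -0.5,
#       is_causal=True,
#   )[:2]
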
def mha_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention.

    Args:
        dout: Gradient of the output, of shape [batch_size, seqlen_q, num_heads, head_size]
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Output tensor from the forward pass of shape [batch_size, seqlen_q, num_heads, head_size]
        softmax_lse: Log-sum-exp values from the forward pass of shape [batch_size, num_heads, seqlen_q]
        dq: Optional gradient tensor for queries, same shape as q
        dk: Optional gradient tensor for keys, same shape as k
        dv: Optional gradient tensor for values, same shape as v
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        p_dropout: Dropout probability
        softmax_scale: Scale factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Softcap value applied to the attention scores (0.0 disables softcapping)
        deterministic: Whether to use a deterministic algorithm
        gen: Optional random number generator
        rng_state: Optional RNG state from the forward pass

    Returns:
        List of tensors: [dq, dk, dv]
    """
    return ops.mha_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )
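# A minimal usage sketch for mha_bwd (illustrative): reuses q, k, v and the
# out / softmax_lse tensors returned by the matching mha_fwd call, with the same
# softmax_scale and causal setting as the forward pass.
#
#   dout = torch.randn_like(out)
#   dq, dk, dv = mha_bwd(
#       dout, q, k, v, out, softmax_lse,
#       softmax_scale=64 ** -0.5,
#       is_causal=True,
#   )[:3]
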
def mha_varlen_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention with variable sequence lengths.

    Args:
        dout: Gradient of the output, of shape [total_q, num_heads, head_size]
        q: Query tensor of shape [total_q, num_heads, head_size]
        k: Key tensor of shape [total_k, num_heads_k, head_size]
        v: Value tensor of shape [total_k, num_heads_k, head_size]
        out: Output tensor from the forward pass of shape [total_q, num_heads, head_size]
        softmax_lse: Log-sum-exp values as returned by the forward pass (mha_varlen_fwd)
        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
        dq: Optional gradient tensor for queries, same shape as q
        dk: Optional gradient tensor for keys, same shape as k
        dv: Optional gradient tensor for values, same shape as v
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        max_seqlen_q: Maximum sequence length for queries
        max_seqlen_k: Maximum sequence length for keys
        p_dropout: Dropout probability
        softmax_scale: Scale factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        zero_tensors: Whether to zero tensors before computation
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Softcap value applied to the attention scores (0.0 disables softcapping)
        deterministic: Whether to use a deterministic algorithm
        gen: Optional random number generator
        rng_state: Optional RNG state from the forward pass

    Returns:
        List of tensors: [dq, dk, dv]
    """
    return ops.mha_varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )
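# A minimal usage sketch for mha_varlen_bwd (illustrative): uses the packed q, k, v,
# the cu_q / cu_k cumulative-length tensors, and the out / softmax_lse values from
# the matching mha_varlen_fwd call above, with the same scale and causal setting.
#
#   dout = torch.randn_like(out)
#   dq, dk, dv = mha_varlen_bwd(
#       dout, q, k, v, out, softmax_lse, cu_q, cu_k,
#       max_seqlen_q=5,
#       max_seqlen_k=5,
#       softmax_scale=64 ** -0.5,
#       is_causal=True,
#   )[:3]
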
def mha_fwd_kvcache(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    seqlens_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    is_rotary_interleaved: bool = False,
    num_splits: int = 1,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention with KV cache.

    Args:
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or
            [num_blocks, page_block_size, num_heads_k, head_size]
        vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or
            [num_blocks, page_block_size, num_heads_k, head_size]
        k: Optional new keys tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
        v: Optional new values tensor of shape [batch_size, seqlen_knew, num_heads_k, head_size]
        seqlens_k: Optional sequence lengths for keys of shape [batch_size]
        rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
        rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
        cache_batch_idx: Optional indices to index into the KV cache
        leftpad_k: Optional left padding for keys of shape [batch_size]
        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        out: Optional output tensor, same shape as q
        softmax_scale: Scale factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        is_causal: Whether to use causal attention
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Softcap value applied to the attention scores (0.0 disables softcapping)
        is_rotary_interleaved: Whether the rotary embeddings are interleaved
        num_splits: Number of splits used to parallelize over the key/value sequence length

    Returns:
        List of tensors: [output, softmax_lse]
    """
    return ops.mha_fwd_kvcache(
        q,
        kcache,
        vcache,
        k,
        v,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        leftpad_k,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        is_rotary_interleaved,
        num_splits,
    )
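# A minimal usage sketch for mha_fwd_kvcache (illustrative, assuming a CUDA device
# and fp16 tensors): a single-token decode step that appends the new key/value to a
# preallocated cache; seqlens_k holds the number of tokens already in each cache slot.
#
#   kcache = torch.zeros(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
#   vcache = torch.zeros(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
#   q      = torch.randn(2, 1, 8, 64, device="cuda", dtype=torch.float16)
#   k_new  = torch.randn(2, 1, 8, 64, device="cuda", dtype=torch.float16)
#   v_new  = torch.randn(2, 1, 8, 64, device="cuda", dtype=torch.float16)
#   seqlens_k = torch.tensor([17, 42], device="cuda", dtype=torch.int32)
#   out, softmax_lse = mha_fwd_kvcache(
#       q, kcache, vcache,
#       k=k_new, v=v_new, seqlens_k=seqlens_k,
#       softmax_scale=64 ** -0.5,
#       is_causal=True,
#   )[:2]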