"""PyTorch Qwen2 model.""" import math from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from einops import rearrange from transformers.cache_utils import Cache from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.utils import ( is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10 ) from transformers.activations import ACT2FN if is_flash_attn_2_available(): from flash_attn.bert_padding import index_first_axis from flash_attn import flash_attn_varlen_func class ScaleDotProductCrossAttention(nn.Module): def __init__(self, layer_number, softmax_scale=None, attention_dropout=0.0): super().__init__() self.layer_number = layer_number self.softmax_scale = softmax_scale self.dropout_p = attention_dropout def forward(self, q, k, v, attn_mask=None): """Implements the multihead softmax attention. Arguments --------- q, k, v: The tensor containing the query, key, and value. (B, S, H, D) """ # (N,...,L,E) if attn_mask is not None: attn_mask = attn_mask[:,None,:,:].repeat(1, q.shape[1], 1, 1) # attention mask, True means it will take part in attention B H s_q s_k if self.training: dropout_p = self.dropout_p else: dropout_p = 0.0 if q.device.type == "cuda" and attn_mask is not None: q = q.contiguous() k = k.contiguous() v = v.contiguous() # debug only, calculate the FLOPs for cross-attn ################## # attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(128) # hardcode # if attn_mask is not None: # no matter the length, we just slice it # causal_mask = attn_mask[:, :, :, : k.shape[-2]] # attn_weights = attn_weights + causal_mask # # upcast attention to fp32 # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) # # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) # o = torch.matmul(attn_weights, v) ################### o = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=self.softmax_scale) # B Head L D -> L B (Head D) o = rearrange(o, 'B Head L D -> B L (Head D)').contiguous() return o class FlashAttnCrossAttention(nn.Module): def __init__(self, layer_number, softmax_scale=None, attention_dropout=0.0): super().__init__() self.layer_number = layer_number self.softmax_scale = softmax_scale self.dropout_p = attention_dropout def _get_unpad_data(self, attention_mask: torch.Tensor): """ Retrieves indexing data required to repad unpadded (ragged) tensors. Arguments: attention_mask (`torch.Tensor`): Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. Return: indices (`torch.Tensor`): The indices of non-masked tokens from the flattened input sequence. cu_seqlens (`torch.Tensor`): The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). max_seqlen_in_batch (`int`): Maximum sequence length in batch. 
""" seqlens_in_batch = attention_mask[:, 0, :].sum(dim=-1, dtype=torch.int32) # attn mask are the same for the query dimension, pick the first query indices = torch.nonzero(attention_mask[:, 0, :].flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, seqlens_in_batch ) def unpad_q(self, q_layer): # no need to unpad, just flatten batch_size, q_seq_len, num_key_value_heads, head_dim = q_layer.shape cu_seqlens_q = torch.tensor([q_seq_len] * batch_size, dtype=torch.int32, device=q_layer.device) cu_seqlens_q = nn.functional.pad(torch.cumsum(cu_seqlens_q, dim=0, dtype=torch.int32), (1, 0)) q_layer = q_layer.reshape(batch_size * q_seq_len, num_key_value_heads, head_dim) return ( q_layer, cu_seqlens_q, q_seq_len) def unpad_kv(self, key_layer, value_layer, attn_mask): indices_k, cu_seqlens_k, max_seqlen_in_batch_k, split_size = self._get_unpad_data(attn_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) value_layer = index_first_axis( value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k ) return ( key_layer, value_layer, indices_k, cu_seqlens_k, max_seqlen_in_batch_k, split_size) def forward(self, q, k, v, attn_mask=None): """ Implements the multihead softmax attention with flash attention varlen api. Unpad the kv sequence Arguments --------- q, k, v: The tensor containing the query, key, and value. (B, S, H, D) """ # (N,...,L,E) q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) # NOTE: don't know if it's necessary if q.device.type == "cuda" and attn_mask is not None: q = q.contiguous() k = k.contiguous() v = v.contiguous() # batch_size = q.shape[0] # first unpad the q and kv, get cu_seq_len and indices batch_size, q_seq_len, head_num, head_dim = q.shape q, cu_seq_lens_q, max_seqlen_in_batch_q = self.unpad_q(q) k, v, indices_kv, cu_seq_lens_kv, max_seqlen_in_batch_kv, split_size = self.unpad_kv(k, v, attn_mask) attn_output = flash_attn_varlen_func( q, k, v, cu_seqlens_q=cu_seq_lens_q, cu_seqlens_k=cu_seq_lens_kv, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_kv, dropout_p=self.dropout_p if self.training else 0.0, softmax_scale=None, causal=False, # **flash_kwargs ) return attn_output.reshape(batch_size, q_seq_len, head_num, head_dim).flatten(2, 3).contiguous() # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 class Qwen2RotaryEmbedding(nn.Module): def __init__( self, dim=None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, rope_type="default", 
config=None, ): super().__init__() # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} if config is None: self.rope_kwargs = { "rope_type": rope_type, "factor": scaling_factor, "dim": dim, "base": base, "max_position_embeddings": max_position_embeddings, } self.rope_type = rope_type self.max_seq_len_cached = max_position_embeddings self.original_max_seq_len = max_position_embeddings else: # BC: "rope_type" was originally "type" if config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn( self.config, device, seq_len=seq_len, **self.rope_kwargs ) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() def forward(self, x, position_ids): if "dynamic" in self.rope_type: self._dynamic_frequency_update(position_ids, device=x.device) # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention cos = cos * self.attention_scaling sin = sin * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. 
position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_state): return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class Qwen2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer and "Generating Long Sequences with Sparse Transformers". """ def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx # if layer_idx is None: # logger.warning_once( # f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " # "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " # "when creating this class." 
# ) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: # logger.warning_once( # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " # "removed and `position_embeddings` will be mandatory." 
# ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class Qwen2FlashAttention2(Qwen2Attention): """ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom config.max_window_layers layers. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: # logger.warning_once( # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " # "removed and `position_embeddings` will be mandatory." # ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 kv_seq_len = key_states.shape[-2] + cache_position[0] if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window and cache_has_contents ): slicing_tokens = 1 - self.config.sliding_window past_key = past_key_value[self.layer_idx][0] past_value = past_key_value[self.layer_idx][1] past_key = past_key[:, :, slicing_tokens:, :].contiguous() past_value = past_value[:, :, slicing_tokens:, :].contiguous() if past_key.shape[-2] != self.config.sliding_window - 1: raise ValueError( f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" f" {past_key.shape}" ) if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) dropout_rate = 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. 
input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype # logger.warning_once( # f"The input hidden states seems to be silently casted in float32, this might be related to" # f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" # f" {target_dtype}." # ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) # Reashape to the expected shape for Flash Attention query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) if ( self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers ): sliding_window = self.config.sliding_window else: sliding_window = None attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, position_ids=position_ids, dropout=dropout_rate, sliding_window=sliding_window, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class Qwen2HybridFlashAttention2(Qwen2FlashAttention2): """ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom config.max_window_layers layers. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, is_hyper_enabled, gating_type, cross_attn_implementation, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() self.is_hyper_enabled = is_hyper_enabled if self.is_hyper_enabled: self.gating_type = gating_type self.cross_attention_implementation = cross_attn_implementation self.cross_attn_kv_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim * 2, bias=True) if gating_type.startswith("whole-dynamic"): if "tanh" in gating_type: self.cross_attn_gate_proj = nn.Sequential( nn.Linear(self.hidden_size, 1), nn.Tanh() ) else: self.cross_attn_gate_proj = nn.Sequential( nn.Linear(self.hidden_size, 1), ) if gating_type.endswith("warmup"): self.cross_attn_warm_up_gate = torch.nn.Parameter(torch.zeros(1)) if "flashattn" in self.cross_attention_implementation: self.cross_attn_core_attention = FlashAttnCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout) else: self.cross_attn_core_attention = ScaleDotProductCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout) def all2media_cross_attn(self, text_state, text_query, vision_features, text2vision_cross_attn_mask=None, all_text_mask=None): ''' text_query: [s b h d] text_state: s b d vision_features: [num_vis, b, d] ''' if vision_features is None or (self.is_hyper_enabled == False): return text_state L_c, B_c = text_state.shape[:2] D_head = self.head_dim if "whole-dynamic" in self.gating_type: gate_value = self.cross_attn_gate_proj(text_state) # n, bs, head_D if "warmup" in self.gating_type: gate_value = gate_value * self.cross_attn_warm_up_gate vision_features = vision_features.contiguous() vision_features = self.cross_attn_kv_proj(vision_features) text_query = rearrange(text_query, 'L B H D -> B H L D') # [25, 2, 32, 128]) vision_kv = rearrange(vision_features, 'BL Lv (H KV D) -> KV BL H Lv D', KV=2, H=self.num_key_value_heads) vision_key = vision_kv[0].contiguous() # [b h s d] vision_value = vision_kv[1].contiguous() vision_key = repeat_kv(vision_key, self.num_key_value_groups) vision_value = repeat_kv(vision_value, self.num_key_value_groups) # expend_cross_attn_mask attention_mask = text2vision_cross_attn_mask[:, None, :].repeat(1, text_state.shape[0], 1) vision_context = self.cross_attn_core_attention(text_query, vision_key, vision_value, attn_mask=attention_mask).transpose(0, 1) # mask out the output if a sample is pure text vision_context = all_text_mask[None, :, None] * vision_context # Apply dynamic gate text_state = text_state + vision_context * gate_value return text_state def onlytext2media_cross_attn(self, text_state, text_query, vision_features, token_type, text2vision_cross_attn_mask=None, all_text_mask=None): ''' text_query: [bs n h d] text_state: [bs n d] vision_features: [bs, vis_n, d] token_type: [bs, n] ''' # if vision_features is None or (self.is_hyper_enabled == False) or (all_text_mask.sum() == 0): if vision_features is None or (self.is_hyper_enabled == False): return text_state # select all the pure text token pure_text_query = [] text_mask = ((token_type - 2) <= 0).bool() if "masksystem" in self.cross_attention_implementation: new_text_masks = [] for idx, text_query_ in enumerate(text_query): # mask out all the tokens before the media first_im_token = (token_type[idx] == 3).nonzero() if len(first_im_token) == 0: start = 0 else: start = first_im_token[0] text_mask_ = text_mask[idx].clone() text_mask_[:start] = False pure_text_query.append(text_query_[text_mask_]) new_text_masks.append(text_mask_) text_mask = torch.stack(new_text_masks, dim=0) else: for idx, text_query_ in enumerate(text_query): 
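                # gather only this sample's text-token queries (text_mask keeps token_type <= 2);
                # media tokens are excluded from the cross-attention query set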
pure_text_query.append(text_query_[text_mask[idx]]) # 2. pad all the text tokens text_query = torch.nn.utils.rnn.pad_sequence(pure_text_query, batch_first=True) padding_attn_mask = torch.ones(text_query.shape[:-2], dtype=torch.bool, device=text_state.device) for i, tensor in enumerate(pure_text_query): padding_attn_mask[i, len(tensor):] = False # Mark padded elements as False B_c, L_c = text_query.shape[:2] D_head = self.head_dim # obtain dynamic gate value gate_value = self.cross_attn_gate_proj(text_state[text_mask]) # n, D if "warmup" in self.gating_type: gate_value = gate_value * self.cross_attn_warm_up_gate.tanh() vision_features = vision_features.contiguous() vision_features = self.cross_attn_kv_proj(vision_features) text_query = text_query.transpose(1, 2) vision_kv = rearrange(vision_features, 'BL Lv (H KV D) -> KV BL H Lv D', KV=2, H=self.num_key_value_heads) vision_key = vision_kv[0].contiguous() # [b h s d] vision_value = vision_kv[1].contiguous() vision_key = repeat_kv(vision_key, self.num_key_value_groups) vision_value = repeat_kv(vision_value, self.num_key_value_groups) # expend_cross_attn_mask attention_mask = text2vision_cross_attn_mask[:, None, :].repeat(1, text_query.shape[2], 1) vision_context = self.cross_attn_core_attention(text_query, vision_key, vision_value, attn_mask=attention_mask) # mask out the output if a sample is pure text vision_context = all_text_mask[:, None, None] * vision_context # Apply dynamic gate extended_attn_output = torch.zeros_like(text_state, dtype=text_state.dtype, device=text_state.device) extended_attn_output[text_mask] = extended_attn_output[text_mask] + vision_context[padding_attn_mask] * gate_value text_state = text_state + extended_attn_output # NOTE Min: just equvalent to the following line. Avoid error under deepspeed zero3 # text_state[text_mask] = text_state[text_mask] + vision_context[padding_attn_mask] * gate_value return text_state def forward( self, hidden_states: torch.Tensor, visual_hidden_states: torch.Tensor, token_type: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, text2visual_attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ): bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: # logger.warning_once( # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " # "removed and `position_embeddings` will be mandatory." 
# ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 kv_seq_len = key_states.shape[-2] + cache_position[0] if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window and cache_has_contents ): slicing_tokens = 1 - self.config.sliding_window past_key = past_key_value[self.layer_idx][0] past_value = past_key_value[self.layer_idx][1] past_key = past_key[:, :, slicing_tokens:, :].contiguous() past_value = past_value[:, :, slicing_tokens:, :].contiguous() if past_key.shape[-2] != self.config.sliding_window - 1: raise ValueError( f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" f" {past_key.shape}" ) if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) dropout_rate = 0.0 if not self.training else self.attention_dropout # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype # logger.warning_once( # f"The input hidden states seems to be silently casted in float32, this might be related to" # f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" # f" {target_dtype}." 
# ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) # Reashape to the expected shape for Flash Attention query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) if ( self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers ): sliding_window = self.config.sliding_window else: sliding_window = None attn_output = _flash_attention_forward( query_states, # bs, n, head, head_dim key_states, value_states, attention_mask, q_len, position_ids=position_ids, dropout=dropout_rate, sliding_window=sliding_window, is_causal=self.is_causal, use_top_left_mask=self._flash_attn_uses_top_left_mask, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() # text-to-image cross-attention #### all_text_mask = (token_type == 3).sum(dim=-1).bool() # [bs, ] if False, indicate that this sample contains no image input if self.cross_attention_implementation.startswith("vanilla"): # all tokens can attend to the slow tokens attn_output = self.all2media_cross_attn(attn_output.permute(1, 0, 2), query_states.permute(1, 0, 2, 3), visual_hidden_states, text2visual_attention_mask, all_text_mask) attn_output = attn_output.permute(1,0,2) elif self.cross_attention_implementation.startswith("text-only-vanilla"): # only text tokens are allowed to attend the slow tokens attn_output = self.onlytext2media_cross_attn(attn_output, query_states, visual_hidden_states, token_type=token_type, text2vision_cross_attn_mask=text2visual_attention_mask, all_text_mask=all_text_mask ) else: raise NotImplementedError(f"cross-attention type {self.cross_attention_implementation} not implemented") #### attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. """ # Adapted from Qwen2Attention.forward def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # logger.warning_once( # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
# ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: # logger.warning_once( # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " # "removed and `position_embeddings` will be mandatory." # ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value # TODO: Min: Not implementated yet class Qwen2HybridSdpaAttention(Qwen2SdpaAttention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to SDPA API. 
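    NOTE: the text-to-vision cross-attention path of this SDPA variant is still a work in progress
    (see the `TODO` above and the comment in `QWEN2_HYBRID_ATTENTION_CLASSES`); only the
    flash-attention hybrid module is fully wired up.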
""" def __init__(self, is_hyper_enabled, gating_type, cross_attn_implementation, *args, **kwargs): super().__init__(*args, **kwargs) self.is_hyper_enabled = is_hyper_enabled if self.is_hyper_enabled: self.gating_type = gating_type self.cross_attention_implementation = cross_attn_implementation self.cross_attn_kv_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim * 2, bias=True) if gating_type.startswith("whole-dynamic"): if "tanh" in gating_type: self.cross_attn_gate_proj = nn.Sequential( nn.Linear(self.hidden_size, 1), nn.Tanh() ) else: self.cross_attn_gate_proj = nn.Sequential( nn.Linear(self.hidden_size, 1), ) if gating_type.endswith("warmup"): self.cross_attn_warm_up_gate = torch.nn.Parameter(torch.zeros(1)) if "flashattn" in self.cross_attention_implementation: self.cross_attn_core_attention = FlashAttnCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout) else: self.cross_attn_core_attention = ScaleDotProductCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout) def text2media_cross_attn(self, text_state, text_query, vision_features, text2vision_cross_attn_mask=None, all_text_mask=None): ''' text_query: [s b h d] text_state: s b d vision_features: [num_vis, b, d] ''' if vision_features is None or (self.is_hyper_enabled == False): return text_state # obtain dynamic gate value L_c, B_c = text_state.shape[:2] D_head = self.head_dim gate_value = rearrange( self.gate_proj( rearrange(text_state, 'L B (Head D) -> (L B Head) D', D=D_head)), '(L B Head) D -> L B (Head D)', L=L_c, B=B_c) vision_features = vision_features.contiguous() vision_features = self.v_kv_proj(vision_features) # length_each_img = vision_features.shape[1] # sequence_length = text_query.shape[0] query_layer = rearrange(query_layer, 'L B H D -> B H L D') # [25, 2, 32, 128]) vision_kv = rearrange(vision_features, 'BL Lv (H KV D) -> KV 1 H (BL Lv) D', KV=2, H=self.num_key_value_heads) vision_key = vision_kv[0].contiguous() # [b h s d] vision_value = vision_kv[1].contiguous() # Apply MI-Rope # key_layer = self.apply_mi_rope(key_layer, media_offset_line=self.visual_cache['media_offset'][batch_id,:,1]-curr_offset[0], length_each_img=length_each_img) key_layer = repeat_kv(key_layer, self.num_key_value_groups) value_layer = repeat_kv(value_layer, self.num_key_value_groups) vision_context = self.v_core_attention_sdpa(query_layer, vision_key, vision_value, attn_mask=None, order='bhsd').squeeze(1) # TODO # Apply dynamic gate text_state = text_state * (1 - gate_value) + vision_context * gate_value return text_state # Adapted from Qwen2Attention.forward def forward( self, hidden_states: torch.Tensor, visual_hidden_states: torch.Tensor, token_type: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, text2visual_attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # logger.warning_once( # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, " # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' # ) return super().forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if position_embeddings is None: # logger.warning_once( # "The attention layers in this model are transitioning from computing the RoPE embeddings internally " # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " # "removed and `position_embeddings` will be mandatory." # ) cos, sin = self.rotary_emb(value_states, position_ids) else: cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and attention_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
is_causal = True if causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.view(bsz, q_len, self.hidden_size) # text-to-image cross-attention #### all_text_mask = (token_type == 3).sum(dim=-1).bool() # [bs, ] if False, indicate that this sample contains no image input if self.cross_attention_implementation.startswith("vanilla"): attn_output = self.text2media_cross_attn(attn_output.permute(1, 0, 2), query_states.permute(1, 0, 2, 3), visual_hidden_states, text2visual_attention_mask, all_text_mask) attn_output = attn_output.permute(1,0,2) elif self.cross_attention_implementation.startswith("text-only-vanilla"): attn_output = self.onlytext2media_cross_attn(attn_output, query_states, visual_hidden_states, token_type=token_type, text2vision_cross_attn_mask=text2visual_attention_mask, all_text_mask=all_text_mask ) else: raise NotImplementedError(f"cross-attention type {self.cross_attention_implementation} not implemented") #### attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value QWEN2_ATTENTION_CLASSES = { "eager": Qwen2Attention, "flash_attention_2": Qwen2FlashAttention2, "sdpa": Qwen2SdpaAttention, } QWEN2_HYBRID_ATTENTION_CLASSES = { "flash_attention_2": Qwen2HybridFlashAttention2, "sdpa": Qwen2HybridSdpaAttention, # Not implemented yet, only support flash attn } class Qwen2DecoderLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size if config.sliding_window and config._attn_implementation != "flash_attention_2": # logger.warning_once( # f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " # "unexpected results may be encountered." # ) pass self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). 
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs class Qwen2HybridDecoderLayer(nn.Module): def __init__(self, config, layer_idx: int, is_hyper_enabled=False, cross_attn_implementation="vanilla", # in ['vanilla' and 'text-only-vanilla'] cross_attn_gating_type="channel-wise-dynamic-sigmoid"): super().__init__() self.is_hyper_enabled = is_hyper_enabled self.hidden_size = config.hidden_size if config.sliding_window and config._attn_implementation != "flash_attention_2": # logger.warning_once( # f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " # "unexpected results may be encountered." 
# ) pass self.self_attn = QWEN2_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx, is_hyper_enabled=is_hyper_enabled, cross_attn_implementation=cross_attn_implementation, gating_type=cross_attn_gating_type) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # move the gradient checkpointing to the forward function of attn and MLP # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) def condition_vis_x(self, vis_x, cross_attn_mask=None, token_type=None): self.vis_x = vis_x self.cross_attn_mask = cross_attn_mask self.media_locations = token_type def clear_vis_x(self): self.vis_x = None self.cross_attn_mask = None self.media_locations = None def mlp_forward(self, hidden_states): hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) return hidden_states def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. 
kwargs (`dict`, *optional*): Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # process image embedding visual_tokens = self.vis_x cross_attn_mask = self.cross_attn_mask token_type = self.media_locations visual_tokens = self.input_layernorm(visual_tokens) # Self Attention if self.gradient_checkpointing and self.training: hidden_states, self_attn_weights, present_key_value = torch.utils.checkpoint.checkpoint( self.self_attn, hidden_states, visual_tokens, token_type, attention_mask, cross_attn_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings ) else: hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, visual_hidden_states=visual_tokens, text2visual_attention_mask=cross_attn_mask, token_type=token_type, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states if self.gradient_checkpointing and self.training: hidden_states = torch.utils.checkpoint.checkpoint( self.mlp_forward, hidden_states) else: hidden_states = self.mlp_forward(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs
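

# The snippet below is a minimal, illustrative smoke test for `ScaleDotProductCrossAttention`
# (not part of the model and not the authors' test setup). It only checks that a boolean
# cross-attention mask with a few padded vision tokens runs through the SDPA path and that the
# output comes back as (batch, query_len, heads * head_dim), matching the rearrange at the end
# of `ScaleDotProductCrossAttention.forward`. All shapes below are made up.
if __name__ == "__main__":
    torch.manual_seed(0)
    bsz, n_heads, q_len, kv_len, head_dim = 2, 4, 5, 7, 16
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, kv_len, head_dim)
    v = torch.randn(bsz, n_heads, kv_len, head_dim)
    # True marks key/value positions that take part in attention (see the mask handling in forward()).
    attn_mask = torch.ones(bsz, q_len, kv_len, dtype=torch.bool)
    attn_mask[0, :, -2:] = False  # pretend the first sample has two padded vision tokens
    cross_attn = ScaleDotProductCrossAttention(layer_number=0)
    out = cross_attn(q, k, v, attn_mask=attn_mask)
    assert out.shape == (bsz, q_len, n_heads * head_dim)
    print("cross-attention output shape:", tuple(out.shape))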