# Phi-4-multimodal-instruct / modeling_phi4_multimodal.py
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py.
# Do NOT edit this file manually, as any edits will be overwritten when the file is
# regenerated from the modular file. If a change is needed, please apply it to the
# modular_phi4_multimodal.py file directly. This is enforced by our CI.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from functools import wraps
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPast,
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig
logger = logging.get_logger(__name__)
def set_attribute_for_modules(module: "torch.nn.Module", key: str, value: Any):
"""
Set a value to a module and all submodules.
"""
setattr(module, key, value)
for submodule in module.children():
set_attribute_for_modules(submodule, key, value)
def del_attribute_from_modules(module: "torch.nn.Module", key: str):
"""
Delete a value from a module and all submodules.
"""
# because we might remove it previously in case it's a shared module, e.g. activation function
if hasattr(module, key):
delattr(module, key)
for submodule in module.children():
del_attribute_from_modules(submodule, key)
def can_return_tuple(func):
"""
Decorator to wrap model method, to call output.to_tuple() if return_dict=False passed as a kwarg or
use_return_dict=False is set in the config.
Note:
output.to_tuple() convert output to tuple skipping all `None` values.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
is_requested_to_return_tuple = kwargs.pop("return_dict", True) is False
is_configured_to_return_tuple = self.config.use_return_dict is False if hasattr(self, "config") else False
        # The following allows converting the output to a tuple ONLY on the top-level forward call,
        # while internal modules of the model will return Output objects
        # to be able to use name-based attribute access in modeling code.
        # We check whether we are in the top-level module; if so, tuple conversion is turned off for all
        # underlying calls.
is_top_level_module = getattr(self, "_is_top_level_module", True)
if is_configured_to_return_tuple and is_top_level_module:
set_attribute_for_modules(self, "_is_top_level_module", False)
try:
output = func(self, *args, **kwargs)
if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
output = output.to_tuple()
finally:
# Remove the flag after the model forward call is finished.
if is_configured_to_return_tuple and is_top_level_module:
del_attribute_from_modules(self, "_is_top_level_module")
return output
return wrapper
def dynamic_rope_update(rope_forward):
"""
Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
(i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
Args:
rope_forward (Callable):
The forward pass of the RoPE implementation.
Returns:
The decorated forward pass.
"""
def longrope_frequency_update(self, position_ids, device):
"""Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
seq_len = torch.max(position_ids) + 1
if hasattr(self.config, "original_max_position_embeddings"):
original_max_position_embeddings = self.config.original_max_position_embeddings
else:
original_max_position_embeddings = self.config.max_position_embeddings
if seq_len > original_max_position_embeddings:
if not hasattr(self, "long_inv_freq"):
self.long_inv_freq, _ = self.rope_init_fn(
self.config, device, seq_len=original_max_position_embeddings + 1
)
self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
else:
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
def dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@wraps(rope_forward)
def wrapper(self, x, position_ids):
if "dynamic" in self.rope_type:
dynamic_frequency_update(self, position_ids, device=x.device)
elif self.rope_type == "longrope":
longrope_frequency_update(self, position_ids, device=x.device)
return rope_forward(self, x, position_ids)
return wrapper
class Phi4MultimodalVisionMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
def simple_eager_attention_forward(
module: nn.Module,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
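    # Plain eager scaled-dot-product attention: scores = Q @ K^T * scaling, optional additive mask,
    # softmax in float32, dropout, then scores @ V. Inputs are expected as (batch, num_heads, seq_len, head_dim).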
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class Phi4MultimodalVisionAttention(nn.Module):
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scaling = self.head_dim**-0.5
self.is_causal = True
self.attention_dropout = config.attention_dropout
self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
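        # Project to queries/keys/values and reshape to (batch, num_heads, seq_len, head_dim)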
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attention_interface: Callable = simple_eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class Phi4MultimodalVisionEncoderLayer(nn.Module):
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.self_attn = Phi4MultimodalVisionAttention(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Phi4MultimodalVisionMLP(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class Phi4MultimodalVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Phi4MultimodalVisionEncoderLayer`].
Args:
config: Phi4MultimodalVisionConfig
"""
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
# Ignore copy
@can_return_tuple
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutput:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \\leq \text{mean} \\leq b`.
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"invalid mode {mode}")
    variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Phi4MultimodalVisionConfig
base_model_prefix = "phi4_vision"
supports_gradient_checkpointing = True
_no_split_modules = ["Phi4MultimodalVisionEncoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Phi4MultimodalVisionEmbeddings):
            width = self.config.hidden_size
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, Phi4MultimodalVisionAttention):
nn.init.normal_(module.q_proj.weight)
nn.init.normal_(module.k_proj.weight)
nn.init.normal_(module.v_proj.weight)
nn.init.normal_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, Phi4MultimodalVisionMLP):
nn.init.normal_(module.fc1.weight)
nn.init.normal_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead):
nn.init.normal_(module.probe.data)
nn.init.normal_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class Phi4MultimodalVisionEmbeddings(nn.Module):
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.config = config
self.patch_size = config.patch_size
self.num_patches_per_side = config.image_size // self.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=config.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
        self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size)
        # Buffer used by `interpolate_pos_encoding`; without it, that method would fail with an AttributeError.
        self.register_buffer(
            "position_ids", torch.arange(self.num_patches_per_side**2).expand((1, -1)), persistent=False
        )
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
        This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
        images. It is also adapted to support torch.jit tracing and the absence of class embeddings.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embedding.weight.shape[0]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embedding(self.position_ids)
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
batch_size = pixel_values.size(0)
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
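        # For each image, map every valid (non-padded) patch to a position id on the pretraining grid by
        # bucketizing its fractional (row, column) coordinates into `num_patches_per_side` buckets.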
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = Phi4MultimodalVisionMLP(config)
def forward(self, hidden_state, attention_mask):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
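        # A single learned query (the "probe") cross-attends over all patch tokens to produce a pooled image representation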
hidden_state = self.attention(
query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel):
config_class = Phi4MultimodalVisionConfig
main_input_name = "pixel_values"
def __init__(self, config: Phi4MultimodalVisionConfig):
super().__init__(config)
self.config = config
self.embeddings = Phi4MultimodalVisionEmbeddings(config)
self.encoder = Phi4MultimodalVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embedding
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_attention_mask = torch.ones(
size=(
batch_size,
pixel_values.size(2) // self.config.patch_size,
pixel_values.size(3) // self.config.patch_size,
),
dtype=torch.bool,
device=pixel_values.device,
)
hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # The call to `_upad_input` in `_flash_attention_forward` is expensive,
        # so when `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
        # avoid passing the attention mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
attention_mask = None
else:
attention_mask = (
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
if not self.config._attn_implementation == "flash_attention_2"
else patch_attention_mask
)
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = self.head(
hidden_state=last_hidden_state,
attention_mask=patch_attention_mask,
)
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class Phi4MultimodalImageEmbedding(nn.Module):
"""Image embedding."""
def __init__(self, config: Phi4MultimodalConfig):
super().__init__()
self.config = config
self.layer_idx = config.vision_config.feature_layer
self.crop_size = config.vision_config.crop_size
self.image_dim_out = config.vision_config.hidden_size
n_patches = config.vision_config.image_size // config.vision_config.patch_size
if n_patches % 2 != 0:
self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
n_patches += 1
self.num_img_tokens = (n_patches // 2) ** 2
self.drop = nn.Dropout(config.embd_pdrop)
self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config)
self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size)
self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size)
self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out]))
self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out]))
def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor:
img_processor_output = self.img_processor(
img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True
)
img_feature = img_processor_output.hidden_states[self.layer_idx]
patch_feature = img_feature
# reshape to 2D tensor
width = int(math.sqrt(patch_feature.size(1)))
patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))
# convert to NCHW
patch_feature = patch_feature.permute(0, 3, 1, 2)
if getattr(self, "img_processor_padding", None) is not None:
patch_feature = self.img_processor_padding(patch_feature)
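        # 2x2 average pooling halves each spatial dimension, reducing the number of image tokens by a factor of 4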
patch_feature = self.image_token_compression(patch_feature)
# convert to NHWC
patch_feature = patch_feature.permute(0, 2, 3, 1)
patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1))
return patch_feature
def forward(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.Tensor,
image_pixel_values: torch.FloatTensor,
image_sizes: Optional[torch.Tensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype)
target_device = self.img_projection_up.bias.device
target_dtype = self.img_projection_up.bias.dtype
batch_size = image_pixel_values.shape[0]
img_features = self.get_img_features(
image_pixel_values.flatten(0, 1),
attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device),
)
base_feat_size = int(np.sqrt(img_features.shape[1]))
img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out)
image_sizes = image_sizes.view(-1, 2)
output_imgs = []
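        # For each image, index 0 holds the global (resized full-image) crop and the remaining entries hold the
        # sub-crops tiling the image at higher resolution; learned separator embeddings are inserted between rows
        # and between the sub-image and global features before everything is flattened into a single token sequence.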
for idx in range(batch_size):
height, width = image_sizes[idx]
height_ratio = height // self.crop_size
width_ratio = width // self.crop_size
area_ratio = height_ratio * width_ratio
global_img = img_features[idx, :1]
global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous()
temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1)
global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)
sub_img = img_features[idx, 1:]
sub_img = sub_img[:area_ratio]
sub_img = (
sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out)
.transpose(1, 2)
.reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out)
.contiguous()
)
if image_attention_mask is not None:
reshaped_image_attention_mask = (
image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2]
.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size)
.transpose(1, 2)
.reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size)
)
useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
sub_img = sub_img[:, :useful_height, :useful_width]
temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1)
else:
temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1)
sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)
# Merge global and sub
output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1))
img_set_tensor = []
for output_img in output_imgs:
output_img = output_img.to(device=target_device, dtype=target_dtype)
img_feature_proj = self.img_projection_up(output_img)
img_feature_proj = nn.functional.gelu(img_feature_proj)
img_feature_proj = self.img_projection_down(img_feature_proj)
img_set_tensor.append(img_feature_proj)
merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
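        # Scatter the projected image features into the text embeddings at the positions of the image placeholder tokens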
with torch.no_grad():
positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True)
# Temporarily disable autocast to avoid issue on bf16 tensors
# Ref: https://github.com/pytorch/pytorch/issues/132715
with torch.autocast(device_type=inputs_embeds.device.type, enabled=False):
image_embeds = inputs_embeds.index_put(
indices=positions_tuple, values=merged_img_set_tensor, accumulate=False
)
image_embeds = self.drop(image_embeds)
return image_embeds
########################################################## AUDIO #############################################
class Phi4MultimodalAudioMLP(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.act_fn = ACT2FN[config.activation]
self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
up_states = self.gate_up_proj(hidden_states)
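        # GLU-style gating: split the fused projection into a value half and a gate half, then value * act(gate)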
up_states, gate = up_states.chunk(2, dim=-1)
up_states = up_states * self.act_fn(gate)
up_states = self.dropout(up_states)
hidden_states = self.down_proj(up_states)
out = self.dropout(hidden_states)
return out
class Phi4MultimodalAudioAttention(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.config = config
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.dropout_rate
self.is_causal = True
self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
**kwargs,
):
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
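        # Project to queries/keys/values and reshape to (batch, num_heads, seq_len, head_dim)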
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attention_interface: Callable = simple_eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0):
super().__init__()
self.dw_conv = nn.Conv1d(
config.hidden_size,
config.hidden_size * config.depthwise_multiplier,
config.kernel_size,
1,
padding=padding,
groups=config.hidden_size,
)
self.pw_conv = nn.Conv1d(
config.hidden_size * config.depthwise_multiplier, config.depthwise_seperable_out_channel, 1, 1, 0
)
def forward(self, hidden_states):
return self.pw_conv(self.dw_conv(hidden_states))
class Phi4MultimodalAudioGluPointWiseConv(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.config = config
self.output_dim = config.ext_pw_out_channel
self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1)
self.glu_act = ACT2FN[config.conv_glu_type]
self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1))
self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1))
def forward(self, hidden_states):
        # The input is expected with the channel (hidden) dimension last, so transpose to
        # (batch, channels, time) before applying the 1D convolution
hidden_states = hidden_states.permute([0, 2, 1])
hidden_states = self.ext_pw_conv_1d(hidden_states)
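        # GLU over channels: the first half (plus bias b1) is the value path, the second half (plus bias b2) is the gate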
out = hidden_states[:, 0 : self.output_dim, :] + self.b1
out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2)
return out.permute([0, 2, 1])
class Phi4MultimodalAudioConvModule(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.config = config
self.kernel_size = config.kernel_size
self.layer_norm = nn.LayerNorm(config.hidden_size)
self.glu = Phi4MultimodalAudioGluPointWiseConv(config)
self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1)
self.act = ACT2FN[config.conv_activation]
self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.glu(self.layer_norm(hidden_states))
hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1]))
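        # The depthwise conv pads both sides with (kernel_size - 1); dropping the extra right-side frames keeps the
        # output length equal to the input length (causal-style trimming)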
if self.kernel_size > 1:
hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)]
hidden_states = self.act(hidden_states)
hidden_states = self.ext_pw_conv_1d(hidden_states)
out = self.dropout(hidden_states.permute([0, 2, 1]))
return out
class Phi4MultimodalAudioConformerEncoderLayer(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.feed_forward_in = Phi4MultimodalAudioMLP(config)
self.self_attn = Phi4MultimodalAudioAttention(config)
self.conv = Phi4MultimodalAudioConvModule(config)
self.feed_forward_out = Phi4MultimodalAudioMLP(config)
self.layer_norm_att = nn.LayerNorm(config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
):
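        # Macaron-style Conformer block: half-step feed-forward, self-attention, convolution module,
        # second half-step feed-forward, then a final layer norm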
residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states)
hidden_states = self.layer_norm_att(residual)
hidden_states = residual + self.self_attn(hidden_states, attention_mask)
hidden_states = hidden_states + self.conv(hidden_states)
hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states)
out = self.layer_norm(hidden_states)
return out
class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.subsampling_factor = config.time_reduction
self.sampling_num = int(math.log(self.subsampling_factor, 2))
self.act_fn = ACT2FN[config.nemo_activation]
conv_channels = config.nemo_conv_channels
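        # Stack of stride-2 convolutions: each stage halves the time (and frequency) resolution, so
        # log2(time_reduction) stages give the requested subsampling factor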
layers = [
nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1),
self.act_fn,
]
for _ in range(self.sampling_num - 1):
layers.extend(
[
nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels),
nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1),
self.act_fn,
]
)
# Aggregate the layers
self.conv = torch.nn.Sequential(*layers)
self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]):
# Unsqueeze Channel Axis
hidden_states = hidden_states.unsqueeze(1)
hidden_states = self.conv(hidden_states)
# Flatten Channel and Frequency Axes
b, _, t, _ = hidden_states.size()
hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1))
if mask is None:
return hidden_states, None
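        # Recompute the padding mask at the subsampled frame rate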
max_audio_length = hidden_states.shape[1]
feature_lens = mask.sum(1)
padding_length = torch.ceil(feature_lens / self.subsampling_factor)
arange_ = torch.arange(0, max_audio_length, device=hidden_states.device)
pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1)
return hidden_states, pad_mask.unsqueeze(1)
class Phi4MultimodalAudioRelativeAttentionBias(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.max_distance = config.bias_max_distance
self.symmetric = config.bias_symmetric
self.num_buckets = self.max_distance
if not config.bias_symmetric:
self.num_buckets *= 2
self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads)
def forward(self, x):
# instantiate bias compatible with shape of x
max_pos = x.size(1)
context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None]
memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :]
relative_position = memory_position - context_position
# clipping to a maximum distance using ops that play well with ONNX export
relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance)
relative_position = relative_position.masked_fill(
relative_position > self.max_distance - 1, self.max_distance - 1
)
# mapping from relative position to index in the bias parameter
bias_idx = relative_position
bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2
att_bias = self.bias_values(bias_idx)
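        # (seq_len, seq_len, num_heads) -> (1, num_heads, seq_len, seq_len) so the bias can be added to attention scores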
att_bias = att_bias.permute(2, 0, 1).unsqueeze(0)
return att_bias
class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__()
self.register_buffer("global_mean", torch.zeros(config.input_size))
self.register_buffer("global_invstd", torch.ones(config.input_size))
def forward(self, x):
return (x - self.global_mean) * self.global_invstd
class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
config_class = Phi4MultimodalAudioConfig
supports_gradient_checkpointing = True
_no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
module.b1.data.zero_()
module.b2.data.zero_()
def unfold_tensor(tensor, max_seq_len):
"""
For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len,
this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len.
Args:
tensor: N, T, D
"""
_, _, D = tensor.shape
tensor = tensor.transpose(-1, -2)
# N x D x 1 x T => N x (D x max_seq_len) x T'
tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len))
new_bsz, _, slen = tensor.shape
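    # (N, D * max_seq_len, T') -> (N, T', max_seq_len, D) -> (N * T', max_seq_len, D)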
tensor = tensor.view(new_bsz, -1, max_seq_len, slen)
tensor = tensor.permute(0, 3, 2, 1)
tensor = tensor.view(-1, max_seq_len, D).contiguous()
return tensor
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
"""
The function is very important for Transformer Transducer Streaming mode
Args:
xs_len (int): sequence length
chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45]
left_window (int): how many left chunks can be seen
right_window (int): how many right chunks can be seen. It is used for chunk overlap model.
Returns:
mask (torch.Tensor): a mask tensor for streaming model
"""
chunk_start_idx = torch.Tensor(chunk_start_idx).long()
start_pad = torch.nn.functional.pad(
chunk_start_idx, (1, 0)
) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48]
end_pad = torch.nn.functional.pad(
chunk_start_idx, (0, 1), value=x_len
) # append x_len to the end, so it becomes [0,18,36,48, x_len]
seq_range = torch.arange(0, x_len).unsqueeze(-1)
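    # For every frame, find the index of the chunk it belongs to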
idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
idx_left = idx - left_window
idx_left[idx_left < 0] = 0
boundary_left = start_pad[idx_left]
mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
idx_right = idx + right_window
idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
boundary_right = end_pad[idx_right]
mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
return mask_left & mask_right
class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
def __init__(self, config: Phi4MultimodalAudioConfig):
super().__init__(config)
self.config = config
self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config)
self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config)
self.encoders = nn.ModuleList(
[Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)]
)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
        # Create the mask matrix for streaming.
        # chunk_start_idx stores the start index of each chunk; e.g. with chunk_size 18 it is [0, 18, 36, ...]
        chunk_start_idx = np.arange(0, seq_len, chunk_size)
        # avoid randomness when running evaluation or decoding
        if self.training and np.random.rand() > 0.5:
            # Either the first or the last chunk is incomplete.
            # If only the last one is incomplete, the EOS is not effective.
            chunk_start_idx = seq_len - chunk_start_idx
            chunk_start_idx = chunk_start_idx[::-1]
            chunk_start_idx = chunk_start_idx[:-1]
            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
enc_streaming_mask = (
adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
.unsqueeze(0)
.expand([batch_size, -1, -1])
)
return enc_streaming_mask
def forward_embeddings(self, hidden_states, masks):
"""Forwarding the inputs through the top embedding layers"""
seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
if seq_len <= 0:
raise ValueError(
f"The squence length after time reduction is invalid: {seq_len}. Your input feature is too short."
)
batch_size = hidden_states.shape[0]
enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk)
enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)
hidden_states, masks = self.embed(hidden_states, masks)
streaming_mask = enc_streaming_mask
if streaming_mask is not None and masks is not None:
hs_mask = masks & streaming_mask
elif masks is not None:
hs_mask = masks
else:
hs_mask = streaming_mask
return hidden_states, hs_mask, masks
def calculate_hs_mask(self, hidden_states, device, mask):
max_audio_length = hidden_states.shape[1]
batch_size = hidden_states.shape[0]
enc_streaming_mask = self._streaming_mask(
max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk
)
enc_streaming_mask = enc_streaming_mask.to(device)
if mask is None:
return enc_streaming_mask
feature_lens = mask.sum(1)
padding_length = feature_lens
pad_mask = torch.arange(0, max_audio_length, device=device).expand(
padding_length.size(0), -1
) < padding_length.unsqueeze(1)
pad_mask = pad_mask.unsqueeze(1)
pad_mask = pad_mask & enc_streaming_mask
return pad_mask
def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]):
hidden_states = self.encoder_embedding(hidden_states)
hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask)
unfolded = False
bs, seq_len, _ = hidden_states.shape
        max_seq_len = 500  # maximum position for absolute positional encoding
if seq_len > max_seq_len:
# audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len
unfolded = True
            # the unfold op would drop the residual frames, so pad the sequence to a multiple of max_seq_len
if seq_len % max_seq_len > 0:
chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
else:
chunk_pad_size = 0
if chunk_pad_size > 0:
hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0)
hidden_states = hidden_states_pad.to(hidden_states.device)
hidden_states = unfold_tensor(hidden_states, max_seq_len)
masks_unfold = None
if mask is not None:
                # revise hs_mask here because the previously calculated hs_mask did not account for the extra padding
                subsampled_pad_mask = mask.squeeze(1)  # [bz, subsampled_unmask_seq_len]
                extra_padded_subsampled_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )  # extra padding to the pad mask
                extra_padded_subsampled_pad_mask = extra_padded_subsampled_pad_mask.unsqueeze(-1).float()
                masks_unfold = unfold_tensor(
                    extra_padded_subsampled_pad_mask, max_seq_len
                )  # unfold the pad mask like we did to the input tensor
                masks_unfold = masks_unfold.squeeze(-1).bool()  # the unfold op does not support bool tensors
hs_mask = self.calculate_hs_mask(
hidden_states, hidden_states.device, masks_unfold
) # calculate hs_mask based on the unfolded pad mask
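        # Combine the streaming/padding mask with the learned relative position bias to form the additive
        # attention mask that is shared across all encoder layers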
relative_attention_bias = self.relative_attention_bias_layer(hidden_states)
attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias
for layer in self.encoders:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
)
else:
hidden_states = layer(hidden_states, attention_mask)
if unfolded:
embed_dim = hidden_states.shape[-1]
hidden_states = hidden_states.reshape(bs, -1, embed_dim)
# if we ever padded before unfolding, we need to remove the padding
if chunk_pad_size > 0:
hidden_states = hidden_states[:, :-chunk_pad_size, :]
return hidden_states
class Phi4MultimodalAudioEmbedding(nn.Module):
def __init__(self, config: Phi4MultimodalConfig):
super().__init__()
self.config = config
self.layer_idx = config.audio_config.feature_layer
self.drop = nn.Dropout(config.embd_pdrop)
self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config)
self.up_proj_for_speech = nn.Linear(
config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size
)
self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size)
self.up_proj_for_vision_speech = nn.Linear(
config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size
)
self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size)
def forward(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.Tensor,
audio_input_features: torch.FloatTensor,
audio_embed_sizes=None,
audio_attention_mask=None,
audio_projection_mode="speech",
) -> torch.FloatTensor:
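        # Locate the audio placeholder tokens; the projected audio features are scattered into those positions below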
with torch.no_grad():
positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True)
up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech
down_proj = (
self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech
)
target_device = up_proj.bias.device
target_dtype = up_proj.bias.dtype
audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype)
audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask)
audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states)
audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states)
audio_embeds = down_proj(audio_encoder_hidden_states)
merged_audio_embeds = torch.cat(
[audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0
)
merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
# Temporarily disable autocast to avoid issue on bf16 tensors
# Ref: https://github.com/pytorch/pytorch/issues/132715
with torch.autocast(device_type=inputs_embeds.device.type, enabled=False):
audio_embeds = inputs_embeds.index_put(
indices=positions_tuple, values=merged_audio_embeds, accumulate=False
)
audio_embeds = self.drop(audio_embeds)
return audio_embeds
class Phi4MultimodalRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Phi4MultimodalRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
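        # Compute the RMS statistic in float32 for numerical stability, then cast back to the input dtype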
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Phi4MultimodalMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
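        # Fused gate/up projection: split into gate and up halves and apply gated activation (up * act(gate))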
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
return self.down_proj(up_states)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
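    # Partial rotary embeddings: only the first `rotary_dim` channels of each head are rotated, the rest pass through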
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1)
return q_embed, k_embed
class Phi4MultimodalAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.num_key_value_heads = config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
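        # Fused QKV projection: query heads plus key and value heads packed into a single linear layer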
op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
qkv = self.qkv_proj(hidden_states)
query_pos = self.config.num_attention_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=getattr(self.config, "sliding_window", None),
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
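# Illustrative only: how the fused `qkv_proj` output of `Phi4MultimodalAttention.forward` above
# is sliced back into Q, K and V. The sizes below (hidden_size=3072, 24 query heads, 8 key/value
# heads, head_dim=128) are example values, not read from any checkpoint.
def _qkv_split_example():
    num_attention_heads, num_key_value_heads, head_dim, seq_len = 24, 8, 128, 4
    op_size = num_attention_heads * head_dim + 2 * (num_key_value_heads * head_dim)  # 5120
    qkv = torch.randn(1, seq_len, op_size)
    query_pos = num_attention_heads * head_dim
    query = qkv[..., :query_pos]                                            # (1, seq_len, 3072)
    key = qkv[..., query_pos : query_pos + num_key_value_heads * head_dim]  # (1, seq_len, 1024)
    value = qkv[..., query_pos + num_key_value_heads * head_dim :]          # (1, seq_len, 1024)
    return query.shape, key.shape, value.shape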
class Phi4MultimodalDecoderLayer(nn.Module):
def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx)
self.mlp = Phi4MultimodalMLP(config)
self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.config = config
self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`):
input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
past_key_value (`Cache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + self.resid_attn_dropout(hidden_states) # main diff with Llama
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.resid_mlp_dropout(hidden_states) # main diff with Llama
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class Phi4MultimodalFeatureEmbedding(nn.Module):
"""Image-audio embedding."""
def __init__(self, config: Phi4MultimodalConfig) -> None:
super().__init__()
self.config = config
self.image_token_id = config.vision_config.image_token_id
self.audio_token_id = config.audio_config.audio_token_id
self.image_embed = Phi4MultimodalImageEmbedding(config)
self.audio_embed = Phi4MultimodalAudioEmbedding(config)
def forward(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.Tensor,
image_pixel_values: Optional[torch.FloatTensor] = None,
audio_input_features: Optional[torch.FloatTensor] = None,
image_sizes=None,
image_attention_mask=None,
audio_embed_sizes=None,
audio_attention_mask=None,
) -> torch.FloatTensor:
with torch.no_grad():
image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1)
non_image_position_mask = ~image_position_mask
image_embeds = None
audio_embeds = None
if image_pixel_values is not None and (input_ids == self.image_token_id).any():
image_embeds = self.image_embed(
input_ids,
inputs_embeds,
image_pixel_values=image_pixel_values,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
)
if audio_input_features is not None and (input_ids == self.audio_token_id).any():
audio_projection_mode = "vision" if image_pixel_values is not None else "speech"
audio_embeds = self.audio_embed(
input_ids,
inputs_embeds,
audio_input_features=audio_input_features,
audio_embed_sizes=audio_embed_sizes,
audio_attention_mask=audio_attention_mask,
audio_projection_mode=audio_projection_mode,
)
# merge image and audio
if image_embeds is not None and audio_embeds is not None:
inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask
elif image_embeds is not None:
inputs_embeds = image_embeds
elif audio_embeds is not None:
inputs_embeds = audio_embeds
return inputs_embeds
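# Illustrative only: a toy version of the merge performed at the end of
# `Phi4MultimodalFeatureEmbedding.forward` above when both image and audio embeddings are
# present. Positions flagged as image tokens keep the image-branch embeddings; every other
# position keeps the audio-branch embeddings. The token id below is hypothetical.
def _feature_merge_example():
    image_token_id = 200010  # hypothetical placeholder id
    input_ids = torch.tensor([[image_token_id, 5, 7, 9]])
    image_embeds = torch.ones(1, 4, 3)          # stand-in for the image branch output
    audio_embeds = torch.full((1, 4, 3), 2.0)   # stand-in for the audio branch output
    image_position_mask = (input_ids == image_token_id).unsqueeze(-1)
    merged = image_embeds * image_position_mask + audio_embeds * ~image_position_mask
    # merged[:, 0] comes from image_embeds, the remaining positions from audio_embeds
    return merged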
class Phi4MultimodalRotaryEmbedding(nn.Module):
def __init__(self, config: Phi4MultimodalConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
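# Illustrative only: the core of `Phi4MultimodalRotaryEmbedding.forward` above, written out with
# a toy `inv_freq` and a plain broadcast instead of the batched matmul (and without the
# `attention_scaling` factor). The outer product of positions and inverse frequencies gives
# angles of shape (batch, seq, dim/2); duplicating them and taking cos/sin yields the
# (batch, seq, dim) tables consumed by `apply_rotary_pos_emb`.
def _rotary_tables_example():
    inv_freq = torch.tensor([1.0, 0.1, 0.01])            # toy dim/2 = 3 inverse frequencies
    position_ids = torch.arange(5)[None, :]               # (batch=1, seq=5)
    freqs = position_ids[:, :, None].float() * inv_freq   # (1, 5, 3)
    emb = torch.cat((freqs, freqs), dim=-1)                # (1, 5, 6)
    return emb.cos(), emb.sin()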
PHI4_MULTIMODAL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Phi4MultimodalConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
PHI4_MULTIMODAL_START_DOCSTRING,
)
class Phi4MultimodalPreTrainedModel(PreTrainedModel):
config_class = Phi4MultimodalConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Phi4MultimodalDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
_version = "0.0.5"
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Phi4MultimodalRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, Phi4MultimodalImageEmbedding):
module.global_img_feature_extensor.data.zero_()
module.sub_img_feature_extensor.data.zero_()
PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
See our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
image_pixel_values (`torch.FloatTensor`, *optional*):
If the input contains images, these correspond to the pixel values after transformations (as returned by
the Processor)
image_sizes (`torch.LongTensor`, *optional*):
If the input contains images, these correspond to the size of each image.
image_attention_mask (`torch.LongTensor`, *optional*):
Attention mask for the images.
audio_input_features (`torch.FloatTensor`, *optional*):
If the input contains audio samples, these correspond to the values after transformation (as returned by
the Processor).
audio_embed_sizes (`torch.Tensor`, *optional*):
Size of the audio inputs.
audio_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the audio inputs.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
"The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
PHI4_MULTIMODAL_START_DOCSTRING,
)
class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalDecoderLayer`].
Args:
config: Phi4MultimodalConfig
"""
def __init__(self, config: Phi4MultimodalConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config)
self.gradient_checkpointing = False
self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@can_return_tuple
@add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_pixel_values: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.LongTensor] = None,
image_attention_mask=None,
audio_input_features: Optional[torch.FloatTensor] = None,
audio_embed_sizes=None,
audio_attention_mask=None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.embed_tokens_extend(
input_ids,
inputs_embeds,
image_pixel_values=image_pixel_values,
audio_input_features=audio_input_features,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
audio_embed_sizes=audio_embed_sizes,
audio_attention_mask=audio_attention_mask,
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: Union[torch.Tensor, "BlockMask"],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool = False,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and past_key_values is not None:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
return attention_mask
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache or StaticCache
if using_sliding_window_cache or using_static_cache:
target_length = past_key_values.get_max_cache_shape()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
config=self.config,
past_key_values=past_key_values,
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
config: Phi4MultimodalConfig,
past_key_values: Cache,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with a static cache, the mask should be as long as the static cache, to account for the 0 padding, i.e. the part of the cache that is not yet filled.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`int`):
Batch size.
config (`Phi4MultimodalConfig`):
The model's configuration class
past_key_values (`Cache`):
The cache class that is being used currently to generate
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if config.get_text_config().sliding_window is not None:
# if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well
# the check is needed to verify whether the current checkpoint was trained with a sliding window or not
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
sliding_attend_mask = torch.arange(target_length, device=device) <= (
cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
)
diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
causal_mask *= diagonal_attend_mask
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.shape[-1] > target_length:
attention_mask = attention_mask[:, :target_length]
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
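# Illustrative only: a reduced rendering of the mask built by
# `_prepare_4d_causal_attention_mask_with_cache_position` above, without the sliding-window and
# padding handling. Allowed positions hold 0 and future positions hold the dtype minimum, so the
# mask can be added directly to the attention scores. The lengths below are arbitrary.
def _toy_causal_mask_example():
    sequence_length, target_length = 3, 5
    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min
    cache_position = torch.arange(2, 2 + sequence_length)  # decoding after 2 cached tokens
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    diagonal_attend_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
    causal_mask *= diagonal_attend_mask
    return causal_mask[None, None, :, :]  # (batch=1, 1, query_length, key_value_length)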
class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
def __init__(self, config):
super().__init__(config)
self.model = Phi4MultimodalModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
@can_return_tuple
@add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
image_pixel_values: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.LongTensor] = None,
image_attention_mask=None,
audio_input_features: Optional[torch.FloatTensor] = None,
audio_embed_sizes=None,
audio_attention_mask=None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only the last token's logits are needed for generation, and computing them only for that
token can save memory, which becomes quite significant for long sequences or a large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM
>>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA")
>>> tokenizer = AutoTokenizer.from_pretrained("TBA")
>>> prompt = "This is an example script ."
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
image_pixel_values=image_pixel_values,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
audio_input_features=audio_input_features,
audio_embed_sizes=audio_embed_sizes,
audio_attention_mask=audio_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
image_pixel_values=None,
image_sizes=None,
image_attention_mask=None,
audio_input_features=None,
audio_embed_sizes=None,
audio_attention_mask=None,
cache_position=None,
position_ids=None,
use_cache=True,
logits_to_keep=0,
**kwargs,
):
# Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
# process
# The first time the input length crosses the long/short factor switching point, force the cache to be re-computed.
# This makes that single token position slower, but it is better than the failure that would otherwise occur.
if (
past_key_values
and self.config.rope_scaling
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
):
past_length = cache_position[0]
if past_length <= self.config.original_max_position_embeddings:
past_key_values = None
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
image_pixel_values=image_pixel_values,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
audio_input_features=audio_input_features,
audio_embed_sizes=audio_embed_sizes,
audio_attention_mask=audio_attention_mask,
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
return model_inputs
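# Illustrative only: the long/short RoPE switching condition applied in
# `prepare_inputs_for_generation` above, reduced to plain integers. Once the running input length
# crosses `original_max_position_embeddings` while the cached length is still at or below it, the
# KV cache is dropped so it can be rebuilt with the long-factor scaling. The default length below
# is a hypothetical example value.
def _rope_switch_example(input_length: int, cached_length: int, original_max: int = 4096) -> bool:
    """Return True when the KV cache should be invalidated and re-computed."""
    return input_length >= original_max + 1 and cached_length <= original_max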
__all__ = [
"Phi4MultimodalAudioPreTrainedModel",
"Phi4MultimodalAudioModel",
"Phi4MultimodalVisionPreTrainedModel",
"Phi4MultimodalVisionModel",
"Phi4MultimodalPreTrainedModel",
"Phi4MultimodalModel",
"Phi4MultimodalForCausalLM",
]
Phi4MultimodalForCausalLM.register_for_auto_class("AutoModelForCausalLM")