|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import warnings |
|
from functools import wraps |
|
from typing import Callable, List, Optional, Tuple, Union, Any |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from torch.nn.init import _calculate_fan_in_and_fan_out |
|
|
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache |
|
from transformers.generation import GenerationMixin |
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutput, |
|
BaseModelOutputWithPast, |
|
BaseModelOutputWithPooling, |
|
CausalLMOutputWithPast, |
|
) |
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
from transformers.processing_utils import Unpack |
|
from transformers.utils import ( |
|
add_start_docstrings, |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
torch_int, |
|
) |
|
from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def set_attribute_for_modules(module: "torch.nn.Module", key: str, value: Any): |
|
""" |
|
Set a value to a module and all submodules. |
|
""" |
|
setattr(module, key, value) |
|
for submodule in module.children(): |
|
set_attribute_for_modules(submodule, key, value) |
|
|
|
|
|
def del_attribute_from_modules(module: "torch.nn.Module", key: str):
    """
    Recursively remove attribute ``key`` from ``module`` and every descendant module.

    Modules that do not have the attribute are skipped silently (checked with ``hasattr``).
    """
    if hasattr(module, key):
        delattr(module, key)

    for child in module.children():
        # Recurse into each direct child so the whole subtree is cleaned up.
        del_attribute_from_modules(child, key)
|
|
|
|
|
def can_return_tuple(func):
    """
    Decorate a model method so it returns ``output.to_tuple()`` instead of a ``ModelOutput`` when either
    ``return_dict=False`` is passed as a keyword argument, or ``self.config.use_return_dict`` is ``False``.

    The ``return_dict`` keyword is consumed by the wrapper and never forwarded to ``func``.

    When the conversion is driven by the config, only the outermost decorated call converts. Before calling
    ``func`` it marks every submodule with ``_is_top_level_module = False``, so nested decorated submodules
    keep returning ``ModelOutput`` objects. The marker is removed again in a ``finally`` clause.

    Note:
        output.to_tuple() convert output to tuple skipping all `None` values.
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        tuple_requested = kwargs.pop("return_dict", True) is False
        tuple_configured = hasattr(self, "config") and self.config.use_return_dict is False

        # This call "owns" the config-driven conversion only if no enclosing call has already claimed it.
        owns_conversion = tuple_configured and getattr(self, "_is_top_level_module", True)
        if owns_conversion:
            set_attribute_for_modules(self, "_is_top_level_module", False)

        try:
            result = func(self, *args, **kwargs)
            if tuple_requested or owns_conversion:
                result = result.to_tuple()
        finally:
            # Always undo the marker, even if func raised, so later calls start from a clean state.
            if owns_conversion:
                del_attribute_from_modules(self, "_is_top_level_module")

        return result

    return wrapper
|
|
|
|
|
def dynamic_rope_update(rope_forward):
    """
    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).

    The decorated module must provide the attributes read below: `rope_type`, `config`, `rope_init_fn` and
    `original_inv_freq`. Dynamic RoPE types additionally need `max_seq_len_cached` and `original_max_seq_len`.
    Frequencies are swapped by re-registering the non-persistent `inv_freq` buffer.

    Args:
        rope_forward (Callable):
            The forward pass of the RoPE implementation, called as `rope_forward(self, x, position_ids)`.

    Returns:
        The decorated forward pass.
    """

    def longrope_frequency_update(self, position_ids, device):
        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
        # Sequence length implied by the largest (0-based) position id.
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
        if seq_len > original_max_position_embeddings:
            # Compute the long-context frequencies once and cache them on the module.
            # seq_len is passed as original_max + 1, presumably so rope_init_fn takes its long-factor path
            # (TODO confirm against the longrope init function). The returned attention scaling is discarded here.
            if not hasattr(self, "long_inv_freq"):
                self.long_inv_freq, _ = self.rope_init_fn(
                    self.config, device, seq_len=original_max_position_embeddings + 1
                )
            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
        else:
            # Short sequences: restore the original frequencies, moved to the current input's device.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

    def dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # Back inside the original range after having scaled up: restore the unscaled frequencies.
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        # Any rope_type containing "dynamic" uses the dynamic update; "longrope" uses its own switch.
        if "dynamic" in self.rope_type:
            dynamic_frequency_update(self, position_ids, device=x.device)
        elif self.rope_type == "longrope":
            longrope_frequency_update(self, position_ids, device=x.device)
        return rope_forward(self, x, position_ids)

    return wrapper
|
|
|
|
|
class Phi4MultimodalVisionMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))`` with the activation chosen by ``config.hidden_act``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # hidden_size -> intermediate_size -> hidden_size
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the expand / activate / project pipeline to the last dimension of ``hidden_states``."""
        expanded = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(expanded))
|
|
|
|
|
def simple_eager_attention_forward(
    module: nn.Module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (non-causal) scaled dot-product attention in eager PyTorch.

    ``query_states``/``key_states``/``value_states`` are multiplied as ``q @ k^T`` over their last two dims. Callers
    in this file pass ``(batch, heads, seq, head_dim)``. ``attention_mask``, if given, is an additive mask whose last
    dim is cropped to the key length. Softmax is computed in float32 and cast back to the query dtype. Dropout is
    applied only while ``module.training`` is true.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has dims 1 and 2 swapped relative to the matmul result
        (e.g. ``(batch, seq, heads, head_dim)``) and is made contiguous.
    """
    scores = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query_states.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
|
|
|
|
|
class Phi4MultimodalVisionAttention(nn.Module):
    """Multi-head self-attention for the vision encoder (bidirectional: every patch attends to every patch)."""

    def __init__(self, config: Phi4MultimodalVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        # The vision encoder is not autoregressive. The eager path (simple_eager_attention_forward) never applies a
        # causal mask, and SDPA/flash backends read `module.is_causal`. With True, those backends would apply a
        # causal mask whenever no padding mask is passed, so eager and SDPA/flash would produce different outputs.
        self.is_causal = False
        self.attention_dropout = config.attention_dropout

        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` input.
            attention_mask: optional mask forwarded unchanged to the selected attention backend.
            **kwargs: extra keyword arguments forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` is ``(batch, seq_len, hidden_size)`` after
            ``out_proj``. ``attn_weights`` is whatever the backend returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = simple_eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
|
class Phi4MultimodalVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``."""

    def __init__(self, config: Phi4MultimodalVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.self_attn = Phi4MultimodalVisionAttention(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Phi4MultimodalVisionMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Mask handed straight to the attention module (e.g. an additive `(batch, 1, q_len, k_v_seq_len)`
                mask where padded positions carry very large negative values).
            output_attentions (`bool`, *optional*, defaults to `False`):
                If true, the attention weights are appended to the returned tuple. The flag is also forwarded to
                the attention module.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)`.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
|
|
|
|
class Phi4MultimodalVisionEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`Phi4MultimodalVisionEncoderLayer`] blocks applied in sequence.

    Args:
        config: Phi4MultimodalVisionConfig
    """

    def __init__(self, config: Phi4MultimodalVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Run every encoder layer over already-embedded patches.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Patch embeddings to encode.
            attention_mask (`torch.Tensor`, *optional*):
                Mask passed unchanged to every layer. In this file the vision model supplies `None` (no padding),
                an additive 4D mask, or the flattened boolean patch mask when using `flash_attention_2`.
            output_attentions (`bool`, *optional*):
                Collect each layer's attention weights. Defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Collect the input to every layer plus the final output. Defaults to
                `config.output_hidden_states`.

        Passing `return_dict=False` (or a config with `use_return_dict=False`) makes the `can_return_tuple`
        decorator convert the result to a tuple.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        collected_states = [] if output_hidden_states else None
        collected_attns = [] if output_attentions else None

        hidden_states = inputs_embeds
        for layer in self.layers:
            if output_hidden_states:
                collected_states.append(hidden_states)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                collected_attns.append(layer_outputs[1])

        # The final hidden state is appended after the last layer, so the tuple has num_layers + 1 entries.
        if output_hidden_states:
            collected_states.append(hidden_states)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=tuple(collected_states) if collected_states is not None else None,
            attentions=tuple(collected_attns) if collected_attns is not None else None,
        )
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b):
    """
    Fill ``tensor`` in place with samples from ``N(mean, std^2)`` truncated to ``[a, b]``.

    Uses inverse-CDF sampling:
    1. Draw uniformly between the (rescaled to [-1, 1]) standard-normal CDF values of the two bounds.
    2. Map the draws back with ``erfinv``.
    3. Scale by ``std * sqrt(2)`` and shift by ``mean``.
    4. Clamp to ``[a, b]`` to absorb floating-point overshoot.

    A warning is emitted when ``mean`` lies more than two ``std`` outside ``[a, b]``.
    """

    def _std_normal_cdf(value):
        # Cumulative distribution function of the standard normal distribution.
        return (1.0 + math.erf(value / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    lower_cdf = _std_normal_cdf((a - mean) / std)
    upper_cdf = _std_normal_cdf((b - mean) / std)

    # Uniform in [2*lower_cdf - 1, 2*upper_cdf - 1]; erfinv maps that range back to standard-normal quantiles / sqrt(2).
    tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
    tensor.erfinv_()

    tensor.mul_(std * math.sqrt(2.0)).add_(mean)
    tensor.clamp_(min=a, max=b)
|
|
|
|
|
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor``, filled in place. It is returned so the function honours its annotation and can be
        chained, matching ``torch.nn.init.trunc_normal_``.
    """
    with torch.no_grad():
        # Bounds apply to the *standard* normal; scaling/shifting happens afterwards (TF/JAX semantics).
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
|
|
|
|
|
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """
    Initialize ``tensor`` in place so its entries have variance ``scale / denom`` (JAX/Flax-style variance scaling).

    Args:
        tensor: tensor to fill in place. It needs at least 2 dims, because
            ``_calculate_fan_in_and_fan_out`` rejects fewer.
        scale: numerator of the target variance.
        mode: which fan is the denominator: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``
            (mean of fan-in and fan-out).
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode left `denom` unbound and surfaced as an UnboundLocalError.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # Constant is the stddev of a standard normal truncated to (-2, 2); dividing by it
        # makes the truncated samples end up with the requested std.
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # U(-bound, bound) has variance bound^2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
|
|
|
|
|
def lecun_normal_(tensor):
    """LeCun-normal initialization in place: truncated normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, distribution="truncated_normal", mode="fan_in")
|
|
|
|
|
def default_flax_embed_init(tensor):
    """Flax-style default embedding init in place: plain normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, distribution="normal", mode="fan_in")
|
|
|
|
|
class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Phi4MultimodalVisionConfig
    base_model_prefix = "phi4_vision"
    supports_gradient_checkpointing = True

    _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize the weights of ``module`` according to its type (first matching branch wins)."""
        if isinstance(module, Phi4MultimodalVisionEmbeddings):
            # Both branches of the former isinstance-ternary read the same attribute, so it is read directly.
            width = self.config.hidden_size
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, Phi4MultimodalVisionAttention):
            # NOTE(review): unit-variance normal init for projection weights; SigLIP-style models typically use
            # xavier_uniform_ here. Left unchanged to preserve behaviour; confirm this is intended.
            nn.init.normal_(module.q_proj.weight)
            nn.init.normal_(module.k_proj.weight)
            nn.init.normal_(module.v_proj.weight)
            nn.init.normal_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, Phi4MultimodalVisionMLP):
            nn.init.normal_(module.fc1.weight)
            nn.init.normal_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead):
            nn.init.normal_(module.probe.data)
            nn.init.normal_(module.attention.in_proj_weight.data)
            nn.init.zeros_(module.attention.in_proj_bias.data)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has no weight/bias; skip them instead of crashing.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
|
|
|
|
|
class Phi4MultimodalVisionEmbeddings(nn.Module):
    """
    Patch embeddings plus learned position embeddings for variable-size (padded) images.

    A strided ``Conv2d`` turns each ``patch_size`` x ``patch_size`` patch into a ``hidden_size`` vector. Position ids
    come from bucketizing each valid patch's fractional row/column coordinate onto the fixed
    ``num_patches_per_side`` x ``num_patches_per_side`` grid of the position table. Valid patches are those marked
    in ``patch_attention_mask``.
    """

    def __init__(self, config: Phi4MultimodalVisionConfig):
        super().__init__()
        self.config = config
        self.patch_size = config.patch_size
        self.num_patches_per_side = config.image_size // self.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and no class embeddings.

        Returns a tensor of shape ``(1, num_patches, dim)``.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # Always interpolate when tracing so the exported graph also works for other image sizes.
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            # This module registers no `position_ids` buffer (reading `self.position_ids` raised AttributeError),
            # so build the identity ids 0..num_positions-1 here.
            position_ids = torch.arange(num_positions, device=self.position_embedding.weight.device).unsqueeze(0)
            return self.position_embedding(position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # Reshape the flat table back into its square grid, channels first, for interpolation.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Embed ``pixel_values`` and add position embeddings.

        Args:
            pixel_values: ``(batch, channels, H, W)`` images (possibly padded).
            patch_attention_mask: ``(batch, H // patch_size, W // patch_size)`` boolean mask of valid patches.
                Rows/columns of valid patches are counted from its first column / first row.

        Returns:
            ``(batch, num_patches, hidden_size)`` embeddings. Padded patches receive position id 0.
        """
        batch_size = pixel_values.size(0)

        patch_embeds = self.patch_embedding(pixel_values)
        # (batch, hidden, h, w) -> (batch, h*w, hidden)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
        # Interior bucket edges of [0, 1) split into num_patches_per_side equal buckets.
        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            # Fractional coordinate of each valid patch row/column in [0, 1).
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
|
|
|
|
|
class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling.

    A single learned ``probe`` query attends over the encoder output. The result is refined by a residual
    LayerNorm + MLP and returned as one pooled vector per example.
    """

    def __init__(self, config: Phi4MultimodalVisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = Phi4MultimodalVisionMLP(config)

    def forward(self, hidden_state, attention_mask=None):
        """
        Args:
            hidden_state: ``(batch, seq_len, hidden_size)`` encoder output (``batch_first=True``).
            attention_mask: optional boolean ``(batch, seq_len)`` mask, True for positions that may be attended.
                ``None`` means every position is valid.

        Returns:
            ``(batch, hidden_size)`` pooled representation.
        """
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        # nn.MultiheadAttention's key_padding_mask marks positions to IGNORE, hence the inversion.
        # Guard None so the head can be used without a mask instead of failing on `~None`.
        key_padding_mask = None if attention_mask is None else ~attention_mask
        hidden_state = self.attention(
            query=probe, key=hidden_state, value=hidden_state, key_padding_mask=key_padding_mask
        )[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
|
|
|
|
|
class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel):
    """Vision tower: patch embeddings -> transformer encoder -> post LayerNorm -> attention-pooling head."""

    config_class = Phi4MultimodalVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: Phi4MultimodalVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = Phi4MultimodalVisionEmbeddings(config)
        self.encoder = Phi4MultimodalVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embedding

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        """
        Encode ``pixel_values`` into per-patch features and a pooled vector.

        Args:
            pixel_values: ``(batch, channels, H, W)`` images.
            patch_attention_mask: optional boolean ``(batch, H // patch_size, W // patch_size)`` mask of valid
                patches. Defaults to all-True.
            output_attentions / output_hidden_states: default to the corresponding config flags.

        Returns:
            ``BaseModelOutputWithPooling`` whose ``last_hidden_state`` has had ``post_layernorm`` applied.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            patch_grid = (
                batch_size,
                pixel_values.size(2) // self.config.patch_size,
                pixel_values.size(3) // self.config.patch_size,
            )
            patch_attention_mask = torch.ones(size=patch_grid, dtype=torch.bool, device=pixel_values.device)

        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

        flat_patch_mask = patch_attention_mask.view(batch_size, -1)

        # Skip building a mask entirely when no patch is padded; otherwise flash-attention consumes the
        # boolean 2D mask directly and every other backend gets an additive 4D mask.
        if torch.any(~flat_patch_mask):
            if self.config._attn_implementation == "flash_attention_2":
                attention_mask = flat_patch_mask
            else:
                attention_mask = _prepare_4d_attention_mask(flat_patch_mask, hidden_states.dtype)
        else:
            attention_mask = None

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        normed_last = self.post_layernorm(encoder_outputs.last_hidden_state)

        pooled = self.head(
            hidden_state=normed_last,
            attention_mask=flat_patch_mask,
        )

        return BaseModelOutputWithPooling(
            last_hidden_state=normed_last,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
|
class Phi4MultimodalImageEmbedding(nn.Module):
    """Image embedding.

    Pipeline:
    1. Run every image crop through the vision tower.
    2. Take the hidden states of layer ``feature_layer`` and compress them with 2x2 average pooling.
    3. Stitch the sub-image crops into one grid and append a learned separator vector to every row.
    4. Concatenate the sub-image grid, a global separator, and the global-image grid.
    5. Project the result to the text hidden size.
    6. Scatter the features into ``inputs_embeds`` at image-token positions.
    """

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__()
        self.config = config
        self.layer_idx = config.vision_config.feature_layer
        self.crop_size = config.vision_config.crop_size
        self.image_dim_out = config.vision_config.hidden_size

        n_patches = config.vision_config.image_size // config.vision_config.patch_size
        if n_patches % 2 != 0:
            # Pad odd patch grids by one row/column so the 2x2 pooling below divides evenly.
            self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1))
            n_patches += 1
        self.num_img_tokens = (n_patches // 2) ** 2

        self.drop = nn.Dropout(config.embd_pdrop)
        self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config)
        self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2)
        self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size)
        self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size)
        self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out]))
        self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out]))

    def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor:
        """Return 2x2-pooled patch features ``(num_crops, tokens, image_dim_out)`` from hidden layer ``layer_idx``."""
        img_processor_output = self.img_processor(
            img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True
        )
        img_feature = img_processor_output.hidden_states[self.layer_idx]

        patch_feature = img_feature

        # Assumes a square patch grid (tokens == width * width).
        width = int(math.sqrt(patch_feature.size(1)))
        patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1))

        # (N, H, W, C) -> (N, C, H, W) for padding / pooling
        patch_feature = patch_feature.permute(0, 3, 1, 2)
        if getattr(self, "img_processor_padding", None) is not None:
            patch_feature = self.img_processor_padding(patch_feature)
        patch_feature = self.image_token_compression(patch_feature)

        # (N, C, H', W') -> (N, H'*W', C)
        patch_feature = patch_feature.permute(0, 2, 3, 1)
        patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1))
        return patch_feature

    def forward(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        image_pixel_values: torch.FloatTensor,
        image_sizes: Optional[torch.Tensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode images and write their features into ``inputs_embeds`` at image-token positions.

        Args:
            input_ids: token ids. Positions equal to ``config.vision_config.image_token_id`` receive image features.
            inputs_embeds: text embeddings. A new tensor is returned (``index_put``, not in place).
            image_pixel_values: image crops; the first two dims are flattened before the vision tower
                (presumably ``(batch, num_crops, C, H, W)``, TODO confirm). Crop 0 of each image is used as the
                global image.
            image_sizes: per-image ``(height, width)``, used with ``crop_size`` to derive the crop grid. Required.
            image_attention_mask: optional per-crop patch mask. When omitted every patch is treated as valid.

        Returns:
            ``inputs_embeds`` with image features scattered in, after dropout.

        Raises:
            ValueError: if ``image_sizes`` is not provided.
        """
        if image_sizes is None:
            raise ValueError("`image_sizes` is required to arrange the image crops into a grid.")

        image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype)

        target_device = self.img_projection_up.bias.device
        target_dtype = self.img_projection_up.bias.dtype

        batch_size = image_pixel_values.shape[0]

        # The mask is optional: without it the vision tower builds an all-valid patch mask itself.
        # (Previously `.flatten` was called on it unconditionally and crashed on None.)
        patch_attention_mask = None
        if image_attention_mask is not None:
            patch_attention_mask = image_attention_mask.flatten(0, 1).to(dtype=bool, device=target_device)

        img_features = self.get_img_features(
            image_pixel_values.flatten(0, 1),
            attention_mask=patch_attention_mask,
        )
        base_feat_size = int(np.sqrt(img_features.shape[1]))
        img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out)
        image_sizes = image_sizes.view(-1, 2)

        output_imgs = []
        for idx in range(batch_size):
            height, width = image_sizes[idx]
            height_ratio = height // self.crop_size
            width_ratio = width // self.crop_size
            area_ratio = height_ratio * width_ratio

            # Global image (crop 0): square grid with one separator column appended to each row.
            global_img = img_features[idx, :1]
            global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous()
            temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1)
            global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)

            # Sub-images: stitch the height_ratio x width_ratio crops into one large feature grid.
            sub_img = img_features[idx, 1:]
            sub_img = sub_img[:area_ratio]
            sub_img = (
                sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out)
                .transpose(1, 2)
                .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out)
                .contiguous()
            )

            if image_attention_mask is not None:
                # Downsample the patch mask 2x (matching the pooling) and crop the grid to the valid region.
                reshaped_image_attention_mask = (
                    image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2]
                    .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size)
                    .transpose(1, 2)
                    .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size)
                )
                useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
                useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
                sub_img = sub_img[:, :useful_height, :useful_width]
                temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1)
            else:
                temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1)

            sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)

            # Token order: sub-image grid, global separator, global image grid.
            output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1))

        img_set_tensor = []
        for output_img in output_imgs:
            output_img = output_img.to(device=target_device, dtype=target_dtype)
            img_feature_proj = self.img_projection_up(output_img)
            img_feature_proj = nn.functional.gelu(img_feature_proj)
            img_feature_proj = self.img_projection_down(img_feature_proj)
            img_set_tensor.append(img_feature_proj)

        merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
        merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        with torch.no_grad():
            positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True)

        # Scatter with autocast disabled so the copied values keep inputs_embeds' dtype.
        with torch.autocast(device_type=inputs_embeds.device.type, enabled=False):
            image_embeds = inputs_embeds.index_put(
                indices=positions_tuple, values=merged_img_set_tensor, accumulate=False
            )

        image_embeds = self.drop(image_embeds)

        return image_embeds
|
|
|
|
|
|
|
|
|
|
|
class Phi4MultimodalAudioMLP(nn.Module):
    """Gated feed-forward block: LayerNorm, then ``down(dropout(up * act(gate))))``, then dropout."""

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.act_fn = ACT2FN[config.activation]
        # One projection produces both halves (up and gate) of width intermediate_size each.
        self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Apply the gated MLP to the last dimension of ``hidden_states``."""
        normed = self.layer_norm(hidden_states)
        # First half is the "up" branch, second half the gate.
        up_half, gate_half = self.gate_up_proj(normed).chunk(2, dim=-1)
        gated = self.dropout(up_half * self.act_fn(gate_half))
        return self.dropout(self.down_proj(gated))
|
|
|
|
|
class Phi4MultimodalAudioAttention(nn.Module):
    """
    Multi-head self-attention used inside the audio conformer layers.

    Separate biased q/k/v/o projections; the attention kernel is selected from
    ``config._attn_implementation`` ("eager" uses ``simple_eager_attention_forward``).
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.config = config
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.dropout_rate
        self.is_causal = True

        projected = config.num_attention_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, projected, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, projected, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, projected, bias=True)
        self.o_proj = nn.Linear(projected, config.hidden_size, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs,
    ):
        batch_dims = hidden_states.shape[:-1]
        per_head_shape = (*batch_dims, -1, self.head_dim)

        def _split_heads(projection):
            # (..., seq, heads * head_dim) -> (..., heads, seq, head_dim)
            return projection(hidden_states).view(per_head_shape).transpose(1, 2)

        query_states = _split_heads(self.q_proj)
        key_states = _split_heads(self.k_proj)
        value_states = _split_heads(self.v_proj)

        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            simple_eager_attention_forward
            if implementation == "eager"
            else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            **kwargs,
        )

        merged = attn_output.reshape(*batch_dims, -1).contiguous()
        return self.o_proj(merged)
|
|
|
|
|
class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module):
    """
    Depthwise convolution (one group per input channel) followed by a 1x1 pointwise convolution.

    Operates on channel-first input, as required by ``nn.Conv1d``.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0):
        super().__init__()
        depthwise_channels = config.hidden_size * config.depthwise_multiplier
        self.dw_conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=depthwise_channels,
            kernel_size=config.kernel_size,
            stride=1,
            padding=padding,
            groups=config.hidden_size,
        )
        self.pw_conv = nn.Conv1d(
            in_channels=depthwise_channels,
            out_channels=config.depthwise_seperable_out_channel,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, hidden_states):
        depthwise_out = self.dw_conv(hidden_states)
        return self.pw_conv(depthwise_out)
|
|
|
|
|
class Phi4MultimodalAudioGluPointWiseConv(nn.Module):
    """
    Pointwise (kernel size 1) convolution with a gated-linear-unit style output.

    A single conv produces ``2 * ext_pw_out_channel`` channels; the first half (plus ``b1``) is multiplied
    by ``glu_act`` of the second half (plus ``b2``). Input and output keep features in the last dimension.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.config = config
        self.output_dim = config.ext_pw_out_channel

        self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1)
        self.glu_act = ACT2FN[config.conv_glu_type]
        self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1))
        self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1))

    def forward(self, hidden_states):
        # nn.Conv1d expects channels in dim 1: move the feature axis there, then back at the end.
        projected = self.ext_pw_conv_1d(hidden_states.permute([0, 2, 1]))
        width = self.output_dim
        value = projected[:, :width, :] + self.b1
        gate = projected[:, width : 2 * width, :] + self.b2
        return (value * self.glu_act(gate)).permute([0, 2, 1])
|
|
|
|
|
class Phi4MultimodalAudioConvModule(nn.Module):
    """
    Convolution module of the audio conformer layer.

    LayerNorm -> GLU pointwise conv -> depthwise-separable conv (trimmed so each output frame only depends on
    the current and earlier frames) -> activation -> pointwise conv -> dropout. Features stay in the last
    dimension on input and output.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.config = config
        self.kernel_size = config.kernel_size

        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.glu = Phi4MultimodalAudioGluPointWiseConv(config)
        # kernel_size - 1 frames of padding on both sides; the trailing surplus is removed in forward.
        self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1)
        self.act = ACT2FN[config.conv_activation]
        self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states: torch.Tensor):
        gated = self.glu(self.layer_norm(hidden_states))
        conv_out = self.dw_sep_conv_1d(gated.permute([0, 2, 1]))

        trim = self.kernel_size - 1
        if trim > 0:
            # Drop the frames produced by the right-hand padding.
            conv_out = conv_out[:, :, :-trim]

        conv_out = self.ext_pw_conv_1d(self.act(conv_out))
        return self.dropout(conv_out.permute([0, 2, 1]))
|
|
|
|
|
class Phi4MultimodalAudioConformerEncoderLayer(nn.Module):
    """
    One conformer block of the audio encoder.

    Residual sequence: 0.5-scaled feed-forward, self-attention on a LayerNorm'd copy, convolution module,
    0.5-scaled feed-forward, then a final LayerNorm.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()

        self.feed_forward_in = Phi4MultimodalAudioMLP(config)
        self.self_attn = Phi4MultimodalAudioAttention(config)
        self.conv = Phi4MultimodalAudioConvModule(config)
        self.feed_forward_out = Phi4MultimodalAudioMLP(config)
        self.layer_norm_att = nn.LayerNorm(config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ):
        stream = hidden_states + 0.5 * self.feed_forward_in(hidden_states)
        # Attention reads a normalized copy but is added back onto the un-normalized stream.
        stream = stream + self.self_attn(self.layer_norm_att(stream), attention_mask)
        stream = stream + self.conv(stream)
        stream = stream + 0.5 * self.feed_forward_out(stream)
        return self.layer_norm(stream)
|
|
|
|
|
class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module):
    """
    Convolutional time-subsampling front-end of the audio encoder.

    Builds ``log2(config.time_reduction)`` stride-2 ``Conv2d`` stages: a regular 3x3 conv first, then
    depthwise 3x3 + pointwise 1x1 pairs. Each stage is followed by ``config.nemo_activation``. The
    flattened (channel x remaining-feature) axis is then projected to ``config.hidden_size``.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.subsampling_factor = config.time_reduction
        # One stride-2 stage per factor of two. math.log2 is exact for powers of two, while
        # math.log(x, 2) computes log(x)/log(2) from two rounded values; int() would truncate a
        # result that lands just below the integer and silently drop a stage.
        self.sampling_num = int(math.log2(self.subsampling_factor))
        self.act_fn = ACT2FN[config.nemo_activation]
        conv_channels = config.nemo_conv_channels

        layers = [
            nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1),
            self.act_fn,
        ]
        for _ in range(self.sampling_num - 1):
            layers.extend(
                [
                    # depthwise stride-2 conv followed by a pointwise channel mix
                    nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels),
                    nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1),
                    self.act_fn,
                ]
            )

        self.conv = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Subsample ``hidden_states`` and (optionally) derive the subsampled padding mask.

        Args:
            hidden_states: input features; a singleton channel axis is inserted at dim 1 before the convs
                (presumably ``(batch, time, feature)`` — TODO confirm against callers).
            mask: optional validity mask; ``mask.sum(1)`` is used as the per-sample input length
                (assumes ``(batch, time)`` — TODO confirm).

        Returns:
            ``(hidden_states, pad_mask)``. ``pad_mask`` is ``None`` when ``mask`` is ``None``; otherwise a
            boolean ``(batch, 1, subsampled_len)`` mask marking the first ``ceil(len / subsampling_factor)``
            positions of each sample as valid.
        """
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = self.conv(hidden_states)

        # (b, channels, t, freq) -> (b, t, channels * freq) -> (b, t, hidden_size)
        b, _, t, _ = hidden_states.size()
        hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1))

        if mask is None:
            return hidden_states, None

        max_audio_length = hidden_states.shape[1]
        feature_lens = mask.sum(1)
        padding_length = torch.ceil(feature_lens / self.subsampling_factor)
        arange_ = torch.arange(0, max_audio_length, device=hidden_states.device)
        pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1)
        return hidden_states, pad_mask.unsqueeze(1)
|
|
|
|
|
class Phi4MultimodalAudioRelativeAttentionBias(nn.Module):
    """
    Learned per-head attention bias indexed by the clipped relative position ``j - i``.

    Symmetric mode uses ``max_distance`` buckets indexed by ``|j - i|``. Asymmetric mode uses
    ``2 * max_distance`` buckets indexed by the signed distance shifted by ``max_distance``.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()

        self.max_distance = config.bias_max_distance
        self.symmetric = config.bias_symmetric
        self.num_buckets = self.max_distance
        if not config.bias_symmetric:
            self.num_buckets *= 2
        self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads)

    def forward(self, x):
        """
        Args:
            x: any tensor whose dim 1 is the sequence length; only its size and device are used.

        Returns:
            Bias of shape ``(1, num_attention_heads, seq_len, seq_len)``.
        """
        max_pos = x.size(1)
        positions = torch.arange(max_pos, device=x.device, dtype=torch.long)
        # relative_position[i, j] = j - i
        relative_position = positions[None, :] - positions[:, None]

        # Symmetric mode indexes by |distance|, whose largest valid bucket is max_distance - 1. The
        # negative side must therefore also stop at -(max_distance - 1); clipping it to -max_distance
        # produced index max_distance and overflowed the embedding table for long sequences.
        lower = -(self.max_distance - 1) if self.symmetric else -self.max_distance
        upper = self.max_distance - 1
        relative_position = relative_position.masked_fill(relative_position < lower, lower)
        relative_position = relative_position.masked_fill(relative_position > upper, upper)

        # Map relative positions to bucket indices in [0, num_buckets).
        bias_idx = relative_position.abs() if self.symmetric else relative_position + self.num_buckets // 2

        att_bias = self.bias_values(bias_idx)  # (seq, seq, heads)
        att_bias = att_bias.permute(2, 0, 1).unsqueeze(0)

        return att_bias
|
|
|
|
|
class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module):
    """
    Normalize features with fixed global statistics: ``(x - global_mean) * global_invstd``.

    The statistics are registered buffers (saved in the state dict, not optimized), sized ``config.input_size``.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        feature_dim = config.input_size
        self.register_buffer("global_mean", torch.zeros(feature_dim))
        self.register_buffer("global_invstd", torch.ones(feature_dim))

    def forward(self, x):
        centered = x - self.global_mean
        return centered * self.global_invstd
|
|
|
|
|
class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
    """Base class wiring the audio encoder into ``PreTrainedModel`` (config class, init and attention flags)."""

    config_class = Phi4MultimodalAudioConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize one submodule: N(0, initializer_range) weights, zero biases, identity LayerNorms."""
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
            # The GLU's standalone bias parameters start at zero.
            for glu_bias in (module.b1, module.b2):
                glu_bias.data.zero_()
|
|
|
|
|
def unfold_tensor(tensor, max_seq_len):
    """
    Split a ``(N, T, D)`` tensor into non-overlapping windows of ``max_seq_len`` frames along time.

    The result has shape ``(N * T', max_seq_len, D)`` with ``T' = T // max_seq_len``. Window ``k`` holds
    ``tensor[k // T', (k % T') * max_seq_len : (k % T' + 1) * max_seq_len, :]``. Trailing frames that do not
    fill a whole window are dropped, so callers should pad ``T`` to a multiple of ``max_seq_len`` first if
    every frame matters.

    Args:
        tensor: input of shape ``(N, T, D)``; must be a dtype supported by ``F.unfold`` (e.g. floating point).
        max_seq_len: window length along the time axis.

    Returns:
        A contiguous tensor of shape ``(N * T', max_seq_len, D)``.
    """
    _, _, D = tensor.shape
    # (N, T, D) -> (N, D, T), then add a unit height axis so F.unfold can slide a (1, max_seq_len) window.
    tensor = tensor.transpose(-1, -2)
    # N x D x 1 x T => N x (D * max_seq_len) x T'
    tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len))

    new_bsz, _, slen = tensor.shape
    # N x D x max_seq_len x T'
    tensor = tensor.view(new_bsz, -1, max_seq_len, slen)
    # N x T' x max_seq_len x D. The permuted tensor is not contiguous, so it must be materialized
    # before N and T' are merged; calling view() on it directly fails whenever N > 1 and T' > 1.
    tensor = tensor.permute(0, 3, 2, 1).contiguous()
    # (N * T') x max_seq_len x D
    return tensor.view(-1, max_seq_len, D)
|
|
|
|
|
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """
    Build the chunk-based boolean attention mask used by the streaming encoder.

    Position ``i`` may attend to position ``j`` iff ``j`` lies between the start of the chunk
    ``left_window`` chunks before ``i``'s chunk and the end of the chunk ``right_window`` chunks after it
    (both clamped to the available chunks).

    Args:
        x_len (int): sequence length.
        chunk_start_idx (sequence of int): first index of each chunk, e.g. ``[0, 18, 36, 48]``; chunks may have
            different sizes, e.g. ``[0, 10, 15, 45]``.
        left_window (int): number of chunks to the left that remain visible.
        right_window (int): number of chunks to the right that remain visible (for chunk-overlap models).

    Returns:
        torch.Tensor: boolean mask of shape ``(x_len, x_len)``; ``mask[i, j]`` is True where ``i`` may attend ``j``.
    """
    starts = torch.Tensor(chunk_start_idx).long()
    last_slot = len(starts)
    # Slot k covers [start_pad[k], end_pad[k]); slot 0 is the empty sentinel [0, 0).
    start_pad = torch.nn.functional.pad(starts, (1, 0))
    end_pad = torch.nn.functional.pad(starts, (0, 1), value=x_len)

    query_pos = torch.arange(0, x_len).unsqueeze(-1)
    owning_slot = ((query_pos >= start_pad) & (query_pos < end_pad)).nonzero()[:, 1]

    left_slot = (owning_slot - left_window).clamp(min=0)
    right_slot = (owning_slot + right_window).clamp(max=last_slot)

    key_pos = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
    after_left_edge = key_pos >= start_pad[left_slot].unsqueeze(-1)
    before_right_edge = key_pos < end_pad[right_slot].unsqueeze(-1)
    return after_left_edge & before_right_edge
|
|
|
|
|
class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
    """
    Conformer audio encoder.

    Pipeline: global mean/variance normalization -> convolutional time subsampling -> ``config.num_blocks``
    conformer layers. All layers share one relative-position attention bias combined with a chunked
    streaming mask. Sequences longer than 500 subsampled frames are split into 500-frame windows that are
    encoded independently and concatenated back afterwards.
    """

    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__(config)
        self.config = config

        self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config)
        self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
        self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config)
        self.encoders = nn.ModuleList(
            [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)]
        )
        self.gradient_checkpointing = False

        self.post_init()

    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
        """
        Build the chunked streaming mask, expanded to ``(batch_size, seq_len, seq_len)``.

        Chunks start every ``chunk_size`` frames. Each frame may attend within its own chunk and
        ``left_chunk`` chunks to its left (via ``adaptive_enc_mask``).
        """
        chunk_start_idx = np.arange(0, seq_len, chunk_size)

        # Training only, with probability 0.5 (numpy global RNG): anchor the chunk grid to the end of
        # the sequence, so the incomplete chunk is the first one instead of the last.
        # In eval mode the grid is always anchored at 0, so the mask is deterministic.
        if self.training and np.random.rand() > 0.5:
            chunk_start_idx = seq_len - chunk_start_idx
            chunk_start_idx = chunk_start_idx[::-1]
            chunk_start_idx = chunk_start_idx[:-1]
            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)

        enc_streaming_mask = (
            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(self, hidden_states, masks):
        """
        Run the subsampling front-end and combine its padding mask with the streaming mask.

        Returns:
            ``(hidden_states, hs_mask, masks)``. ``masks`` is the subsampled padding mask from ``self.embed``
            (or ``None``). ``hs_mask`` is ``masks & streaming_mask`` when ``masks`` is given, otherwise
            the streaming mask alone.

        Raises:
            ValueError: if ``ceil(T / time_reduction)`` is not positive.
        """
        seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
        if seq_len <= 0:
            raise ValueError(
                f"The squence length after time reduction is invalid: {seq_len}. Your input feature is too short."
            )

        batch_size = hidden_states.shape[0]

        # NOTE(review): the streaming mask is sized from ceil(T / time_reduction), which assumes that
        # equals the length produced by self.embed below — TODO confirm for the configured conv stack.
        enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk)
        enc_streaming_mask = enc_streaming_mask.to(hidden_states.device)

        hidden_states, masks = self.embed(hidden_states, masks)

        streaming_mask = enc_streaming_mask
        if streaming_mask is not None and masks is not None:
            hs_mask = masks & streaming_mask
        elif masks is not None:
            hs_mask = masks
        else:
            hs_mask = streaming_mask

        return hidden_states, hs_mask, masks

    def calculate_hs_mask(self, hidden_states, device, mask):
        """
        Recompute the attention mask for already-subsampled ``hidden_states``.

        Returns the streaming mask alone when ``mask`` is ``None``. Otherwise it is ANDed with a padding
        mask that keeps the first ``mask.sum(1)`` positions of each row valid (assumes valid frames are a
        prefix of each row).
        """
        max_audio_length = hidden_states.shape[1]
        batch_size = hidden_states.shape[0]
        enc_streaming_mask = self._streaming_mask(
            max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk
        )
        enc_streaming_mask = enc_streaming_mask.to(device)
        if mask is None:
            return enc_streaming_mask

        feature_lens = mask.sum(1)
        padding_length = feature_lens
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(1)
        pad_mask = pad_mask.unsqueeze(1)
        pad_mask = pad_mask & enc_streaming_mask
        return pad_mask

    def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Encode audio features.

        Args:
            hidden_states: raw input features fed to the normalization and subsampling layers.
            mask: optional validity mask for the input frames.

        Returns:
            Encoded sequence of shape ``(batch, subsampled_len, hidden_size)``.
        """
        hidden_states = self.encoder_embedding(hidden_states)
        hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask)

        unfolded = False
        bs, seq_len, _ = hidden_states.shape
        # Hard-coded window length: longer sequences are split into independent windows of this size.
        max_seq_len = 500
        if seq_len > max_seq_len:
            unfolded = True

            # unfold_tensor drops frames that do not fill a whole window, so pad to a multiple of max_seq_len.
            if seq_len % max_seq_len > 0:
                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
            else:
                chunk_pad_size = 0
            if chunk_pad_size > 0:
                hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0)
                # .to() is a no-op here: F.pad already returns a tensor on the same device.
                hidden_states = hidden_states_pad.to(hidden_states.device)

            # (bs, padded_len, hidden) -> (bs * n_windows, max_seq_len, hidden)
            hidden_states = unfold_tensor(hidden_states, max_seq_len)
            masks_unfold = None
            if mask is not None:
                # The hs_mask from forward_embeddings ignores the extra padding and the windowing, so
                # rebuild it from the subsampled pad mask, padded and unfolded the same way as the input.
                subsampled_pad_mask = mask.squeeze(1)
                extra_padded_subsamlped_pad_mask = F.pad(
                    subsampled_pad_mask, (0, chunk_pad_size), "constant", False
                )
                # Round-trip through float for unfold_tensor (presumably because F.unfold rejects bool).
                extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
                masks_unfold = unfold_tensor(
                    extra_padded_subsamlped_pad_mask, max_seq_len
                )
                masks_unfold = masks_unfold.squeeze(-1).bool()
            hs_mask = self.calculate_hs_mask(
                hidden_states, hidden_states.device, masks_unfold
            )

        relative_attention_bias = self.relative_attention_bias_layer(hidden_states)
        # NOTE(review): hs_mask is boolean, so this sum adds 1.0 where attention is allowed and 0.0 where it
        # is masked, on top of the bias; it does not apply a large negative fill. Confirm this matches the
        # reference implementation before changing it.
        attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias

        for layer in self.encoders:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                )
            else:
                hidden_states = layer(hidden_states, attention_mask)

        if unfolded:
            # Stitch the windows back into (bs, padded_len, hidden) and strip the padding added above.
            embed_dim = hidden_states.shape[-1]
            hidden_states = hidden_states.reshape(bs, -1, embed_dim)

            if chunk_pad_size > 0:
                hidden_states = hidden_states[:, :-chunk_pad_size, :]

        return hidden_states
|
|
|
|
|
class Phi4MultimodalAudioEmbedding(nn.Module):
    """
    Encode audio features and write the projected embeddings into ``inputs_embeds`` at audio-token positions.

    Two projection heads exist: ``*_for_speech`` (used when ``audio_projection_mode == "speech"``) and
    ``*_for_vision_speech`` (used for any other mode).
    """

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__()
        self.config = config
        self.layer_idx = config.audio_config.feature_layer

        self.drop = nn.Dropout(config.embd_pdrop)
        self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config)
        audio_width = config.audio_config.hidden_size * config.audio_config.downsample_rate
        self.up_proj_for_speech = nn.Linear(audio_width, config.hidden_size)
        self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size)
        self.up_proj_for_vision_speech = nn.Linear(audio_width, config.hidden_size)
        self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        audio_input_features: torch.FloatTensor,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        audio_projection_mode="speech",
    ) -> torch.FloatTensor:
        """
        Returns a copy of ``inputs_embeds`` (after dropout) in which audio-token positions hold audio embeddings.

        ``audio_embed_sizes[i]`` selects how many leading encoder frames of sample ``i`` are used; the selected
        frames of all samples are concatenated in order and must match the number of audio tokens.
        """
        with torch.no_grad():
            audio_token_mask = input_ids == self.config.audio_config.audio_token_id
            positions_tuple = torch.nonzero(audio_token_mask, as_tuple=True)

        if audio_projection_mode == "speech":
            up_proj, down_proj = self.up_proj_for_speech, self.down_proj_for_speech
        else:
            up_proj, down_proj = self.up_proj_for_vision_speech, self.down_proj_for_vision_speech

        # Run the encoder on the device/dtype of the selected projection head.
        features = audio_input_features.to(device=up_proj.bias.device, dtype=up_proj.bias.dtype)
        encoded = self.encoder(features, audio_attention_mask)
        projected = down_proj(nn.functional.gelu(up_proj(encoded)))

        merged_audio_embeds = torch.cat(
            [projected[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0
        ).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        # The scatter runs with autocast explicitly disabled.
        with torch.autocast(device_type=inputs_embeds.device.type, enabled=False):
            scattered = inputs_embeds.index_put(
                indices=positions_tuple, values=merged_audio_embeds, accumulate=False
            )

        return self.drop(scattered)
|
|
|
|
|
class Phi4MultimodalRMSNorm(nn.Module):
    """
    Root-mean-square normalization with a learnable per-feature scale (no centering, no bias).

    Equivalent to T5LayerNorm. Statistics are computed in float32; the normalized values are cast back to the
    input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
|
class Phi4MultimodalMLP(nn.Module):
    """
    Gated MLP of the language-model decoder: ``down_proj(up * act(gate))``.

    ``gate_up_proj`` emits ``[gate | up]`` (gate first) in a single bias-free projection.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        fused = self.gate_up_proj(hidden_states)
        gate_half, up_half = fused.chunk(2, dim=-1)
        return self.down_proj(up_half * self.activation_fn(gate_half))
|
|
|
|
|
def rotate_half(x):
    """Return ``cat(-second_half, first_half)``, splitting the last dimension at ``size // 2``."""
    split_at = x.shape[-1] // 2
    first_half = x[..., :split_at]
    second_half = x[..., split_at:]
    return torch.cat((second_half.neg(), first_half), dim=-1)
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` -> ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    Returns the input object unchanged when ``n_rep == 1``.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (non-fused) scaled dot-product attention with grouped key/value heads.

    Keys/values are expanded by ``module.num_key_value_groups``; the additive ``attention_mask`` is cropped to
    the key length. Softmax is computed in float32 and cast back to ``query.dtype``.

    Returns:
        ``(attn_output, attn_weights)`` with ``attn_output`` transposed to ``(batch, seq, heads, head_dim)``.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        key_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :key_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, values).transpose(1, 2).contiguous()

    return context, probs
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embeddings to the leading ``cos.shape[-1]`` features of ``q`` and ``k``.

    Features beyond that width pass through unchanged (partial rotary).

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine table.
        sin (`torch.Tensor`): sine table.
        position_ids (`torch.Tensor`, *optional*): deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis inserted into ``cos``/``sin`` so they broadcast
            against ``q``/``k``. Use 1 for ``[batch, heads, seq, head_dim]`` inputs and 2 for
            ``[batch, seq, heads, head_dim]`` inputs (with tables shaped ``[batch, seq, head_dim]``).

    Returns:
        `tuple(torch.Tensor)`: the rotated ``(q, k)``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_width = cos.shape[-1]

    def _rotate(tensor):
        rotated_part = tensor[..., :rotary_width]
        passthrough = tensor[..., rotary_width:]
        return torch.cat([(rotated_part * cos) + (rotate_half(rotated_part) * sin), passthrough], dim=-1)

    return _rotate(q), _rotate(k)
|
|
|
|
|
class Phi4MultimodalAttention(nn.Module):
    """Grouped-query self-attention with a fused QKV projection and rotary position embeddings."""

    def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=False)
        # Fused projection laid out as [query | key | value] along the output dimension.
        self.qkv_proj = nn.Linear(config.hidden_size, query_width + 2 * kv_width, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        qkv = self.qkv_proj(hidden_states)
        q_end = self.config.num_attention_heads * self.head_dim
        k_end = q_end + self.num_key_value_heads * self.head_dim
        # Each slice: (..., seq, n * head_dim) -> (..., n, seq, head_dim)
        query_states, key_states, value_states = (
            part.view(hidden_shape).transpose(1, 2)
            for part in (qkv[..., :q_end], qkv[..., q_end:k_end], qkv[..., k_end:])
        )

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # The rotary tables and cache_position are forwarded to the cache with the new states.
            key_states, value_states = past_key_value.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = self.config._attn_implementation
        attention_interface: Callable = eager_attention_forward
        if implementation == "sdpa" and kwargs.get("output_attentions", False):
            # SDPA cannot return attention weights: keep the eager path and warn once.
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        elif implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),
            **kwargs,
        )

        attn_output = self.o_proj(attn_output.reshape(*input_shape, -1).contiguous())
        return attn_output, attn_weights
|
|
|
|
|
class Phi4MultimodalDecoderLayer(nn.Module):
    """
    Pre-norm decoder block.

    ``x + dropout(attn(rmsnorm(x)))`` followed by ``x + dropout(mlp(rmsnorm(x)))``, both dropouts using
    ``config.resid_pdrop``.
    """

    def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx)
        self.mlp = Phi4MultimodalMLP(config)
        self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.config = config
        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of size `(batch, 1, tgt_len, src_len)`;
                padding positions hold very large negative values.
            position_ids (`torch.LongTensor`, *optional*): forwarded to the attention module.
            past_key_value (`Cache`, *optional*): cache of past key/value projections, updated in place by attention.
            output_attentions (`bool`, *optional*): if True, the attention weights are appended to the output tuple.
            use_cache (`bool`, *optional*): forwarded to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): positions of the input
                tokens in the sequence, forwarded to the cache.
            position_embeddings (`Tuple[torch.Tensor, torch.Tensor]`): rotary ``(cos, sin)`` tables.
            kwargs (`dict`, *optional*): extra keyword arguments passed through to the attention module.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when ``output_attentions`` is True.
        """
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.resid_attn_dropout(attn_out)

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.resid_mlp_dropout(mlp_out)

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
|
|
|
|
|
class Phi4MultimodalFeatureEmbedding(nn.Module):
    """
    Merge image and audio embeddings into the token embeddings.

    Each modality embedder is run only if its inputs are given and its placeholder token occurs in
    ``input_ids``. When both run, image positions take the image result and every other position takes
    the audio result.
    """

    def __init__(self, config: Phi4MultimodalConfig) -> None:
        super().__init__()
        self.config = config
        self.image_token_id = config.vision_config.image_token_id
        self.audio_token_id = config.audio_config.audio_token_id
        self.image_embed = Phi4MultimodalImageEmbedding(config)
        self.audio_embed = Phi4MultimodalAudioEmbedding(config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        image_pixel_values: Optional[torch.FloatTensor] = None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        image_sizes=None,
        image_attention_mask=None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
    ) -> torch.FloatTensor:
        with torch.no_grad():
            image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1)
            non_image_position_mask = ~image_position_mask

        image_embeds = None
        if image_pixel_values is not None and (input_ids == self.image_token_id).any():
            image_embeds = self.image_embed(
                input_ids,
                inputs_embeds,
                image_pixel_values=image_pixel_values,
                image_sizes=image_sizes,
                image_attention_mask=image_attention_mask,
            )

        audio_embeds = None
        if audio_input_features is not None and (input_ids == self.audio_token_id).any():
            # The "speech" projection head is used only when no image pixels were supplied.
            projection_mode = "speech" if image_pixel_values is None else "vision"
            audio_embeds = self.audio_embed(
                input_ids,
                inputs_embeds,
                audio_input_features=audio_input_features,
                audio_embed_sizes=audio_embed_sizes,
                audio_attention_mask=audio_attention_mask,
                audio_projection_mode=projection_mode,
            )

        if image_embeds is None:
            return inputs_embeds if audio_embeds is None else audio_embeds
        if audio_embeds is None:
            return image_embeds
        return image_embeds * image_position_mask + audio_embeds * non_image_position_mask
|
|
|
|
|
class Phi4MultimodalRotaryEmbedding(nn.Module):
    """Rotary position embedding that produces per-position cos/sin tables.

    The rope variant comes from ``config.rope_scaling`` (key ``"rope_type"``, falling back to the
    legacy ``"type"`` key) and is ``"default"`` when no scaling is configured.
    """

    def __init__(self, config: Phi4MultimodalConfig, device=None):
        super().__init__()

        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        self.max_seq_len_cached = self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        freqs, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: recomputed from the config, never written to checkpoints.
        self.register_buffer("inv_freq", freqs, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` in ``x.dtype``, scaled by ``self.attention_scaling``.

        For 2-D ``position_ids`` of shape ``(batch, seq)`` both outputs have shape
        ``(batch, seq, 2 * len(inv_freq))``.
        """
        n_batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(n_batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()

        # autocast is disabled so the angle math runs in float32; "mps" (and any non-string device
        # type) falls back to the "cpu" autocast context.
        autocast_device = x.device.type
        if not isinstance(autocast_device, str) or autocast_device == "mps":
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):
            angles = torch.matmul(freq_col.float(), pos_row.float()).transpose(1, 2)
            full_angles = torch.cat([angles, angles], dim=-1)
            cos = full_angles.cos() * self.attention_scaling
            sin = full_angles.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
|
# Shared class-docstring prefix; passed to `add_start_docstrings` on the Phi4Multimodal model classes.
PHI4_MULTIMODAL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Phi4MultimodalConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
|
@add_start_docstrings(
    "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
    PHI4_MULTIMODAL_START_DOCSTRING,
)
class Phi4MultimodalPreTrainedModel(PreTrainedModel):
    # Class-level configuration for the transformers `PreTrainedModel` machinery (presumably read
    # by the base class for loading, device placement and attention-backend selection).
    config_class = Phi4MultimodalConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi4MultimodalDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True
    _version = "0.0.5"

    def _init_weights(self, module):
        """Initialize ``module``'s parameters in place.

        - ``nn.Linear``: weight ~ N(0, config.initializer_range); bias (if any) zeroed.
        - ``nn.Embedding``: weight ~ N(0, config.initializer_range); the ``padding_idx`` row (if any) zeroed.
        - ``Phi4MultimodalRMSNorm``: weight filled with 1.0.
        - ``Phi4MultimodalImageEmbedding``: ``global_img_feature_extensor`` and
          ``sub_img_feature_extensor`` zeroed.
        Other module types are left untouched.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, Phi4MultimodalRMSNorm):
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, Phi4MultimodalImageEmbedding):
            for extensor in (module.global_img_feature_extensor, module.sub_img_feature_extensor):
                extensor.data.zero_()
|
|
|
|
|
# Shared `forward` argument documentation, prepended via `add_start_docstrings_to_model_forward`.
PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding indices in `input_ids`. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous
            stage of decoding, when `use_cache=True` or `config.use_cache=True`.
            See our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        image_pixel_values (`torch.FloatTensor`, *optional*):
            If the input contains images, these correspond to the pixel values after transformations (as returned by
            the Processor).
        image_sizes (`torch.LongTensor`, *optional*):
            If the input contains images, these correspond to size of each image.
        image_attention_mask (`torch.LongTensor`, *optional*):
            Attention mask for the images.
        audio_input_features (`torch.FloatTensor`, *optional*):
            If the input contains audio samples, these correspond to the values after transformation (as returned by
            the Processor).
        audio_embed_sizes (`torch.Tensor`, *optional*):
            Size of the audio inputs.
        audio_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the audio inputs.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
|
|
|
|
|
@add_start_docstrings(
    "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
    PHI4_MULTIMODAL_START_DOCSTRING,
)
class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi4MultimodalDecoderLayer`].
    Text token embeddings are extended with image/audio features by `embed_tokens_extend` before the layers run.

    Args:
        config: Phi4MultimodalConfig
    """

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        self.layers = nn.ModuleList(
            [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # NOTE(review): `embed_dropout` is constructed but never applied in this class's `forward`;
        # confirm whether that is intentional.
        self.embed_dropout = nn.Dropout(config.embd_pdrop)

        # Merges image/audio features into the text embeddings (see `forward`).
        self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)

        # Final setup hook inherited from `PreTrainedModel`.
        self.post_init()

    def get_input_embeddings(self):
        """Return the text token embedding table."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the text token embedding table."""
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        audio_embed_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        # Unset flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # XOR: raises when both are None or when both are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # Multimodal features are only merged when embeddings are computed here from `input_ids`;
        # caller-supplied `inputs_embeds` are used as-is.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            inputs_embeds = self.embed_tokens_extend(
                input_ids,
                inputs_embeds,
                image_pixel_values=image_pixel_values,
                audio_input_features=audio_input_features,
                image_sizes=image_sizes,
                image_attention_mask=image_attention_mask,
                audio_embed_sizes=audio_embed_sizes,
                audio_attention_mask=audio_attention_mask,
            )

        # Default positions continue from however many tokens the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # Rotary cos/sin computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            # Records the input to each layer; the post-norm output is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are passed positionally here and `**kwargs` are not forwarded on this path.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Add the final (normalized) hidden state.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the attention mask in the form expected by the configured attention backend.

        Returns `None` when the backend can rely on its own causal handling. Otherwise it returns the
        2D mask unchanged (flash-attention 2 with padding), a flex-attention block mask, or a 4D
        additive float mask.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                # The last column is all ones only under left padding: its sum equals the batch size.
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            # Only hand the 2D mask over when it actually masks something.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                # NOTE(review): `make_flex_block_causal_mask` is not among the file's visible top-level
                # imports; confirm it is defined elsewhere in this module.
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # With SDPA and no explicit masking needed, return None so SDPA can use its own causal path.
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # Static / sliding-window caches have a fixed key length, so the mask must span the whole cache.
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Un-mask rows that would otherwise be fully masked (e.g. left padding). Presumably this
            # works around SDPA memory-efficient kernels producing NaN on all-masked rows; confirm.
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Phi4MultimodalConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            config (`Phi4MultimodalConfig`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # A 4D mask is assumed to already be in its final additive form and is returned untouched.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            # True where the key position lies after the query's cache position, i.e. in the future.
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.get_text_config().sliding_window is not None:
                # Also mask keys that fell out of the sliding window. This is skipped for a
                # SlidingWindowCache unless the input is longer than the cache, presumably because
                # that cache only keeps in-window keys (TODO confirm).
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            # Keeps min_dtype where masked (True) and zeroes allowed positions (False).
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                # Sum is 0 only where causal allows (0) and the 2D mask marks padding (0).
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
|
|
|
|
|
class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin):
    """`Phi4MultimodalModel` with a bias-free linear language-modeling head (`lm_head`) on top."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Phi4MultimodalModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Final setup hook inherited from `PreTrainedModel`.
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's text token embedding table."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's text token embedding table."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the underlying `Phi4MultimodalModel`."""
        self.model = decoder

    def get_decoder(self):
        """Return the underlying `Phi4MultimodalModel`."""
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        audio_embed_sizes: Optional[torch.Tensor] = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        kwargs (`dict`, *optional*):
            Forwarded to the decoder and, when `labels` is given, to the loss function (e.g. `num_items_in_batch`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM

        >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA")
        >>> tokenizer = AutoTokenizer.from_pretrained("TBA")

        >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            image_pixel_values=image_pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            audio_input_features=audio_input_features,
            audio_embed_sizes=audio_embed_sizes,
            audio_attention_mask=audio_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # logits_to_keep == 0 gives slice(0, None), i.e. logits for every position.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Forward **kwargs so loss-specific arguments reach the loss. For example, the HF Trainer
            # passes `num_items_in_batch` to models whose forward accepts **kwargs and then skips
            # dividing by the gradient-accumulation steps. Dropping it here would mis-scale gradients.
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        image_pixel_values=None,
        image_sizes=None,
        image_attention_mask=None,
        audio_input_features=None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=0,
        **kwargs,
    ):
        """Build per-step generation inputs, forwarding the multimodal arguments unchanged.

        With rope scaling enabled, the cache is discarded on the step where the sequence first grows
        past `config.original_max_position_embeddings`, so the whole sequence is re-encoded.
        Presumably this is because the rope scaling differs on either side of that boundary
        (TODO confirm).
        """
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = cache_position[0]
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            image_pixel_values=image_pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            audio_input_features=audio_input_features,
            audio_embed_sizes=audio_embed_sizes,
            audio_attention_mask=audio_attention_mask,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs
|
|
|
|
|
# Public API of this module, i.e. the names exported by `from <module> import *`.
__all__ = [
    "Phi4MultimodalAudioPreTrainedModel",
    "Phi4MultimodalAudioModel",
    "Phi4MultimodalVisionPreTrainedModel",
    "Phi4MultimodalVisionModel",
    "Phi4MultimodalPreTrainedModel",
    "Phi4MultimodalModel",
    "Phi4MultimodalForCausalLM",
]


# Register the causal-LM class with transformers' auto-class mechanism under "AutoModelForCausalLM".
# Presumably this is so the class can be resolved when the checkpoint is loaded as remote code; confirm.
Phi4MultimodalForCausalLM.register_for_auto_class("AutoModelForCausalLM")