# Source: upload by ramimu ("Upload 586 files", commit 1c72248, verified).
import math
from typing import List, Optional, Tuple, Any, Union, TYPE_CHECKING
import os
import torch
import torch.nn as nn
from dataclasses import dataclass
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
import json
if TYPE_CHECKING:
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature gain.

    Args:
        dim: size of the last (feature) axis that is normalized.
        eps: small constant added to the mean square before the inverse sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # One gain per feature, initialised to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1 x) * w3 x)``, all projections bias-free.

    Args:
        dim: input/output feature size.
        hidden_dim: inner (gated) feature size.
    """

    def __init__(self, dim: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = nn.functional.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)  # type: ignore
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Repeat every slice of ``keys`` and ``values`` ``repeats`` times along ``dim``.

    Used to expand key/value heads so their count matches the query heads.

    Returns:
        ``(expanded_keys, expanded_values)``.
    """
    def _expand(t: torch.Tensor) -> torch.Tensor:
        return t.repeat_interleave(repeats, dim=dim)

    return _expand(keys), _expand(values)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by the complex rotary factors ``freqs_cis``.

    Consecutive pairs along the last axis of ``xq`` / ``xk`` are treated as
    complex numbers (real, imag) and multiplied by ``freqs_cis``, which gets a
    singleton axis inserted at position 1 so it broadcasts over the head axis.
    Results are cast back to the dtype of the respective input.
    """
    rotation = freqs_cis[:, None, :]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        as_complex = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        rotated = torch.view_as_real(as_complex * rotation).flatten(-2)
        return rotated.type_as(t)

    return _rotate(xq), _rotate(xk)
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, computed by xformers.

    Works on a flattened token sequence (``forward`` unpacks ``x.shape`` into
    ``(seqlen_sum, dim)``). Key/value heads are repeated to match the query
    heads, and an optional KV cache object can be supplied; the cache's type
    and semantics (``prefill``, ``interleave_kv``, ``update``, ``key``,
    ``value``, ``max_seq_len``, ``mask``) are defined outside this file.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        n_kv_heads: int,
        **kwargs,
    ):
        super().__init__()
        self.n_heads: int = n_heads
        self.head_dim: int = head_dim
        self.n_kv_heads: int = n_kv_heads
        # Number of query heads sharing each key/value head (assumes n_heads is
        # a multiple of n_kv_heads — integer division is used).
        self.repeats = self.n_heads // self.n_kv_heads
        # NOTE(review): not read anywhere in this class; memory_efficient_attention
        # below is called without an explicit scale argument.
        self.scale = self.head_dim ** -0.5
        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[Any] = None,
        mask: Optional['BlockDiagonalMask'] = None,
    ) -> torch.Tensor:
        """Attend over the tokens in ``x``.

        Args:
            x: 2-D token features ``(seqlen_sum, dim)``.
            freqs_cis: complex rotary factors passed to ``apply_rotary_emb``;
                presumably ``(seqlen_sum, head_dim // 2)`` — TODO confirm.
            cache: optional KV cache object (type defined elsewhere). Mutually
                exclusive with ``mask``; when given, ``cache.mask`` is used.
            mask: optional xformers attention bias, used only without a cache.

        Returns:
            ``(seqlen_sum, dim)`` tensor after the output projection ``wo``.
        """
        # Imported here rather than at module level — presumably so the module
        # can be imported without xformers installed.
        from xformers.ops.fmha import memory_efficient_attention
        assert mask is None or cache is None
        seqlen_sum, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # Split the projections into heads: (seqlen_sum, heads, head_dim).
        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        if cache is None:
            # No cache: attend over the current tokens only.
            key, val = xk, xv
        elif cache.prefill:
            # Prefill: K/V come from cache.interleave_kv (called before the new
            # entries are stored via cache.update).
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            # Otherwise: store the new K/V first, then read the whole cache back
            # and flatten it to (seqlen_sum * max_seq_len, n_kv_heads, head_dim).
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(seqlen_sum * cache.max_seq_len,
                           self.n_kv_heads, self.head_dim)
            val = val.view(seqlen_sum * cache.max_seq_len,
                           self.n_kv_heads, self.head_dim)
        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)
        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        output = memory_efficient_attention(
            xq, key, val, mask if cache is None else cache.mask)
        # Drop the batch axis and merge heads back into one feature axis.
        output = output.view(seqlen_sum, self.n_heads * self.head_dim)
        assert isinstance(output, torch.Tensor)
        return self.wo(output)  # type: ignore
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer.

    ``x + attention(attention_norm(x))`` followed by
    ``h + feed_forward(ffn_norm(h))``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        norm_eps: float,
        **kwargs,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.attention = Attention(
            dim=dim,
            n_heads=n_heads,
            head_dim=head_dim,
            n_kv_heads=n_kv_heads,
        )
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.feed_forward: nn.Module
        self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[Any] = None,
        mask: Optional['BlockDiagonalMask'] = None,
    ) -> torch.Tensor:
        """Apply the block to the flattened token sequence ``x``.

        Args:
            x: token features, as accepted by ``Attention.forward``.
            freqs_cis: rotary factors for the tokens in ``x``.
            cache: optional KV cache forwarded to the attention layer.
            mask: optional attention bias forwarded to the attention layer
                (must be None when ``cache`` is given; see ``Attention.forward``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The mask must be forwarded: previously it was accepted here but never
        # passed on, so attention ran unmasked and tokens of different images in
        # one concatenated sequence attended to each other.
        r = self.attention.forward(
            self.attention_norm(x), freqs_cis, cache=cache, mask=mask)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        return h + r
@dataclass
class VisionEncoderArgs:
    """Hyper-parameters of the Pixtral vision encoder."""

    # Model width and input layout.
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    # Transformer stack size.
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    # Base of the 2-D rotary embedding frequencies.
    rope_theta: float = 1e4  # for rope-2D
    image_token_id: int = 10
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """Build the 2-D rotary embedding table.

    The inverse frequencies ``1 / theta ** (2k / dim)`` are split alternately:
    even-indexed ones rotate with the row position, odd-indexed ones with the
    column position. Entry ``[i, j]`` depends only on ``(i, j)``.

    Returns:
        Unit-magnitude complex tensor of shape ``(height, width, dim // 2)``
        (for even ``dim``), indexed by (row, column) patch position.
    """
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    row_ids = torch.arange(height, device=inv_freq.device)
    col_ids = torch.arange(width, device=inv_freq.device)
    row_angles = torch.outer(row_ids, inv_freq[::2]).float()   # (height, n_even)
    col_angles = torch.outer(col_ids, inv_freq[1::2]).float()  # (width, n_odd)
    angles = torch.cat(
        (
            row_angles.unsqueeze(1).expand(-1, width, -1),
            col_angles.unsqueeze(0).expand(height, -1, -1),
        ),
        dim=-1,
    )
    return torch.polar(torch.ones_like(angles), angles)
def position_meshgrid(
    patch_embeds_list: list[torch.Tensor],
) -> torch.Tensor:
    """Return the (row, column) position of every patch, image after image.

    Each tensor's last two axes are read as the patch grid ``(rows, cols)``;
    positions are enumerated row-major and concatenated over the list.

    Returns:
        Integer tensor of shape ``(total_patches, 2)``.
    """
    grids = []
    for embeds in patch_embeds_list:
        n_rows, n_cols = embeds.shape[-2], embeds.shape[-1]
        row_idx, col_idx = torch.meshgrid(
            torch.arange(n_rows),
            torch.arange(n_cols),
            indexing="ij",
        )
        grids.append(torch.stack((row_idx, col_idx), dim=-1).reshape(-1, 2))
    return torch.cat(grids)
class PixtralVisionEncoder(nn.Module):
    """Pixtral-style vision transformer encoder.

    Each image is patchified by a strided convolution, the patch tokens of all
    images are concatenated into one sequence and RMS-normalized, and the
    sequence is run through ``VisionTransformerBlocks`` with 2-D rotary
    position embeddings and an xformers ``BlockDiagonalMask`` built from the
    per-image token counts.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,  # extra keys (e.g. from config.json) are accepted and ignored
    ):
        super().__init__()
        self.args = VisionEncoderArgs(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        args = self.args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = VisionTransformerBlocks(args)
        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # 2-D RoPE table, built lazily by the ``freqs_cis`` property.
        self._freqs_cis: Optional[torch.Tensor] = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
        """Build an encoder from a local folder or a Hugging Face Hub repo id.

        ``config.json`` in the folder is passed as keyword arguments to the
        constructor. If ``model.safetensors`` exists its weights are loaded;
        otherwise the freshly initialised model is returned.

        Raises:
            ValueError: if the folder has no ``config.json``.
        """
        if os.path.isdir(pretrained_model_name_or_path):
            model_folder = pretrained_model_name_or_path
        else:
            model_folder = snapshot_download(pretrained_model_name_or_path)
        config_path = os.path.join(model_folder, "config.json")
        if not os.path.exists(config_path):
            raise ValueError(f"Could not find config.json in {model_folder}")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        model = cls(**config)
        weights_path = os.path.join(model_folder, "model.safetensors")
        if os.path.exists(weights_path):
            model.load_state_dict(load_file(weights_path))
        return model

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Cached 2-D RoPE table (at least max_patches_per_side square), on ``self.device``."""
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )
        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)
        return self._freqs_cis

    def _freqs_cis_covering(self, height: int, width: int) -> torch.Tensor:
        """Return the RoPE table, enlarging the cache so it covers a (height, width) patch grid."""
        table = self.freqs_cis
        if height > table.shape[0] or width > table.shape[1]:
            # Entry (i, j) of precompute_freqs_cis_2d depends only on (i, j), so a
            # larger table agrees exactly with the old one where they overlap.
            # Positions beyond max_patches_per_side are extrapolated RoPE.
            table = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=max(height, table.shape[0]),
                width=max(width, table.shape[1]),
                theta=self.args.rope_theta,
            ).to(device=self.device)
            self._freqs_cis = table
        return table

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes, each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for all tokens of all images of
                shape (N_toks, D)
        """
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask

        assert isinstance(
            images, list), f"Expected list of images, got {type(images)}"
        assert all(len(img.shape) == 3 for img in
                   images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"
        # pass images through initial convolution independently
        patch_embeds_list = [self.patch_conv(
            img.unsqueeze(0)).squeeze(0) for img in images]
        # flatten to a single sequence of (N_toks, D)
        patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
                                  for p in patch_embeds_list], dim=0)
        patch_embeds = self.ln_pre(patch_embeds)
        # positional embeddings; the table is grown if any image has more
        # patches per side than max_patches_per_side (the preprocessor in this
        # file bounds sqrt(H * W), not the longest side)
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        max_rows = max(p.shape[-2] for p in patch_embeds_list)
        max_cols = max(p.shape[-1] for p in patch_embeds_list)
        freqs_table = self._freqs_cis_covering(max_rows, max_cols)
        freqs_cis = freqs_table[positions[:, 0], positions[:, 1]]
        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
        return out  # type: ignore[no-any-return]
class VisionLanguageAdapter(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear, both with bias) mapping in_dim to out_dim."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.w_in = nn.Linear(in_dim, out_dim, bias=True)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.gelu(self.w_in(x))
        return self.w_out(hidden)  # type: ignore[no-any-return]
class VisionTransformerBlocks(nn.Module):
    """Stack of ``args.num_hidden_layers`` TransformerBlocks applied in sequence.

    Every block uses full multi-head attention (as many KV heads as query
    heads) and RMSNorm epsilon 1e-5.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        per_head_dim = args.hidden_size // args.num_attention_heads
        self.layers = torch.nn.ModuleList(
            [
                TransformerBlock(
                    dim=args.hidden_size,
                    hidden_dim=args.intermediate_size,
                    n_heads=args.num_attention_heads,
                    n_kv_heads=args.num_attention_heads,
                    head_dim=per_head_dim,
                    norm_eps=1e-5,
                )
                for _ in range(args.num_hidden_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: 'BlockDiagonalMask',
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
# Per-channel image normalization statistics, in RGB order.
# (These values match the widely used OpenAI CLIP normalization constants.)
DATASET_MEAN = [
    0.48145466,  # R
    0.4578275,   # G
    0.40821073,  # B
]
DATASET_STD = [
    0.26862954,  # R
    0.26130258,  # G
    0.27577711,  # B
]
def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Standardize each channel of a (C, H, W) image: ``(image - mean) / std``.

    Args:
        image: tensor of shape (C, H, W).
        mean: per-channel means, C elements.
        std: per-channel standard deviations, C elements.

    Returns:
        Tensor of shape (C, H, W).
    """
    assert image.shape[0] == len(mean) == len(
        std), f"{image.shape=}, {mean.shape=}, {std.shape=}"
    # View the statistics as (C, 1, 1) so they broadcast over H and W.
    return image.sub(mean.view(-1, 1, 1)).div(std.view(-1, 1, 1))
def transform_image(image: torch.Tensor, new_size: tuple[int, int]) -> torch.Tensor:
    """Bicubically resize a (C, H, W) image, then normalize it with the dataset stats.

    Args:
        image: tensor of shape (C, H, W), values expected in [0, 1].
        new_size: target ``(height, width)`` as understood by
            ``torch.nn.functional.interpolate`` for a 4-D input.

    Returns:
        Tensor of shape (C, new_H, new_W), normalized with DATASET_MEAN / DATASET_STD.
    """
    batched = image.unsqueeze(0)  # interpolate needs an (N, C, H, W) input
    resized = torch.nn.functional.interpolate(
        batched,
        size=new_size,
        mode='bicubic',
        align_corners=False,
    )[0]
    stat_kwargs = dict(device=image.device, dtype=image.dtype)
    mean = torch.tensor(DATASET_MEAN, **stat_kwargs)
    std = torch.tensor(DATASET_STD, **stat_kwargs)
    return normalize(resized, mean, std)
class PixtralVisionImagePreprocessor:
    """Resize and normalize a single (C, H, W) image tensor for the Pixtral encoder.

    The image is scaled down (aspect ratio preserved) when sqrt(H * W) exceeds
    ``max_image_size``, each side is rounded up to a whole number of
    ``image_patch_size`` patches, and the result is normalized via
    ``transform_image``. Because the limit is on sqrt(H * W) rather than the
    longest side, one side of a non-square image may exceed ``max_image_size``.
    """

    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        self.image_patch_size = image_patch_size
        self.max_image_size = max_image_size
        self.image_token = 10

    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size: Optional[int] = None) -> Tuple[int, int]:
        """Return ``(width_tokens, height_tokens)``: patches per side after any downscaling."""
        w: Union[int, float]
        h: Union[int, float]
        if max_image_size is None:
            max_image_size = self.max_image_size
        w, h = img.shape[-1], img.shape[-2]
        # originally, pixtral used the largest of the 2 dimensions, but we
        # will use the base size of the image based on number of pixels.
        # ratio = max(h / self.max_image_size, w / self.max_image_size) # original
        base_size = int(math.sqrt(w * h))
        ratio = base_size / max_image_size
        if ratio > 1:
            w = round(w / ratio)
            h = round(h / ratio)
        # ceil division: a partial patch still costs a full token
        width_tokens = (w - 1) // self.image_patch_size + 1
        height_tokens = (h - 1) // self.image_patch_size + 1
        return width_tokens, height_tokens

    def __call__(self, image: torch.Tensor, max_image_size: Optional[int] = None) -> torch.Tensor:
        """Resize and normalize one image.

        Args:
            image: tensor of shape (C, H, W) with values in [0, 1].
            max_image_size: override for ``self.max_image_size``.

        Returns:
            Normalized tensor of shape (C, height_tokens * patch, width_tokens * patch).

        Raises:
            ValueError: if ``image`` is not 3-D or has values outside [0, 1].
        """
        # should be a single (C, H, W) image, no batch axis
        if len(image.shape) != 3:
            raise ValueError(
                f"Expected image with shape (C, H, W), got {image.shape}")
        lo, hi = image.min(), image.max()
        if lo < 0.0 or hi > 1.0:
            raise ValueError(
                f"image tensor values must be between 0 and 1. Got min: {lo}, max: {hi}")
        if max_image_size is None:
            max_image_size = self.max_image_size
        width_tokens, height_tokens = self._image_to_num_tokens(
            image, max_image_size=max_image_size)
        assert width_tokens > 0
        assert height_tokens > 0
        # transform_image (torch interpolate) takes (height, width); passing
        # (width, height) here would transpose every non-square image.
        new_image_size = (
            height_tokens * self.image_patch_size,
            width_tokens * self.image_patch_size,
        )
        return transform_image(image, new_image_size)
class PixtralVisionImagePreprocessorCompatibleReturn:
    """Result container exposing the processed image batch as ``pixel_values``."""

    def __init__(self, pixel_values) -> None:
        # Stored as-is; no copying or validation.
        self.pixel_values = pixel_values
# Compatible version for use with the ai-toolkit flow
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
    """Image-processor-style wrapper around ``PixtralVisionImagePreprocessor``.

    Accepts one (C, H, W) image or a (B, C, H, W) batch and returns an object
    with a ``pixel_values`` attribute holding the stacked processed images.
    """

    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        super().__init__(
            image_patch_size=image_patch_size,
            max_image_size=max_image_size
        )
        # Attributes exposed for callers that treat this like an image processor.
        self.size = {
            'height': max_image_size,
            'width': max_image_size
        }
        self.image_mean = DATASET_MEAN
        self.image_std = DATASET_STD

    def __call__(
        self,
        images,
        return_tensors="pt",
        do_resize=True,
        do_rescale=False,
        max_image_size=None,
    ) -> "PixtralVisionImagePreprocessorCompatibleReturn":
        """Process every image in ``images`` and stack the results.

        ``return_tensors``, ``do_resize`` and ``do_rescale`` are accepted for
        interface compatibility and are not used. All images in a batch share
        one size, so their processed outputs can be stacked.

        Returns:
            ``PixtralVisionImagePreprocessorCompatibleReturn`` whose
            ``pixel_values`` has shape (B, C, H', W').
        """
        if max_image_size is None:
            max_image_size = self.max_image_size
        if len(images.shape) == 3:
            images = images.unsqueeze(0)
        # Bind the parent method first: zero-argument super() does not work
        # inside a comprehension on Python < 3.12.
        process_one = super().__call__
        processed = [process_one(image, max_image_size=max_image_size) for image in images]
        return PixtralVisionImagePreprocessorCompatibleReturn(torch.stack(processed, dim=0))
class PixtralVisionEncoderCompatibleReturn:
    """Result container exposing encoder outputs as ``hidden_states``."""

    def __init__(self, hidden_states) -> None:
        # Stored as-is; no copying or validation.
        self.hidden_states = hidden_states
class PixtralVisionEncoderCompatibleConfig:
    """Plain attribute container describing the encoder's input/output geometry.

    Defaults: 1024-pixel images, 1024-wide features, 16-pixel patches.
    """

    def __init__(self):
        # input image side length, in pixels
        self.image_size = 1024
        # feature width of the encoder output
        self.hidden_size = 1024
        # side length of one square patch, in pixels
        self.patch_size = 16
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
    """``PixtralVisionEncoder`` with a batched-tensor interface for ai-toolkit.

    ``forward`` takes a (C, H, W) image or a (B, C, H, W) batch, encodes each
    image separately, and returns an object whose ``hidden_states`` is a
    one-element list holding the stacked (B, N_toks, D) features.
    """

    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        self.config = PixtralVisionEncoderCompatibleConfig()
        # Keep the advertised config in sync with the actual architecture; the
        # config class hard-codes 1024/1024/16, which is wrong for other sizes
        # (e.g. when constructed from a different config.json).
        self.config.image_size = image_size
        self.config.hidden_size = hidden_size
        self.config.patch_size = patch_size

    def forward(
        self,
        images,
        output_hidden_states=True,
    ) -> "PixtralVisionEncoderCompatibleReturn":
        """Encode a batch of same-sized images.

        Args:
            images: tensor of shape (C, H, W) or (B, C, H, W).
            output_hidden_states: accepted for interface compatibility; unused.

        Returns:
            ``PixtralVisionEncoderCompatibleReturn`` with ``hidden_states`` set
            to ``[features]``, ``features`` of shape (B, N_toks, D).
        """
        if len(images.shape) == 3:
            images = images.unsqueeze(0)
        # Encode one image per call (the parent expects a list of images); bind
        # the parent method first because zero-argument super() does not work
        # inside a comprehension on Python < 3.12.
        encode = super().forward
        per_image = [encode([image]) for image in images]
        output = torch.stack(per_image, dim=0)
        return PixtralVisionEncoderCompatibleReturn([output])