import json
import math
import os
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union, TYPE_CHECKING

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

if TYPE_CHECKING:
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
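
# Quick reference for the normalization above: RMSNorm rescales each feature
# vector by the root-mean-square of its components (no mean subtraction,
# unlike LayerNorm), then applies a learned per-channel gain:
#
#   y = x / sqrt(mean(x^2) + eps) * weight
#
# The statistic is computed in float32 and the result is cast back to the
# input dtype before the gain is applied.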
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, **kwargs):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
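
# The block above is the SwiGLU feed-forward used by Mistral/Pixtral-style
# transformers: a gated MLP where w1 produces the gate (passed through SiLU),
# w3 produces the value, and w2 projects the element-wise product back to dim:
#
#   FFN(x) = w2( SiLU(w1 x) * w3 x )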
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
    return keys, values
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis[:, None, :]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
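
# Shape sketch for the rotary embedding above (illustrative, not from the
# original source): xq/xk arrive as (seq, n_heads, head_dim). Channel pairs
# are viewed as complex numbers, giving (seq, n_heads, head_dim // 2), and
# freqs_cis of shape (seq, head_dim // 2) is broadcast across the head axis
# via freqs_cis[:, None, :]. Multiplying by the unit-magnitude complex
# frequencies rotates each channel pair, and view_as_real(...).flatten(-2)
# restores the original (seq, n_heads, head_dim) layout.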
class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        n_kv_heads: int,
        **kwargs,
    ):
        super().__init__()
        self.n_heads: int = n_heads
        self.head_dim: int = head_dim
        self.n_kv_heads: int = n_kv_heads
        self.repeats = self.n_heads // self.n_kv_heads
        self.scale = self.head_dim ** -0.5

        self.wq = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.wo = nn.Linear(n_heads * head_dim, dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[Any] = None,
        mask: Optional['BlockDiagonalMask'] = None,
    ) -> torch.Tensor:
        from xformers.ops.fmha import memory_efficient_attention

        assert mask is None or cache is None
        seqlen_sum, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if cache is None:
            key, val = xk, xv
        elif cache.prefill:
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(seqlen_sum * cache.max_seq_len,
                           self.n_kv_heads, self.head_dim)
            val = val.view(seqlen_sum * cache.max_seq_len,
                           self.n_kv_heads, self.head_dim)

        # Repeat keys and values to match the number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        output = memory_efficient_attention(
            xq, key, val, mask if cache is None else cache.mask)
        output = output.view(seqlen_sum, self.n_heads * self.head_dim)

        assert isinstance(output, torch.Tensor)
        return self.wo(output)  # type: ignore
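
# Note on layout: the attention above operates on a packed batch, i.e. x is
# (sum_of_sequence_lengths, dim) with all sequences concatenated along dim 0.
# Sequence boundaries are carried either by the xformers BlockDiagonalMask
# (mask) or by the cache object, never by a batch dimension; the singleton
# batch axis is only added immediately before memory_efficient_attention.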
class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        n_heads: int,
        n_kv_heads: int,
        head_dim: int,
        norm_eps: float,
        **kwargs,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.attention = Attention(
            dim=dim,
            n_heads=n_heads,
            head_dim=head_dim,
            n_kv_heads=n_kv_heads,
        )
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        self.feed_forward: nn.Module
        self.feed_forward = FeedForward(dim=dim, hidden_dim=hidden_dim)
    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[Any] = None,
        mask: Optional['BlockDiagonalMask'] = None,
    ) -> torch.Tensor:
        # Pre-norm residual block: attention, then feed-forward. The mask is
        # forwarded to attention so the block-diagonal structure (one block
        # per image) set up by the caller is actually applied.
        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache, mask=mask)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out
@dataclass
class VisionEncoderArgs:
    hidden_size: int
    num_channels: int
    image_size: int
    patch_size: int
    intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    rope_theta: float = 1e4  # for rope-2D
    image_token_id: int = 10
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by
        (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
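
# Worked shape example (illustrative numbers, not from the original source):
# with hidden_size=1024 and 16 attention heads, head_dim is 64, so dim=64
# here. For a 1024px image with 16px patches, height = width = 64 patch
# positions, and the result is a complex tensor of shape (64, 64, 32): half
# of the 32 frequency pairs encode the row index, the other half the column.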
def position_meshgrid(
    patch_embeds_list: List[torch.Tensor],
) -> torch.Tensor:
    positions = torch.cat(
        [
            torch.stack(
                torch.meshgrid(
                    torch.arange(p.shape[-2]),
                    torch.arange(p.shape[-1]),
                    indexing="ij",
                ),
                dim=-1,
            ).reshape(-1, 2)
            for p in patch_embeds_list
        ]
    )
    return positions
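
# Illustrative example: for an image whose patch grid is 2 rows by 3 columns,
# the meshgrid above yields the (row, col) pairs
#   [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
# in the same row-major order in which the encoder below flattens the patch
# grid, so each patch token can look up its 2-D rotary frequencies by position.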
class PixtralVisionEncoder(nn.Module):
    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,
    ):
        super().__init__()
        self.args = VisionEncoderArgs(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        args = self.args
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = VisionTransformerBlocks(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: Optional[torch.Tensor] = None
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str) -> 'PixtralVisionEncoder':
        if os.path.isdir(pretrained_model_name_or_path):
            model_folder = pretrained_model_name_or_path
        else:
            model_folder = snapshot_download(pretrained_model_name_or_path)

        # make sure there is a config
        if not os.path.exists(os.path.join(model_folder, "config.json")):
            raise ValueError(f"Could not find config.json in {model_folder}")

        # load config
        with open(os.path.join(model_folder, "config.json"), "r") as f:
            config = json.load(f)

        model = cls(**config)

        # see if there is a state_dict
        if os.path.exists(os.path.join(model_folder, "model.safetensors")):
            state_dict = load_file(os.path.join(
                model_folder, "model.safetensors"))
            model.load_state_dict(state_dict)
        return model
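
    # Minimal usage sketch for from_pretrained (the path below is hypothetical;
    # it assumes a local folder or Hub repo containing a config.json with the
    # constructor arguments above and, optionally, a model.safetensors file):
    #
    #   encoder = PixtralVisionEncoder.from_pretrained("path/to/vision-encoder")
    #   encoder = encoder.to("cuda").eval()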
    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )
        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)
        return self._freqs_cis
    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes, each of shape (C, H, W)

        Returns:
            image_features: tensor of token features for all tokens of all images of
                shape (N_toks, D)
        """
        from xformers.ops.fmha.attn_bias import BlockDiagonalMask

        assert isinstance(
            images, list), f"Expected list of images, got {type(images)}"
        assert all(len(img.shape) == 3 for img in
                   images), f"Expected images with shape (C, H, W), got {[img.shape for img in images]}"

        # pass images through initial convolution independently
        patch_embeds_list = [self.patch_conv(
            img.unsqueeze(0)).squeeze(0) for img in images]

        # flatten all patch grids into a single packed sequence
        patch_embeds = torch.cat([p.flatten(1).permute(1, 0)
                                  for p in patch_embeds_list], dim=0)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # output is already (N_toks, D); there is no batch dimension to remove
        return out  # type: ignore[no-any-return]
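
# Illustrative call (assumes xformers and a CUDA device; sizes are arbitrary):
#
#   imgs = [torch.rand(3, 512, 768, device=encoder.device),
#           torch.rand(3, 256, 256, device=encoder.device)]
#   feats = encoder(imgs)
#   # feats.shape == (32 * 48 + 16 * 16, 1024) == (1792, 1024)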
class VisionLanguageAdapter(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.w_in = nn.Linear(
            in_dim,
            out_dim,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))  # type: ignore[no-any-return]
class VisionTransformerBlocks(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(
                TransformerBlock(
                    dim=args.hidden_size,
                    hidden_dim=args.intermediate_size,
                    n_heads=args.num_attention_heads,
                    n_kv_heads=args.num_attention_heads,
                    head_dim=args.hidden_size // args.num_attention_heads,
                    norm_eps=1e-5,
                )
            )

    def forward(
        self,
        x: torch.Tensor,
        mask: 'BlockDiagonalMask',
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask=mask, freqs_cis=freqs_cis)
        return x
DATASET_MEAN = [0.48145466, 0.4578275, 0.40821073]  # RGB
DATASET_STD = [0.26862954, 0.26130258, 0.27577711]  # RGB
def normalize(image: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """
    Normalize a tensor image with mean and standard deviation.

    Args:
        image (torch.Tensor): Image to be normalized, shape (C, H, W), values in [0, 1].
        mean (torch.Tensor): Mean for each channel.
        std (torch.Tensor): Standard deviation for each channel.

    Returns:
        torch.Tensor: Normalized image with shape (C, H, W).
    """
    assert image.shape[0] == len(mean) == len(
        std), f"{image.shape=}, {mean.shape=}, {std.shape=}"

    # Reshape mean and std to (C, 1, 1) for broadcasting
    mean = mean.view(-1, 1, 1)
    std = std.view(-1, 1, 1)
    return (image - mean) / std
def transform_image(image: torch.Tensor, new_size: Tuple[int, int]) -> torch.Tensor:
    """
    Resize and normalize the input image.

    Args:
        image (torch.Tensor): Input image tensor of shape (C, H, W), values in [0, 1].
        new_size (tuple[int, int]): Target size (height, width) for resizing.

    Returns:
        torch.Tensor: Resized and normalized image tensor of shape (C, new_H, new_W).
    """
    # Resize the image
    resized_image = torch.nn.functional.interpolate(
        image.unsqueeze(0),
        size=new_size,
        mode='bicubic',
        align_corners=False
    ).squeeze(0)

    # Normalize the image
    normalized_image = normalize(
        resized_image,
        torch.tensor(DATASET_MEAN, device=image.device, dtype=image.dtype),
        torch.tensor(DATASET_STD, device=image.device, dtype=image.dtype)
    )
    return normalized_image
class PixtralVisionImagePreprocessor:
    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        self.image_patch_size = image_patch_size
        self.max_image_size = max_image_size
        self.image_token = 10

    def _image_to_num_tokens(self, img: torch.Tensor, max_image_size: Optional[int] = None) -> Tuple[int, int]:
        w: Union[int, float]
        h: Union[int, float]
        if max_image_size is None:
            max_image_size = self.max_image_size
        w, h = img.shape[-1], img.shape[-2]

        # originally, pixtral used the largest of the 2 dimensions, but we
        # will use the base size of the image based on number of pixels.
        # ratio = max(h / self.max_image_size, w / self.max_image_size)  # original
        base_size = int(math.sqrt(w * h))
        ratio = base_size / max_image_size
        if ratio > 1:
            w = round(w / ratio)
            h = round(h / ratio)

        width_tokens = (w - 1) // self.image_patch_size + 1
        height_tokens = (h - 1) // self.image_patch_size + 1
        return width_tokens, height_tokens
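
    # Worked example (illustrative numbers): for a 1200x800 (W x H) image with
    # max_image_size=1024 and image_patch_size=16, base_size = int(sqrt(1200 * 800))
    # = 979, so ratio < 1 and no downscaling happens; width_tokens =
    # (1200 - 1) // 16 + 1 = 75 and height_tokens = (800 - 1) // 16 + 1 = 50,
    # i.e. 3750 image tokens for this image.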
    def __call__(self, image: torch.Tensor, max_image_size: Optional[int] = None) -> torch.Tensor:
        """
        Resizes the image to a multiple of the patch size and normalizes it.

        Args:
            image: torch tensor with values in [0, 1] and shape (C, H, W)

        Returns:
            processed_image: resized and normalized image tensor of shape (C, new_H, new_W)
        """
        # should not have a batch dimension
        if len(image.shape) == 4:
            raise ValueError(
                f"Expected image with shape (C, H, W), got {image.shape}")
        if image.min() < 0.0 or image.max() > 1.0:
            raise ValueError(
                f"image tensor values must be between 0 and 1. Got min: {image.min()}, max: {image.max()}")
        if max_image_size is None:
            max_image_size = self.max_image_size

        w, h = self._image_to_num_tokens(image, max_image_size=max_image_size)
        assert w > 0
        assert h > 0

        # transform_image / interpolate expect (height, width)
        new_image_size = (
            h * self.image_patch_size,
            w * self.image_patch_size,
        )
        processed_image = transform_image(image, new_image_size)
        return processed_image
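
    # Illustrative usage (random data, default sizes):
    #
    #   preprocessor = PixtralVisionImagePreprocessor()
    #   img = torch.rand(3, 800, 1200)   # (C, H, W), values in [0, 1]
    #   out = preprocessor(img)          # (3, 800, 1200): H and W snapped to
    #                                    # multiples of the 16px patch size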
class PixtralVisionImagePreprocessorCompatibleReturn:
    def __init__(self, pixel_values) -> None:
        self.pixel_values = pixel_values


# Compatible version for the ai-toolkit flow
class PixtralVisionImagePreprocessorCompatible(PixtralVisionImagePreprocessor):
    def __init__(self, image_patch_size=16, max_image_size=1024) -> None:
        super().__init__(
            image_patch_size=image_patch_size,
            max_image_size=max_image_size
        )
        self.size = {
            'height': max_image_size,
            'width': max_image_size
        }
        self.max_image_size = max_image_size
        self.image_mean = DATASET_MEAN
        self.image_std = DATASET_STD

    def __call__(
        self,
        images,
        return_tensors="pt",
        do_resize=True,
        do_rescale=False,
        max_image_size=None,
    ) -> 'PixtralVisionImagePreprocessorCompatibleReturn':
        if max_image_size is None:
            max_image_size = self.max_image_size

        out_stack = []
        if len(images.shape) == 3:
            images = images.unsqueeze(0)
        for i in range(images.shape[0]):
            image = images[i]
            processed_image = super().__call__(image, max_image_size=max_image_size)
            out_stack.append(processed_image)
        output = torch.stack(out_stack, dim=0)
        return PixtralVisionImagePreprocessorCompatibleReturn(output)
class PixtralVisionEncoderCompatibleReturn:
    def __init__(self, hidden_states) -> None:
        self.hidden_states = hidden_states


class PixtralVisionEncoderCompatibleConfig:
    def __init__(self):
        self.image_size = 1024
        self.hidden_size = 1024
        self.patch_size = 16
class PixtralVisionEncoderCompatible(PixtralVisionEncoder):
    def __init__(
        self,
        hidden_size: int = 1024,
        num_channels: int = 3,
        image_size: int = 1024,
        patch_size: int = 16,
        intermediate_size: int = 4096,
        num_hidden_layers: int = 24,
        num_attention_heads: int = 16,
        rope_theta: float = 1e4,  # for rope-2D
        image_token_id: int = 10,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            rope_theta=rope_theta,
            image_token_id=image_token_id,
        )
        self.config = PixtralVisionEncoderCompatibleConfig()

    def forward(
        self,
        images,
        output_hidden_states=True,
    ) -> 'PixtralVisionEncoderCompatibleReturn':
        out_stack = []
        if len(images.shape) == 3:
            images = images.unsqueeze(0)
        for i in range(images.shape[0]):
            image = images[i]
            # the parent forward expects a list of (C, H, W) images
            image_output = super().forward([image])
            out_stack.append(image_output)
        output = torch.stack(out_stack, dim=0)
        return PixtralVisionEncoderCompatibleReturn([output])
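

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original module.
    # It assumes xformers is installed and a CUDA device is available, and it
    # uses randomly initialized weights (no checkpoint is loaded).
    device = torch.device("cuda")
    encoder = PixtralVisionEncoder().to(device).eval()
    preprocessor = PixtralVisionImagePreprocessor()

    # Two images of different sizes, values in [0, 1]
    raw_images = [torch.rand(3, 512, 768), torch.rand(3, 300, 200)]
    processed = [preprocessor(img).to(device) for img in raw_images]

    with torch.no_grad():
        tokens = encoder(processed)

    # One row of features per 16x16 patch across both images
    print(tokens.shape)  # (N_toks, 1024)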