Spaces:
Paused
Paused
import math | |
from functools import partial | |
from torch import nn | |
import torch | |
class AuraFlowPatchEmbed(nn.Module):
    """Patchify a latent, project each patch linearly, and add learned positions.

    Args:
        height: Reference input height, used to derive ``num_patches``,
            ``self.height`` and ``base_size``.
        width: Reference input width, used to derive ``num_patches`` and
            ``self.width``.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input latent.
        embed_dim: Output feature size of the patch projection.
        pos_embed_max_size: Number of learned positional embeddings. When
            ``None`` it falls back to the patch count implied by ``height``
            and ``width``.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        pos_embed_max_size=None,
    ):
        super().__init__()

        self.num_patches = (height // patch_size) * (width // patch_size)
        # ``torch.randn`` cannot take ``None`` as a dimension, so resolve the
        # default to one embedding per patch of the reference resolution.
        if pos_embed_max_size is None:
            pos_embed_max_size = self.num_patches
        self.pos_embed_max_size = pos_embed_max_size

        # Each patch is flattened to patch_size * patch_size * in_channels values.
        self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)

        self.patch_size = patch_size
        # Stored in patch units, not in input units.
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size

    def forward(self, latent):
        """Embed ``latent`` of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, patches, embed_dim)`` with ``pos_embed``
            added.

        Raises:
            RuntimeError: If the patch sequence cannot be added to
                ``pos_embed`` (their lengths are incompatible). The ``view``
                below also fails when height or width is not a multiple of
                ``patch_size``.
        """
        batch_size, num_channels, height, width = latent.size()
        # Split both spatial axes into (number of patches, patch_size).
        latent = latent.view(
            batch_size,
            num_channels,
            height // self.patch_size,
            self.patch_size,
            width // self.patch_size,
            self.patch_size,
        )
        # -> (batch, patches, channels * patch_size * patch_size)
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError as err:
            # Report the sizes actually involved rather than the init-time
            # patch count, and keep the original error as the cause.
            raise RuntimeError(
                f"Positional embeddings do not match the number of patches: "
                f"got {latent.shape[1]} patches but `pos_embed` holds "
                f"{self.pos_embed.shape[1]} positions. "
                f"Please set `pos_embed_max_size` to {latent.shape[1]}."
            ) from err
# comfy | |
# def apply_pos_embeds(self, x, h, w): | |
# h = (h + 1) // self.patch_size | |
# w = (w + 1) // self.patch_size | |
# max_dim = max(h, w) | |
# | |
# cur_dim = self.h_max | |
# pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype) | |
# | |
# if max_dim > cur_dim: | |
# pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, | |
# -1) | |
# cur_dim = max_dim | |
# | |
# from_h = (cur_dim - h) // 2 | |
# from_w = (cur_dim - w) // 2 | |
# pos_encoding = pos_encoding[:, from_h:from_h + h, from_w:from_w + w] | |
# return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1]) | |
# def patchify(self, x): | |
# B, C, H, W = x.size() | |
# pad_h = (self.patch_size - H % self.patch_size) % self.patch_size | |
# pad_w = (self.patch_size - W % self.patch_size) % self.patch_size | |
# | |
# x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect') | |
# x = x.view( | |
# B, | |
# C, | |
# (H + 1) // self.patch_size, | |
# self.patch_size, | |
# (W + 1) // self.patch_size, | |
# self.patch_size, | |
# ) | |
# x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2) | |
# return x | |
def patch_auraflow_pos_embed(pos_embed):
    """Replace ``pos_embed.forward`` with a variant that pads or crops the latent.

    The replacement tries to bring the number of patches produced from the
    latent in line with ``pos_embed.pos_embed.shape[1]`` before patchifying,
    projecting and adding the positional embeddings.

    Args:
        pos_embed: Patch-embedding module exposing ``pos_embed``,
            ``patch_size`` and ``proj``. It is modified in place: its
            ``forward`` attribute is overwritten.
    """

    # The new forward is bound to the module below, so ``self`` is ``pos_embed``.
    def new_forward(self, latent):
        batch_size, num_channels, height, width = latent.size()
        patch_size = self.patch_size

        # Number of patches the latent will produce. The earlier expression
        # ``height * width * num_channels / 16`` equals this only when
        # patch_size ** 2 * num_channels == 16 (e.g. 4 channels, patch size 2),
        # and being a float it broke the slice in the crop branch.
        latent_size = (height // patch_size) * (width // patch_size)
        pos_embed_size = self.pos_embed.shape[1]

        if latent_size < pos_embed_size:
            # NOTE(review): the ``// 16`` split below is a heuristic carried
            # over unchanged; it only reaches exactly ``pos_embed_size``
            # patches for some shapes. Confirm the intended target size.
            total_padding = (pos_embed_size - latent_size) // 16
            pad_height = total_padding // 2
            pad_width = total_padding - pad_height
            # Mirror-pad on the right and bottom edges only.
            padding = (0, pad_width, 0, pad_height)
            latent = torch.nn.functional.pad(latent, padding, mode='reflect')
        elif latent_size > pos_embed_size:
            # Drop trailing rows along the height axis. Always >= 1 in this
            # branch, so the negative slice bound is never ``-0``.
            # NOTE(review): this removes one row per surplus patch, which
            # matches ``pos_embed_size`` only for particular widths.
            amount_to_remove = latent_size - pos_embed_size
            latent = latent[:, :, :-amount_to_remove]

        # Re-read the shape: padding or cropping may have changed it.
        batch_size, num_channels, height, width = latent.size()
        latent = latent.view(
            batch_size,
            num_channels,
            height // patch_size,
            patch_size,
            width // patch_size,
            patch_size,
        )
        # -> (batch, patches, channels * patch_size * patch_size)
        latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        latent = self.proj(latent)
        try:
            return latent + self.pos_embed
        except RuntimeError as err:
            # Report the sizes actually involved and keep the original cause.
            raise RuntimeError(
                f"Positional embeddings do not match the number of patches: "
                f"got {latent.shape[1]} patches but `pos_embed` holds "
                f"{self.pos_embed.shape[1]} positions. "
                f"Please set `pos_embed_max_size` to {latent.shape[1]}."
            ) from err

    pos_embed.forward = partial(new_forward, pos_embed)