import random

from einops import rearrange
from diffusers.models import AutoencoderKL
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from flux.sampling import prepare_modified
from flux.util import load_clip, load_t5, load_flow_model
from transport import Sampler, create_transport
from imgproc import to_rgb_if_rgba


def center_crop(image, target_size):
    """Center-crop a PIL image to the given (width, height)."""
    width, height = image.size
    new_width, new_height = target_size
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height
    return image.crop((left, top, right, bottom))


def resize_with_aspect_ratio(img, resolution, divisible=16, aspect_ratio=None):
    """Resize an image while keeping its aspect ratio, so that the output area is
    close to resolution**2 and both dimensions are divisible by `divisible`.

    Args:
        img: PIL Image or torch.Tensor of shape (C, H, W) or (B, C, H, W)
        resolution: target resolution (output area is approximately resolution**2)
        divisible: ensure output dimensions are divisible by this number (default 16)
        aspect_ratio: optional aspect ratio (width / height) to enforce;
            defaults to the input image's own aspect ratio
    Returns:
        Resized image of the same type as the input
    """
    # Check input type and get dimensions
    is_tensor = isinstance(img, torch.Tensor)
    if is_tensor:
        if img.dim() == 3:
            c, h, w = img.shape
            batch_dim = False
        else:
            b, c, h, w = img.shape
            batch_dim = True
    else:
        w, h = img.size

    # Calculate new dimensions
    if aspect_ratio is None:
        aspect_ratio = w / h
    target_area = resolution * resolution
    new_h = int((target_area / aspect_ratio) ** 0.5)
    new_w = int(new_h * aspect_ratio)

    # Ensure both dimensions are divisible by `divisible`
    new_w = max(new_w // divisible, 1) * divisible
    new_h = max(new_h // divisible, 1) * divisible

    # Resize according to the input type
    if is_tensor:
        # Use torch interpolation
        mode = 'bilinear'
        align_corners = False
        if batch_dim:
            return F.interpolate(img, size=(new_h, new_w),
                                 mode=mode, align_corners=align_corners)
        else:
            return F.interpolate(img.unsqueeze(0), size=(new_h, new_w),
                                 mode=mode, align_corners=align_corners).squeeze(0)
    else:
        # Use PIL LANCZOS resampling
        return img.resize((new_w, new_h), Image.LANCZOS)
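
# Illustrative usage (a minimal sketch, not from the original file): for a 1000x750
# PIL image and resolution=384, the function keeps the 4:3 aspect ratio and snaps
# both sides down to multiples of 16:
#
#   img = Image.new("RGB", (1000, 750))
#   out = resize_with_aspect_ratio(img, 384)
#   print(out.size)  # expected: (432, 320), an area close to 384 ** 2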


class VisualClozeModel:
    def __init__(
            self, model_path, model_name="flux-dev-fill-lora", max_length=512, lora_rank=256,
            atol=1e-6, rtol=1e-3, solver='euler', time_shifting_factor=1,
            resolution=384, precision='bf16'):
        self.atol = atol
        self.rtol = rtol
        self.solver = solver
        self.time_shifting_factor = time_shifting_factor
        self.resolution = resolution
        self.precision = precision
        self.max_length = max_length
        self.lora_rank = lora_rank

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]

        # Initialize model
        print("Initializing model...")
        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)

        # Initialize VAE
        print("Initializing VAE...")
        self.ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
        self.ae.requires_grad_(False)

        # Initialize text encoders
        print("Initializing text encoders...")
        self.t5 = load_t5(self.device, max_length=self.max_length)
        self.clip = load_clip(self.device)

        self.model.eval().to(self.device, dtype=self.dtype)

        # Load model weights
        ckpt = torch.load(model_path)
        self.model.load_state_dict(ckpt, strict=False)
        del ckpt

        # Initialize sampler
        transport = create_transport(
            "Linear",
            "velocity",
            do_shift=True,
        )
        self.sampler = Sampler(transport)
        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=30,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # Image transformation
        self.image_transform = transforms.Compose([
            transforms.Lambda(lambda img: to_rgb_if_rgba(img)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.grid_h = None
        self.grid_w = None

    def set_grid_size(self, h, w):
        """Set the grid size (number of rows and columns) used by subsequent calls."""
        self.grid_h = h
        self.grid_w = w

    def upsampling(self, image, target_size, cfg, upsampling_steps, upsampling_noise, generator, content_prompt):
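        """Refine a generated target image at higher resolution with an SDEdit-style pass.

        The image is resized toward `target_size` (capped at a 1600x1600 area), re-encoded
        with the VAE, mixed with fresh noise according to `upsampling_noise`, and denoised
        for `upsampling_steps` steps conditioned on `content_prompt` (with any grid-layout
        prefix stripped).
        """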
        # Strip any grid-layout prefix from the content prompt
        content_instruction = [
            "The content of the last image in the final row is: ",
            "The last image of the last row depicts: ",
            "In the final row, the last image shows: ",
            "The last image in the bottom row illustrates: ",
            "The content of the bottom-right image is: ",
            "The final image in the last row portrays: ",
            "The last image of the final row displays: ",
            "In the last row, the final image captures: ",
            "The bottom-right corner image presents: ",
            "The content of the last image in the concluding row is: ",
            "In the last row, ",
            "The editing instruction in the last row is: ",
        ]
        for c in content_instruction:
            if content_prompt.startswith(c):
                content_prompt = content_prompt.replace(c, '')

        if target_size is None:
            aspect_ratio = 1
            target_area = 1024 * 1024
            new_h = int((target_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)

        # Cap the upsampling area at 1600 x 1600
        if target_size[0] * target_size[1] > 1600 * 1600:
            aspect_ratio = target_size[0] / target_size[1]
            target_area = 1600 * 1600
            new_h = int((target_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)

        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=upsampling_steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=False,
            time_shifting_factor=1.0,
            strength=upsampling_noise,
        )

        image = image.resize(((target_size[0] // 16) * 16, (target_size[1] // 16) * 16))
        processed_image = self.image_transform(image)
        processed_image = processed_image.to(self.device, non_blocking=True)
        blank = torch.zeros_like(processed_image, device=self.device, dtype=self.dtype)
        mask = torch.full((1, 1, processed_image.shape[1], processed_image.shape[2]), fill_value=1, device=self.device, dtype=self.dtype)

        with torch.no_grad():
            latent = self.ae.encode(processed_image[None].to(self.ae.dtype)).latent_dist.sample()
            blank = self.ae.encode(blank[None].to(self.ae.dtype)).latent_dist.sample()
            latent = (latent - self.ae.config.shift_factor) * self.ae.config.scaling_factor
            blank = (blank - self.ae.config.shift_factor) * self.ae.config.scaling_factor
            latent_h, latent_w = latent.shape[2:]

            # Pack the pixel-space mask into the 2x2-patchified latent token layout
            # (8x VAE downsampling followed by 2x2 patches)
            mask = rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8)
            mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

            latent = latent.to(self.dtype)
            blank = blank.to(self.dtype)
            latent = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            blank = rearrange(blank, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            img_cond = torch.cat((blank, mask), dim=-1)

            # Generate noise
            noise = torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=generator).to(self.dtype)
            x = [[noise]]

            inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[content_prompt], proportion_empty_prompts=0.0)
            # SDEdit-style start: mix the noise with the encoded latent; a larger
            # upsampling_noise keeps more of the latent (i.e. adds less noise)
            inp["img"] = inp["img"] * (1 - upsampling_noise) + latent * upsampling_noise

            model_kwargs = dict(
                txt=inp["txt"],
                txt_ids=inp["txt_ids"],
                txt_mask=inp["txt_mask"],
                y=inp["vec"],
                img_ids=inp["img_ids"],
                img_mask=inp["img_mask"],
                cond=img_cond,
                guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
            )

            sample = self.sample_fn(
                inp["img"], self.model.forward, model_kwargs
            )[-1]
            sample = sample[:1]
            sample = rearrange(sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h // 2, w=latent_w // 2)
            sample = self.ae.decode(sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
            sample = (sample + 1.0) / 2.0
            sample.clamp_(0.0, 1.0)
            sample = sample[0]
            output_image = to_pil_image(sample.float())

        return output_image

    def process_images(
            self, images: list[list[Image.Image]], text_prompt: list[str],
            seed: int = 0,
            cfg: int = 30,
            steps: int = 30,
            upsampling_steps: int = 10,
            upsampling_noise: float = 0.4,
            is_upsampling: bool = True):
""" | |
Processes a list of images based on provided text prompts and settings, | |
with optional upsampling steps to improve image resolution or detail. | |
Parameters: | |
images (list[list[Image.Image]]): A grid-layout image collection, each row represents an in-context example or the current query, | |
where the current query should be placed in the last row. | |
The target image can be None in the input. The other images should be the PIL Image class (Image.Image). | |
text_prompt (list[str]): Three prompts, representing the layout prompt, task prompt, and content prompt respectively. | |
seed (int): A fixed integer seed to ensure reproducibility of the random elements in the processing. | |
cfg (int): The strength of Classifier-Free Diffusion Guidance. | |
steps (int): The number of sampling steps. | |
upsampling_steps (int): The number of denoising steps when upsampling. | |
upsampling_noise (float): When upsampling using SDEdit, | |
the noise is used as a starting point and less noise is added the higher the strength. | |
A value of 1 means added noise is maximum. | |
is_upsampling (bool, optional): A flag to indicate if upsampling should be applied using SDEdit. | |
Returns: | |
Processed images as a result of the algorithm, with optional upsampling applied based on the `is_upsampling` flag. | |
""" | |

        if seed == 0:
            seed = random.randint(0, 2 ** 32 - 1)

        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # Use class grid size
        grid_h, grid_w = self.grid_h, self.grid_w

        # Ensure all images are RGB mode or None
        for i in range(grid_h):
            images[i] = [img.convert("RGB") if img is not None else None for img in images[i]]

        # Adjust all image sizes
        resolution = self.resolution
        processed_images = []
        mask_position = []
        target_size = None
        upsampling_size = None
        for i in range(grid_h):
            # Find the size of the first non-empty image in this row
            reference_size = None
            for j in range(grid_w):
                if images[i][j] is not None:
                    if i == grid_h - 1 and upsampling_size is None:
                        upsampling_size = images[i][j].size
                    resized = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
                    reference_size = resized.size
                    if i == grid_h - 1 and target_size is None:
                        target_size = reference_size
                    break
            # Process all images in this row
            for j in range(grid_w):
                if images[i][j] is not None:
                    target = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
                    if target.width <= target.height:
                        target = target.resize((reference_size[0], int(reference_size[0] / target.width * target.height)))
                        target = center_crop(target, reference_size)
                    elif target.width > target.height:
                        target = target.resize((int(reference_size[1] / target.height * target.width), reference_size[1]))
                        target = center_crop(target, reference_size)
                    processed_images.append(target)
                    if i == grid_h - 1:
                        mask_position.append(0)
                else:
                    # If this row has a reference size, use it; otherwise use the default size
                    if reference_size:
                        blank = Image.new('RGB', reference_size, (0, 0, 0))
                    else:
                        blank = Image.new('RGB', (resolution, resolution), (0, 0, 0))
                    processed_images.append(blank)
                    if i == grid_h - 1:
                        mask_position.append(1)

        # When more than one image in the query row is masked, resize every cell to a common width
        if len(mask_position) > 1 and sum(mask_position) > 1:
            if target_size is None:
                new_w = 384
            else:
                new_w = target_size[0]
            for i in range(len(processed_images)):
                if processed_images[i] is not None:
                    new_h = int(processed_images[i].height * (new_w / processed_images[i].width))
                    new_w = int(new_w / 16) * 16
                    new_h = int(new_h / 16) * 16
                    processed_images[i] = processed_images[i].resize((new_w, new_h))

        # Build grid image and mask
        with torch.autocast("cuda", self.dtype):
            grid_image = []
            fill_mask = []
            for i in range(grid_h):
                row_images = [self.image_transform(img) for img in processed_images[i * grid_w: (i + 1) * grid_w]]
                if i == grid_h - 1:
                    row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=m, device=self.device) for m in mask_position]
                else:
                    row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=0, device=self.device) for m in mask_position]
                grid_image.append(torch.cat(row_images, dim=2).to(self.device, non_blocking=True))
                fill_mask.append(torch.cat(row_masks, dim=3))

            # Encode condition image
            with torch.no_grad():
                fill_cond = [self.ae.encode(img[None].to(self.ae.dtype)).latent_dist.sample()[0] for img in grid_image]
                fill_cond = [(img - self.ae.config.shift_factor) * self.ae.config.scaling_factor for img in fill_cond]

            # Rearrange mask
            fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8) for mask in fill_mask]
            fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for mask in fill_mask]

            fill_cond = [img.to(self.dtype) for img in fill_cond]
            fill_cond = [rearrange(img.unsqueeze(0), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for img in fill_cond]

            fill_cond = torch.cat(fill_cond, dim=1)
            fill_mask = torch.cat(fill_mask, dim=1)
            img_cond = torch.cat((fill_cond, fill_mask), dim=-1)

            # Generate sample
            noise = []
            sliced_subimage = []
            rng = torch.Generator(device=self.device).manual_seed(int(seed))
            for sub_img in grid_image:
                h, w = sub_img.shape[-2:]
                sliced_subimage.append((h, w))
                latent_w, latent_h = w // 8, h // 8
                noise.append(torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=rng).to(self.dtype))
            x = [noise]

            with torch.no_grad():
                inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[' '.join(text_prompt)], proportion_empty_prompts=0.0)

                model_kwargs = dict(
                    txt=inp["txt"],
                    txt_ids=inp["txt_ids"],
                    txt_mask=inp["txt_mask"],
                    y=inp["vec"],
                    img_ids=inp["img_ids"],
                    img_mask=inp["img_mask"],
                    cond=img_cond,
                    guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
                )

                samples = self.sample_fn(
                    inp["img"], self.model.forward, model_kwargs
                )[-1]

            # Keep the first sample and split it back into per-row latents before decoding
            samples = samples[:1]
            row_samples = []
            start = 0
            for size in sliced_subimage:
                end = start + (size[0] * size[1] // 256)
                latent_h = size[0] // 8
                latent_w = size[1] // 8
                row_sample = samples[:, start:end, :]
                row_sample = rearrange(row_sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h // 2, w=latent_w // 2)
                row_sample = self.ae.decode(row_sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
                row_sample = (row_sample + 1.0) / 2.0
                row_sample.clamp_(0.0, 1.0)
                row_samples.append(row_sample[0])
                start = end

        # Convert all samples to PIL images
        output_images = []
        for row_sample in row_samples:
            output_image = to_pil_image(row_sample.float())
            output_images.append(output_image)

        # Crop each cell of the query (last) row; upsample the masked targets if requested
        ret = []
        ret_w = output_images[-1].width
        ret_h = output_images[-1].height
        row_start = (grid_h - 1) * grid_w
        row_end = grid_h * grid_w
        for i in range(row_start, row_end):
            # Output every cell of the query row (the commented condition would restrict this to masked cells)
            if True:  # images[i] is None:
                cropped = output_images[-1].crop(((i - row_start) * ret_w // self.grid_w, 0, ((i - row_start) + 1) * ret_w // self.grid_w, ret_h))
                ret.append(cropped)
                if mask_position[i - row_start] and is_upsampling:
                    upsampled = self.upsampling(
                        cropped,
                        upsampling_size,
                        cfg,
                        upsampling_steps=upsampling_steps,
                        upsampling_noise=upsampling_noise,
                        generator=rng,
                        content_prompt=text_prompt[2])
                    ret.append(upsampled)

        return ret
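

# A minimal usage sketch (not part of the original file); the checkpoint path, example
# images, and prompts below are placeholders and depend on how the surrounding Space
# loads its assets:
#
#   model = VisualClozeModel(model_path="path/to/visualcloze_checkpoint.pth", resolution=384)
#   model.set_grid_size(2, 3)  # one in-context example row plus the query row
#   grid = [[ex_cond_a, ex_cond_b, ex_target],
#           [qry_cond_a, qry_cond_b, None]]  # the masked target to generate
#   outputs = model.process_images(grid, [layout_prompt, task_prompt, content_prompt],
#                                  seed=0, cfg=30, steps=30, is_upsampling=True)
#   outputs[-1].save("result.png")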