import random
from pathlib import Path
from typing import Optional

import numpy as np
import pyrallis
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from PIL import Image

from ip_adapter import IPAdapterPlusXL
from model.dit import DiT_Llama
from model.pipeline_pit import PiTPipeline
from training.train_config import TrainConfig


def paste_on_background(image, background, min_scale=0.4, max_scale=0.8, scale=None):
    # Calculate aspect ratio and determine resizing based on the smaller dimension of the background
    aspect_ratio = image.width / image.height
    scale = random.uniform(min_scale, max_scale) if scale is None else scale
    new_width = int(min(background.width, background.height * aspect_ratio) * scale)
    new_height = int(new_width / aspect_ratio)

    # Resize image and calculate a random position that keeps it fully on the background
    image = image.resize((new_width, new_height), resample=Image.LANCZOS)
    pos_x = random.randint(0, background.width - new_width)
    pos_y = random.randint(0, background.height - new_height)

    # Paste the image using its alpha channel as mask if present
    background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
    return background
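

# A minimal usage sketch for paste_on_background (hypothetical: "crop.png" stands in
# for any local RGB/RGBA image; this example is not part of the original demo code):
#
#     canvas = Image.new("RGB", (1024, 1024), (255, 255, 255))
#     composed = paste_on_background(Image.open("crop.png"), canvas, scale=0.5)
#     composed.save("composed.png")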


def set_seed(seed: int):
    """Ensures reproducibility across multiple libraries."""
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random module
    torch.manual_seed(seed)  # PyTorch CPU random seed
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU random seed
    torch.backends.cudnn.deterministic = True  # Ensures deterministic behavior
    torch.backends.cudnn.benchmark = False  # Disable benchmarking to avoid randomness


class PiTDemoPipeline:
    def __init__(self, prior_repo: str, prior_path: str):
        # Download the prior checkpoint and its training config from the Hub
        prior_ckpt_path = hf_hub_download(
            repo_id=prior_repo,
            filename=str(prior_path),
            local_dir="pretrained_models",
        )
        prior_cfg_path = hf_hub_download(
            repo_id=prior_repo,
            filename=str(Path(prior_path).parent / "cfg.yaml"),
            local_dir="pretrained_models",
        )
        self.model_cfg: TrainConfig = pyrallis.load(TrainConfig, open(prior_cfg_path, "r"))
        self.weight_dtype = torch.float32
        self.device = "cuda:0"

        # Build the DiT prior and load the downloaded weights
        prior = DiT_Llama(
            embedding_dim=2048,
            hidden_dim=self.model_cfg.hidden_dim,
            n_layers=self.model_cfg.num_layers,
            n_heads=self.model_cfg.num_attention_heads,
        )
        prior.load_state_dict(torch.load(prior_ckpt_path))

        # SDXL base pipeline, wrapped below with an IP-Adapter for image-conditioned decoding
        image_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            add_watermarker=False,
        )
        ip_ckpt_path = hf_hub_download(
            repo_id="h94/IP-Adapter",
            filename="ip-adapter-plus_sdxl_vit-h.bin",
            subfolder="sdxl_models",
            local_dir="pretrained_models",
        )
        self.ip_model = IPAdapterPlusXL(
            image_pipe,
            "models/image_encoder",
            ip_ckpt_path,
            self.device,
            num_tokens=16,
        )
        self.image_processor = self.ip_model.clip_image_processor

        # Embeddings of a blank white image, reused later as the unconditional input
        empty_image = Image.new("RGB", (256, 256), (255, 255, 255))
        zero_image = torch.Tensor(self.image_processor(empty_image)["pixel_values"][0])
        self.zero_image_embeds = self.ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True)

        self.prior_pipeline = PiTPipeline(prior=prior).to(self.device)
        set_seed(42)

    def run(self, crops_paths: list[str], scale: float = 2.0, seed: Optional[int] = None, n_images: int = 1):
        if seed is not None:
            set_seed(seed)
        processed_crops = []
        input_images = []

        # Prepend a None placeholder for the reference slot and pad to at least 3 entries
        crops_paths = [None] + crops_paths
        while len(crops_paths) < 3:
            crops_paths.append(None)

        for path_ind, path in enumerate(crops_paths):
            # Missing crops are replaced with blank white images
            if path is None:
                image = Image.new("RGB", (224, 224), (255, 255, 255))
            else:
                image = Image.open(path).convert("RGB")
            # Non-reference crops are pasted onto a white 1024x1024 canvas;
            # the reference image (index 0, when used) is simply resized
            if path_ind > 0 or not self.model_cfg.use_ref:
                background = Image.new("RGB", (1024, 1024), (255, 255, 255))
                image = paste_on_background(image, background, scale=0.92)
            else:
                image = image.resize((1024, 1024))
            input_images.append(image)
            processed_image = (
                torch.Tensor(self.image_processor(image)["pixel_values"][0])
                .to(self.device)
                .unsqueeze(0)
                .to(self.weight_dtype)
            )
            processed_crops.append(processed_image)
        # Encode each crop with the IP-Adapter image encoder and concatenate along the token dimension
        image_embed_inputs = [
            self.ip_model.get_image_embeds(crop, skip_uncond=True) for crop in processed_crops
        ]
        crops_input_sequence = torch.cat(image_embed_inputs, dim=1)

        generated_images = []
        for _ in range(n_images):
            seed = random.randint(0, 1000000)
            # Negative (unconditional) sequence: the blank-image embeddings tiled over every crop slot
            negative_cond_sequence = torch.zeros_like(crops_input_sequence)
            embeds_len = self.zero_image_embeds.shape[1]
            for i in range(0, negative_cond_sequence.shape[1], embeds_len):
                negative_cond_sequence[:, i : i + embeds_len] = self.zero_image_embeds.detach()
            # Run the prior to predict IP-Adapter image embeddings from the crop sequence
            img_emb = self.prior_pipeline(
                cond_sequence=crops_input_sequence,
                negative_cond_sequence=negative_cond_sequence,
                num_inference_steps=25,
                num_images_per_prompt=1,
                guidance_scale=scale,
                generator=torch.Generator(device="cuda").manual_seed(seed),
            ).image_embeds
            # Decode the predicted embeddings into an image with the IP-Adapter SDXL pipeline
            images = self.ip_model.generate(
                image_prompt_embeds=img_emb,
                num_samples=1,
                num_inference_steps=50,
            )
            generated_images += images
        return generated_images
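

# A minimal sketch of how this pipeline might be invoked. The repo id, checkpoint
# path, and crop filenames below are placeholders, not values from the original code:
#
# if __name__ == "__main__":
#     demo = PiTDemoPipeline(
#         prior_repo="<user>/<pit-prior-repo>",  # hypothetical Hub repo id
#         prior_path="checkpoints/model.ckpt",   # hypothetical checkpoint path
#     )
#     images = demo.run(["crop_a.png", "crop_b.png"], scale=2.0, seed=42, n_images=2)
#     for i, img in enumerate(images):
#         img.save(f"output_{i}.png")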