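"""Demo: generate images from part crops with the PiT prior and an IP-Adapter SDXL decoder."""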
import random
from pathlib import Path
from typing import Optional

import numpy as np
import pyrallis
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from PIL import Image

from ip_adapter import IPAdapterPlusXL
from model.dit import DiT_Llama
from model.pipeline_pit import PiTPipeline
from training.train_config import TrainConfig

def paste_on_background(image, background, min_scale=0.4, max_scale=0.8, scale=None):
    """Resize `image` to a (random or fixed) fraction of `background` and paste it at a random position."""
    # Calculate aspect ratio and determine resizing based on the smaller dimension of the background
    aspect_ratio = image.width / image.height
    scale = random.uniform(min_scale, max_scale) if scale is None else scale
    new_width = int(min(background.width, background.height * aspect_ratio) * scale)
    new_height = int(new_width / aspect_ratio)

    # Resize image and calculate a random paste position
    image = image.resize((new_width, new_height), resample=Image.LANCZOS)
    pos_x = random.randint(0, background.width - new_width)
    pos_y = random.randint(0, background.height - new_height)

    # Paste the image using its alpha channel as mask if present
    background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
    return background

def set_seed(seed: int):
    """Ensures reproducibility across multiple libraries."""
    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy random module
    torch.manual_seed(seed)  # PyTorch CPU random seed
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU random seed
    torch.backends.cudnn.deterministic = True  # Ensures deterministic behavior
    torch.backends.cudnn.benchmark = False  # Disable benchmarking to avoid randomness
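
# Two-stage generation: (1) the PiT prior maps a sequence of IP-Adapter image
# embeddings (one group of `num_tokens` tokens per crop) to a single predicted
# image embedding; (2) IPAdapterPlusXL decodes that embedding with SDXL.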

class PiTDemoPipeline:
    def __init__(self, prior_repo: str, prior_path: str):
        # Download the prior checkpoint and its training config from the Hub
        prior_ckpt_path = hf_hub_download(
            repo_id=prior_repo,
            filename=str(prior_path),
            local_dir="pretrained_models",
        )
        prior_cfg_path = hf_hub_download(
            repo_id=prior_repo,
            filename=str(Path(prior_path).parent / "cfg.yaml"),
            local_dir="pretrained_models",
        )
        self.model_cfg: TrainConfig = pyrallis.load(TrainConfig, open(prior_cfg_path, "r"))
        self.weight_dtype = torch.float32
        self.device = "cuda:0"

        # Build the PiT prior (a DiT) and load the downloaded weights
        prior = DiT_Llama(
            embedding_dim=2048,
            hidden_dim=self.model_cfg.hidden_dim,
            n_layers=self.model_cfg.num_layers,
            n_heads=self.model_cfg.num_attention_heads,
        )
        prior.load_state_dict(torch.load(prior_ckpt_path))

        # SDXL base pipeline, wrapped by the IP-Adapter for image-conditioned decoding
        image_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            add_watermarker=False,
        )
        ip_ckpt_path = hf_hub_download(
            repo_id="h94/IP-Adapter",
            filename="ip-adapter-plus_sdxl_vit-h.bin",
            subfolder="sdxl_models",
            local_dir="pretrained_models",
        )
        self.ip_model = IPAdapterPlusXL(
            image_pipe,
            "models/image_encoder",
            ip_ckpt_path,
            self.device,
            num_tokens=16,
        )
        self.image_processor = self.ip_model.clip_image_processor

        # Embeddings of a blank white image, used as the unconditional input for classifier-free guidance
        empty_image = Image.new("RGB", (256, 256), (255, 255, 255))
        zero_image = torch.Tensor(self.image_processor(empty_image)["pixel_values"][0])
        self.zero_image_embeds = self.ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True)

        prior_pipeline = PiTPipeline(
            prior=prior,
        )
        self.prior_pipeline = prior_pipeline.to(self.device)

        set_seed(42)

    def run(self, crops_paths: list[str], scale: float = 2.0, seed: Optional[int] = None, n_images: int = 1):
        if seed is not None:
            set_seed(seed)
        processed_crops = []
        input_images = []
        # The first slot is reserved for an optional reference image; pad to three entries with Nones
        crops_paths = [None] + crops_paths
        while len(crops_paths) < 3:
            crops_paths.append(None)
        for path_ind, path in enumerate(crops_paths):
            if path is None:
                # Blank white placeholder for empty slots
                image = Image.new("RGB", (224, 224), (255, 255, 255))
            else:
                image = Image.open(path).convert("RGB")
            if path_ind > 0 or not self.model_cfg.use_ref:
                # Paste the crop onto a white 1024x1024 canvas at a fixed scale
                background = Image.new("RGB", (1024, 1024), (255, 255, 255))
                image = paste_on_background(image, background, scale=0.92)
            else:
                image = image.resize((1024, 1024))
            input_images.append(image)
            processed_image = (
                torch.Tensor(self.image_processor(image)["pixel_values"][0])
                .to(self.device)
                .unsqueeze(0)
                .to(self.weight_dtype)
            )
            processed_crops.append(processed_image)

        # Embed each crop with the IP-Adapter image encoder and concatenate along the token dimension
        image_embed_inputs = []
        for crop_ind in range(len(processed_crops)):
            image_embed_inputs.append(self.ip_model.get_image_embeds(processed_crops[crop_ind], skip_uncond=True))
        crops_input_sequence = torch.cat(image_embed_inputs, dim=1)

        generated_images = []
        for _ in range(n_images):
            sample_seed = random.randint(0, 1000000)
            # Negative conditioning: repeat the blank-image embeddings for every crop slot
            negative_cond_sequence = torch.zeros_like(crops_input_sequence)
            embeds_len = self.zero_image_embeds.shape[1]
            for i in range(0, negative_cond_sequence.shape[1], embeds_len):
                negative_cond_sequence[:, i : i + embeds_len] = self.zero_image_embeds.detach()
            # Run the PiT prior to predict an image embedding from the crop sequence
            img_emb = self.prior_pipeline(
                cond_sequence=crops_input_sequence,
                negative_cond_sequence=negative_cond_sequence,
                num_inference_steps=25,
                num_images_per_prompt=1,
                guidance_scale=scale,
                generator=torch.Generator(device="cuda").manual_seed(sample_seed),
            ).image_embeds
            # Decode the predicted embedding to pixels with the IP-Adapter SDXL pipeline
            images = self.ip_model.generate(
                image_prompt_embeds=img_emb,
                num_samples=1,
                num_inference_steps=50,
            )
            generated_images += images
        return generated_images
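

# Minimal usage sketch; the repo id, checkpoint path, and crop paths below are
# placeholders, assumed to point at a Hugging Face repo hosting a PiT prior
# checkpoint alongside its cfg.yaml.
if __name__ == "__main__":
    demo = PiTDemoPipeline(
        prior_repo="<user>/<pit-prior-repo>",  # placeholder: HF repo with the prior
        prior_path="prior/model.ckpt",  # placeholder: checkpoint path inside the repo
    )
    outputs = demo.run(
        crops_paths=["crops/subject.png", "crops/object.png"],  # placeholder crop paths
        scale=2.0,
        seed=0,
        n_images=1,
    )
    for i, img in enumerate(outputs):
        img.save(f"result_{i}.png")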