Spaces:
Running
on
Zero
Running
on
Zero
import random | |
from pathlib import Path | |
from typing import Optional | |
import numpy as np | |
import pyrallis | |
import torch | |
from diffusers import ( | |
StableDiffusionXLPipeline, | |
) | |
from huggingface_hub import hf_hub_download | |
from PIL import Image | |
from ip_adapter import IPAdapterPlusXL | |
from model.dit import DiT_Llama | |
from model.pipeline_pit import PiTPipeline | |
from training.train_config import TrainConfig | |
def paste_on_background(image, background, min_scale=0.4, max_scale=0.8, scale=None): | |
# Calculate aspect ratio and determine resizing based on the smaller dimension of the background | |
aspect_ratio = image.width / image.height | |
scale = random.uniform(min_scale, max_scale) if scale is None else scale | |
new_width = int(min(background.width, background.height * aspect_ratio) * scale) | |
new_height = int(new_width / aspect_ratio) | |
# Resize image and calculate position | |
image = image.resize((new_width, new_height), resample=Image.LANCZOS) | |
pos_x = random.randint(0, background.width - new_width) | |
pos_y = random.randint(0, background.height - new_height) | |
# Paste the image using its alpha channel as mask if present | |
background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None) | |
return background | |
def set_seed(seed: int): | |
"""Ensures reproducibility across multiple libraries.""" | |
random.seed(seed) # Python random module | |
np.random.seed(seed) # NumPy random module | |
torch.manual_seed(seed) # PyTorch CPU random seed | |
torch.cuda.manual_seed_all(seed) # PyTorch GPU random seed | |
torch.backends.cudnn.deterministic = True # Ensures deterministic behavior | |
torch.backends.cudnn.benchmark = False # Disable benchmarking to avoid randomness | |
class PiTDemoPipeline: | |
def __init__(self, prior_repo: str, prior_path: str): | |
# Download model and config | |
prior_ckpt_path = hf_hub_download( | |
repo_id=prior_repo, | |
filename=str(prior_path), | |
local_dir="pretrained_models", | |
) | |
prior_cfg_path = hf_hub_download( | |
repo_id=prior_repo, filename=str(Path(prior_path).parent / "cfg.yaml"), local_dir="pretrained_models" | |
) | |
self.model_cfg: TrainConfig = pyrallis.load(TrainConfig, open(prior_cfg_path, "r")) | |
self.weight_dtype = torch.float32 | |
self.device = "cuda:0" | |
prior = DiT_Llama( | |
embedding_dim=2048, | |
hidden_dim=self.model_cfg.hidden_dim, | |
n_layers=self.model_cfg.num_layers, | |
n_heads=self.model_cfg.num_attention_heads, | |
) | |
prior.load_state_dict(torch.load(prior_ckpt_path)) | |
image_pipe = StableDiffusionXLPipeline.from_pretrained( | |
"stabilityai/stable-diffusion-xl-base-1.0", | |
torch_dtype=torch.float16, | |
add_watermarker=False, | |
) | |
ip_ckpt_path = hf_hub_download( | |
repo_id="h94/IP-Adapter", | |
filename="ip-adapter-plus_sdxl_vit-h.bin", | |
subfolder="sdxl_models", | |
local_dir="pretrained_models", | |
) | |
self.ip_model = IPAdapterPlusXL( | |
image_pipe, | |
"models/image_encoder", | |
ip_ckpt_path, | |
self.device, | |
num_tokens=16, | |
) | |
self.image_processor = self.ip_model.clip_image_processor | |
empty_image = Image.new("RGB", (256, 256), (255, 255, 255)) | |
zero_image = torch.Tensor(self.image_processor(empty_image)["pixel_values"][0]) | |
self.zero_image_embeds = self.ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True) | |
prior_pipeline = PiTPipeline( | |
prior=prior, | |
) | |
self.prior_pipeline = prior_pipeline.to(self.device) | |
set_seed(42) | |
def run(self, crops_paths: list[str], scale: float = 2.0, seed: Optional[int] = None, n_images: int = 1): | |
if seed is not None: | |
set_seed(seed) | |
processed_crops = [] | |
input_images = [] | |
crops_paths = [None] + crops_paths | |
# Extend to >3 with Nones | |
while len(crops_paths) < 3: | |
crops_paths.append(None) | |
for path_ind, path in enumerate(crops_paths): | |
if path is None: | |
image = Image.new("RGB", (224, 224), (255, 255, 255)) | |
else: | |
image = Image.open(path).convert("RGB") | |
if path_ind > 0 or not self.model_cfg.use_ref: | |
background = Image.new("RGB", (1024, 1024), (255, 255, 255)) | |
image = paste_on_background(image, background, scale=0.92) | |
else: | |
image = image.resize((1024, 1024)) | |
input_images.append(image) | |
# Name should be parent directory name | |
processed_image = ( | |
torch.Tensor(self.image_processor(image)["pixel_values"][0]) | |
.to(self.device) | |
.unsqueeze(0) | |
.to(self.weight_dtype) | |
) | |
processed_crops.append(processed_image) | |
image_embed_inputs = [] | |
for crop_ind in range(len(processed_crops)): | |
image_embed_inputs.append(self.ip_model.get_image_embeds(processed_crops[crop_ind], skip_uncond=True)) | |
crops_input_sequence = torch.cat(image_embed_inputs, dim=1) | |
generated_images = [] | |
for _ in range(n_images): | |
seed = random.randint(0, 1000000) | |
for curr_scale in [scale]: | |
negative_cond_sequence = torch.zeros_like(crops_input_sequence) | |
embeds_len = self.zero_image_embeds.shape[1] | |
for i in range(0, negative_cond_sequence.shape[1], embeds_len): | |
negative_cond_sequence[:, i : i + embeds_len] = self.zero_image_embeds.detach() | |
img_emb = self.prior_pipeline( | |
cond_sequence=crops_input_sequence, | |
negative_cond_sequence=negative_cond_sequence, | |
num_inference_steps=25, | |
num_images_per_prompt=1, | |
guidance_scale=curr_scale, | |
generator=torch.Generator(device="cuda").manual_seed(seed), | |
).image_embeds | |
for seed_2 in range(1): | |
images = self.ip_model.generate( | |
image_prompt_embeds=img_emb, | |
num_samples=1, | |
num_inference_steps=50, | |
) | |
generated_images += images | |
return generated_images | |