File size: 6,492 Bytes
779c9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import random
from pathlib import Path
from typing import Optional

import numpy as np
import pyrallis
import torch
from diffusers import (
    StableDiffusionXLPipeline,
)
from huggingface_hub import hf_hub_download
from PIL import Image

from ip_adapter import IPAdapterPlusXL
from model.dit import DiT_Llama
from model.pipeline_pit import PiTPipeline
from training.train_config import TrainConfig


def paste_on_background(image, background, min_scale=0.4, max_scale=0.8, scale=None):
    """Resize ``image`` (keeping its aspect ratio) and paste it at a random spot on ``background``.

    The pasted image is ``scale`` times the largest size with ``image``'s aspect ratio that fits
    inside ``background``.

    Args:
        image: PIL image to paste. If its mode contains an "A" band, the image itself is passed
            as the paste mask.
        background: PIL image that is modified in place and returned.
        min_scale: Lower bound of the uniform draw for ``scale`` when ``scale`` is None.
        max_scale: Upper bound of the uniform draw for ``scale`` when ``scale`` is None.
        scale: Fixed relative size; overrides the random draw. Must keep the image inside
            ``background`` (i.e. be at most 1).

    Returns:
        ``background`` with ``image`` pasted onto it.

    Raises:
        ValueError: If the scaled image would not fit inside ``background``.
    """
    aspect_ratio = image.width / image.height
    scale = random.uniform(min_scale, max_scale) if scale is None else scale

    # Largest width that fits both background dimensions at this aspect ratio, then scaled.
    # Clamp each side to >= 1 px: very thin/tall inputs would otherwise truncate to 0,
    # which PIL's resize rejects.
    new_width = max(1, int(min(background.width, background.height * aspect_ratio) * scale))
    new_height = max(1, int(new_width / aspect_ratio))

    # Fail fast with a clear message instead of resizing and then hitting randint's empty range.
    if new_width > background.width or new_height > background.height:
        raise ValueError(
            f"Scaled image ({new_width}x{new_height}) does not fit on background "
            f"({background.width}x{background.height}); scale={scale} must be <= 1."
        )

    # Resize image and pick a random top-left corner that keeps it fully on the background
    image = image.resize((new_width, new_height), resample=Image.LANCZOS)
    pos_x = random.randint(0, background.width - new_width)
    pos_y = random.randint(0, background.height - new_height)

    # Paste the image using its alpha channel as mask if present
    background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
    return background


def set_seed(seed: int):
    """Seed every RNG this script touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the RNGs of all
    CUDA devices with the same value.

    Args:
        seed: Value used to seed every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    # Use deterministic cuDNN kernels and turn off autotuning, whose kernel choice can vary.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class PiTDemoPipeline:
    """Demo inference wrapper pairing the PiT prior (``DiT_Llama``) with an SDXL IP-Adapter Plus.

    ``run`` embeds a set of input crops with the IP-Adapter image encoder and feeds the
    concatenated sequence to the prior pipeline. The resulting ``image_embeds`` are decoded
    into images by ``IPAdapterPlusXL.generate``.
    """

    def __init__(self, prior_repo: str, prior_path: str):
        """Download the prior and IP-Adapter weights and build both pipelines.

        Args:
            prior_repo: Hugging Face Hub repo id hosting the prior checkpoint.
            prior_path: Checkpoint path inside ``prior_repo``. A ``cfg.yaml`` with the
                training config is downloaded from the same directory.
        """
        # Download model and config
        prior_ckpt_path = hf_hub_download(
            repo_id=prior_repo,
            filename=str(prior_path),
            local_dir="pretrained_models",
        )
        # Hub filenames are always "/"-separated; as_posix() keeps this correct on Windows too.
        prior_cfg_path = hf_hub_download(
            repo_id=prior_repo,
            filename=(Path(prior_path).parent / "cfg.yaml").as_posix(),
            local_dir="pretrained_models",
        )
        # Close the config file deterministically instead of leaking the handle.
        with open(prior_cfg_path, "r", encoding="utf-8") as cfg_file:
            self.model_cfg: TrainConfig = pyrallis.load(TrainConfig, cfg_file)

        self.weight_dtype = torch.float32  # dtype of the preprocessed crop tensors in run()
        self.device = "cuda:0"
        prior = DiT_Llama(
            embedding_dim=2048,
            hidden_dim=self.model_cfg.hidden_dim,
            n_layers=self.model_cfg.num_layers,
            n_heads=self.model_cfg.num_attention_heads,
        )
        # Load on CPU; the whole prior pipeline is moved to self.device further below.
        prior.load_state_dict(torch.load(prior_ckpt_path, map_location="cpu"))
        image_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            add_watermarker=False,
        )
        ip_ckpt_path = hf_hub_download(
            repo_id="h94/IP-Adapter",
            filename="ip-adapter-plus_sdxl_vit-h.bin",
            subfolder="sdxl_models",
            local_dir="pretrained_models",
        )

        self.ip_model = IPAdapterPlusXL(
            image_pipe,
            "models/image_encoder",
            ip_ckpt_path,
            self.device,
            num_tokens=16,
        )
        self.image_processor = self.ip_model.clip_image_processor

        # Embedding of a blank white image; run() tiles it to build the negative sequence.
        empty_image = Image.new("RGB", (256, 256), (255, 255, 255))
        zero_image = torch.Tensor(self.image_processor(empty_image)["pixel_values"][0])
        self.zero_image_embeds = self.ip_model.get_image_embeds(zero_image.unsqueeze(0), skip_uncond=True)

        prior_pipeline = PiTPipeline(
            prior=prior,
        )
        self.prior_pipeline = prior_pipeline.to(self.device)
        # Default seed so results are reproducible even when run() is called without a seed.
        set_seed(42)

    def run(self, crops_paths: list[str], scale: float = 2.0, seed: Optional[int] = None, n_images: int = 1):
        """Generate ``n_images`` images conditioned on the given crops.

        Args:
            crops_paths: Paths of the crop images. A blank slot is prepended as slot 0, and
                blank slots are appended until there are at least three slots. The caller's
                list is not modified.
            scale: Guidance scale passed to the prior pipeline.
            seed: If given, all RNGs are reseeded first so the call is reproducible.
            n_images: Number of prior/decoder passes to run.

        Returns:
            List accumulating the output of ``self.ip_model.generate`` for each pass
            (each call requests ``num_samples=1``).
        """
        if seed is not None:
            set_seed(seed)

        # Slot 0 is the reference slot; this demo always leaves it blank.
        slot_paths: list[Optional[str]] = [None, *crops_paths]
        # Pad with blank slots up to three slots in total.
        slot_paths += [None] * (3 - len(slot_paths))

        processed_crops = []
        for path_ind, path in enumerate(slot_paths):
            if path is None:
                image = Image.new("RGB", (224, 224), (255, 255, 255))
            else:
                # Context manager closes the underlying file once the pixels are converted.
                with Image.open(path) as opened:
                    image = opened.convert("RGB")
                if path_ind > 0 or not self.model_cfg.use_ref:
                    background = Image.new("RGB", (1024, 1024), (255, 255, 255))
                    image = paste_on_background(image, background, scale=0.92)
                else:
                    # NOTE(review): unreachable as written, since slot 0 is always None above.
                    image = image.resize((1024, 1024))
            processed_crops.append(
                torch.Tensor(self.image_processor(image)["pixel_values"][0])
                .to(self.device)
                .unsqueeze(0)
                .to(self.weight_dtype)
            )

        image_embed_inputs = [
            self.ip_model.get_image_embeds(crop, skip_uncond=True) for crop in processed_crops
        ]
        crops_input_sequence = torch.cat(image_embed_inputs, dim=1)

        # Negative conditioning: the blank-image embedding written into every slot position.
        # It does not depend on the per-image seed, so it is built once outside the loop.
        negative_cond_sequence = torch.zeros_like(crops_input_sequence)
        embeds_len = self.zero_image_embeds.shape[1]
        for i in range(0, negative_cond_sequence.shape[1], embeds_len):
            negative_cond_sequence[:, i : i + embeds_len] = self.zero_image_embeds.detach()

        generated_images = []
        for _ in range(n_images):
            # Per-pass prior seed drawn from Python's RNG (reseeded above when `seed` is given);
            # kept in its own name so it no longer shadows the `seed` argument.
            prior_seed = random.randint(0, 1000000)
            img_emb = self.prior_pipeline(
                cond_sequence=crops_input_sequence,
                negative_cond_sequence=negative_cond_sequence,
                num_inference_steps=25,
                num_images_per_prompt=1,
                guidance_scale=scale,
                generator=torch.Generator(device=self.device).manual_seed(prior_seed),
            ).image_embeds

            generated_images += self.ip_model.generate(
                image_prompt_embeds=img_emb,
                num_samples=1,
                num_inference_steps=50,
            )

        return generated_images