import random
from einops import rearrange
from diffusers.models import AutoencoderKL
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from models.sampling import prepare_modified
from models.util import load_clip, load_t5, load_flow_model
from transport import Sampler, create_transport
from data.imgproc import to_rgb_if_rgba


def center_crop(image, target_size):
    width, height = image.size
    new_width, new_height = target_size

    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height

    return image.crop((left, top, right, bottom))


def resize_with_aspect_ratio(img, resolution, divisible=16, aspect_ratio=None):
    """Resize image while maintaining aspect ratio, ensuring area is close to resolution**2 and dimensions are divisible by 16
    
    Args:
        img: PIL Image or torch.Tensor (C,H,W)/(B,C,H,W)
        resolution: target resolution
        divisible: ensure output dimensions are divisible by this number
    
    Returns:
        Resized image of the same type as input
    """
    # Check input type and get dimensions
    is_tensor = isinstance(img, torch.Tensor)
    if is_tensor:
        if img.dim() == 3:
            c, h, w = img.shape
            batch_dim = False
        else:
            b, c, h, w = img.shape
            batch_dim = True
    else:
        w, h = img.size
        
    # Calculate new dimensions
    if aspect_ratio is None:
        aspect_ratio = w / h
    target_area = resolution * resolution
    new_h = int((target_area / aspect_ratio) ** 0.5)
    new_w = int(new_h * aspect_ratio)
    
    # Ensure divisible by divisible
    new_w = max(new_w // divisible, 1) * divisible
    new_h = max(new_h // divisible, 1) * divisible
    
    # Adjust size based on input type
    if is_tensor:
        # Use torch interpolation method
        mode = 'bilinear'
        align_corners = False
        if batch_dim:
            return F.interpolate(img, size=(new_h, new_w), 
                               mode=mode, align_corners=align_corners)
        else:
            return F.interpolate(img.unsqueeze(0), size=(new_h, new_w),
                               mode=mode, align_corners=align_corners).squeeze(0)
    else:
        # Use PIL LANCZOS resampling
        return img.resize((new_w, new_h), Image.LANCZOS)
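
# A quick worked example of the resizing arithmetic above (illustrative values,
# not part of the original code): for a 1000x600 PIL image with resolution=384,
# aspect_ratio = 1000/600, target_area = 384*384 = 147456, so
# new_h = int((147456 / (1000/600)) ** 0.5) = 297 and new_w = int(297 * (1000/600)) = 495,
# which then snap down to the nearest multiples of 16: (new_w, new_h) = (480, 288).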


class VisualClozeModel:
    def __init__(
        self, model_path, model_name="flux-dev-fill-lora", max_length=512, lora_rank=256, 
        atol=1e-6, rtol=1e-3, solver='euler', time_shifting_factor=1, 
        resolution=384, precision='bf16'):
        self.atol = atol
        self.rtol = rtol
        self.solver = solver
        self.time_shifting_factor = time_shifting_factor
        self.resolution = resolution
        self.precision = precision
        self.max_length = max_length
        self.lora_rank = lora_rank
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
        
        # Initialize model
        print("Initializing model...")
        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
        
        # Initialize VAE
        print("Initializing VAE...")
        self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
        self.ae.requires_grad_(False)
        
        # Initialize text encoders
        print("Initializing text encoders...")
        self.t5 = load_t5(self.device, max_length=self.max_length)
        self.clip = load_clip(self.device)
        
        self.model.eval().to(self.device, dtype=self.dtype)
        
        # Load model weights
        ckpt = torch.load(model_path)
        self.model.load_state_dict(ckpt, strict=False)
        del ckpt
        
        # Initialize sampler
        transport = create_transport(
            "Linear",
            "velocity",
            do_shift=True,
        ) 
        self.sampler = Sampler(transport)
        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=30,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )
        
        # Image transformation
        self.image_transform = transforms.Compose([
            transforms.Lambda(lambda img: to_rgb_if_rgba(img)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
        
        self.grid_h = None
        self.grid_w = None
        
    def set_grid_size(self, h, w):
        """Set grid size"""
        self.grid_h = h
        self.grid_w = w
    
    @torch.no_grad()
    def upsampling(self, image, target_size, cfg, upsampling_steps, upsampling_noise, generator, content_prompt):
        """Refine a generated image at a higher resolution using SDEdit-style partial denoising."""
        content_instruction = [
            "The content of the last image in the final row is: ",
            "The last image of the last row depicts: ",
            "In the final row, the last image shows: ",
            "The last image in the bottom row illustrates: ",
            "The content of the bottom-right image is: ",
            "The final image in the last row portrays: ",
            "The last image of the final row displays: ",
            "In the last row, the final image captures: ",
            "The bottom-right corner image presents: ",
            "The content of the last image in the concluding row is: ",
            "In the last row, ",
            "The editing instruction in the last row is: ", 
        ]
        # Strip any known instruction prefix from the content prompt
        for c in content_instruction:
            if content_prompt.startswith(c):
                content_prompt = content_prompt[len(c):]
        
        # Default to a 1024x1024 target when no size is provided
        if target_size is None:
            aspect_ratio = 1
            target_area = 1024 * 1024
            new_h = int((target_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)

        # Cap the target area at 1024x1024 while preserving the aspect ratio
        if target_size[0] * target_size[1] > 1024 * 1024:
            aspect_ratio = target_size[0] / target_size[1]
            target_area = 1024 * 1024
            new_h = int((target_area / aspect_ratio) ** 0.5)
            new_w = int(new_h * aspect_ratio)
            target_size = (new_w, new_h)
        
        # Snap the target size down to multiples of 16; a noise level of 1.0 disables
        # SDEdit, so the plain resize is returned as-is
        image = image.resize(((target_size[0] // 16) * 16, (target_size[1] // 16) * 16))
        if upsampling_noise >= 1.0:
            return image

        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=upsampling_steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=False,
            time_shifting_factor=1.0, 
            strength=upsampling_noise
        )

        processed_image = self.image_transform(image)
        processed_image = processed_image.to(self.device, non_blocking=True)
        blank = torch.zeros_like(processed_image, device=self.device, dtype=self.dtype)
        mask = torch.full((1, 1, processed_image.shape[1], processed_image.shape[2]), fill_value=1, device=self.device, dtype=self.dtype)
        with torch.no_grad():
            latent = self.ae.encode(processed_image[None].to(self.ae.dtype)).latent_dist.sample()
            blank = self.ae.encode(blank[None].to(self.ae.dtype)).latent_dist.sample()
            latent = (latent - self.ae.config.shift_factor) * self.ae.config.scaling_factor
            blank = (blank - self.ae.config.shift_factor) * self.ae.config.scaling_factor
            latent_h, latent_w = latent.shape[2:]

            mask = rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8) 
            mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            
            latent = latent.to(self.dtype)
            blank = blank.to(self.dtype)
            latent = rearrange(latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            blank = rearrange(blank, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            
            img_cond = torch.cat((blank, mask), dim=-1)
    
            # Generate noise
            noise = torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=generator).to(self.dtype)
            x = [[noise]]
            
            inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[content_prompt], proportion_empty_prompts=0.0)
            # SDEdit-style initialization: blend pure noise with the encoded image latent,
            # where a higher upsampling_noise keeps more of the original image
            inp["img"] = inp["img"] * (1 - upsampling_noise) + latent * upsampling_noise
            model_kwargs = dict(
                txt=inp["txt"], 
                txt_ids=inp["txt_ids"], 
                txt_mask=inp["txt_mask"],
                y=inp["vec"], 
                img_ids=inp["img_ids"], 
                img_mask=inp["img_mask"], 
                cond=img_cond,
                guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
            )
            sample = self.sample_fn(
                inp["img"], self.model.forward, model_kwargs
            )[-1]
            
            sample = sample[:1]
            sample = rearrange(sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h // 2, w=latent_w // 2)
            sample = self.ae.decode(sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
            sample = (sample + 1.0) / 2.0
            sample.clamp_(0.0, 1.0)
            sample = sample[0]
            
            output_image = to_pil_image(sample.float())
            
            return output_image
    
    def process_images(
        self, images: list[list[Image.Image]], 
        prompts: list[str], 
        seed: int = 0, 
        cfg: int = 30, 
        steps: int = 30, 
        upsampling_steps: int = 10, 
        upsampling_noise: float = 0.4, 
        is_upsampling: bool = True):
        """
        Processes a list of images based on the provided text prompts and settings, with optional upsampling to enhance image resolution or detail.

        Parameters:
            images (list[list[Image.Image]]): A collection of images arranged in a grid layout, where each row represents an in-context example or the current query.
                The current query should be placed in the last row. The target image may be None in the input, while all other images should be of the PIL Image type (Image.Image).
            
            prompts (list[str]): A list containing three prompts: the layout prompt, task prompt, and content prompt, respectively.
            
            seed (int): A fixed integer seed to ensure reproducibility of random elements during processing.
            
            cfg (int): The strength of Classifier-Free Diffusion Guidance, which controls the degree of influence over the generated results.
            
            steps (int): The number of sampling steps to be performed during processing.
            
            upsampling_steps (int): The number of denoising steps to apply when performing upsampling.
            
            upsampling_noise (float): The noise level used as a starting point when upsampling with SDEdit. A higher value reduces noise, and setting it to 1 disables SDEdit, causing the PIL resize function to be used instead.
            
            is_upsampling (bool, optional): A flag indicating whether upsampling should be applied using SDEdit.

        Returns:
            Processed images resulting from the algorithm, with optional upsampling applied based on the `is_upsampling` flag.
        """
        
        # A seed of 0 means "randomize": draw a fresh seed for this call
        if seed == 0:
            seed = random.randint(0, 2 ** 32 - 1)
        
        self.sample_fn = self.sampler.sample_ode(
            sampling_method=self.solver,
            num_steps=steps,
            atol=self.atol,
            rtol=self.rtol,
            reverse=False,
            do_shift=True,
            time_shifting_factor=self.time_shifting_factor,
        )

        # Use class grid size
        grid_h, grid_w = self.grid_h, self.grid_w
        
        # Ensure all images are RGB mode or None
        for i in range(0, grid_h):
            images[i] = [img.convert("RGB") if img is not None else None for img in images[i]]
        
        # Adjust all image sizes
        resolution = self.resolution
        processed_images = []
        mask_position = []
        target_size = None
        upsampling_size = None
        
        for i in range(grid_h):
            # Find the size of the first non-empty image in this row
            reference_size = None
            for j in range(0, grid_w):
                if images[i][j] is not None:
                    if i == grid_h - 1 and upsampling_size is None:
                        upsampling_size = images[i][j].size

                    resized = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
                    reference_size = resized.size
                    if i == grid_h - 1 and target_size is None:
                        target_size = reference_size
                    break
            
            # Process all images in this row
            for j in range(0, grid_w):
                if images[i][j] is not None:
                    target = resize_with_aspect_ratio(images[i][j], resolution, aspect_ratio=None)
                    # Match one dimension to the row's reference size, then center-crop the other
                    if target.width <= target.height:
                        target = target.resize((reference_size[0], int(reference_size[0] / target.width * target.height)))
                        target = center_crop(target, reference_size)
                    elif target.width > target.height:
                        target = target.resize((int(reference_size[1] / target.height * target.width), reference_size[1]))
                        target = center_crop(target, reference_size)
                    
                    processed_images.append(target)
                    if i == grid_h - 1:
                        mask_position.append(0)
                else:
                    # If this row has a reference size, use it; otherwise use default size
                    if reference_size:
                        blank = Image.new('RGB', reference_size, (0, 0, 0))
                    else:
                        blank = Image.new('RGB', (resolution, resolution), (0, 0, 0))
                    processed_images.append(blank)
                    if i == grid_h - 1:
                        mask_position.append(1)
                    else:
                        raise ValueError('Please provide each image in the in-context example.')
            
        
        # When more than one target image is masked in the query row, unify the widths
        # of all processed images so the generated targets share a common size
        if len(mask_position) > 1 and sum(mask_position) > 1:
            if target_size is None:
                new_w = 384
            else:
                new_w = target_size[0]
            for i in range(len(processed_images)):
                if processed_images[i] is not None:
                    new_h = int(processed_images[i].height * (new_w / processed_images[i].width))
                    new_w = int(new_w / 16) * 16
                    new_h = int(new_h / 16) * 16
                    processed_images[i] = processed_images[i].resize((new_w, new_h))
                
        # Build grid image and mask
        with torch.autocast("cuda", self.dtype):
            grid_image = []
            fill_mask = []
            for i in range(grid_h):
                row_images = [self.image_transform(img) for img in processed_images[i * grid_w: (i + 1) * grid_w]]
                if i == grid_h - 1:
                    row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=m, device=self.device) for m in mask_position]
                else:
                    row_masks = [torch.full((1, 1, row_images[0].shape[1], row_images[0].shape[2]), fill_value=0, device=self.device) for _ in mask_position]

                grid_image.append(torch.cat(row_images, dim=2).to(self.device, non_blocking=True))
                fill_mask.append(torch.cat(row_masks, dim=3))
            # Encode condition image
            with torch.no_grad():
                fill_cond = [self.ae.encode(img[None].to(self.ae.dtype)).latent_dist.sample()[0] for img in grid_image]
                fill_cond = [(img - self.ae.config.shift_factor) * self.ae.config.scaling_factor for img in fill_cond]
                
                # Rearrange mask
                fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (c ph pw) h w", ph=8, pw=8) for mask in fill_mask]
                fill_mask = [rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for mask in fill_mask]
            
            fill_cond = [img.to(self.dtype) for img in fill_cond]
            fill_cond = [rearrange(img.unsqueeze(0), "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) for img in fill_cond]
            
            fill_cond = torch.cat(fill_cond, dim=1)
            fill_mask = torch.cat(fill_mask, dim=1)
            img_cond = torch.cat((fill_cond, fill_mask), dim=-1)
        
            # Generate sample
            noise = []
            sliced_subimage = []
            rng = torch.Generator(device=self.device).manual_seed(int(seed))
            for sub_img in grid_image:
                h, w = sub_img.shape[-2:]
                sliced_subimage.append((h, w))
                latent_w, latent_h = w // 8, h // 8
                noise.append(torch.randn([1, 16, latent_h, latent_w], device=self.device, generator=rng).to(self.dtype))
            x = [noise]
            
            with torch.no_grad():
                inp = prepare_modified(t5=self.t5, clip=self.clip, img=x, prompt=[' '.join(prompts)], proportion_empty_prompts=0.0)
                
                model_kwargs = dict(
                    txt=inp["txt"], 
                    txt_ids=inp["txt_ids"], 
                    txt_mask=inp["txt_mask"],
                    y=inp["vec"], 
                    img_ids=inp["img_ids"], 
                    img_mask=inp["img_mask"], 
                    cond=img_cond,
                    guidance=torch.full((1,), cfg, device=self.device, dtype=self.dtype),
                )
                samples = self.sample_fn(
                    inp["img"], self.model.forward, model_kwargs
                )[-1]

            # Get query row
            with torch.no_grad():
                samples = samples[:1]
                row_samples = []
                start = 0
                for size in sliced_subimage:
                    # Each latent token covers a 16x16 pixel patch (8x VAE downsampling, then 2x2 packing),
                    # so a row of size (h, w) occupies h * w // 256 tokens
                    end = start + (size[0] * size[1] // 256)
                    latent_h = size[0] // 8
                    latent_w = size[1] // 8
                    row_sample = samples[:, start:end, :]
                    row_sample = rearrange(row_sample, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=latent_h//2, w=latent_w//2)
                    row_sample = self.ae.decode(row_sample / self.ae.config.scaling_factor + self.ae.config.shift_factor)[0]
                    row_sample = (row_sample + 1.0) / 2.0
                    row_sample.clamp_(0.0, 1.0)
                    row_samples.append(row_sample[0])
                    start = end
            
            # Convert all samples to PIL images
            output_images = []
            for row_sample in row_samples:
                output_image = to_pil_image(row_sample.float())
                output_images.append(output_image)
            
            torch.cuda.empty_cache()
            
            ret = []
            ret_w = output_images[-1].width
            ret_h = output_images[-1].height
            
            row_start = (grid_h - 1) * grid_w
            row_end = grid_h * grid_w
            for i in range(row_start, row_end):
                # Output only the masked (target) positions in the query row
                if mask_position[i - row_start] and is_upsampling:
                    cropped = output_images[-1].crop(((i - row_start) * ret_w // self.grid_w, 0, ((i - row_start) + 1) * ret_w // self.grid_w, ret_h))
                    upsampled = self.upsampling(
                        cropped, 
                        upsampling_size, 
                        cfg, 
                        upsampling_steps=upsampling_steps, 
                        upsampling_noise=upsampling_noise, 
                        generator=rng, 
                        content_prompt=prompts[2])
                    ret.append(upsampled)
                elif mask_position[i - row_start]:
                    cropped = output_images[-1].crop(((i - row_start) * ret_w // self.grid_w, 0, ((i - row_start) + 1) * ret_w // self.grid_w, ret_h))
                    ret.append(cropped)
            
            return ret
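

# --- Usage sketch (illustrative, not part of the original file) --------------
# A minimal example of how VisualClozeModel might be driven, assuming a local
# checkpoint, one in-context example row plus a query row (a 2x2 grid), and
# placeholder image files and prompts. All names below are assumptions.
if __name__ == "__main__":
    model = VisualClozeModel(
        model_path="checkpoints/visualcloze.pth",  # assumed path; replace with a real checkpoint
        resolution=384,
    )
    # 2 rows x 2 columns: one in-context example row, one query row
    model.set_grid_size(2, 2)

    # First row: a complete in-context example (condition image -> target image).
    # Second row: the query condition image, with None marking the image to generate.
    example_cond = Image.open("example_condition.png")   # placeholder file
    example_target = Image.open("example_target.png")    # placeholder file
    query_cond = Image.open("query_condition.png")       # placeholder file
    images = [
        [example_cond, example_target],
        [query_cond, None],
    ]

    # Layout, task, and content prompts, in that order (see the process_images docstring)
    prompts = [
        "A grid layout with 2 rows and 2 columns.",
        "Every row shows a condition image and the corresponding edited result.",
        "The content of the last image in the final row is: a placeholder description.",
    ]

    outputs = model.process_images(images, prompts, seed=0, cfg=30, steps=30)
    for idx, img in enumerate(outputs):
        img.save(f"output_{idx}.png")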