import torch
import numpy as np

import os
import re
import jsonc as json
from PIL import Image


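def img_list_to_pil(img_list, cond_image = None, seperation = 10):
    """Paste images side by side on one canvas, `seperation` pixels apart."""
    # Work on a copy so the caller's list is not mutated.
    img_list = list(img_list)
    if cond_image is not None:
        img_list.append(cond_image)

    widths, heights = zip(*(i.size for i in img_list))
    # Gaps only sit between images, so there are len(img_list) - 1 of them.
    total_width = sum(widths) + seperation * (len(img_list) - 1)
    max_height = max(heights)

    new_im = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for im in img_list:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.size[0] + seperation

    return new_im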


def grid_image_visualize(images, row_size):
    """Arrange images in a grid with `row_size` images per row."""
    widths, heights = zip(*(i.size for i in images))
    # Size every cell to the largest image so mixed sizes never overlap.
    cell_width, cell_height = max(widths), max(heights)
    num_rows = (len(images) + row_size - 1) // row_size
    total_width = cell_width * row_size + 10 * (row_size - 1)
    total_height = cell_height * num_rows
    new_im = Image.new('RGB', (total_width, total_height))

    for i, im in enumerate(images):
        row, col = divmod(i, row_size)
        new_im.paste(im, (col * (cell_width + 10), row * cell_height))

    return new_im

def process_images(images, res=512):
    """Center-crop each image to a square, then resize it to `res` x `res`."""
    res_images = []
    for image in images:
        # The largest centered square is bounded by the shorter side.
        crop_size = min(image.size)

        left = (image.size[0] - crop_size) // 2
        top = (image.size[1] - crop_size) // 2
        right = left + crop_size
        bottom = top + crop_size

        image = image.crop((left, top, right, bottom))
        image = image.resize((res, res), Image.BILINEAR)
        res_images.append(image)
    return res_images

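def sanitize_prompt(prompt: str, max_len: int = 50) -> str:
    """Turn a prompt into a filesystem-safe name of at most `max_len` characters."""
    # Collapse every run of disallowed characters into a single underscore.
    sanitized = re.sub(r'[^a-zA-Z0-9_\-]+', '_', prompt)
    return sanitized[:max_len].strip("_")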

def get_next_index(folder_path: str) -> int:
    """Return one past the highest `_<n>.png` / `_<n>.json` index in the folder."""
    if not os.path.isdir(folder_path):
        return 0

    pattern = re.compile(r'.*_(\d+)\.(?:png|json)$')
    max_index = -1

    for filename in os.listdir(folder_path):
        match = pattern.match(filename)
        if match:
            max_index = max(max_index, int(match.group(1)))

    return max_index + 1

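def save_results(
    args,
    source_prompt: str,
    target_prompt: str,
    images: list,  # PIL images; the first is the input, the last is the output
):
    """Save the input, output, and side-by-side images plus the run arguments."""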
    src_name = sanitize_prompt(source_prompt)
    tgt_name = sanitize_prompt(target_prompt)
    folder_name = f"{src_name}#{tgt_name}"

    output_dir = os.path.join(args.output_dir, folder_name)
    os.makedirs(output_dir, exist_ok=True)

    next_idx = get_next_index(output_dir)

    concated_image = img_list_to_pil([images[0], images[-1]], cond_image=None, seperation=10)
    concated_image.save(os.path.join(output_dir, f"concat_{next_idx}.png"))
    images[0].save(os.path.join(output_dir, f"input_{next_idx}.png"))
    images[-1].save(os.path.join(output_dir, f"output_{next_idx}.png"))

    args_filename = f"args_{next_idx}.json"
    args_path = os.path.join(output_dir, args_filename)

    with open(args_path, "w") as f:
        json.dump(vars(args), f, indent=4)

    print(f"Saved images to {output_dir} and args to {args_path}")