Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,569 Bytes
c025a3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import argparse
import os
import json
import itertools
from PIL import Image, ImageDraw, ImageFont
def wrap_text(text, max_width, draw, font):
    """
    Wrap the text to fit within the given width by breaking it into lines.

    The text is split on single spaces and the words are packed greedily:
    a word is added to the current line unless doing so would make the
    measured width exceed ``max_width``, in which case the current line is
    finished and the word starts a new one. Words are never split, so a
    single word wider than ``max_width`` ends up alone on its own line.

    Args:
        text: The string to wrap.
        max_width: Maximum allowed line width, in the units reported by
            ``draw.textbbox``.
        draw: Object providing ``textbbox((x, y), text, font=...)``; index 2
            of its result is used as the rendered width of the text.
        font: Font forwarded unchanged to ``draw.textbbox``.

    Returns:
        The list of wrapped lines, in order.
    """
    lines = []
    current_line = []
    for word in text.split(' '):
        candidate = ' '.join(current_line + [word])
        # Only flush when there is something to flush: an over-long word at
        # the start of a line must not emit an empty line before itself.
        if current_line and draw.textbbox((0, 0), candidate, font=font)[2] > max_width:
            lines.append(' '.join(current_line))
            current_line = [word]
        else:
            current_line.append(word)
    if current_line:
        lines.append(' '.join(current_line))
    return lines
def image_grid_with_titles(imgs, rows, cols, top_titles, left_titles, margin=20, cell_size=(256, 256)):
    """
    Compose images into a titled grid on a white canvas.

    Every image is resized to ``cell_size`` and pasted row by row (row-major
    order). A 50-pixel-high strip above the grid holds one title per column
    and a 120-pixel-wide strip on the left holds one title per row; titles
    are word-wrapped with ``wrap_text`` and centred in their strip.

    Args:
        imgs: Images to place; must contain exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.
        top_titles: One title per column.
        left_titles: One title per row.
        margin: Gap in pixels between cells and around the title strips.
        cell_size: ``(width, height)`` each image is resized to.

    Returns:
        A new RGB image containing the titles and the pasted images.

    Raises:
        AssertionError: If the number of images or titles does not match
            ``rows`` and ``cols``.
    """
    assert len(imgs) == rows * cols
    assert len(top_titles) == cols
    assert len(left_titles) == rows

    # Take the cell size from the parameter rather than from imgs[0], so an
    # empty grid (rows * cols == 0) does not fail with an IndexError.
    w, h = cell_size
    imgs = [img.resize((w, h)) for img in imgs]

    title_height = 50
    title_width = 120
    grid_width = cols * (w + margin) + title_width + margin
    grid_height = rows * (h + margin) + title_height + margin
    grid = Image.new('RGB', size=(grid_width, grid_height), color='white')
    draw = ImageDraw.Draw(grid)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        # Font file not available: fall back to the built-in default font.
        font = ImageFont.load_default()

    def measure(line):
        # Right and bottom edges of the text box anchored at the origin,
        # used as the line's width and height.
        box = draw.textbbox((0, 0), line, font=font)
        return box[2], box[3]

    # Column titles, centred horizontally over each column and vertically
    # inside the top strip.
    for col, title in enumerate(top_titles):
        measured = [(line, *measure(line)) for line in wrap_text(title, w, draw, font)]
        total_text_height = sum(line_height for _, _, line_height in measured)
        y_offset = (title_height - total_text_height) // 2
        cell_left = col * (w + margin) + title_width + margin
        for line, text_width, line_height in measured:
            draw.text((cell_left + (w - text_width) // 2, y_offset), line, fill="black", font=font)
            y_offset += line_height

    # Row titles, centred horizontally in the left strip and vertically
    # against each row.
    for row, title in enumerate(left_titles):
        measured = [(line, *measure(line)) for line in wrap_text(title, title_width - 10, draw, font)]
        total_text_height = sum(line_height for _, _, line_height in measured)
        y_offset = row * (h + margin) + title_height + (h - total_text_height) // 2 + margin
        for line, text_width, line_height in measured:
            draw.text(((title_width - text_width) // 2, y_offset), line, fill="black", font=font)
            y_offset += line_height

    # Paste the images in row-major order.
    for index, img in enumerate(imgs):
        x_pos = (index % cols) * (w + margin) + title_width + margin
        y_pos = (index // cols) * (h + margin) + title_height + margin
        grid.paste(img, box=(x_pos, y_pos))

    return grid
def _load_image_or_blank(path):
    """Open the image at ``path``, or return a white 256x256 placeholder if the path does not exist."""
    if os.path.exists(path):
        return Image.open(path)
    return Image.new("RGB", (256, 256), color="white")


def _fit_row(row_images, width):
    """Return ``row_images`` truncated or padded with white placeholders to exactly ``width`` cells."""
    row = list(row_images[:width])
    row.extend(Image.new("RGB", (256, 256), color="white") for _ in range(width - len(row)))
    return row


def create_grids(config):
    """
    Build and save one comparison grid per base image and fixed-concept combination.

    Each grid has one row per image in the last concept directory. A row
    holds the base image, one image per fixed concept (every concept
    directory except the last), the last-concept image, and then the
    generated samples read from
    ``<output_base_dir>/<base>_to_<concept names joined by "_">``.
    Grids are written as PNG files under ``<output_base_dir>/grids`` and the
    saved path is printed.

    Every row is padded or truncated to the grid width, so a missing, short
    or over-full sample directory only affects its own row instead of
    shifting the images of the rows that follow.

    Args:
        config: Mapping with the keys
            ``"input_dir_base"`` (directory of base images),
            ``"input_dirs_concepts"`` (list of concept image directories),
            ``"output_base_dir"`` (directory holding the sample folders and
            receiving the ``grids`` folder), and optionally
            ``"num_samples"`` (number of sample columns, default 4).
    """
    num_samples = config.get("num_samples", 4)
    concept_dirs = config["input_dirs_concepts"]
    input_dir_base = config["input_dir_base"]
    output_base_dir = config["output_base_dir"]
    output_grid_dir = os.path.join(output_base_dir, "grids")
    os.makedirs(output_grid_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so the row order
    # and the set of generated grids are reproducible.
    base_images = sorted(os.listdir(input_dir_base))

    single_concept = len(concept_dirs) == 1
    fixed_concepts = concept_dirs[:-1]
    last_concept_dir = concept_dirs[-1]
    last_concept_images = sorted(os.listdir(last_concept_dir))

    if single_concept:
        concept_titles = ["Concept 1"]
    else:
        concept_titles = [f"Concept {i+1}" for i in range(len(fixed_concepts))] + ["Last Concept"]
    top_titles = ["Base Image"] + concept_titles + ["Samples"] + [""] * (num_samples - 1)
    left_titles = [""] * len(last_concept_images)
    row_width = len(top_titles)

    fixed_concept_images = [sorted(os.listdir(concept_dir)) for concept_dir in fixed_concepts]

    for base_image in base_images:
        base_image_path = os.path.join(input_dir_base, base_image)

        # With a single concept there are no fixed concepts, so the product
        # yields exactly one empty combination and one grid per base image.
        for fixed_combination in itertools.product(*fixed_concept_images):
            # Leading cells shared by every row of this grid.
            fixed_images = [_load_image_or_blank(base_image_path)]
            for concept_dir, concept_image in zip(fixed_concepts, fixed_combination):
                fixed_images.append(_load_image_or_blank(os.path.join(concept_dir, concept_image)))

            images = []
            for last_image in last_concept_images:
                row_images = fixed_images + [_load_image_or_blank(os.path.join(last_concept_dir, last_image))]

                # Add generated samples for the current row, opening no more
                # files than there are free cells left in the row.
                sample_dir = os.path.join(
                    output_base_dir,
                    f"{base_image}_to_" + "_".join([*fixed_combination, last_image]),
                )
                if os.path.isdir(sample_dir):
                    free_cells = max(row_width - len(row_images), 0)
                    for sample_image in sorted(os.listdir(sample_dir))[:free_cells]:
                        row_images.append(_load_image_or_blank(os.path.join(sample_dir, sample_image)))

                images.extend(_fit_row(row_images, row_width))

            # Create the grid
            grid = image_grid_with_titles(
                imgs=images,
                rows=len(left_titles),
                cols=row_width,
                top_titles=top_titles,
                left_titles=left_titles
            )

            # Save the grid
            if single_concept:
                grid_name = f"grid_base_{base_image}_concept1.png"
            else:
                grid_name = f"grid_base_{base_image}_combo_{'_'.join(map(str, fixed_combination))}.png"
            grid_save_path = os.path.join(output_grid_dir, grid_name)
            grid.save(grid_save_path)
            print(f"Grid saved at {grid_save_path}")
if __name__ == "__main__":
    # Command-line entry point: read a JSON configuration and build the grids.
    cli = argparse.ArgumentParser(description="Create image grids based on a configuration file.")
    cli.add_argument("config_path", type=str, help="Path to the configuration JSON file.")
    cli_args = cli.parse_args()

    # Parse the configuration file named on the command line.
    with open(cli_args.config_path, 'r') as config_file:
        config = json.load(config_file)

    # Use four samples per row when the configuration gives no count.
    if "num_samples" not in config:
        config["num_samples"] = 4

    create_grids(config)
|