# ip-composer / IP_Composer / create_grids.py
# Uploaded by linoyts (HF Staff) in "Upload 64 files", commit c025a3d (verified); 8.57 kB.
import argparse
import os
import json
import itertools
from PIL import Image, ImageDraw, ImageFont
def wrap_text(text, max_width, draw, font):
    """
    Wrap the text to fit within the given width by breaking it into lines.

    Words are split on single spaces and packed greedily: a word is added to
    the current line as long as the rendered width (the right edge reported by
    ``draw.textbbox``) does not exceed ``max_width``.

    Args:
        text: The string to wrap.
        max_width: Maximum rendered line width, in the units ``textbbox`` uses.
        draw: Object exposing ``textbbox((x, y), text, font=...)``.
        font: Font forwarded unchanged to ``draw.textbbox``.

    Returns:
        A list of lines. A single word that is wider than ``max_width`` is
        kept whole on its own line (it is never split or dropped).
    """
    lines = []
    current_line = []
    for word in text.split(' '):
        candidate = current_line + [word]
        too_wide = draw.textbbox((0, 0), ' '.join(candidate), font=font)[2] > max_width
        # Only break when there is something to flush: an over-long word at
        # the start of a line must not produce an empty line in front of it.
        if too_wide and current_line:
            lines.append(' '.join(current_line))
            current_line = [word]
        else:
            current_line = candidate
    if current_line:
        lines.append(' '.join(current_line))
    return lines
def _draw_centered_lines(draw, lines, font, box_left, box_width, box_top, box_height):
    """Draw ``lines`` stacked vertically, centred inside the given box."""
    # Measure each line once; index 2 is the right edge (used as the width)
    # and index 3 the bottom edge (used as the line height).
    boxes = [draw.textbbox((0, 0), line, font=font) for line in lines]
    total_text_height = sum([box[3] for box in boxes])
    y_offset = box_top + (box_height - total_text_height) // 2
    for line, box in zip(lines, boxes):
        x_offset = box_left + (box_width - box[2]) // 2
        draw.text((x_offset, y_offset), line, fill="black", font=font)
        y_offset += box[3]


def image_grid_with_titles(imgs, rows, cols, top_titles, left_titles, margin=20):
    """
    Paste images into a titled grid on a white canvas.

    Every image is resized to 256x256 and placed row by row (``imgs`` is in
    row-major order). A 50-pixel strip above the grid holds one title per
    column and a 120-pixel strip on the left holds one title per row; titles
    are word-wrapped with ``wrap_text`` and centred in their strip.

    Args:
        imgs: Sequence of PIL images, ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.
        top_titles: One title string per column.
        left_titles: One title string per row.
        margin: Gap in pixels between cells and around the title strips.

    Returns:
        The composed ``PIL.Image.Image`` in RGB mode.

    Raises:
        AssertionError: If the lengths of ``imgs``, ``top_titles`` or
            ``left_titles`` do not match ``rows`` and ``cols``.
    """
    assert len(imgs) == rows * cols
    assert len(top_titles) == cols
    assert len(left_titles) == rows

    # All cells share one fixed size, so take the dimensions from the constant
    # instead of imgs[0]; that keeps an empty grid (no rows or no columns)
    # from failing with an IndexError.
    cell_size = (256, 256)
    w, h = cell_size
    imgs = [img.resize(cell_size) for img in imgs]

    title_height = 50
    title_width = 120
    grid_width = cols * (w + margin) + title_width + margin
    grid_height = rows * (h + margin) + title_height + margin
    grid = Image.new('RGB', size=(grid_width, grid_height), color='white')
    draw = ImageDraw.Draw(grid)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        # Arial is not installed everywhere; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    # Column titles: centred above each column inside the top strip.
    for i, title in enumerate(top_titles):
        wrapped_title = wrap_text(title, w, draw, font)
        column_left = (i * (w + margin)) + title_width + margin
        _draw_centered_lines(draw, wrapped_title, font, column_left, w, 0, title_height)

    # Row titles: wrapped slightly narrower than the strip, centred beside each row.
    for i, title in enumerate(left_titles):
        wrapped_title = wrap_text(title, title_width - 10, draw, font)
        row_top = (i * (h + margin)) + title_height + margin
        _draw_centered_lines(draw, wrapped_title, font, 0, title_width, row_top, h)

    for i, img in enumerate(imgs):
        x_pos = (i % cols) * (w + margin) + title_width + margin
        y_pos = (i // cols) * (h + margin) + title_height + margin
        grid.paste(img, box=(x_pos, y_pos))

    return grid
def _load_image_or_blank(path):
    """Open the image at ``path``, or return a white 256x256 placeholder if it is missing."""
    if os.path.exists(path):
        return Image.open(path)
    return Image.new("RGB", (256, 256), color="white")


def _load_sample_images(sample_dir, limit):
    """Load at most ``limit`` images from ``sample_dir`` in sorted filename order."""
    if not os.path.isdir(sample_dir):
        return []
    sample_names = sorted(os.listdir(sample_dir))[:limit]
    return [_load_image_or_blank(os.path.join(sample_dir, name)) for name in sample_names]


def create_grids(config):
    """
    Build and save one overview grid per base image and fixed-concept combination.

    Each grid row shows the base image, the chosen image of every concept and
    the generated samples found in the matching sample directory under
    ``output_base_dir``. Rows iterate over the images of the last concept
    directory; with several concept directories, one grid is produced for
    every combination of images from the preceding ("fixed") directories.
    Grids are written as PNG files to ``<output_base_dir>/grids`` and their
    paths are printed.

    Every row is padded with blank cells, or truncated, to exactly the number
    of title columns. A missing or incomplete sample directory therefore
    leaves blanks in its own row instead of shifting later rows, and surplus
    sample files are ignored.

    Args:
        config: Mapping with the keys
            ``"num_samples"`` (number of sample columns),
            ``"input_dirs_concepts"`` (list of concept image directories),
            ``"input_dir_base"`` (directory of base images) and
            ``"output_base_dir"`` (directory holding the sample directories).
    """
    num_samples = config["num_samples"]
    concept_dirs = config["input_dirs_concepts"]
    input_dir_base = config["input_dir_base"]
    output_base_dir = config["output_base_dir"]
    output_grid_dir = os.path.join(output_base_dir, "grids")
    os.makedirs(output_grid_dir, exist_ok=True)

    base_images = os.listdir(input_dir_base)

    # A single concept uses its own titles and naming scheme; otherwise the
    # last directory supplies the rows and the others are combined per grid.
    single_concept = len(concept_dirs) == 1
    fixed_concepts = concept_dirs[:-1]
    last_concept_dir = concept_dirs[-1]
    last_concept_images = os.listdir(last_concept_dir)
    fixed_concept_images = [os.listdir(concept_dir) for concept_dir in fixed_concepts]

    if single_concept:
        lead_titles = ["Base Image", "Concept 1"]
    else:
        lead_titles = (
            ["Base Image"]
            + [f"Concept {i+1}" for i in range(len(fixed_concepts))]
            + ["Last Concept"]
        )
    top_titles = lead_titles + ["Samples"] + [""] * (num_samples - 1)
    left_titles = ["" for _ in last_concept_images]
    sample_columns = len(top_titles) - len(lead_titles)

    for base_image in base_images:
        base_image_path = os.path.join(input_dir_base, base_image)

        # With no fixed concepts, product() yields exactly one empty tuple,
        # so the single-concept case produces one grid per base image.
        for fixed_combination in itertools.product(*fixed_concept_images):
            combination_name = "_".join(map(str, fixed_combination))

            # Cells repeated at the start of every row of this grid.
            fixed_images = [_load_image_or_blank(base_image_path)]
            for concept_dir, concept_image in zip(fixed_concepts, fixed_combination):
                fixed_images.append(_load_image_or_blank(os.path.join(concept_dir, concept_image)))

            images = []
            for last_image in last_concept_images:
                last_image_path = os.path.join(last_concept_dir, last_image)
                row_images = fixed_images + [_load_image_or_blank(last_image_path)]

                # Add generated samples for the current row
                if single_concept:
                    sample_dir_name = f"{base_image}_to_{last_image}"
                else:
                    sample_dir_name = f"{base_image}_to_{combination_name}_{last_image}"
                sample_dir = os.path.join(output_base_dir, sample_dir_name)
                row_images.extend(_load_sample_images(sample_dir, sample_columns))

                # Pad this row to the full grid width so rows stay aligned.
                missing = len(top_titles) - len(row_images)
                row_images.extend(
                    Image.new("RGB", (256, 256), color="white") for _ in range(missing)
                )
                images.extend(row_images)

            # Create the grid
            grid = image_grid_with_titles(
                imgs=images,
                rows=len(left_titles),
                cols=len(top_titles),
                top_titles=top_titles,
                left_titles=left_titles,
            )

            # Save the grid
            if single_concept:
                grid_file_name = f"grid_base_{base_image}_concept1.png"
            else:
                grid_file_name = f"grid_base_{base_image}_combo_{combination_name}.png"
            grid_save_path = os.path.join(output_grid_dir, grid_file_name)
            grid.save(grid_save_path)
            print(f"Grid saved at {grid_save_path}")
if __name__ == "__main__":
    # Command-line entry point: takes the path of a JSON configuration file
    # and renders the grids it describes.
    cli = argparse.ArgumentParser(
        description="Create image grids based on a configuration file.",
    )
    cli.add_argument(
        "config_path",
        type=str,
        help="Path to the configuration JSON file.",
    )
    config_path = cli.parse_args().config_path

    # Load the configuration
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)

    # Use four sample columns when the configuration does not specify a count.
    if "num_samples" not in config:
        config["num_samples"] = 4

    create_grids(config)