Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import os | |
import json | |
import itertools | |
from PIL import Image, ImageDraw, ImageFont | |
def wrap_text(text, max_width, draw, font):
    """Greedily wrap *text* into lines no wider than *max_width* pixels.

    Words are appended to the current line until the line, measured with
    ``draw.textbbox``, becomes wider than *max_width*; the line is then
    flushed and the overflowing word starts the next line. A single word that
    is wider than *max_width* on its own is kept on its own line (it is never
    split, and never preceded by an empty line).

    Args:
        text: String to wrap. Runs of whitespace are treated as one separator.
        max_width: Maximum line width in pixels.
        draw: Object providing ``textbbox(xy, text, font=...)`` returning a
            ``(left, top, right, bottom)`` box, e.g. ``ImageDraw.Draw``.
        font: Font passed through to ``textbbox``.

    Returns:
        The list of wrapped lines; empty if *text* contains no words.
    """
    lines = []
    current_line = []
    for word in text.split():
        candidate = " ".join(current_line + [word])
        # Only break when there is something to flush; otherwise an over-wide
        # first word would emit an empty line before itself.
        if current_line and draw.textbbox((0, 0), candidate, font=font)[2] > max_width:
            lines.append(" ".join(current_line))
            current_line = [word]
        else:
            current_line.append(word)
    if current_line:
        lines.append(" ".join(current_line))
    return lines
def _draw_centered_lines(draw, lines, font, origin, box_size):
    """Draw *lines* stacked top-to-bottom, centred inside the box at *origin* of *box_size*.

    Each line's height is taken as the ``bottom`` value of its ``textbbox``.
    """
    x0, y0 = origin
    box_w, box_h = box_size
    boxes = [draw.textbbox((0, 0), line, font=font) for line in lines]
    y = y0 + (box_h - sum(box[3] for box in boxes)) // 2
    for line, (_, _, right, bottom) in zip(lines, boxes):
        draw.text((x0 + (box_w - right) // 2, y), line, fill="black", font=font)
        y += bottom


def image_grid_with_titles(imgs, rows, cols, top_titles, left_titles, margin=20,
                           cell_size=(256, 256), title_height=50, title_width=120,
                           font_size=20):
    """Lay out *imgs* in a ``rows x cols`` grid with column headers and row labels.

    Args:
        imgs: PIL images in row-major order; exactly ``rows * cols`` of them.
        rows: Number of grid rows.
        cols: Number of grid columns.
        top_titles: One header string per column, drawn in a band above the grid.
        left_titles: One label string per row, drawn in a column left of the grid.
        margin: Gap in pixels between cells and around the grid.
        cell_size: ``(width, height)`` every image is resized to.
        title_height: Height in pixels of the header band.
        title_width: Width in pixels of the row-label column.
        font_size: Point size used when "arial.ttf" can be loaded.

    Returns:
        A new white RGB ``PIL.Image`` containing the titled grid.

    Raises:
        ValueError: If the number of images or titles does not match the grid.
    """
    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}")
    if len(top_titles) != cols:
        raise ValueError(f"expected {cols} top titles, got {len(top_titles)}")
    if len(left_titles) != rows:
        raise ValueError(f"expected {rows} left titles, got {len(left_titles)}")

    # Use the target cell size directly (every image is resized to it), so an
    # empty grid does not need an imgs[0] to measure.
    w, h = cell_size
    grid_width = cols * (w + margin) + title_width + margin
    grid_height = rows * (h + margin) + title_height + margin
    grid = Image.new("RGB", size=(grid_width, grid_height), color="white")
    draw = ImageDraw.Draw(grid)
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # Arial is commonly unavailable (e.g. on Linux); use Pillow's built-in font.
        font = ImageFont.load_default()

    def cell_origin(row, col):
        """Top-left pixel of the image cell at (row, col)."""
        return (col * (w + margin) + title_width + margin,
                row * (h + margin) + title_height + margin)

    for col, title in enumerate(top_titles):
        x, _ = cell_origin(0, col)
        lines = wrap_text(title, w, draw, font)
        _draw_centered_lines(draw, lines, font, (x, 0), (w, title_height))

    for row, title in enumerate(left_titles):
        _, y = cell_origin(row, 0)
        # Leave a little horizontal slack so labels don't touch the first column.
        lines = wrap_text(title, title_width - 10, draw, font)
        _draw_centered_lines(draw, lines, font, (0, y), (title_width, h))

    for index, img in enumerate(imgs):
        grid.paste(img.resize((w, h)), box=cell_origin(*divmod(index, cols)))
    return grid
def _list_visible(directory):
    """Return the non-hidden entries of *directory*, sorted so grid order is reproducible."""
    return sorted(name for name in os.listdir(directory) if not name.startswith("."))


def _blank_tile():
    """Return a white 256x256 RGB placeholder tile."""
    return Image.new("RGB", (256, 256), color="white")


def _load_image(path):
    """Load *path* fully into memory as RGB, or return a blank tile if it does not exist.

    ``convert`` forces the pixel data to be read, so the file is closed right
    away instead of keeping one open handle per grid cell.
    """
    if not os.path.exists(path):
        return _blank_tile()
    with Image.open(path) as img:
        return img.convert("RGB")


def _build_row(leading_images, last_image_path, sample_dir, width):
    """Assemble one grid row of exactly *width* cells.

    The row is *leading_images* (base image + fixed concept images), then the
    last-concept image, then up to the remaining number of sorted sample
    images from *sample_dir* (if that directory exists), then blank tiles.
    """
    row = leading_images + [_load_image(last_image_path)]
    if os.path.isdir(sample_dir):
        free = max(width - len(row), 0)
        row.extend(
            _load_image(os.path.join(sample_dir, name))
            for name in _list_visible(sample_dir)[:free]
        )
    # Pad each row on its own: padding only at the end of the grid would shift
    # every later row whenever one sample directory is missing or short.
    row.extend(_blank_tile() for _ in range(width - len(row)))
    return row[:width]


def create_grids(config):
    """Build and save titled comparison grids, one per base image and fixed-concept combination.

    Config keys read:
        input_dir_base: Directory of base images; each file gets its own grid(s).
        input_dirs_concepts: List of concept directories. The last one supplies
            the grid rows; the others are combined with ``itertools.product``
            and every combination gets its own grid.
        output_base_dir: Directory containing the generated sample folders;
            grids are written to its ``grids`` sub-directory.
        num_samples: Number of sample columns shown per row.

    Sample folders are looked up inside ``output_base_dir`` as
    ``<base>_to_<last>`` (one concept) or ``<base>_to_<c1>_..._<cN>_<last>``
    (several concepts).

    Raises:
        ValueError: If ``input_dirs_concepts`` is empty.
    """
    num_samples = config["num_samples"]
    concept_dirs = config["input_dirs_concepts"]
    base_dir = config["input_dir_base"]
    output_base_dir = config["output_base_dir"]
    if not concept_dirs:
        raise ValueError("config['input_dirs_concepts'] must list at least one directory")

    output_grid_dir = os.path.join(output_base_dir, "grids")
    os.makedirs(output_grid_dir, exist_ok=True)

    fixed_concepts = concept_dirs[:-1]
    last_concept_dir = concept_dirs[-1]
    last_concept_images = _list_visible(last_concept_dir)
    if not last_concept_images:
        print(f"No images found in {last_concept_dir}; no grids created.")
        return

    sample_titles = ["Samples"] + [""] * (num_samples - 1)
    if fixed_concepts:
        concept_titles = [f"Concept {i + 1}" for i in range(len(fixed_concepts))] + ["Last Concept"]
    else:
        concept_titles = ["Concept 1"]
    top_titles = ["Base Image"] + concept_titles + sample_titles
    left_titles = [""] * len(last_concept_images)
    width = len(top_titles)
    fixed_concept_images = [_list_visible(concept_dir) for concept_dir in fixed_concepts]

    for base_image in _list_visible(base_dir):
        base_img = _load_image(os.path.join(base_dir, base_image))
        # With a single concept, product() of nothing yields exactly one empty tuple.
        for fixed_combination in itertools.product(*fixed_concept_images):
            leading_images = [base_img] + [
                _load_image(os.path.join(concept_dir, name))
                for concept_dir, name in zip(fixed_concepts, fixed_combination)
            ]
            if fixed_concepts:
                sample_prefix = f"{base_image}_to_" + "_".join(fixed_combination) + "_"
                grid_name = f"grid_base_{base_image}_combo_{'_'.join(fixed_combination)}.png"
            else:
                sample_prefix = f"{base_image}_to_"
                grid_name = f"grid_base_{base_image}_concept1.png"

            images = []
            for last_image in last_concept_images:
                images.extend(_build_row(
                    leading_images,
                    os.path.join(last_concept_dir, last_image),
                    os.path.join(output_base_dir, sample_prefix + last_image),
                    width,
                ))

            grid = image_grid_with_titles(
                imgs=images,
                rows=len(left_titles),
                cols=width,
                top_titles=top_titles,
                left_titles=left_titles,
            )
            grid_save_path = os.path.join(output_grid_dir, grid_name)
            grid.save(grid_save_path)
            print(f"Grid saved at {grid_save_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create image grids based on a configuration file.")
    parser.add_argument("config_path", type=str, help="Path to the configuration JSON file.")
    args = parser.parse_args()

    # Load the configuration. JSON text is UTF-8, so don't depend on the locale's default encoding.
    with open(args.config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    # "num_samples" is optional in the config file; default to 4 sample columns.
    config.setdefault("num_samples", 4)
    create_grids(config)