"""Build titled image grids comparing base images, concept images and samples.

The script reads a JSON configuration, looks up generated sample directories
under the configured output directory, and writes one PNG grid per base image
(and per combination of "fixed" concept images when several concept
directories are configured).
"""

import argparse
import itertools
import json
import os

from PIL import Image, ImageDraw, ImageFont

# Every image is resized to this (width, height) before being pasted into a
# grid cell; blank filler cells use the same size.
_CELL_SIZE = (256, 256)


def wrap_text(text, max_width, draw, font):
    """Wrap *text* into lines that fit within *max_width* pixels.

    Words are split on single spaces and greedily packed into lines. A word
    that is wider than *max_width* on its own is kept on a line by itself
    (it is never split), and no empty line is emitted in front of it.

    Args:
        text: The string to wrap.
        max_width: Maximum allowed line width in pixels.
        draw: Object exposing ``textbbox(xy, text, font=...)`` (an
            ``ImageDraw.ImageDraw`` instance) used to measure text.
        font: Font passed through to ``draw.textbbox``.

    Returns:
        A list of line strings.
    """
    lines = []
    current_line = []
    for word in text.split(' '):
        candidate = current_line + [word]
        line_width = draw.textbbox((0, 0), ' '.join(candidate), font=font)[2]
        # Only break when there is something to flush; otherwise a single
        # over-wide word would produce a spurious empty line.
        if line_width > max_width and current_line:
            lines.append(' '.join(current_line))
            current_line = [word]
        else:
            current_line = candidate
    if current_line:
        lines.append(' '.join(current_line))
    return lines


def _measure_lines(lines, draw, font):
    """Return a ``(width, height)`` pair per line from the bbox right/bottom."""
    sizes = []
    for line in lines:
        bbox = draw.textbbox((0, 0), line, font=font)
        sizes.append((bbox[2], bbox[3]))
    return sizes


def image_grid_with_titles(imgs, rows, cols, top_titles, left_titles, margin=20):
    """Compose images into a grid with column titles and row titles.

    Each image is resized to 256x256 and pasted row-major (``imgs[0]`` is the
    top-left cell). Column titles are centred above each column and row
    titles are centred in a strip to the left of each row.

    Args:
        imgs: Sequence of PIL images; must contain exactly ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        top_titles: One title string per column.
        left_titles: One title string per row.
        margin: Spacing in pixels between cells and around the title strips.

    Returns:
        A new RGB ``PIL.Image.Image`` containing the assembled grid.

    Raises:
        ValueError: If the number of images or titles does not match the
            requested grid dimensions.
    """
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
        )
    if len(top_titles) != cols:
        raise ValueError(f"expected {cols} top titles, got {len(top_titles)}")
    if len(left_titles) != rows:
        raise ValueError(f"expected {rows} left titles, got {len(left_titles)}")

    imgs = [img.resize(_CELL_SIZE) for img in imgs]
    # All cells share the resize target, so the cell size does not depend on
    # imgs being non-empty.
    w, h = _CELL_SIZE

    title_height = 50
    title_width = 120

    grid_width = cols * (w + margin) + title_width + margin
    grid_height = rows * (h + margin) + title_height + margin
    grid = Image.new('RGB', size=(grid_width, grid_height), color='white')
    draw = ImageDraw.Draw(grid)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except OSError:
        # Arial is not available on every system; fall back to PIL's default.
        font = ImageFont.load_default()

    # Column titles, centred horizontally over each column and vertically
    # within the title strip.
    for i, title in enumerate(top_titles):
        lines = wrap_text(title, w, draw, font)
        sizes = _measure_lines(lines, draw, font)
        column_left = i * (w + margin) + title_width + margin
        y_offset = (title_height - sum(height for _, height in sizes)) // 2
        for line, (text_width, text_height) in zip(lines, sizes):
            x_offset = column_left + (w - text_width) // 2
            draw.text((x_offset, y_offset), line, fill="black", font=font)
            y_offset += text_height

    # Row titles, centred in the left strip next to each row.
    for i, title in enumerate(left_titles):
        lines = wrap_text(title, title_width - 10, draw, font)
        sizes = _measure_lines(lines, draw, font)
        total_text_height = sum(height for _, height in sizes)
        y_offset = i * (h + margin) + title_height + (h - total_text_height) // 2 + margin
        for line, (text_width, text_height) in zip(lines, sizes):
            x_offset = (title_width - text_width) // 2
            draw.text((x_offset, y_offset), line, fill="black", font=font)
            y_offset += text_height

    for i, img in enumerate(imgs):
        x_pos = (i % cols) * (w + margin) + title_width + margin
        y_pos = (i // cols) * (h + margin) + title_height + margin
        grid.paste(img, box=(x_pos, y_pos))

    return grid


def _blank_image():
    """Return a fresh white placeholder image the size of one grid cell."""
    return Image.new("RGB", _CELL_SIZE, color="white")


def _load_image(path):
    """Open the image at *path*, or return a white placeholder if it is missing."""
    return Image.open(path) if os.path.exists(path) else _blank_image()


def _build_row(leading_images, sample_dir, row_length):
    """Build one grid row: leading images, then samples, padded to *row_length*.

    Samples are read from *sample_dir* in sorted file-name order. The row is
    padded with blank cells (or truncated) so that every row has exactly
    *row_length* cells and columns stay aligned across rows.
    """
    row = list(leading_images)
    if os.path.isdir(sample_dir):
        row.extend(
            _load_image(os.path.join(sample_dir, sample_image))
            for sample_image in sorted(os.listdir(sample_dir))
        )
    row = row[:row_length]
    row.extend(_blank_image() for _ in range(row_length - len(row)))
    return row


def _render_and_save(rows_of_images, top_titles, left_titles, save_path):
    """Flatten the rows, render the titled grid and write it to *save_path*."""
    images = [img for row in rows_of_images for img in row]
    grid = image_grid_with_titles(
        imgs=images,
        rows=len(left_titles),
        cols=len(top_titles),
        top_titles=top_titles,
        left_titles=left_titles,
    )
    grid.save(save_path)
    print(f"Grid saved at {save_path}")


def create_grids(config):
    """Create and save comparison grids described by *config*.

    The last entry of ``config["input_dirs_concepts"]`` supplies the grid
    rows (one row per file in that directory). Any preceding concept
    directories are "fixed" concepts: one grid is produced per base image and
    per combination of one file from each fixed concept directory. With a
    single concept directory, one grid is produced per base image.

    Each row shows the base image, the fixed concept images (if any), the
    row's last-concept image, and then the files found in the matching sample
    directory under ``config["output_base_dir"]``:

    * single concept: ``{base_image}_to_{last_image}``
    * several concepts: ``{base_image}_to_{fixed1}_{fixed2}..._{last_image}``

    Rows with missing or too few samples are padded with blank cells, and
    rows with more samples than there are sample columns are truncated, so
    columns always line up. Grids are written as PNG files into
    ``<output_base_dir>/grids``.

    Args:
        config: Mapping with the keys ``"num_samples"``,
            ``"input_dirs_concepts"``, ``"output_base_dir"`` and
            ``"input_dir_base"``.

    Raises:
        ValueError: If ``config["input_dirs_concepts"]`` is empty.
    """
    num_samples = config["num_samples"]
    concept_dirs = config["input_dirs_concepts"]
    if not concept_dirs:
        raise ValueError("config['input_dirs_concepts'] must list at least one directory")

    output_base_dir = config["output_base_dir"]
    output_grid_dir = os.path.join(output_base_dir, "grids")
    os.makedirs(output_grid_dir, exist_ok=True)

    base_dir = config["input_dir_base"]
    base_images = os.listdir(base_dir)

    single_concept = len(concept_dirs) == 1
    fixed_concepts = concept_dirs[:-1]
    last_concept_dir = concept_dirs[-1]
    last_concept_images = os.listdir(last_concept_dir)

    sample_titles = ["Samples"] + [""] * (num_samples - 1)
    if single_concept:
        top_titles = ["Base Image", "Concept 1"] + sample_titles
    else:
        top_titles = (
            ["Base Image"]
            + [f"Concept {i+1}" for i in range(len(fixed_concepts))]
            + ["Last Concept"]
            + sample_titles
        )
    left_titles = [""] * len(last_concept_images)

    fixed_concept_images = [os.listdir(concept_dir) for concept_dir in fixed_concepts]

    for base_image in base_images:
        base_image_path = os.path.join(base_dir, base_image)

        # With no fixed concepts, product() yields exactly one empty
        # combination, i.e. one grid per base image.
        for fixed_combination in itertools.product(*fixed_concept_images):
            combo = "_".join(fixed_combination)
            if single_concept:
                sample_prefix = f"{base_image}_to_"
                grid_name = f"grid_base_{base_image}_concept1.png"
            else:
                sample_prefix = f"{base_image}_to_{combo}_"
                grid_name = f"grid_base_{base_image}_combo_{combo}.png"

            # Images shared by every row of this grid.
            fixed_images = [_load_image(base_image_path)]
            for concept_dir, concept_image in zip(fixed_concepts, fixed_combination):
                fixed_images.append(_load_image(os.path.join(concept_dir, concept_image)))

            grid_rows = []
            for last_image in last_concept_images:
                last_image_path = os.path.join(last_concept_dir, last_image)
                sample_dir = os.path.join(output_base_dir, sample_prefix + last_image)
                grid_rows.append(
                    _build_row(
                        fixed_images + [_load_image(last_image_path)],
                        sample_dir,
                        len(top_titles),
                    )
                )

            _render_and_save(
                grid_rows,
                top_titles,
                left_titles,
                os.path.join(output_grid_dir, grid_name),
            )


def main():
    """Parse the command line, load the JSON config and create the grids."""
    parser = argparse.ArgumentParser(description="Create image grids based on a configuration file.")
    parser.add_argument("config_path", type=str, help="Path to the configuration JSON file.")
    args = parser.parse_args()

    # Load the configuration
    with open(args.config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)

    # Default number of sample columns when the config does not specify one.
    config.setdefault("num_samples", 4)

    create_grids(config)


if __name__ == "__main__":
    main()