import os
import json
import torch
import gc
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Adapter.ip_adapter import IPAdapterXL
from perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition
from create_grids import create_grids
import argparse

# File-name suffixes (compared case-insensitively) that are treated as input images.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def save_images(output_dir, image_list):
    """Save every image in ``image_list`` into ``output_dir``.

    Files are named ``sample_1.png``, ``sample_2.png``, ... in list order.
    ``output_dir`` is created (including parents) if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    for i, img in enumerate(image_list, start=1):
        img.save(os.path.join(output_dir, f"sample_{i}.png"))


def get_image_embeds(pil_image, model, preprocess, device):
    """Encode a single PIL image with an open_clip model.

    The image is run through ``preprocess``, given a leading batch dimension of
    size 1, moved to ``device`` and encoded with ``model.encode_image`` without
    gradient tracking.

    Returns:
        The embedding as a NumPy array on the CPU (presumably shape
        ``(1, embed_dim)`` — depends on the model; confirm).
    """
    image = preprocess(pil_image).unsqueeze(0)
    with torch.no_grad():
        embeds = model.encode_image(image.to(device))
    return embeds.cpu().numpy()


def process_combo(
        image_embeds_base,
        image_names_base,
        concept_embeds,
        concept_names,
        projection_matrices,
        ip_model,
        output_base_dir,
        num_samples=4,
        seed=420,
        prompt=None,
        scale=1.0
):
    """Generate and save images for every (base image, concept combination) pair.

    For each base embedding, every combination that picks one image per concept
    (the Cartesian product over ``concept_embeds``) is paired with that
    concept's projection matrix and passed to
    ``get_modified_images_embeds_composition``. Results are written to
    ``<output_base_dir>/<base_name>_to_<name_1>_..._<name_n>/``.

    Combinations whose output directory already exists are skipped, which lets
    an interrupted run resume. A directory left half-written by a crash is
    skipped as well, so delete it to regenerate.

    Args:
        image_embeds_base: embeddings of the base images, parallel to ``image_names_base``.
        image_names_base: base image file names, used in output directory names.
        concept_embeds: one list of embeddings per concept.
        concept_names: one list of file names per concept, parallel to ``concept_embeds``.
        projection_matrices: one projection matrix per concept.
        ip_model: IP-Adapter model forwarded to the composition function.
        output_base_dir: root directory for all generated outputs.
        num_samples, seed, prompt, scale: forwarded to the composition function.
    """
    # Number of candidate images per concept; invariant across base images.
    combo_shape = tuple(len(embeds) for embeds in concept_embeds)

    for base_embed, base_name in zip(image_embeds_base, image_names_base):
        for combo_indices in np.ndindex(*combo_shape):
            concept_combo_names = [concept_names[c][idx] for c, idx in enumerate(combo_indices)]
            combo_dir = os.path.join(
                output_base_dir,
                f"{base_name}_to_" + "_".join(concept_combo_names)
            )

            if os.path.exists(combo_dir):
                print(f"Directory {combo_dir} already exists. Skipping...")
                continue

            projections_data = [
                {
                    "embed": concept_embeds[c][idx],
                    "projection_matrix": projection_matrices[c]
                }
                for c, idx in enumerate(combo_indices)
            ]

            modified_images = get_modified_images_embeds_composition(
                base_embed, projections_data, ip_model,
                prompt=prompt, scale=scale, num_samples=num_samples, seed=seed
            )
            save_images(combo_dir, modified_images)

            # Drop references and collect garbage first so that empty_cache()
            # can hand the now-unreferenced CUDA blocks back to the driver.
            del modified_images
            gc.collect()
            torch.cuda.empty_cache()


def _list_images(directory):
    """Return full paths of image files directly inside ``directory``, sorted by name."""
    # Sorting makes processing order deterministic (os.listdir order is arbitrary).
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(IMAGE_EXTENSIONS)
    ]


def _embed_directory(directory, model, preprocess, device):
    """Embed every image in ``directory``; return ``(embeds, names)`` as parallel lists."""
    embeds = []
    names = []
    for path in _list_images(directory):
        names.append(os.path.basename(path))
        # The context manager closes the file handle PIL keeps open after decoding.
        with Image.open(path) as img:
            embeds.append(get_image_embeds(img.convert("RGB"), model, preprocess, device))
    return embeds, names


def _apply_config_defaults(config):
    """Fill in optional config keys in place: prompt, scale, seed, num_samples."""
    config.setdefault("prompt", None)
    # The default scale depends on the (possibly defaulted) prompt: 1.0 without one, 0.6 with one.
    config.setdefault("scale", 1.0 if config["prompt"] is None else 0.6)
    config.setdefault("seed", 420)
    config.setdefault("num_samples", 4)


def main(config_path, should_create_grids):
    """Run the concept-composition pipeline described by a JSON config file.

    Config keys read here: ``input_dir_base``, ``input_dirs_concepts``,
    ``all_embeds_paths``, ``ranks``, ``output_base_dir``. Optional keys:
    ``prompt``, ``scale``, ``seed``, ``num_samples``. Defaults for the optional
    keys are written into the config dict before it is passed to ``create_grids``.

    Args:
        config_path: path to the JSON configuration file.
        should_create_grids: if true, call ``create_grids(config)`` after generation.

    Raises:
        ValueError: if ``input_dirs_concepts``, ``all_embeds_paths`` and ``ranks``
            have different lengths.
    """
    with open(config_path, 'r', encoding="utf-8") as f:
        config = json.load(f)
    _apply_config_defaults(config)

    concept_dirs = config["input_dirs_concepts"]
    embeds_paths = config["all_embeds_paths"]
    ranks = config["ranks"]
    # zip() would silently drop unmatched entries; fail before loading any model.
    if not len(concept_dirs) == len(embeds_paths) == len(ranks):
        raise ValueError(
            "Config lists must have equal lengths: "
            f"input_dirs_concepts={len(concept_dirs)}, "
            f"all_embeds_paths={len(embeds_paths)}, ranks={len(ranks)}"
        )

    # A single device is used for both the IP-Adapter and the CLIP encoder.
    device = "cuda"

    base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )

    image_encoder_repo = 'h94/IP-Adapter'
    image_encoder_subfolder = 'models/image_encoder'
    ip_ckpt = hf_hub_download(image_encoder_repo, subfolder="sdxl_models", filename='ip-adapter_sdxl_vit-h.bin')
    ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)

    model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
    model.to(device)
    model.eval()  # inference only

    # Base image embeddings
    image_embeds_base, image_names_base = _embed_directory(
        config["input_dir_base"], model, preprocess, device
    )

    # Per-concept image embeddings and projection matrices
    concept_embeds = []
    concept_names = []
    projection_matrices = []
    for concept_dir, embeds_path, rank in zip(concept_dirs, embeds_paths, ranks):
        embeds, names = _embed_directory(concept_dir, model, preprocess, device)
        if not embeds:
            print(f"Warning: no images found in concept directory {concept_dir}; "
                  "no combinations will be generated.")
        concept_embeds.append(embeds)
        concept_names.append(names)

        with open(embeds_path, "rb") as f:
            all_embeds_in = np.load(f)
        projection_matrices.append(compute_dataset_embeds_svd(all_embeds_in, rank))

    process_combo(
        image_embeds_base,
        image_names_base,
        concept_embeds,
        concept_names,
        projection_matrices,
        ip_model,
        config["output_base_dir"],
        num_samples=config["num_samples"],
        seed=config["seed"],
        prompt=config["prompt"],
        scale=config["scale"],
    )

    if should_create_grids:
        create_grids(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process images using embeddings and configurations.")
    parser.add_argument("--config", type=str, required=True, help="Path to the configuration JSON file.")
    parser.add_argument("--create_grids", action="store_true", help="Enable grid creation")
    args = parser.parse_args()

    main(args.config, args.create_grids)