# ip-composer / IP_Composer / generate_compositions.py
# NOTE: Hugging Face file-viewer chrome ("Upload 64 files", revision c025a3d,
# 5.6 kB, history/blame links) removed so the file parses as Python.
import os
import json
import torch
import gc
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Adapter.ip_adapter import IPAdapterXL
from perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition
from create_grids import create_grids
import argparse
def save_images(output_dir, image_list):
    """Save every image in *image_list* into *output_dir* as sample_<n>.png (1-based)."""
    os.makedirs(output_dir, exist_ok=True)
    for index, image in enumerate(image_list, start=1):
        target_path = os.path.join(output_dir, f"sample_{index}.png")
        image.save(target_path)
def get_image_embeds(pil_image, model, preprocess, device):
    """Encode one PIL image with a CLIP-style *model* and return the embedding as numpy.

    *preprocess* maps the PIL image to a CHW tensor; a leading batch axis is
    added before the forward pass, so the result has shape (1, embed_dim).
    """
    batch = preprocess(pil_image)[np.newaxis, :, :, :]
    with torch.no_grad():
        features = model.encode_image(batch.to(device))
    return features.cpu().detach().numpy()
def process_combo(
    image_embeds_base,
    image_names_base,
    concept_embeds,
    concept_names,
    projection_matrices,
    ip_model,
    output_base_dir,
    num_samples=4,
    seed=420,
    prompt=None,
    scale=1.0
):
    """Generate composed images for every (base image, concept combination) pair.

    For each base embedding, iterate the Cartesian product of the per-concept
    image lists, build the projection payload for that combination, run the
    IP-Adapter composition, and save the results under a directory named after
    the base image and the chosen concept images. Existing output directories
    are skipped so the run is resumable.
    """
    for base_embed, base_name in zip(image_embeds_base, image_names_base):
        # One index per concept axis; np.ndindex walks the full product.
        axis_sizes = tuple(len(group) for group in concept_embeds)
        for indices in np.ndindex(*axis_sizes):
            chosen_names = [concept_names[axis][i] for axis, i in enumerate(indices)]
            combo_dir = os.path.join(
                output_base_dir,
                f"{base_name}_to_" + "_".join(chosen_names)
            )
            if os.path.exists(combo_dir):
                print(f"Directory {combo_dir} already exists. Skipping...")
                continue

            projections_data = []
            for axis, i in enumerate(indices):
                projections_data.append({
                    "embed": concept_embeds[axis][i],
                    "projection_matrix": projection_matrices[axis],
                })

            images = get_modified_images_embeds_composition(
                base_embed,
                projections_data,
                ip_model,
                prompt=prompt,
                scale=scale,
                num_samples=num_samples,
                seed=seed,
            )
            save_images(combo_dir, images)

            # Free generated images and GPU memory before the next combination.
            del images
            torch.cuda.empty_cache()
            gc.collect()
# File extensions accepted when scanning input directories for images.
IMAGE_EXTENSIONS = ('png', 'jpg', 'jpeg')


def _collect_image_embeds(input_dir, model, preprocess, device):
    """Embed every image file in *input_dir*.

    Returns a pair (file_names, embeddings) in matching, deterministic
    (sorted) order. Non-image files are ignored.
    """
    names, embeds = [], []
    # sorted() makes processing order deterministic across filesystems.
    for file_name in sorted(os.listdir(input_dir)):
        if not file_name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        path = os.path.join(input_dir, file_name)
        names.append(file_name)
        embeds.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))
    return names, embeds


def main(config_path, should_create_grids):
    """Run the full composition pipeline described by a JSON config file.

    Loads the config, builds the SDXL IP-Adapter pipeline and the open_clip
    encoder, embeds the base and concept images, computes per-concept SVD
    projection matrices, generates all combinations via process_combo, and
    optionally renders comparison grids.

    Args:
        config_path: Path to the configuration JSON file. Required keys:
            "input_dir_base", "input_dirs_concepts", "all_embeds_paths",
            "ranks", "output_base_dir". Optional keys get defaults below.
        should_create_grids: When True, call create_grids(config) at the end.
    """
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Fill in optional settings. "prompt" must be resolved first because the
    # default "scale" depends on whether a prompt was supplied.
    config.setdefault("prompt", None)
    config.setdefault("scale", 1.0 if config["prompt"] is None else 0.6)
    config.setdefault("seed", 420)
    config.setdefault("num_samples", 4)

    # Single device string for both models (the original assigned "cuda" and
    # then shadowed it with 'cuda:0'; on a default setup these are the same
    # device, so unifying them is behavior-preserving).
    device = "cuda"

    base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )
    image_encoder_repo = 'h94/IP-Adapter'
    image_encoder_subfolder = 'models/image_encoder'
    ip_ckpt = hf_hub_download('h94/IP-Adapter', subfolder="sdxl_models", filename='ip-adapter_sdxl_vit-h.bin')
    ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)

    model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
    model.to(device)
    model.eval()  # inference only; disables any training-mode layers

    # Embed the base images.
    image_names_base, image_embeds_base = _collect_image_embeds(
        config["input_dir_base"], model, preprocess, device
    )

    # Embed each concept directory and build its SVD projection matrix.
    concept_embeds = []
    concept_names = []
    projection_matrices = []
    for concept_dir, embeds_path, rank in zip(
        config["input_dirs_concepts"], config["all_embeds_paths"], config["ranks"]
    ):
        names, embeds = _collect_image_embeds(concept_dir, model, preprocess, device)
        concept_embeds.append(embeds)
        concept_names.append(names)
        with open(embeds_path, "rb") as f:
            all_embeds_in = np.load(f)
        projection_matrices.append(compute_dataset_embeds_svd(all_embeds_in, rank))

    # Generate every (base image, concept combination) output set.
    process_combo(
        image_embeds_base,
        image_names_base,
        concept_embeds,
        concept_names,
        projection_matrices,
        ip_model,
        config["output_base_dir"],
        config["num_samples"],
        config["seed"],
        config["prompt"],
        config["scale"]
    )

    # Optionally render comparison grids of the generated outputs.
    if should_create_grids:
        create_grids(config)
if __name__ == "__main__":
    # CLI entry point: --config is required; --create_grids is an opt-in flag.
    arg_parser = argparse.ArgumentParser(description="Process images using embeddings and configurations.")
    arg_parser.add_argument("--config", type=str, required=True, help="Path to the configuration JSON file.")
    arg_parser.add_argument("--create_grids", action="store_true", help="Enable grid creation")
    cli_args = arg_parser.parse_args()
    main(cli_args.config, cli_args.create_grids)