import argparse
import gc
import json
import os

import numpy as np
import open_clip
import torch
from PIL import Image
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

from IP_Adapter.ip_adapter import IPAdapterXL
from create_grids import create_grids
from perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition


def save_images(output_dir, image_list):
    """Save a list of PIL images to output_dir as sample_1.png, sample_2.png, ..."""
    os.makedirs(output_dir, exist_ok=True)
    for i, img in enumerate(image_list):
        img.save(os.path.join(output_dir, f"sample_{i + 1}.png"))


def get_image_embeds(pil_image, model, preprocess, device):
    """Encode a single PIL image with the OpenCLIP image encoder; returns a NumPy array."""
    image = preprocess(pil_image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        embeds = model.encode_image(image.to(device))
    return embeds.cpu().numpy()


def process_combo(
    image_embeds_base,
    image_names_base,
    concept_embeds,
    concept_names,
    projection_matrices,
    ip_model,
    output_base_dir,
    num_samples=4,
    seed=420,
    prompt=None,
    scale=1.0,
):
    """For each base image, compose it with every combination of one image per
    concept and save the generated samples to a dedicated subdirectory."""
    for base_embed, base_name in zip(image_embeds_base, image_names_base):
        # Iterate over the Cartesian product of concept image indices
        for combo_indices in np.ndindex(*(len(embeds) for embeds in concept_embeds)):
            concept_combo_names = [concept_names[c][idx] for c, idx in enumerate(combo_indices)]
            combo_dir = os.path.join(
                output_base_dir,
                f"{base_name}_to_" + "_".join(concept_combo_names),
            )
            if os.path.exists(combo_dir):
                print(f"Directory {combo_dir} already exists. Skipping...")
                continue
            # Pair each selected concept embedding with its concept's projection matrix
            projections_data = [
                {
                    "embed": concept_embeds[c][idx],
                    "projection_matrix": projection_matrices[c],
                }
                for c, idx in enumerate(combo_indices)
            ]
            modified_images = get_modified_images_embeds_composition(
                base_embed, projections_data, ip_model,
                prompt=prompt, scale=scale, num_samples=num_samples, seed=seed,
            )
            save_images(combo_dir, modified_images)
            # Release GPU memory between combinations
            del modified_images
            torch.cuda.empty_cache()
            gc.collect()


def main(config_path, should_create_grids):
    with open(config_path, "r") as f:
        config = json.load(f)

    # Fill in optional settings with defaults; when a prompt is given, a lower
    # scale leaves more room for the text condition to influence the result
    config.setdefault("prompt", None)
    config.setdefault("scale", 1.0 if config["prompt"] is None else 0.6)
    config.setdefault("seed", 420)
    config.setdefault("num_samples", 4)

    # Load the SDXL base pipeline in half precision
    base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )

    device = "cuda"

    # Wrap the pipeline with the pretrained SDXL IP-Adapter (ViT-H image encoder)
    image_encoder_repo = "h94/IP-Adapter"
    image_encoder_subfolder = "models/image_encoder"
    ip_ckpt = hf_hub_download("h94/IP-Adapter", subfolder="sdxl_models", filename="ip-adapter_sdxl_vit-h.bin")
    ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)

    # OpenCLIP ViT-H/14 encoder used to embed the input images
    model, _, preprocess = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    model.to(device)
    model.eval()

    # Get base image embeddings
    image_files_base = [
        os.path.join(config["input_dir_base"], f)
        for f in os.listdir(config["input_dir_base"])
        if f.lower().endswith(("png", "jpg", "jpeg"))
    ]
    image_embeds_base = []
    image_names_base = []
    for path in image_files_base:
        image_names_base.append(os.path.basename(path))
        image_embeds_base.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))

    # Handle an arbitrary number of concepts: embed each concept's images and
    # build a rank-r projection matrix from its precomputed dataset embeddings
    concept_dirs = config["input_dirs_concepts"]
    concept_embeds = []
    concept_names = []
    projection_matrices = []
    for concept_dir, embeds_path, rank in zip(concept_dirs, config["all_embeds_paths"], config["ranks"]):
        image_files = [
            os.path.join(concept_dir, f)
            for f in os.listdir(concept_dir)
            if f.lower().endswith(("png", "jpg", "jpeg"))
        ]
        embeds = []
        names = []
        for path in image_files:
            names.append(os.path.basename(path))
            embeds.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))
        concept_embeds.append(embeds)
        concept_names.append(names)

        with open(embeds_path, "rb") as f:
            all_embeds_in = np.load(f)
        projection_matrices.append(compute_dataset_embeds_svd(all_embeds_in, rank))

    # Process all base-image / concept combinations
    process_combo(
        image_embeds_base,
        image_names_base,
        concept_embeds,
        concept_names,
        projection_matrices,
        ip_model,
        config["output_base_dir"],
        num_samples=config["num_samples"],
        seed=config["seed"],
        prompt=config["prompt"],
        scale=config["scale"],
    )

    # Generate grids if requested
    if should_create_grids:
        create_grids(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process images using embeddings and configurations.")
    parser.add_argument("--config", type=str, required=True, help="Path to the configuration JSON file.")
    parser.add_argument("--create_grids", action="store_true", help="Enable grid creation.")
    args = parser.parse_args()
    main(args.config, args.create_grids)
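
# Example invocation (script and config filenames here are hypothetical; substitute
# your actual paths):
#   python run_composition.py --config configs/composition.json --create_grids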