Files changed (2)
  1. app.py +206 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,206 @@
+ import gc
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from diffusers import StableDiffusionXLPipeline
+ import open_clip
+ from huggingface_hub import hf_hub_download
+ from IP_Composer.IP_Adapter.ip_adapter import IPAdapterXL
+ from IP_Composer.perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Initialize SDXL pipeline
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     base_model_path,
+     torch_dtype=torch.float16,
+     add_watermarker=False,
+ )
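+ # NOTE: the pipeline is not moved to `device` here; IPAdapterXL below is
+ # expected to handle device placement for the wrapped pipeline.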
+
+ # Initialize IP-Adapter (the ViT-H checkpoint pairs with the ViT-H CLIP encoder below)
+ image_encoder_repo = 'h94/IP-Adapter'
+ image_encoder_subfolder = 'models/image_encoder'
+ ip_ckpt = hf_hub_download('h94/IP-Adapter', subfolder="sdxl_models", filename='ip-adapter_sdxl_vit-h.bin')
+ ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)
+
+ # Initialize CLIP model
+ clip_model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
+ clip_model.to(device)
+ clip_model.eval()
+ print("Models initialized successfully!")
+
+ def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
+     """Return the CLIP image embedding for a PIL image as a numpy array."""
+     image = preproc(pil_image)[np.newaxis, :, :, :]
+     with torch.no_grad():
+         embeds = model.encode_image(image.to(dev))
+     return embeds.cpu().detach().numpy()
+
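+ # Illustrative usage (hypothetical file path; ViT-H-14 embeddings are 1024-dim):
+ #   emb = get_image_embeds(Image.open("photo.png").convert("RGB"))  # -> shape (1, 1024)
+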
+ def process_images(
+     base_image,
+     concept_image1, concept_desc1,
+     concept_image2=None, concept_desc2=None,
+     concept_image3=None, concept_desc3=None,
+     rank1=10, rank2=10, rank3=10,
+     prompt=None,
+     scale=1.0,
+     seed=420
+ ):
+     """Compose a base image with up to three concept images and return the generated results."""
+     # Embed the base image
+     base_image_pil = Image.fromarray(base_image).convert("RGB")
+     base_embed = get_image_embeds(base_image_pil)
+
+     # Collect the provided concept images and their descriptions
+     concept_images = []
+     concept_descriptions = []
+
+     # First concept (required)
+     if concept_image1 is None:
+         raise gr.Error("Please upload at least one concept image")
+     concept_images.append(concept_image1)
+     concept_descriptions.append(concept_desc1 if concept_desc1 else "Concept 1")
+
+     # Second concept (optional)
+     if concept_image2 is not None:
+         concept_images.append(concept_image2)
+         concept_descriptions.append(concept_desc2 if concept_desc2 else "Concept 2")
+
+     # Third concept (optional)
+     if concept_image3 is not None:
+         concept_images.append(concept_image3)
+         concept_descriptions.append(concept_desc3 if concept_desc3 else "Concept 3")
+
+     # One rank per provided concept
+     ranks = [rank1]
+     if concept_image2 is not None:
+         ranks.append(rank2)
+     if concept_image3 is not None:
+         ranks.append(rank3)
+
+     # Embed each concept image with CLIP
+     concept_embeds = []
+     for img in concept_images:
+         if img is not None:
+             img_pil = Image.fromarray(img).convert("RGB")
+             concept_embeds.append(get_image_embeds(img_pil))
+
+     # Compute a rank-constrained projection matrix per concept via SVD
+     projection_matrices = []
+     for i, embed in enumerate(concept_embeds):
+         # A single image is reshaped to match the format of an embedding collection
+         single_embed = embed.reshape(1, *embed.shape)
+         projection_matrix = compute_dataset_embeds_svd(single_embed, ranks[i])
+         projection_matrices.append(projection_matrix)
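+
+     # Conceptually, each projection matrix spans a low-rank subspace around its
+     # concept embedding; the composition step swaps the base embedding's
+     # component inside that subspace for the concept's.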
+     # Pair each concept embedding with its projection matrix for the composition
+     projections_data = [
+         {
+             "embed": embed,
+             "projection_matrix": proj_matrix
+         }
+         for embed, proj_matrix in zip(concept_embeds, projection_matrices)
+     ]
+
+     # Generate the composed images
+     modified_images = get_modified_images_embeds_composition(
+         base_embed,
+         projections_data,
+         ip_model,
+         prompt=prompt,
+         scale=scale,
+         num_samples=1,
+         seed=seed
+     )
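+     # The helper is expected to return a list of images (num_samples entries)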
+
+     return modified_images
+
+ def process_and_display(
+     base_image,
+     concept_image1, concept_desc1,
+     concept_image2=None, concept_desc2=None,
+     concept_image3=None, concept_desc3=None,
+     rank1=10, rank2=10, rank3=10,
+     prompt=None, scale=1.0, seed=420
+ ):
+     """Validate inputs, run process_images, and return a single image for the UI."""
+     if base_image is None:
+         raise gr.Error("Please upload a base image")
+
+     if concept_image1 is None:
+         raise gr.Error("Please upload at least one concept image")
+
+     modified_images = process_images(
+         base_image,
+         concept_image1, concept_desc1,
+         concept_image2, concept_desc2,
+         concept_image3, concept_desc3,
+         rank1, rank2, rank3,
+         prompt, scale, seed
+     )
+
+     # Clean up memory between runs (no-op on CPU-only machines)
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     # num_samples=1, so return the single generated image for the gr.Image output
+     return modified_images[0]
+
+ with gr.Blocks(title="Image Concept Composition") as demo:
+     gr.Markdown("# IP Composer")
+
+     with gr.Row():
+         with gr.Column():
+             base_image = gr.Image(label="Base Image (Required)", type="numpy")
+
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     concept_image1 = gr.Image(label="Concept Image 1 (Required)", type="numpy")
+                 with gr.Column(scale=1):
+                     concept_desc1 = gr.Textbox(label="Concept 1 Description", placeholder="Describe this concept")
+                     rank1 = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Rank 1")
+
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     concept_image2 = gr.Image(label="Concept Image 2 (Optional)", type="numpy")
+                 with gr.Column(scale=1):
+                     concept_desc2 = gr.Textbox(label="Concept 2 Description", placeholder="Describe this concept")
+                     rank2 = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Rank 2")
+
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     concept_image3 = gr.Image(label="Concept Image 3 (Optional)", type="numpy")
+                 with gr.Column(scale=1):
+                     concept_desc3 = gr.Textbox(label="Concept 3 Description", placeholder="Describe this concept")
+                     rank3 = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Rank 3")
+
+             prompt = gr.Textbox(label="Guidance Prompt (Optional)", placeholder="Optional text prompt to guide generation")
+
+             with gr.Row():
+                 scale = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Scale")
+                 seed = gr.Number(value=420, label="Seed", precision=0)
+
+             submit_btn = gr.Button("Generate")
+
+         with gr.Column():
+             output_image = gr.Image(label="Composed Output", show_label=True)
+
+     submit_btn.click(
+         fn=process_and_display,
+         inputs=[
+             base_image,
+             concept_image1, concept_desc1,
+             concept_image2, concept_desc2,
+             concept_image3, concept_desc3,
+             rank1, rank2, rank3,
+             prompt, scale, seed
+         ],
+         outputs=[output_image]
+     )
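+
+ # Optionally enable request queuing for concurrent users: demo.queue().launch()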
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ diffusers
+ torchvision
+ transformers
+ accelerate
+ safetensors
+ einops
+ spaces
+ peft
+ huggingface-hub
+ open_clip_torch