linoyts (HF Staff) committed
Commit 8506b26 · verified · 1 Parent(s): 5da04a3

Create app.py

initial app commit

Files changed (1): app.py (+217 −0)

app.py ADDED
@@ -0,0 +1,217 @@
import os
import json
import torch
import gc
import numpy as np
import gradio as gr
from PIL import Image
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Adapter.ip_adapter import IPAdapterXL
from perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition
import tempfile
import uuid
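
# Note: IP_Adapter and perform_swap appear to be local modules bundled with
# this Space's repository rather than packages installed from PyPI.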

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize SDXL pipeline
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    add_watermarker=False,
)

# Initialize IP-Adapter
image_encoder_repo = 'h94/IP-Adapter'
image_encoder_subfolder = 'models/image_encoder'
ip_ckpt = hf_hub_download('h94/IP-Adapter', subfolder="sdxl_models", filename='ip-adapter_sdxl_vit-h.bin')
ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)

# Initialize CLIP model
clip_model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
clip_model.to(device)
print("Models initialized successfully!")
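
# The open_clip backbone above is ViT-H-14, which matches the ViT-H image
# encoder used by the ip-adapter_sdxl_vit-h checkpoint, so the CLIP embeddings
# computed below live in the space the adapter expects.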

def get_image_embeds(pil_image, model=clip_model, preproc=preprocess, dev=device):
    """Get CLIP image embeddings for a given PIL image"""
    image = preproc(pil_image)[np.newaxis, :, :, :]
    with torch.no_grad():
        embeds = model.encode_image(image.to(dev))
    return embeds.cpu().detach().numpy()
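
# For ViT-H-14 the returned array has shape (1, 1024); it is left
# un-normalized, which is presumably what compute_dataset_embeds_svd expects.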

def save_temp_image(image):
    """Save a PIL image to a temporary file and return the path"""
    temp_dir = tempfile.gettempdir()
    filename = f"{uuid.uuid4()}.png"
    filepath = os.path.join(temp_dir, filename)
    image.save(filepath)
    return filepath
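
# (save_temp_image is not called anywhere in this file yet; it looks like a
# helper kept for future use.)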

def process_images(
    base_image,
    concept_image1, concept_desc1,
    concept_image2=None, concept_desc2=None,
    concept_image3=None, concept_desc3=None,
    rank1=10, rank2=10, rank3=10,
    prompt=None,
    scale=1.0,
    seed=420
):
    """Process the base image and concept images to generate modified images"""
    # Process base image
    base_image_pil = Image.fromarray(base_image).convert("RGB")
    base_embed = get_image_embeds(base_image_pil)

    # Process concept images
    concept_images = []
    concept_descriptions = []  # collected for clarity; not used downstream yet

    # Add first concept (required)
    if concept_image1 is not None:
        concept_images.append(concept_image1)
        concept_descriptions.append(concept_desc1 if concept_desc1 else "Concept 1")
    else:
        # Raising gr.Error surfaces the message in the UI and keeps the
        # success path returning a single value.
        raise gr.Error("Please upload at least one concept image")

    # Add second concept (optional)
    if concept_image2 is not None:
        concept_images.append(concept_image2)
        concept_descriptions.append(concept_desc2 if concept_desc2 else "Concept 2")

    # Add third concept (optional)
    if concept_image3 is not None:
        concept_images.append(concept_image3)
        concept_descriptions.append(concept_desc3 if concept_desc3 else "Concept 3")

    # Get all ranks (one per provided concept)
    ranks = [rank1]
    if concept_image2 is not None:
        ranks.append(rank2)
    if concept_image3 is not None:
        ranks.append(rank3)

    concept_embeds = []
    for img in concept_images:
        if img is not None:
            img_pil = Image.fromarray(img).convert("RGB")
            concept_embeds.append(get_image_embeds(img_pil))

    # Compute projection matrices
    projection_matrices = []
    for i, embed in enumerate(concept_embeds):
        # For a single image, reshape to the same format as a collection
        single_embed = embed.reshape(1, *embed.shape)
        projection_matrix = compute_dataset_embeds_svd(single_embed, ranks[i])
        projection_matrices.append(projection_matrix)
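
    # Each rank presumably sets how many singular vectors are kept when
    # compute_dataset_embeds_svd builds a concept's projection matrix: a higher
    # rank keeps a larger subspace of the concept's CLIP embedding, so more of
    # that concept carries over into the composition.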

    # Create projection data structure for the composition
    projections_data = [
        {
            "embed": embed,
            "projection_matrix": proj_matrix
        }
        for embed, proj_matrix in zip(concept_embeds, projection_matrices)
    ]

    # Generate modified images
    modified_images = get_modified_images_embeds_composition(
        base_embed,
        projections_data,
        ip_model,
        prompt=prompt,
        scale=scale,
        num_samples=1,
        seed=seed
    )

    return modified_images

def process_and_display(
    base_image,
    concept_image1, concept_desc1,
    concept_image2=None, concept_desc2=None,
    concept_image3=None, concept_desc3=None,
    rank1=10, rank2=10, rank3=10,
    prompt=None, scale=1.0, seed=420
):
    """Wrapper for process_images that handles UI updates"""
    if base_image is None:
        return None, "Please upload a base image"

    if concept_image1 is None:
        return None, "Please upload at least one concept image"

    modified_images = process_images(
        base_image,
        concept_image1, concept_desc1,
        concept_image2, concept_desc2,
        concept_image3, concept_desc3,
        rank1, rank2, rank3,
        prompt, scale, seed
    )

    # # Clean up memory
    # torch.cuda.empty_cache()
    # gc.collect()

    # Return a (gallery, status) pair to match the two outputs wired up below.
    return modified_images, "Generation complete"

with gr.Blocks(title="Image Concept Composition") as demo:
    gr.Markdown("# Image Concept Composition")
    gr.Markdown("Upload a base image and 1-3 concept images to create new images that combine these concepts.")

    with gr.Row():
        with gr.Column():
            base_image = gr.Image(label="Base Image (Required)", type="numpy")

            with gr.Row():
                with gr.Column(scale=2):
                    concept_image1 = gr.Image(label="Concept Image 1 (Required)", type="numpy")
                with gr.Column(scale=1):
                    concept_desc1 = gr.Textbox(label="Concept 1 Description", placeholder="Describe this concept")
                    rank1 = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Rank 1")

            with gr.Row():
                with gr.Column(scale=2):
                    concept_image2 = gr.Image(label="Concept Image 2 (Optional)", type="numpy")
                with gr.Column(scale=1):
                    concept_desc2 = gr.Textbox(label="Concept 2 Description", placeholder="Describe this concept")
                    rank2 = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Rank 2")

            with gr.Row():
                with gr.Column(scale=2):
                    concept_image3 = gr.Image(label="Concept Image 3 (Optional)", type="numpy")
                with gr.Column(scale=1):
                    concept_desc3 = gr.Textbox(label="Concept 3 Description", placeholder="Describe this concept")
                    rank3 = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="Rank 3")

            prompt = gr.Textbox(label="Guidance Prompt (Optional)", placeholder="Optional text prompt to guide generation")

            with gr.Row():
                scale = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Scale")
                seed = gr.Number(value=420, label="Seed", precision=0)

            submit_btn = gr.Button("Generate Image")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Image", show_label=True)
            status = gr.Markdown("Upload images and click Generate")

    submit_btn.click(
        fn=process_and_display,
        inputs=[
            base_image,
            concept_image1, concept_desc1,
            concept_image2, concept_desc2,
            concept_image3, concept_desc3,
            rank1, rank2, rank3,
            prompt, scale, seed
        ],
        outputs=[gallery, status]
    )


demo.launch()
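
For readers skimming the commit, here is a minimal sketch of what the composition step could look like, assuming get_modified_images_embeds_composition swaps the base embedding's component inside each concept's SVD subspace for the concept's own component. The names mirror the app; the sketch is illustrative, not the committed perform_swap code.

# Hypothetical sketch; not part of app.py.
import numpy as np

def compose_embeds_sketch(base_embed, projections_data):
    """Replace the base embedding's coordinates inside each concept
    subspace with that concept's coordinates, leaving the rest intact."""
    combined = base_embed.reshape(1, -1).astype(np.float32)
    for item in projections_data:
        P = item["projection_matrix"]                    # assumed symmetric projector V @ V.T
        concept = item["embed"].reshape(1, -1).astype(np.float32)
        combined = combined - combined @ P + concept @ P
    return combined                                      # conditioning fed to the IP-Adapter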