123123aa123 commited on
Commit
84e6f56
·
verified ·
1 Parent(s): 80ad2c6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +244 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image, ImageFilter, ImageDraw
8
+ from huggingface_hub import snapshot_download
9
+ from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
10
+ import math
11
+ from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
12
+
13
+
14
+
15
+ snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev", local_dir="./FLUX.1-Fill-dev")
16
+ snapshot_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", local_dir="./FLUX.1-Redux-dev")
17
+ snapshot_download(repo_id="123123aa123/insertanything_model", local_dir="./insertanything_model")
18
+
19
+
20
+ dtype = torch.bfloat16
21
+ size = (768, 768)
22
+
23
+ pipe = FluxFillPipeline.from_pretrained(
24
+ "./FLUX.1-Fill-dev",
25
+ torch_dtype=dtype
26
+ ).to("cuda")
27
+
28
+ pipe.load_lora_weights(
29
+ "./insertanything_model/20250321-082022_steps5000_pytorch_lora_weights.safetensors"
30
+ )
31
+
32
+
33
+ redux = FluxPriorReduxPipeline.from_pretrained("./FLUX.1-Redux-dev").to(dtype=dtype).to("cuda")
34
+
35
+
36
+
37
+ ### example #####
38
+ ref_dir='./examples/ref_image'
39
+ ref_mask_dir='./examples/ref_mask'
40
+ image_dir='./examples/source_image'
41
+ image_mask_dir='./examples/source_mask'
42
+
43
+ ref_list=[os.path.join(ref_dir,file) for file in os.listdir(ref_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file ]
44
+ ref_list.sort()
45
+
46
+ ref_mask_list=[os.path.join(ref_mask_dir,file) for file in os.listdir(ref_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
47
+ ref_mask_list.sort()
48
+
49
+ image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file ]
50
+ image_list.sort()
51
+
52
+ image_mask_list=[os.path.join(image_mask_dir,file) for file in os.listdir(image_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
53
+ image_mask_list.sort()
54
+ ### example #####
55
+
56
+
57
+
58
+
59
+ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
60
+
61
+ if base_mask_option == "Draw Mask":
62
+ tar_image = base_image["image"]
63
+ tar_mask = base_image["mask"]
64
+ else:
65
+ tar_image = base_image["image"]
66
+ tar_mask = base_mask
67
+
68
+ if ref_mask_option == "Draw Mask":
69
+ ref_image = reference_image["image"]
70
+ ref_mask = reference_image["mask"]
71
+ else:
72
+ ref_image = reference_image["image"]
73
+ ref_mask = ref_mask
74
+
75
+
76
+ tar_image = tar_image.convert("RGB")
77
+ tar_mask = tar_mask.convert("L")
78
+ ref_image = ref_image.convert("RGB")
79
+ ref_mask = ref_mask.convert("L")
80
+
81
+ tar_image = np.asarray(tar_image)
82
+ tar_mask = np.asarray(tar_mask)
83
+ tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)
84
+
85
+ ref_image = np.asarray(ref_image)
86
+ ref_mask = np.asarray(ref_mask)
87
+ ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
88
+
89
+
90
+ ref_box_yyxx = get_bbox_from_mask(ref_mask)
91
+ ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
92
+ masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3)
93
+ y1,y2,x1,x2 = ref_box_yyxx
94
+ masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
95
+ ref_mask = ref_mask[y1:y2,x1:x2]
96
+ ratio = 1.3
97
+ masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
98
+
99
+
100
+ masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False)
101
+
102
+ kernel = np.ones((7, 7), np.uint8)
103
+ iterations = 2
104
+ tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)
105
+
106
+ # zome in
107
+ tar_box_yyxx = get_bbox_from_mask(tar_mask)
108
+ tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=1.2)
109
+
110
+ tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=2) #1.2 1.6
111
+ tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop) # crop box
112
+ y1,y2,x1,x2 = tar_box_yyxx_crop
113
+
114
+
115
+ old_tar_image = tar_image.copy()
116
+ tar_image = tar_image[y1:y2,x1:x2,:]
117
+ tar_mask = tar_mask[y1:y2,x1:x2]
118
+
119
+ H1, W1 = tar_image.shape[0], tar_image.shape[1]
120
+ # zome in
121
+
122
+
123
+ tar_mask = pad_to_square(tar_mask, pad_value=0)
124
+ tar_mask = cv2.resize(tar_mask, size)
125
+
126
+ masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
127
+ pipe_prior_output = redux(Image.fromarray(masked_ref_image))
128
+
129
+
130
+ tar_image = pad_to_square(tar_image, pad_value=255)
131
+
132
+ H2, W2 = tar_image.shape[0], tar_image.shape[1]
133
+
134
+ tar_image = cv2.resize(tar_image, size)
135
+ diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)
136
+
137
+
138
+ tar_mask = np.stack([tar_mask,tar_mask,tar_mask],-1)
139
+ mask_black = np.ones_like(tar_image) * 0
140
+ mask_diptych = np.concatenate([mask_black, tar_mask], axis=1)
141
+
142
+
143
+ diptych_ref_tar = Image.fromarray(diptych_ref_tar)
144
+ mask_diptych[mask_diptych == 1] = 255
145
+ mask_diptych = Image.fromarray(mask_diptych)
146
+
147
+
148
+
149
+ generator = torch.Generator("cuda").manual_seed(seed)
150
+ edited_image = pipe(
151
+ image=diptych_ref_tar,
152
+ mask_image=mask_diptych,
153
+ height=mask_diptych.size[1],
154
+ width=mask_diptych.size[0],
155
+ max_sequence_length=512,
156
+ generator=generator,
157
+ **pipe_prior_output,
158
+ ).images[0]
159
+
160
+
161
+
162
+ width, height = edited_image.size
163
+ left = width // 2
164
+ right = width
165
+ top = 0
166
+ bottom = height
167
+ edited_image = edited_image.crop((left, top, right, bottom))
168
+
169
+
170
+ edited_image = np.array(edited_image)
171
+ edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
172
+ edited_image = Image.fromarray(edited_image)
173
+
174
+
175
+ return [edited_image]
176
+
177
+ def update_ui(option):
178
+ if option == "Draw Mask":
179
+ return gr.update(visible=False), gr.update(visible=True)
180
+ else:
181
+ return gr.update(visible=True), gr.update(visible=False)
182
+
183
+
184
+ with gr.Blocks() as demo:
185
+
186
+
187
+ gr.Markdown("#  Play with InsertAnything to Insert your Target Objects! ")
188
+ gr.Markdown("# Upload / Draw Images for the Background (up) and Reference Object (down)")
189
+ gr.Markdown("### Draw mask on the background or just upload the mask.")
190
+ gr.Markdown("### Only select one of these two methods. Don't forget to click the corresponding button!!")
191
+
192
+ with gr.Row():
193
+ with gr.Column():
194
+ with gr.Row():
195
+ base_image = gr.Image(label="Background Image", source="upload", tool="sketch", type="pil",
196
+ brush_color='#FFFFFF', mask_opacity=0.5)
197
+
198
+ base_mask = gr.Image(label="Background Mask", source="upload", type="pil")
199
+
200
+ with gr.Row():
201
+ base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")
202
+
203
+ with gr.Row():
204
+ ref_image = gr.Image(label="Reference Image", source="upload", tool="sketch", type="pil",
205
+ brush_color='#FFFFFF', mask_opacity=0.5)
206
+
207
+ ref_mask = gr.Image(label="Reference Mask", source="upload", type="pil")
208
+
209
+ with gr.Row():
210
+ ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Reference Mask Input Option", value="Upload with Mask")
211
+
212
+ baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=512, columns=1)
213
+ with gr.Accordion("Advanced Option", open=True):
214
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
215
+ gr.Markdown("### Guidelines")
216
+ gr.Markdown(" Users can try using different seeds. For example, seeds like 42 and 123456 may produce different effects.")
217
+
218
+ run_local_button = gr.Button(value="Run")
219
+
220
+
221
+ # #### example #####
222
+ num_examples = len(image_list)
223
+ for i in range(num_examples):
224
+ with gr.Row():
225
+ if i == 0:
226
+ gr.Examples([image_list[i]], inputs=[base_image], label="Examples - Background Image", examples_per_page=1)
227
+ gr.Examples([image_mask_list[i]], inputs=[base_mask], label="Examples - Background Mask", examples_per_page=1)
228
+ gr.Examples([ref_list[i]], inputs=[ref_image], label="Examples - Reference Object", examples_per_page=1)
229
+ gr.Examples([ref_mask_list[i]], inputs=[ref_mask], label="Examples - Reference Mask", examples_per_page=1)
230
+ else:
231
+ gr.Examples([image_list[i]], inputs=[base_image], examples_per_page=1, label="")
232
+ gr.Examples([image_mask_list[i]], inputs=[base_mask], examples_per_page=1, label="")
233
+ gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
234
+ gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
235
+ if i < num_examples - 1:
236
+ with gr.Row():
237
+ gr.HTML("<hr>")
238
+ # #### example #####
239
+
240
+ run_local_button.click(fn=run_local,
241
+ inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option],
242
+ outputs=[baseline_gallery]
243
+ )
244
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ diffusers==0.32.2
4
+ transformers==4.50.3
5
+ peft==0.15.1
6
+ opencv-python
7
+ protobuf
8
+ sentencepiece
9
+ gradio==3.39.0
10
+ bezier
11
+ lightning==2.5.1
12
+ datasets
13
+ prodigyopt
14
+ einops
15
+ scipy