Upload 2 files
- app.py (+244 -0)
- requirements.txt (+15 -0)
app.py
ADDED
@@ -0,0 +1,244 @@
import os
import sys
import cv2
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageFilter, ImageDraw
from huggingface_hub import snapshot_download
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
import math
from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
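
# Fetch model weights from the Hugging Face Hub at startup: the FLUX.1-Fill
# base model, the FLUX.1-Redux prior, and the InsertAnything LoRA checkpoint.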
snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev", local_dir="./FLUX.1-Fill-dev")
snapshot_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", local_dir="./FLUX.1-Redux-dev")
snapshot_download(repo_id="123123aa123/insertanything_model", local_dir="./insertanything_model")
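
# Load FLUX.1-Fill in bfloat16, attach the InsertAnything LoRA weights, and
# load the Redux prior used to encode the reference object.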
dtype = torch.bfloat16
size = (768, 768)

pipe = FluxFillPipeline.from_pretrained(
    "./FLUX.1-Fill-dev",
    torch_dtype=dtype
).to("cuda")

pipe.load_lora_weights(
    "./insertanything_model/20250321-082022_steps5000_pytorch_lora_weights.safetensors"
)

redux = FluxPriorReduxPipeline.from_pretrained("./FLUX.1-Redux-dev").to(dtype=dtype).to("cuda")

### example #####
ref_dir = './examples/ref_image'
ref_mask_dir = './examples/ref_mask'
image_dir = './examples/source_image'
image_mask_dir = './examples/source_mask'

ref_list = [os.path.join(ref_dir, file) for file in os.listdir(ref_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
ref_list.sort()

ref_mask_list = [os.path.join(ref_mask_dir, file) for file in os.listdir(ref_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
ref_mask_list.sort()

image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
image_list.sort()

image_mask_list = [os.path.join(image_mask_dir, file) for file in os.listdir(image_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
image_mask_list.sort()
### example #####


def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
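    """Insert the reference object into the masked region of the background.

    Pipeline: cut the reference object out onto white, zoom in on the masked
    target region, place both side by side as a diptych, inpaint the target
    half with FLUX.1-Fill + the InsertAnything LoRA (conditioned on Redux
    embeddings of the reference), then paste the result back.
    """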
    if base_mask_option == "Draw Mask":
        tar_image = base_image["image"]
        tar_mask = base_image["mask"]
    else:
        tar_image = base_image["image"]
        tar_mask = base_mask

    if ref_mask_option == "Draw Mask":
        ref_image = reference_image["image"]
        ref_mask = reference_image["mask"]
    else:
        ref_image = reference_image["image"]  # the uploaded ref_mask is used as-is

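    # Normalize inputs: RGB uint8 images and binary {0, 1} masks thresholded at 128.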
    tar_image = tar_image.convert("RGB")
    tar_mask = tar_mask.convert("L")
    ref_image = ref_image.convert("RGB")
    ref_mask = ref_mask.convert("L")

    tar_image = np.asarray(tar_image)
    tar_mask = np.asarray(tar_mask)
    tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)

    ref_image = np.asarray(ref_image)
    ref_mask = np.asarray(ref_mask)
    ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)

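    # Reference branch: composite the object onto a white background, crop to
    # its bounding box, loosen the crop (ratio 1.3), and pad to a square.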
    ref_box_yyxx = get_bbox_from_mask(ref_mask)
    ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
    masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)
    y1, y2, x1, x2 = ref_box_yyxx
    masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
    ref_mask = ref_mask[y1:y2, x1:x2]
    ratio = 1.3
    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)

    masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)

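    # Target branch: dilate the mask so the fill region fully covers the
    # object boundary, then zoom in on a square crop around the masked area.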
    kernel = np.ones((7, 7), np.uint8)
    iterations = 2
    tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)

    # zoom in
    tar_box_yyxx = get_bbox_from_mask(tar_mask)
    tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=1.2)

    tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=2)  # 1.2 1.6
    tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop)  # crop box
    y1, y2, x1, x2 = tar_box_yyxx_crop

    old_tar_image = tar_image.copy()
    tar_image = tar_image[y1:y2, x1:x2, :]
    tar_mask = tar_mask[y1:y2, x1:x2]

    H1, W1 = tar_image.shape[0], tar_image.shape[1]
    # zoom in

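    # Encode the reference crop with the Redux prior; its outputs stand in
    # for a text prompt when calling the Fill pipeline below.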
    tar_mask = pad_to_square(tar_mask, pad_value=0)
    tar_mask = cv2.resize(tar_mask, size)

    masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
    pipe_prior_output = redux(Image.fromarray(masked_ref_image))

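    # Build the diptych: reference on the left, target crop on the right.
    # Only the target half of the mask is nonzero, so the model inpaints that
    # side while attending to the reference side.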
    tar_image = pad_to_square(tar_image, pad_value=255)

    H2, W2 = tar_image.shape[0], tar_image.shape[1]

    tar_image = cv2.resize(tar_image, size)
    diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)

    tar_mask = np.stack([tar_mask, tar_mask, tar_mask], -1)
    mask_black = np.ones_like(tar_image) * 0
    mask_diptych = np.concatenate([mask_black, tar_mask], axis=1)

    diptych_ref_tar = Image.fromarray(diptych_ref_tar)
    mask_diptych[mask_diptych == 1] = 255
    mask_diptych = Image.fromarray(mask_diptych)

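    # Inpaint with the LoRA-tuned Fill pipeline, conditioned on the Redux
    # prior outputs (prompt embeddings passed via **pipe_prior_output).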
    generator = torch.Generator("cuda").manual_seed(seed)
    edited_image = pipe(
        image=diptych_ref_tar,
        mask_image=mask_diptych,
        height=mask_diptych.size[1],
        width=mask_diptych.size[0],
        max_sequence_length=512,
        generator=generator,
        **pipe_prior_output,
    ).images[0]

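    # Keep only the right half of the diptych (the edited target), then undo
    # the zoom-in by pasting the result back into the full-resolution image.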
    width, height = edited_image.size
    left = width // 2
    right = width
    top = 0
    bottom = height
    edited_image = edited_image.crop((left, top, right, bottom))

    edited_image = np.array(edited_image)
    edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
    edited_image = Image.fromarray(edited_image)

    return [edited_image]

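
# Swap visibility between the uploaded-mask widget and the sketch tool.
# Note: defined here but not wired to the radio buttons in this revision.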
def update_ui(option):
    if option == "Draw Mask":
        return gr.update(visible=False), gr.update(visible=True)
    else:
        return gr.update(visible=True), gr.update(visible=False)

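
# Gradio front end: image/mask inputs and mask-source options on the left,
# output gallery and seed control alongside, paired example rows below.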
with gr.Blocks() as demo:

    gr.Markdown("# Play with InsertAnything to Insert Your Target Objects!")
    gr.Markdown("# Upload / Draw Images for the Background (top) and Reference Object (bottom)")
    gr.Markdown("### Draw a mask on the background, or simply upload one.")
    gr.Markdown("### Use only one of the two methods, and make sure the matching radio option is selected!")

    with gr.Row():
        with gr.Column():
            with gr.Row():
                base_image = gr.Image(label="Background Image", source="upload", tool="sketch", type="pil",
                                      brush_color='#FFFFFF', mask_opacity=0.5)

                base_mask = gr.Image(label="Background Mask", source="upload", type="pil")

            with gr.Row():
                base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")

            with gr.Row():
                ref_image = gr.Image(label="Reference Image", source="upload", tool="sketch", type="pil",
                                     brush_color='#FFFFFF', mask_opacity=0.5)

                ref_mask = gr.Image(label="Reference Mask", source="upload", type="pil")

            with gr.Row():
                ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Reference Mask Input Option", value="Upload with Mask")

        baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=512, columns=1)
        with gr.Accordion("Advanced Option", open=True):
            seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
            gr.Markdown("### Guidelines")
            gr.Markdown("Different seeds can produce noticeably different results; try, for example, 42 or 123456.")

    run_local_button = gr.Button(value="Run")

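    # Paired example assets: each row shows the i-th background, background
    # mask, reference object, and reference mask side by side.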
    # #### example #####
    num_examples = len(image_list)
    for i in range(num_examples):
        with gr.Row():
            if i == 0:
                gr.Examples([image_list[i]], inputs=[base_image], label="Examples - Background Image", examples_per_page=1)
                gr.Examples([image_mask_list[i]], inputs=[base_mask], label="Examples - Background Mask", examples_per_page=1)
                gr.Examples([ref_list[i]], inputs=[ref_image], label="Examples - Reference Object", examples_per_page=1)
                gr.Examples([ref_mask_list[i]], inputs=[ref_mask], label="Examples - Reference Mask", examples_per_page=1)
            else:
                gr.Examples([image_list[i]], inputs=[base_image], examples_per_page=1, label="")
                gr.Examples([image_mask_list[i]], inputs=[base_mask], examples_per_page=1, label="")
                gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
                gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
        if i < num_examples - 1:
            with gr.Row():
                gr.HTML("<hr>")
    # #### example #####

    run_local_button.click(fn=run_local,
                           inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option],
                           outputs=[baseline_gallery])

demo.launch()
requirements.txt
ADDED
@@ -0,0 +1,15 @@
torch==2.5.1
torchvision==0.20.1
diffusers==0.32.2
transformers==4.50.3
peft==0.15.1
opencv-python
protobuf
sentencepiece
gradio==3.39.0
bezier
lightning==2.5.1
datasets
prodigyopt
einops
scipy