Spaces: Running on Zero

Update app.py
app.py CHANGED
@@ -9,7 +9,139 @@ from huggingface_hub import snapshot_download
 from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
 import math
 from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
+
+import os, sys
+os.system("python -m pip install -e segment_anything")
+os.system("python -m pip install -e GroundingDINO")
+sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
+sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
+os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth")
+os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
+
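Note: the os.system wget calls above re-download both checkpoints on every Space restart, and a failed download is silently ignored. A hedged alternative sketch using huggingface_hub (already imported in app.py), which caches the files and returns local paths; the repo_id/filename pairs are read off the URLs above:

    from huggingface_hub import hf_hub_download
    dino_ckpt = hf_hub_download(repo_id="ShilongLiu/GroundingDINO", filename="groundingdino_swinb_cogcoor.pth")
    sam_ckpt = hf_hub_download(repo_id="mrtlive/segment-anything-model", repo_type="space", filename="sam_vit_h_4b8939.pth")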
+import torchvision
+from GroundingDINO.groundingdino.util.inference import load_model
+from segment_anything import build_sam, SamPredictor
 import spaces
+import GroundingDINO.groundingdino.datasets.transforms as T
+from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+
+
+# GroundingDINO config and checkpoint
+GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
+GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swinb_cogcoor.pth"
+
+# Segment-Anything checkpoint
+SAM_ENCODER_VERSION = "vit_h"
+SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
+
+# Building GroundingDINO inference model
+groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="cpu")
+# Building SAM Model and SAM Predictor
+sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
+sam_predictor = SamPredictor(sam)
+
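Note: both models are built on CPU on purpose. On a ZeroGPU Space ("Running on Zero" above), CUDA is only attached while a function decorated with @spaces.GPU is executing, so module-level initialization must stay CPU-only. A minimal sketch of the pattern, with a hypothetical helper that is not part of this commit:

    @spaces.GPU
    def gpu_smoke_test():
        # inside the decorated call a GPU is available; device moves are safe here
        sam.to("cuda")
        return torch.cuda.is_available()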
+def transform_image(image_pil):
+
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image, _ = transform(image_pil, None)  # 3, h, w
+    return image
+
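transform_image reproduces GroundingDINO's evaluation preprocessing: resize the short side to 800 (long side capped at 1333), convert to a tensor, and normalize with ImageNet statistics. A usage sketch with a hypothetical file name:

    from PIL import Image
    pil = Image.open("reference.jpg").convert("RGB")  # hypothetical input
    tensor = transform_image(pil)                     # torch.Tensor of shape (3, H, W)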
+
+def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
+    caption = caption.lower()
+    caption = caption.strip()
+    if not caption.endswith("."):
+        caption = caption + "."
+
+    with torch.no_grad():
+        outputs = model(image[None], captions=[caption])
+    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+
+    # filter output
+    logits_filt = logits.clone()
+    boxes_filt = boxes.clone()
+    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+    logits_filt = logits_filt[filt_mask]  # num_filt, 256
+    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+
+    # get phrase
+    tokenlizer = model.tokenizer
+    tokenized = tokenlizer(caption)
+    # build pred
+    pred_phrases = []
+    scores = []
+    for logit, box in zip(logits_filt, boxes_filt):
+        pred_phrase = get_phrases_from_posmap(
+            logit > text_threshold, tokenized, tokenlizer)
+        if with_logits:
+            pred_phrases.append(
+                pred_phrase + f"({str(logit.max().item())[:4]})")
+        else:
+            pred_phrases.append(pred_phrase)
+        scores.append(logit.max().item())
+
+    return boxes_filt, torch.Tensor(scores), pred_phrases
+
+
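The boxes returned here are still in GroundingDINO's normalized (cx, cy, w, h) format; get_mask below rescales them to pixel (x1, y1, x2, y2) before NMS and SAM. A call sketch, with a hypothetical label:

    boxes, scores, phrases = get_grounding_output(groundingdino_model, tensor, "backpack")
    # with with_logits=True each phrase carries its confidence, e.g. "backpack(0.52)"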
+def get_mask(image, label):
+    global groundingdino_model, sam_predictor
+
+    image_pil = image.convert("RGB")
+    transformed_image = transform_image(image_pil)
+
+    boxes_filt, scores, pred_phrases = get_grounding_output(
+        groundingdino_model, transformed_image, label
+    )
+
+    size = image_pil.size
+
+    # process boxes
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+        boxes_filt[i][2:] += boxes_filt[i][:2]
+
+    boxes_filt = boxes_filt.cpu()
+
+    # nms
+    nms_idx = torchvision.ops.nms(
+        boxes_filt, scores, 0.8).numpy().tolist()
+    boxes_filt = boxes_filt[nms_idx]
+    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+
+    image = np.array(image_pil)
+    sam_predictor.set_image(image)
+
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
+        boxes_filt, image.shape[:2])
+
+    masks, _, _ = sam_predictor.predict_torch(
+        point_coords=None,
+        point_labels=None,
+        boxes=transformed_boxes,
+        multimask_output=False,
+    )
+    result_mask = masks[0][0].cpu().numpy()
+
+    result_mask = Image.fromarray(result_mask)
+
+    return result_mask
+
 
 hf_token = os.getenv("HF_TOKEN")
 
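Note: get_mask keeps only masks[0][0]. torchvision.ops.nms returns indices sorted by descending score, so this is the SAM mask for the highest-scoring surviving box; any other detections for the label are discarded. A usage sketch with hypothetical inputs:

    ref = Image.open("ref.jpg")        # hypothetical image
    mask = get_mask(ref, "backpack")   # binary PIL image; run_local later applies .convert("L")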
@@ -59,26 +191,31 @@ image_mask_list.sort()
 @spaces.GPU
 def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
 
+
     if base_mask_option == "Draw Mask":
+        tar_image = base_image["background"]
+        tar_mask = base_image["layers"][0]
     else:
+        tar_image = base_image["background"]
+        tar_mask = base_mask["background"]
 
     if ref_mask_option == "Draw Mask":
+        ref_image = reference_image["background"]
+        ref_mask = reference_image["layers"][0]
+    elif ref_mask_option == "Upload with Mask":
+        ref_image = reference_image["background"]
+        ref_mask = ref_mask["background"]
     else:
+        ref_image = reference_image["background"]
+        ref_mask = get_mask(ref_image, text_prompt)
 
     tar_image = tar_image.convert("RGB")
     tar_mask = tar_mask.convert("L")
     ref_image = ref_image.convert("RGB")
     ref_mask = ref_mask.convert("L")
 
+    return_ref_mask = ref_mask.copy()
+
     tar_image = np.asarray(tar_image)
     tar_mask = np.asarray(tar_mask)
     tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)
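Note: the "Label to Mask" branch reads text_prompt, but the unchanged def line above still lists seven parameters, while the click wiring at the bottom of this diff passes eight inputs (text_prompt last). As committed, Gradio would call run_local with an extra positional argument and raise a TypeError, and the bare name text_prompt inside the function would otherwise resolve to the global Textbox component rather than its value. The likely intended signature:

    def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt):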
@@ -87,15 +224,20 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
     ref_mask = np.asarray(ref_mask)
     ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
 
+    if tar_mask.sum() == 0:
+        raise gr.Error('No mask for the background image. Please check the mask button!')
+
+    if ref_mask.sum() == 0:
+        raise gr.Error('No mask for the reference image. Please check the mask button!')
 
     ref_box_yyxx = get_bbox_from_mask(ref_mask)
     ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
     masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)
     y1, y2, x1, x2 = ref_box_yyxx
+    masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
     ref_mask = ref_mask[y1:y2, x1:x2]
     ratio = 1.3
+    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
 
 
     masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)
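The yyxx suffix marks the row-first bounding-box convention used by utils.utils: (y1, y2, x1, x2), sliced as [y1:y2, x1:x2]. A toy check of the convention, with assumed end-exclusive bounds:

    m = np.zeros((100, 100), np.uint8)
    m[10:20, 30:40] = 1                # foreground block
    y1, y2, x1, x2 = 10, 20, 30, 40    # what get_bbox_from_mask would return here (assumed)
    assert m[y1:y2, x1:x2].all()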
@@ -172,8 +314,10 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
     edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
     edited_image = Image.fromarray(edited_image)
 
+    if ref_mask_option != "Label to Mask":
+        return [edited_image]
+    else:
+        return [return_ref_mask, edited_image]
 
 def update_ui(option):
     if option == "Draw Mask":
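run_local returns a list because its only output component is a gr.Gallery: normally just the edited image, and under "Label to Mask" the auto-generated reference mask (return_ref_mask, saved before binarization) followed by the edited image, so the user can inspect what GroundingDINO + SAM segmented.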
@@ -185,32 +329,37 @@ def update_ui(option):
 with gr.Blocks() as demo:
 
 
+    gr.Markdown("# Insert-Anything")
+    gr.Markdown("### Draw a mask or upload a mask. Only select one of these two methods, and don't forget to click the corresponding button!")
+
 
     with gr.Row():
+        with gr.Column(scale=1):
             with gr.Row():
+                base_image = gr.ImageEditor(label="Background Image", sources="upload", type="pil",
+                                            brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
+                                            layers=False,
+                                            interactive=True)
 
+                base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil", layers=False, brush=False, eraser=False)
 
             with gr.Row():
                 base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")
 
             with gr.Row():
+                ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil",
+                                           brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
+                                           layers=False,
+                                           interactive=True)
 
+                ref_mask = gr.ImageEditor(label="Reference Mask", sources="upload", type="pil", layers=False, brush=False, eraser=False)
 
             with gr.Row():
+                ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Reference Mask Input Option", value="Upload with Mask")
 
+            with gr.Row():
+                text_prompt = gr.Textbox(label="Label")
+
+        with gr.Column(scale=1):
+            baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=701, columns=1)
     with gr.Accordion("Advanced Option", open=True):
         seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
     gr.Markdown("### Guidelines")
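update_ui is defined earlier, but its event wiring falls outside this diff's hunks; presumably the two radios toggle the relevant editors on change, along the lines of this hypothetical sketch:

    base_mask_option.change(fn=update_ui, inputs=base_mask_option, outputs=base_mask)
    ref_mask_option.change(fn=update_ui, inputs=ref_mask_option, outputs=ref_mask)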
@@ -218,7 +367,6 @@ with gr.Blocks() as demo:
 
     run_local_button = gr.Button(value="Run")
 
     # #### example #####
     num_examples = len(image_list)
     for i in range(num_examples):
@@ -234,12 +382,11 @@ with gr.Blocks() as demo:
         gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
         gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
         if i < num_examples - 1:
+            gr.HTML("<hr>")
     # #### example #####
+
+    run_local_button.click(fn=run_local,
+                           inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
+                           outputs=[baseline_gallery]
+                           )
 demo.launch()