WensongSong committed on
Commit 50d0879 · verified · 1 Parent(s): cb27c42

Update app.py

Files changed (1)
  1. app.py +181 -34
app.py CHANGED
@@ -9,7 +9,139 @@ from huggingface_hub import snapshot_download
 from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
 import math
 from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
+
+import os,sys
+os.system("python -m pip install -e segment_anything")
+os.system("python -m pip install -e GroundingDINO")
+sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
+sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
+os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth")
+os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
+
+import torchvision
+from GroundingDINO.groundingdino.util.inference import load_model
+from segment_anything import build_sam, SamPredictor
 import spaces
+import GroundingDINO.groundingdino.datasets.transforms as T
+from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+
+
+
+# GroundingDINO config and checkpoint
+GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
+GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swinb_cogcoor.pth"
+
+# Segment-Anything checkpoint
+SAM_ENCODER_VERSION = "vit_h"
+SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
+
+# Building GroundingDINO inference model
+groundingdino_model = load_model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device="cpu")
+# Building SAM Model and SAM Predictor
+sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
+sam_predictor = SamPredictor(sam)
+
+def transform_image(image_pil):
+
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image, _ = transform(image_pil, None)  # 3, h, w
+    return image
+
+
+def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
+    caption = caption.lower()
+    caption = caption.strip()
+    if not caption.endswith("."):
+        caption = caption + "."
+
+    with torch.no_grad():
+        outputs = model(image[None], captions=[caption])
+    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+    logits.shape[0]
+
+    # filter output
+    logits_filt = logits.clone()
+    boxes_filt = boxes.clone()
+    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+    logits_filt = logits_filt[filt_mask]  # num_filt, 256
+    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+    logits_filt.shape[0]
+
+    # get phrase
+    tokenlizer = model.tokenizer
+    tokenized = tokenlizer(caption)
+    # build pred
+    pred_phrases = []
+    scores = []
+    for logit, box in zip(logits_filt, boxes_filt):
+        pred_phrase = get_phrases_from_posmap(
+            logit > text_threshold, tokenized, tokenlizer)
+        if with_logits:
+            pred_phrases.append(
+                pred_phrase + f"({str(logit.max().item())[:4]})")
+        else:
+            pred_phrases.append(pred_phrase)
+        scores.append(logit.max().item())
+
+    return boxes_filt, torch.Tensor(scores), pred_phrases
+
+
+def get_mask(image, label):
+    global groundingdino_model, sam_predictor
+
+
+    image_pil = image.convert("RGB")
+    transformed_image = transform_image(image_pil)
+
+
+    boxes_filt, scores, pred_phrases = get_grounding_output(
+        groundingdino_model, transformed_image, label
+    )
+
+    size = image_pil.size
+
+    # process boxes
+    H, W = size[1], size[0]
+    for i in range(boxes_filt.size(0)):
+        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+        boxes_filt[i][2:] += boxes_filt[i][:2]
+
+    boxes_filt = boxes_filt.cpu()
+
+    # nms
+
+    nms_idx = torchvision.ops.nms(
+        boxes_filt, scores, 0.8).numpy().tolist()
+    boxes_filt = boxes_filt[nms_idx]
+    pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+
+
+    image = np.array(image_pil)
+    sam_predictor.set_image(image)
+
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(
+        boxes_filt, image.shape[:2])
+
+    masks, _, _ = sam_predictor.predict_torch(
+        point_coords=None,
+        point_labels=None,
+        boxes=transformed_boxes,
+        multimask_output=False,
+    )
+    result_mask = masks[0][0].cpu().numpy()
+
+    result_mask = Image.fromarray(result_mask)
+
+    return result_mask
+
 
 hf_token = os.getenv("HF_TOKEN")
 
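For orientation: the block above adds a text-driven mask path, where GroundingDINO grounds the label to candidate boxes and SAM turns the surviving box into a segmentation mask. A minimal sketch of exercising get_mask outside the UI (the input path and label here are hypothetical, and it assumes the checkpoints above were downloaded and the module-level models built):

    from PIL import Image

    ref = Image.open("examples/ref.png")    # hypothetical input image
    mask = get_mask(ref, "toy")             # label -> boxes (DINO) -> mask (SAM)
    mask.convert("L").save("ref_mask.png")  # grayscale mask, as run_local expects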
@@ -59,26 +191,31 @@ image_mask_list.sort()
 @spaces.GPU
 def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
 
+
     if base_mask_option == "Draw Mask":
-        tar_image = base_image["image"]
-        tar_mask = base_image["mask"]
+        tar_image = base_image["background"]
+        tar_mask = base_image["layers"][0]
     else:
-        tar_image = base_image["image"]
-        tar_mask = base_mask
+        tar_image = base_image["background"]
+        tar_mask = base_mask["background"]
 
     if ref_mask_option == "Draw Mask":
-        ref_image = reference_image["image"]
-        ref_mask = reference_image["mask"]
+        ref_image = reference_image["background"]
+        ref_mask = reference_image["layers"][0]
+    elif ref_mask_option == "Upload with Mask":
+        ref_image = reference_image["background"]
+        ref_mask = ref_mask["background"]
     else:
-        ref_image = reference_image["image"]
-        ref_mask = ref_mask
-
+        ref_image = reference_image["background"]
+        ref_mask = get_mask(ref_image, text_prompt)
 
     tar_image = tar_image.convert("RGB")
     tar_mask = tar_mask.convert("L")
     ref_image = ref_image.convert("RGB")
     ref_mask = ref_mask.convert("L")
 
+    return_ref_mask = ref_mask.copy()
+
     tar_image = np.asarray(tar_image)
     tar_mask = np.asarray(tar_mask)
     tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)
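The branching above tracks Gradio 4's ImageEditor value format: with type="pil" the component returns a dict holding "background" (the uploaded image), "layers" (the drawn brush layers), and "composite", replacing the 3.x sketch-tool dict of "image"/"mask". A small sketch of pulling a drawn mask out of that dict (the helper name is illustrative, not part of the commit):

    def extract_drawn_mask(editor_value):
        # "background" is a PIL image; white brush strokes land in layers[0].
        background = editor_value["background"]
        drawn = editor_value["layers"][0].convert("L")  # run_local binarizes this at >128
        return background, drawn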
@@ -87,15 +224,20 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
     ref_mask = np.asarray(ref_mask)
     ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
 
+    if tar_mask.sum() == 0:
+        raise gr.Error('No mask for the background image. Please check the mask button!')
+
+    if ref_mask.sum() == 0:
+        raise gr.Error('No mask for the reference image. Please check the mask button!')
 
     ref_box_yyxx = get_bbox_from_mask(ref_mask)
     ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
     masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3)
     y1,y2,x1,x2 = ref_box_yyxx
-    masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
+    masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
     ref_mask = ref_mask[y1:y2,x1:x2]
     ratio = 1.3
-    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
+    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
 
 
     masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False)
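The crop above leans on helpers from utils/utils.py, which this diff does not touch. A rough equivalent of get_bbox_from_mask, under the assumption (implied by the y1,y2,x1,x2 unpacking) that it returns the tight bounds of the foreground (a sketch, not the repo's implementation):

    import numpy as np

    def get_bbox_from_mask_sketch(mask):
        # mask: HxW uint8 in {0, 1}; tight box around the nonzero region
        ys, xs = np.nonzero(mask)
        return ys.min(), ys.max(), xs.min(), xs.max()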
@@ -172,8 +314,10 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_
     edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
     edited_image = Image.fromarray(edited_image)
 
-
-    return [edited_image]
+    if ref_mask_option != "Label to Mask":
+        return [edited_image]
+    else:
+        return [return_ref_mask, edited_image]
 
 def update_ui(option):
     if option == "Draw Mask":
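Both return shapes satisfy the single baseline_gallery output because gr.Gallery consumes a list of images: the "Draw Mask" and "Upload with Mask" paths render one tile ([edited_image]), while "Label to Mask" renders the generated reference mask alongside the result ([return_ref_mask, edited_image]).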
@@ -185,32 +329,37 @@ def update_ui(option):
 with gr.Blocks() as demo:
 
 
-    gr.Markdown("#  Play with InsertAnything to Insert your Target Objects! ")
-    gr.Markdown("# Upload / Draw Images for the Background (up) and Reference Object (down)")
-    gr.Markdown("### Draw mask on the background or just upload the mask.")
-    gr.Markdown("### Only select one of these two methods. Don't forget to click the corresponding button!!")
+    gr.Markdown("# Insert-Anything")
+    gr.Markdown("### Draw a mask or upload one. Select only one of these two methods, and don't forget to click the corresponding button!")
+
 
     with gr.Row():
-        with gr.Column():
+        with gr.Column(scale=1):
             with gr.Row():
-                base_image = gr.Image(label="Background Image", source="upload", tool="sketch", type="pil",
-                                      brush_color='#FFFFFF', mask_opacity=0.5)
+                base_image = gr.ImageEditor(label="Background Image", sources="upload", type="pil", brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
+                                            layers=False,
+                                            interactive=True)
 
-                base_mask = gr.Image(label="Background Mask", source="upload", type="pil")
+                base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil", layers=False, brush=False, eraser=False)
 
             with gr.Row():
                 base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")
 
             with gr.Row():
-                ref_image = gr.Image(label="Reference Image", source="upload", tool="sketch", type="pil",
-                                     brush_color='#FFFFFF', mask_opacity=0.5)
+                ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil", brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"),
+                                           layers=False,
+                                           interactive=True)
 
-                ref_mask = gr.Image(label="Reference Mask", source="upload", type="pil")
+                ref_mask = gr.ImageEditor(label="Reference Mask", sources="upload", type="pil", layers=False, brush=False, eraser=False)
 
             with gr.Row():
-                ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Reference Mask Input Option", value="Upload with Mask")
+                ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Reference Mask Input Option", value="Upload with Mask")
 
-    baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=512, columns=1)
+            with gr.Row():
+                text_prompt = gr.Textbox(label="Label")
+
+        with gr.Column(scale=1):
+            baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=701, columns=1)
     with gr.Accordion("Advanced Option", open=True):
         seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
         gr.Markdown("### Guidelines")
@@ -218,7 +367,6 @@ with gr.Blocks() as demo:
 
     run_local_button = gr.Button(value="Run")
 
-
    # #### example #####
    num_examples = len(image_list)
    for i in range(num_examples):
@@ -234,12 +382,11 @@ with gr.Blocks() as demo:
         gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
         gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
         if i < num_examples - 1:
-            with gr.Row():
-                gr.HTML("<hr>")
+            gr.HTML("<hr>")
     # #### example #####
-
-    run_local_button.click(fn=run_local,
-                           inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option],
-                           outputs=[baseline_gallery]
-                           )
+
+    run_local_button.click(fn=run_local,
+                           inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
+                           outputs=[baseline_gallery]
+                           )
 demo.launch()
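One wiring detail worth flagging: the click handler above now passes text_prompt as an eighth input, but the run_local signature earlier in the diff still declares seven parameters, so inside the function the name text_prompt resolves to the module-level Textbox component rather than its submitted value. A minimal reconciliation sketch (an assumption on my part, not something this commit contains):

    # Hypothetical fix: accept the eighth Gradio input explicitly.
    def run_local(base_image, base_mask, reference_image, ref_mask,
                  seed, base_mask_option, ref_mask_option, text_prompt=""):
        ...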
 