LPX55 committed
Commit 8e20124 · 1 Parent(s): 8602a2c

merged: zerogpu doesnt like strangers

Files changed (1)
  1. app.py +55 -4
app.py CHANGED
@@ -7,12 +7,13 @@ from gradio_imageslider import ImageSlider
 from huggingface_hub import hf_hub_download
 from controlnet_union import ControlNetModel_Union
 from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
+from gradio_image_prompter import ImagePrompter
 from PIL import Image, ImageDraw
 import numpy as np
-from sam2_mask import create_sam2_mask_interface
-
-#from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+SAM_MODEL = "facebook/sam2.1-hiera-large"
 
 MODELS = {
     "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
@@ -61,6 +62,38 @@ def load_default_pipeline():
     ).to("cuda")
     return gr.update(value="Default pipeline loaded!")
 
+@spaces.GPU()
+def predict_masks(image, points):
+    """Predict a single mask from the image based on selected points."""
+    image_np = np.array(image)
+    points_list = [[point["x"], point["y"]] for point in points]
+    input_labels = [1] * len(points_list)
+
+    with torch.inference_mode():
+        PREDICTOR.set_image(image_np)
+        masks, _, _ = PREDICTOR.predict(
+            point_coords=points_list, point_labels=input_labels, multimask_output=False
+        )
+
+    # Prepare the overlay image
+    red_mask = np.zeros_like(image_np)
+    if masks and len(masks) > 0:
+        red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255  # Apply the red channel
+        red_mask = PILImage.fromarray(red_mask)
+        original_image = PILImage.fromarray(image_np)
+        blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
+        return np.array(blended_image)
+    else:
+        return image_np
+
+
+def update_mask(prompts):
+    """Update the mask based on the prompts."""
+    image = prompts["image"]
+    points = prompts["points"]
+    return predict_masks(image, points)
+
+
 @spaces.GPU(duration=12)
 def fill_image(prompt, image, model_selection, paste_back):
     print(f"Received image: {image}")
@@ -489,7 +522,25 @@ with gr.Blocks(css=css, fill_height=True) as demo:
     history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
     preview_image = gr.Image(label="Preview")
     with gr.TabItem("SAM2 Mask"):
-        sam2_mask_tab
+        gr.Markdown("# Object Segmentation with SAM2")
+        gr.Markdown(
+            """
+            This application utilizes **Segment Anything V2 (SAM2)** to allow you to upload an image and interactively generate a segmentation mask based on multiple points you select on the image.
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                upload_image_input = ImagePrompter(show_label=False)
+            with gr.Column():
+                image_output = gr.Image(label="Segmented Image", type="pil", height=400)
+
+        # Define the action triggered by the upload_image_input change
+        upload_image_input.change(
+            fn=update_mask,
+            inputs=[upload_image_input],
+            outputs=[image_output],
+            show_progress=True,
+        )
     with gr.TabItem("Misc"):
         with gr.Column():
             clear_cache_button = gr.Button("Clear CUDA Cache")
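
For reference, the new tab's wiring can be exercised on its own; a minimal sketch with update_mask stubbed out so the snippet runs without the SAM2 weights (the real app uses the update_mask() added above):

import numpy as np
import gradio as gr
from gradio_image_prompter import ImagePrompter


def update_mask_stub(prompts):
    # Stand-in for app.py's update_mask(): echo the uploaded image unchanged.
    return np.array(prompts["image"])


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            upload_image_input = ImagePrompter(show_label=False)
        with gr.Column():
            image_output = gr.Image(label="Segmented Image", type="pil", height=400)
    upload_image_input.change(
        fn=update_mask_stub,
        inputs=[upload_image_input],
        outputs=[image_output],
        show_progress=True,
    )

if __name__ == "__main__":
    demo.launch()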
 