LPX55 committed on
Commit
cc9ba96
·
1 Parent(s): 1647712
Files changed (1) hide show
  1. app.py +28 -11
app.py CHANGED
@@ -12,8 +12,21 @@ from PIL import Image, ImageDraw
12
  import numpy as np
13
  from sam2.sam2_image_predictor import SAM2ImagePredictor
14
 
15
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- SAM_MODEL = "facebook/sam2.1-hiera-large"
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  MODELS = {
19
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
@@ -51,6 +64,8 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
51
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
52
  pipe.to("cuda")
53
  print(pipe)
 
 
54
  PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
55
 
56
  def load_default_pipeline():
@@ -64,8 +79,11 @@ def load_default_pipeline():
64
  return gr.update(value="Default pipeline loaded!")
65
 
66
  @spaces.GPU()
67
- def predict_masks(image, points):
68
  """Predict a single mask from the image based on selected points."""
 
 
 
69
  if not points:
70
  return image # Return the original image if no points are selected
71
 
@@ -74,29 +92,28 @@ def predict_masks(image, points):
74
 
75
  # Ensure points is a list of lists with at least two elements
76
  if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
77
- points_list = [[point[0], point[1]] for point in points]
78
  else:
79
  return image # Return the original image if points structure is unexpected
80
 
81
- input_labels = [1] * len(points_list)
82
 
83
  with torch.inference_mode():
84
- PREDICTOR.set_image(np.array(image))
85
  masks, _, _ = PREDICTOR.predict(
86
- point_coords=points_list, point_labels=input_labels, multimask_output=False
87
  )
88
 
89
  # Prepare the overlay image
90
- image_np = np.array(image)
91
- red_mask = np.zeros_like(image_np)
92
  if masks and len(masks) > 0:
93
  red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
94
  red_mask = PILImage.fromarray(red_mask)
95
- original_image = PILImage.fromarray(image_np)
96
  blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
97
  return np.array(blended_image)
98
  else:
99
- return image_np
100
 
101
  def update_mask(prompts):
102
  """Update the mask based on the prompts."""
 
12
  import numpy as np
13
  from sam2.sam2_image_predictor import SAM2ImagePredictor
14
 
15
+ # class SAM2PredictorSingleton:
16
+ # _instance = None
17
+
18
+ # def __new__(cls):
19
+ # if cls._instance is None:
20
+ # cls._instance = super(SAM2PredictorSingleton, cls).__new__(cls)
21
+ # cls._instance._initialize_predictor()
22
+ # return cls._instance
23
+
24
+ # def _initialize_predictor(self):
25
+ # MODEL = "facebook/sam2-hiera-large"
26
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ # self.predictor = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
28
+
29
+
30
 
31
  MODELS = {
32
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
 
64
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
65
  pipe.to("cuda")
66
  print(pipe)
67
+ DEVICE = torch.device("cuda")
68
+ SAM_MODEL = "facebook/sam2.1-hiera-large"
69
  PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
70
 
71
  def load_default_pipeline():
 
79
  return gr.update(value="Default pipeline loaded!")
80
 
81
  @spaces.GPU()
82
+ def predict_masks(prompts):
83
  """Predict a single mask from the image based on selected points."""
84
+ image = np.array(prompts["image"]) # Convert the image to a numpy array
85
+ points = prompts["points"] # Get the points from prompts
86
+
87
  if not points:
88
  return image # Return the original image if no points are selected
89
 
 
92
 
93
  # Ensure points is a list of lists with at least two elements
94
  if isinstance(points, list) and all(isinstance(point, list) and len(point) >= 2 for point in points):
95
+ input_points = [[point[0], point[1]] for point in points]
96
  else:
97
  return image # Return the original image if points structure is unexpected
98
 
99
+ input_labels = [1] * len(input_points)
100
 
101
  with torch.inference_mode():
102
+ PREDICTOR.set_image(image)
103
  masks, _, _ = PREDICTOR.predict(
104
+ point_coords=input_points, point_labels=input_labels, multimask_output=False
105
  )
106
 
107
  # Prepare the overlay image
108
+ red_mask = np.zeros_like(image)
 
109
  if masks and len(masks) > 0:
110
  red_mask[:, :, 0] = masks[0].astype(np.uint8) * 255 # Apply the red channel
111
  red_mask = PILImage.fromarray(red_mask)
112
+ original_image = PILImage.fromarray(image)
113
  blended_image = PILImage.blend(original_image, red_mask, alpha=0.5)
114
  return np.array(blended_image)
115
  else:
116
+ return image
117
 
118
  def update_mask(prompts):
119
  """Update the mask based on the prompts."""