LPX55 commited on
Commit
c767877
·
1 Parent(s): cc9ba96
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -26,7 +26,7 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
26
  # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  # self.predictor = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
28
 
29
-
30
 
31
  MODELS = {
32
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
@@ -64,9 +64,6 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
64
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
65
  pipe.to("cuda")
66
  print(pipe)
67
- DEVICE = torch.device("cuda")
68
- SAM_MODEL = "facebook/sam2.1-hiera-large"
69
- PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
70
 
71
  def load_default_pipeline():
72
  global pipe
@@ -80,6 +77,15 @@ def load_default_pipeline():
80
 
81
  @spaces.GPU()
82
  def predict_masks(prompts):
 
 
 
 
 
 
 
 
 
83
  """Predict a single mask from the image based on selected points."""
84
  image = np.array(prompts["image"]) # Convert the image to a numpy array
85
  points = prompts["points"] # Get the points from prompts
 
26
  # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  # self.predictor = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)
28
 
29
+ PREDICTOR = None
30
 
31
  MODELS = {
32
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
 
64
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
65
  pipe.to("cuda")
66
  print(pipe)
 
 
 
67
 
68
  def load_default_pipeline():
69
  global pipe
 
77
 
78
  @spaces.GPU()
79
  def predict_masks(prompts):
80
+
81
+ DEVICE = torch.device("cuda")
82
+ SAM_MODEL = "facebook/sam2.1-hiera-large"
83
+ if PREDICTOR is None:
84
+ PREDICTOR = SAM2ImagePredictor.from_pretrained(SAM_MODEL, device=DEVICE)
85
+ else:
86
+ PREDICTOR = PREDICTOR
87
+
88
+
89
  """Predict a single mask from the image based on selected points."""
90
  image = np.array(prompts["image"]) # Convert the image to a numpy array
91
  points = prompts["points"] # Get the points from prompts