sagar007 commited on
Commit
73a8e2b
·
verified ·
1 Parent(s): 1d4c655

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -2
app.py CHANGED
@@ -9,6 +9,14 @@ processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
  def segment_everything(image):
 
 
 
 
 
 
 
 
12
  inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
13
  with torch.no_grad():
14
  outputs = model(**inputs)
@@ -17,9 +25,17 @@ def segment_everything(image):
17
  return Image.fromarray(segmentation)
18
 
19
  def segment_box(image, x1, y1, x2, y2):
 
 
 
 
 
 
 
 
20
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
21
  cropped_image = image[y1:y2, x1:x2]
22
- inputs = processor(text=["object"], images=[cropped_image], padding="max_length", return_tensors="pt")
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
  preds = outputs.logits.squeeze().sigmoid()
@@ -31,9 +47,13 @@ def update_image(image, segmentation):
31
  if segmentation is None:
32
  return image
33
 
 
 
 
 
34
  # Ensure image is in the correct format (PIL Image)
35
  if isinstance(image, np.ndarray):
36
- image_pil = Image.fromarray((image * 255).astype(np.uint8))
37
  else:
38
  image_pil = image
39
 
 
9
  model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
10
 
11
  def segment_everything(image):
12
+ # Check if image is a list and extract the actual image data
13
+ if isinstance(image, list):
14
+ image = image[0]
15
+
16
+ # Convert numpy array to PIL Image
17
+ if isinstance(image, np.ndarray):
18
+ image = Image.fromarray(image)
19
+
20
  inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
21
  with torch.no_grad():
22
  outputs = model(**inputs)
 
25
  return Image.fromarray(segmentation)
26
 
27
  def segment_box(image, x1, y1, x2, y2):
28
+ # Check if image is a list and extract the actual image data
29
+ if isinstance(image, list):
30
+ image = image[0]
31
+
32
+ # Convert PIL Image to numpy array if necessary
33
+ if isinstance(image, Image.Image):
34
+ image = np.array(image)
35
+
36
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
37
  cropped_image = image[y1:y2, x1:x2]
38
+ inputs = processor(text=["object"], images=[Image.fromarray(cropped_image)], padding="max_length", return_tensors="pt")
39
  with torch.no_grad():
40
  outputs = model(**inputs)
41
  preds = outputs.logits.squeeze().sigmoid()
 
47
  if segmentation is None:
48
  return image
49
 
50
+ # Check if image is a list and extract the actual image data
51
+ if isinstance(image, list):
52
+ image = image[0]
53
+
54
  # Ensure image is in the correct format (PIL Image)
55
  if isinstance(image, np.ndarray):
56
+ image_pil = Image.fromarray(image)
57
  else:
58
  image_pil = image
59