Samarth991 commited on
Commit
7fef6fd
·
1 Parent(s): 0702aa5

modifided clip.py

Browse files
extract_tools.py CHANGED
@@ -187,7 +187,7 @@ def get_all_tools():
187
  clipseg_tool = Tool(
188
  name = 'ClipSegmentation-tool',
189
  func = clipsegmentation_mask,
190
- description="""Use this tool when user ask to generate the segmentation Mask of the objects provided by the user.
191
  The input to the tool is the path of the image and list of objects for which Segmenation mask is to generated.
192
  For example :
193
  Query :Provide a segmentation mask of all road car and dog in the image
@@ -212,10 +212,10 @@ def get_all_tools():
212
  )
213
 
214
  object_extractor = Tool(
215
- name = "Object Extraction Tool",
216
  func = object_extraction,
217
- description = " The Tool is used to extract objects within the image . Use this tool if user specifically ask to identify \
218
- what are the objects I can view in the image or identify the objects within the image . "
219
  )
220
 
221
  image_parameters_tool = Tool(
 
187
  clipseg_tool = Tool(
188
  name = 'ClipSegmentation-tool',
189
  func = clipsegmentation_mask,
190
+ description="""Use this tool when user ask to extract the objects from the image .
191
  The input to the tool is the path of the image and list of objects for which Segmenation mask is to generated.
192
  For example :
193
  Query :Provide a segmentation mask of all road car and dog in the image
 
212
  )
213
 
214
  object_extractor = Tool(
215
+ name = "Object description Tool",
216
  func = object_extraction,
217
+ description = " The Tool is used to describe the objects within the image . Use this tool if user specifically ask to identify \
218
+ what are the objects I can view in the image or identify the objects within the image. "
219
  )
220
 
221
  image_parameters_tool = Tool(
tool_utils/clip_segmentation.py CHANGED
@@ -14,11 +14,22 @@ class CLIPSEG:
14
  self.threshould = threshould
15
  self.clip_model.to('cpu')
16
 
 
 
 
 
 
 
 
 
 
 
 
17
  @staticmethod
18
  def create_rgb_mask(mask,color=None):
19
- color = tuple(np.random.choice(range(0,256), size=3))
20
  gray_3_channel = cv2.merge((mask, mask, mask))
21
- gray_3_channel[mask==255] = color
22
  return gray_3_channel.astype(np.uint8)
23
 
24
  def get_segmentation_mask(self,image_path:str,object_prompts:List):
@@ -41,16 +52,25 @@ class CLIPSEG:
41
  predicted_mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
42
  predicted_mask = np.where(predicted_mask>self.threshould, 255,0)
43
  predicted_masks.append(predicted_mask)
44
-
 
 
 
45
  resize_image = cv2.resize(image,(352,352))
46
- mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(object_prompts)]
47
- cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
 
 
48
 
49
- bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
50
- final_mask = overlay_masks(resize_image,np.stack(bool_masks,-1),labels=mask_labels,colors=cmap,alpha=0.5,beta=0.7)
51
  try:
52
- cv2.imwrite('final_mask.png',final_mask)
53
  return 'Segmentation image created : final_mask.png'
54
  except Exception as e:
55
  logging.error("Error while saving the final mask :",e)
56
- return "unable to create a mask image "
 
 
 
 
 
14
  self.threshould = threshould
15
  self.clip_model.to('cpu')
16
 
17
+ @ staticmethod
18
+ def create_single_mask(predicted_masks , color = None ):
19
+
20
+ if len(predicted_masks)>0:
21
+ mask_image = np.zeros_like(predicted_masks[0])
22
+ else:
23
+ mask_image = np.zeros(shape=(352,352),dtype=np.unit8)
24
+ for masks in predicted_masks:
25
+ mask_image = np.bitwise_or(mask_image,masks)
26
+ return mask_image
27
+
28
  @staticmethod
29
  def create_rgb_mask(mask,color=None):
30
+ color = tuple(np.random.choice(range(128,255), size=3))
31
  gray_3_channel = cv2.merge((mask, mask, mask))
32
+ gray_3_channel[mask==255] = 255 # for orignial color
33
  return gray_3_channel.astype(np.uint8)
34
 
35
  def get_segmentation_mask(self,image_path:str,object_prompts:List):
 
52
  predicted_mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
53
  predicted_mask = np.where(predicted_mask>self.threshould, 255,0)
54
  predicted_masks.append(predicted_mask)
55
+
56
+ final_mask = self.create_single_mask(predicted_masks)
57
+ rgb_predicted_mask = self.create_rgb_mask(final_mask)
58
+
59
  resize_image = cv2.resize(image,(352,352))
60
+ rgb_mask_img = cv2.bitwise_and(resize_image,rgb_predicted_mask )
61
+
62
+ # mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(object_prompts)]
63
+ # cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
64
 
65
+ # bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
66
+ # final_mask = overlay_masks(resize_image,np.stack(bool_masks,-1),labels=mask_labels,colors=cmap,alpha=0.5,beta=0.7)
67
  try:
68
+ cv2.imwrite('final_mask.png',rgb_mask_img)
69
  return 'Segmentation image created : final_mask.png'
70
  except Exception as e:
71
  logging.error("Error while saving the final mask :",e)
72
+ return "unable to create a mask image "
73
+
74
+ if __name__=="__main__":
75
+ clip = CLIPSEG()
76
+ obj = clip.get_segmentation_mask(image_path="../image_store/demo.jpg",object_prompts=['sand','dog'])