Commit · 7fef6fd
Parent(s): 0702aa5
modifided clip.py

Files changed:
- extract_tools.py (+4 -4)
- tool_utils/clip_segmentation.py (+29 -9)
extract_tools.py (CHANGED)

@@ -187,7 +187,7 @@ def get_all_tools():
     clipseg_tool = Tool(
         name = 'ClipSegmentation-tool',
         func = clipsegmentation_mask,
-        description="""Use this tool when user ask to
+        description="""Use this tool when the user asks to extract the objects from the image.
         The input to the tool is the path of the image and the list of objects for which a segmentation mask is to be generated.
         For example:
         Query: Provide a segmentation mask of all road, car and dog in the image
@@ -212,10 +212,10 @@ def get_all_tools():
     )

     object_extractor = Tool(
-        name = "Object
+        name = "Object description Tool",
         func = object_extraction,
-        description = " The Tool is used to
-        what are the objects I can view in the image or identify the objects within the image
+        description = "The Tool is used to describe the objects within the image. Use this tool if the user specifically asks to identify \
+        what are the objects I can view in the image or identify the objects within the image."
     )

     image_parameters_tool = Tool(
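For context, the two registrations completed by this hunk follow the Tool(name=..., func=..., description=...) pattern. Below is a minimal sketch of how they might be wired up, assuming (as the constructor signature suggests) that Tool is langchain.agents.Tool; the two functions are trivial placeholders, not the project's real clipsegmentation_mask and object_extraction implementations.

# Sketch only: assumes langchain.agents.Tool and placeholder tool functions.
from langchain.agents import Tool

def clipsegmentation_mask(query: str) -> str:
    # Placeholder standing in for the real CLIPSeg-based segmentation tool.
    return "final_mask.png"

def object_extraction(image_path: str) -> str:
    # Placeholder standing in for the real object-description tool.
    return "car, dog, road"

tools = [
    Tool(
        name="ClipSegmentation-tool",
        func=clipsegmentation_mask,
        description="Use this tool when the user asks to extract objects from the image. "
                    "Input: the image path and the list of objects to segment.",
    ),
    Tool(
        name="Object description Tool",
        func=object_extraction,
        description="Use this tool when the user asks to identify or describe the objects in the image.",
    ),
]

An agent framework selects between these tools based only on the description strings, which is why the commit fleshes out the previously incomplete descriptions.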
tool_utils/clip_segmentation.py (CHANGED)

@@ -14,11 +14,22 @@ class CLIPSEG:
         self.threshould = threshould
         self.clip_model.to('cpu')

+    @staticmethod
+    def create_single_mask(predicted_masks, color=None):
+        # Combine the per-prompt binary masks into a single mask via bitwise OR.
+        if len(predicted_masks) > 0:
+            mask_image = np.zeros_like(predicted_masks[0])
+        else:
+            mask_image = np.zeros(shape=(352, 352), dtype=np.uint8)
+        for mask in predicted_masks:
+            mask_image = np.bitwise_or(mask_image, mask)
+        return mask_image
+
     @staticmethod
     def create_rgb_mask(mask, color=None):
-        color = tuple(np.random.choice(range(
+        color = tuple(np.random.choice(range(128, 255), size=3))
         gray_3_channel = cv2.merge((mask, mask, mask))
-        gray_3_channel[mask==255] = color
+        gray_3_channel[mask==255] = 255  # for original color
         return gray_3_channel.astype(np.uint8)

     def get_segmentation_mask(self, image_path: str, object_prompts: List):
@@ -41,16 +52,25 @@ class CLIPSEG:
             predicted_mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
             predicted_mask = np.where(predicted_mask > self.threshould, 255, 0)
             predicted_masks.append(predicted_mask)
-
+
+        final_mask = self.create_single_mask(predicted_masks)
+        rgb_predicted_mask = self.create_rgb_mask(final_mask)
+
         resize_image = cv2.resize(image, (352, 352))
-
-
+        rgb_mask_img = cv2.bitwise_and(resize_image, rgb_predicted_mask)
+
+        # mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(object_prompts)]
+        # cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]

-        bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
-        final_mask = overlay_masks(resize_image,np.stack(bool_masks,-1),labels=mask_labels,colors=cmap,alpha=0.5,beta=0.7)
+        # bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
+        # final_mask = overlay_masks(resize_image,np.stack(bool_masks,-1),labels=mask_labels,colors=cmap,alpha=0.5,beta=0.7)
         try:
-            cv2.imwrite('final_mask.png',
+            cv2.imwrite('final_mask.png', rgb_mask_img)
             return 'Segmentation image created : final_mask.png'
         except Exception as e:
             logging.error("Error while saving the final mask :", e)
-            return "unable to create a mask image "
+            return "unable to create a mask image"
+
+if __name__ == "__main__":
+    clip = CLIPSEG()
+    obj = clip.get_segmentation_mask(image_path="../image_store/demo.jpg", object_prompts=['sand', 'dog'])
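Taken together, the new code ORs the per-prompt binary masks into one mask, expands it to three channels, and uses cv2.bitwise_and to keep only the masked pixels of the resized input image. Below is a minimal, self-contained sketch of that pipeline; the synthetic masks stand in for the CLIPSeg predictions, the 352x352 size and final_mask.png filename follow the diff, and dtype=np.uint8 is used throughout.

import cv2
import numpy as np

def create_single_mask(predicted_masks):
    # OR the per-prompt binary masks (values 0/255) into one combined mask.
    if len(predicted_masks) > 0:
        mask_image = np.zeros_like(predicted_masks[0], dtype=np.uint8)
    else:
        mask_image = np.zeros((352, 352), dtype=np.uint8)
    for mask in predicted_masks:
        mask_image = np.bitwise_or(mask_image, mask.astype(np.uint8))
    return mask_image

def create_rgb_mask(mask):
    # Replicate the single-channel mask into three channels; masked pixels become white.
    gray_3_channel = cv2.merge((mask, mask, mask))
    gray_3_channel[mask == 255] = 255
    return gray_3_channel.astype(np.uint8)

# Synthetic stand-ins for two CLIPSeg prompt masks at the model's 352x352 resolution.
mask_a = np.zeros((352, 352), dtype=np.uint8)
mask_a[50:150, 50:150] = 255
mask_b = np.zeros((352, 352), dtype=np.uint8)
mask_b[200:300, 200:300] = 255

image = np.full((480, 640, 3), 127, dtype=np.uint8)  # stand-in for the loaded photo
resize_image = cv2.resize(image, (352, 352))

final_mask = create_single_mask([mask_a, mask_b])
rgb_predicted_mask = create_rgb_mask(final_mask)
rgb_mask_img = cv2.bitwise_and(resize_image, rgb_predicted_mask)  # original pixels kept only inside the mask
cv2.imwrite('final_mask.png', rgb_mask_img)

Because create_rgb_mask now sets masked pixels to 255 rather than a random color, the bitwise_and keeps the image's own colors inside the mask and blacks out everything else, which is why the old overlay_masks call and its mask_labels/cmap inputs could be commented out.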