Commit 3945649 · Parent(s): 52932d2
Commit message: "added object detection imporved"

Files changed:
- extract_tools.py +6 -3
- tool_utils/final_mask.png +0 -0
- tool_utils/yolo_world.py +21 -13
- tool_utils/yolov8x-worldv2.pt +3 -0
- utils.py +9 -10
extract_tools.py
CHANGED

@@ -11,7 +11,6 @@ from langchain_groq import ChatGroq
 from utils import draw_panoptic_segmentation
 
 from tool_utils.clip_segmentation import CLIPSEG
-from tool_utils.object_extractor import create_object_extraction_chain
 from tool_utils.yolo_world import YoloWorld
 from tool_utils.image_qualitycheck import brightness_check,gaussian_noise_check,snr_check
 
@@ -34,6 +33,10 @@ try:
 except:
     logging.error("Unable to Maskformer model {}".format(err))
 
+try:
+    yolo_world_model= YoloWorld()
+except :
+    logging.error("Unable to Yolo world model {}".format(err))
 
 def get_groq_model(model_name = "gemma2-9b-it"):
     os.environ.get("GROQ_API_KEY")
@@ -117,7 +120,6 @@ def clipsegmentation_mask(input_data:str)->str:
 @tool
 def generate_bounding_box_tool(input_data:str)->str:
     "use this tool when its is required to detect object and provide bounding boxes for the given image and list of objects"
-    yolo_world_model= YoloWorld()
     data = input_data.split(",")
     image_path = data[0]
     object_prompts = data[1:]
@@ -142,7 +144,8 @@ def object_extraction(image_path:str)->str:
         segment_label_id = segment['label_id']
         segment_label = maskformer_model.config.id2label[segment_label_id]
         objects.append(segment_label)
-
+
+    return "Detected objects are: "+ ",".join( objects)
 
 @tool
 def get_image_quality(image_path:str)->str:
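Note on the new model-loading block: like the Maskformer block above it, the added try/except uses a bare `except:` (and `except :`), which never binds the exception to `err`, so the `format(err)` call would itself raise a NameError the moment the load actually fails. A minimal corrected sketch, assuming the repo's module layout (not part of the commit):

import logging
from tool_utils.yolo_world import YoloWorld

try:
    yolo_world_model = YoloWorld()   # heavyweight load: weights file, CUDA, etc.
except Exception as err:             # bind the exception so it can be logged
    logging.error("Unable to load Yolo world model {}".format(err))
    yolo_world_model = None          # callers can check for a failed load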
tool_utils/final_mask.png
ADDED
(binary image preview)
tool_utils/yolo_world.py
CHANGED

@@ -32,14 +32,16 @@ class YoloWorld:
         return object_details
 
     @staticmethod
-    def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
-        rgb_frame = cv2.imread(rgb_frame)
-        rgb_frame = cv2.cvtColor(rgb_frame,cv2.COLOR_BGR2RGB)
+    def draw_bboxes(rgb_frame,boxes,labels,line_thickness=3):
+        rgb_frame = cv2.cvtColor(cv2.imread(rgb_frame),cv2.COLOR_BGR2RGB)
 
         tl = line_thickness or round(0.002 * (rgb_frame.shape[0] + rgb_frame.shape[1]) / 2) + 1  # line/font thickness
         rgb_frame_copy = rgb_frame.copy()
-        color = color or [random.randint(0, 255) for _ in range(3)]
-
+        color_dict = {}
+        # color = color or [random.randint(0, 255) for _ in range(3)]
+        for item in np.unique(np.asarray(labels)):
+            color_dict[item] = [random.randint(28, 255) for _ in range(3)]
+
         for box,label in zip(boxes,labels):
             if box.type() == 'torch.IntTensor':
                 box = box.numpy()
@@ -47,31 +49,32 @@ class YoloWorld:
             x1,y1,x2,y2 = box
             c1,c2 = (x1,y1),(x2,y2)
             # Draw rectangle
-            cv2.rectangle(rgb_frame_copy, c1,c2, color, thickness=tl, lineType=cv2.LINE_AA)
-
+            cv2.rectangle(rgb_frame_copy, c1,c2, color_dict[label], thickness=tl, lineType=cv2.LINE_AA)
             tf = max(tl - 1, 1)  # font thickness
             # label = label2id[int(label.numpy())]
             t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
             c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
-            cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+            cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, color_dict[label], thickness=tf, lineType=cv2.LINE_AA)
         return rgb_frame_copy
 
+
     def run_yolo_infer(self,image_path:str,object_prompts:List):
-        self.model.set_classes(object_prompts)
-        results = self.model.predict(image_path)
         processed_predictions = []
         bounding_boxes = []
         labels = []
         scores = []
+
+        self.model.set_classes(object_prompts)
+        results = self.model.predict(image_path)
         for result in results:
             for i,box in enumerate(result.boxes):
                 x1, y1, x2, y2 = np.array(box.xyxy.cpu(), dtype=np.int32).squeeze()
                 bounding_boxes.append([x1,y1,x2,y2])
-                labels.append(int(box.cls.cpu()))
+                labels.append(result.names[int(box.cls.cpu())])
                 scores.append(round(float(box.conf.cpu()),2))
 
         processed_predictions.append(dict(boxes= torch.tensor(bounding_boxes),
-                                          labels= torch.tensor(labels),
+                                          labels= labels,
                                           scores=torch.tensor(scores))
                                     )
         detected_image = self.draw_bboxes(rgb_frame=image_path,
@@ -80,5 +83,10 @@ class YoloWorld:
                                           )
 
         cv2.imwrite('final_mask.png', cv2.cvtColor(detected_image,cv2.COLOR_BGR2RGB))
-        return "Predicted image
+        return "Predicted image : final_mask.jpg . Details :{}".format(processed_predictions[0])
 
+if __name__ == "__main__":
+    yolo = YoloWorld()
+    predicted_data = yolo.run_yolo_infer('../image_store/demo2.jpg',['person','hat','building'])
+    print(predicted_data)
+
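For context, a minimal standalone sketch of the open-vocabulary flow this class wraps, assuming the current ultralytics YOLOWorld API and the checkpoint added in this commit. Note that run_yolo_infer writes final_mask.png while its new return string says final_mask.jpg, so the saved filename is the one to trust:

from ultralytics import YOLOWorld   # pip install ultralytics

# Load the open-vocabulary checkpoint committed alongside this change.
model = YOLOWorld("tool_utils/yolov8x-worldv2.pt")

# Constrain detection to a prompt vocabulary, then run inference.
model.set_classes(["person", "hat", "building"])
results = model.predict("demo.jpg")   # hypothetical sample image

for result in results:
    for box in result.boxes:
        # result.names maps class indices back to the prompt strings,
        # which is what the new labels.append(...) line relies on.
        label = result.names[int(box.cls.cpu())]
        score = round(float(box.conf.cpu()), 2)
        x1, y1, x2, y2 = box.xyxy.cpu().numpy().astype(int).squeeze()
        print(label, score, (x1, y1, x2, y2))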
tool_utils/yolov8x-worldv2.pt
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41e771bfbbb8894dd857f3fef7cac3b3578dffd49fd3547101efa6a606a02a0e
+size 146355704

(This is a Git LFS pointer file: the ~146 MB checkpoint itself is stored out of band, and only the pointer with its hash and size is committed.)
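The three committed lines follow the fixed key-value layout of the LFS pointer spec, so they can be parsed mechanically; a small illustrative parser (hypothetical helper, not in the repo):

def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer file: one 'key value' pair per line."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {"version": fields["version"],
            "oid": fields["oid"],              # "sha256:41e771bf..."
            "size_bytes": int(fields["size"])}

pointer = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:41e771bfbbb8894dd857f3fef7cac3b3578dffd49fd3547101efa6a606a02a0e\n"
    "size 146355704"
)
print(parse_lfs_pointer(pointer)["size_bytes"] / 1e6)   # ~146.4 MB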
utils.py
CHANGED

@@ -1,3 +1,4 @@
+import numpy as np
 from collections import defaultdict
 import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
@@ -28,14 +29,17 @@ def draw_panoptic_segmentation(model,segmentation, segments_info):
     return 'final_mask.png'
 
 
-def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
+def draw_bboxes(rgb_frame,boxes,labels,line_thickness=3):
     rgb_frame = cv2.imread(rgb_frame)
     # rgb_frame = cv2.cvtColor(rgb_frame,cv2.COLOR_BGR2RGB)
 
     tl = line_thickness or round(0.002 * (rgb_frame.shape[0] + rgb_frame.shape[1]) / 2) + 1  # line/font thickness
     rgb_frame_copy = rgb_frame.copy()
-    color = color or [random.randint(0, 255) for _ in range(3)]
-
+    color_dict = {}
+    # color = color or [random.randint(0, 255) for _ in range(3)]
+    for item in np.unique(np.asarray(labels)):
+        color_dict[item] = [random.randint(28, 255) for _ in range(3)]
+
     for box,label in zip(boxes,labels):
         if box.type() == 'torch.IntTensor':
             box = box.numpy()
@@ -43,15 +47,10 @@ def draw_bboxes(rgb_frame,boxes,labels,color=None,line_thickness=3):
         x1,y1,x2,y2 = box
         c1,c2 = (x1,y1),(x2,y2)
         # Draw rectangle
-        cv2.rectangle(rgb_frame_copy, c1,c2, color, thickness=tl, lineType=cv2.LINE_AA)
-
+        cv2.rectangle(rgb_frame_copy, c1,c2, color_dict[label], thickness=tl, lineType=cv2.LINE_AA)
         tf = max(tl - 1, 1)  # font thickness
         # label = label2id[int(label.numpy())]
         t_size = cv2.getTextSize(str(label), 0, fontScale=tl / 3, thickness=tf)[0]
         c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
-        cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
+        cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, color_dict[label], thickness=tf, lineType=cv2.LINE_AA)
     return rgb_frame_copy
-
-def object_extraction_using_maskformer(image_path):
-    processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
-    model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
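Both copies of draw_bboxes now share the same technique: one random color per unique label, so every box of a class is drawn alike. A self-contained sketch of that idea, with hypothetical names (the lower bound of 28 mirrors the commit and keeps colors away from near-black):

import random
from typing import Dict, List, Sequence

def make_label_colors(labels: Sequence[str]) -> Dict[str, List[int]]:
    # One random BGR color per unique label; 28 avoids near-black values.
    return {label: [random.randint(28, 255) for _ in range(3)]
            for label in set(labels)}

colors = make_label_colors(["person", "hat", "person"])
print(colors)   # e.g. {'person': [201, 64, 112], 'hat': [38, 222, 90]}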