# CV-Agent / extract_tools.py
# Last change by Samarth991: "Update extract_tools.py" (commit 608b2b6, verified)
import os
import cv2
import requests
from PIL import Image
import logging
import torch
from llm.llm_service import get_llm
from langchain_core.tools import tool,Tool
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_groq import ChatGroq
from utils import draw_panoptic_segmentation
from typing import List
from tool_utils.clip_segmentation import CLIPSEG
from tool_utils.yolo_world import YoloWorld
from tool_utils.image_qualitycheck import brightness_check,gaussian_noise_check,snr_check
from tool_utils.image_description import SMOLVLM2
# Mask2Former is optional at import time: if transformers is missing the error
# is logged and the model-loading code below reports the failure on its own.
try:
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
except ImportError as import_err:
    logging.error("Import error :%s", import_err)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info("Loading Foundation Models")

# Hugging Face checkpoint shared by the Mask2Former processor and model.
MASK2FORMER_CHECKPOINT = "facebook/mask2former-swin-base-coco-panoptic"

# Every model is optional at import time. A failed load is logged (with its
# traceback) and the handle stays None, so this module still imports and only
# the tools that need the missing model fail when they are invoked.
clipseg_model = None
maskformer_processor = None
maskformer_model = None
yolo_world_model = None

try:
    clipseg_model = CLIPSEG()
except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit propagate
    logging.exception("Unable to load CLIPSeg model: %s", err)

try:
    maskformer_processor = AutoImageProcessor.from_pretrained(MASK2FORMER_CHECKPOINT)
    maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(MASK2FORMER_CHECKPOINT)
except Exception as err:
    # Also covers NameError when the guarded transformers import failed.
    logging.exception("Unable to load Mask2Former model: %s", err)

try:
    yolo_world_model = YoloWorld()
except Exception as err:
    logging.exception("Unable to load YOLO-World model: %s", err)
def get_groq_model(model_name="gemma2-9b-it"):
    """Create a Groq-hosted chat model.

    Args:
        model_name: Groq model identifier passed to ``ChatGroq``.

    Returns:
        A ``ChatGroq`` instance for ``model_name``.
    """
    # ChatGroq reads GROQ_API_KEY from the environment by itself. The previous
    # bare ``os.environ.get(...)`` discarded its result and so checked nothing;
    # warn explicitly so a missing key shows up in the logs.
    if not os.environ.get("GROQ_API_KEY"):
        logging.warning("GROQ_API_KEY is not set; ChatGroq will not be able to authenticate")
    return ChatGroq(model=model_name)
@tool
def panoptic_image_segemntation(image_path:str)->str:
    """
    The tool is used to create a Panoptic segmentation mask. It uses a Maskformer network to create a panoptic segmentation of all
    the objects present in the image. Use the tool in case the user asks to create a panoptic image segmentation or panoptic mask.
    """
    # The docstring above is the tool description the LLM agent sees.
    # image_path: a local file path or an http(s) URL.
    # Returns a sentence with the mask path produced by draw_panoptic_segmentation
    # and the comma-separated labels of all predicted segments.
    import io  # only needed to wrap downloaded bytes for PIL

    image_path = image_path.strip()
    if image_path.startswith(("http://", "https://")):
        # Download with a timeout and a status check. Reading `.raw` directly
        # neither checks the HTTP status nor decodes gzip/deflate encodings,
        # and it left the response unclosed.
        response = requests.get(image_path, timeout=30)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')

    maskformer_model.to(device)
    inputs = maskformer_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = maskformer_model(**inputs)
    # PIL's image.size is (width, height); target_sizes expects (height, width).
    prediction = maskformer_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    save_mask_path = draw_panoptic_segmentation(
        maskformer_model, prediction['segmentation'], prediction['segments_info']
    )

    id2label = maskformer_model.config.id2label
    # Comma-separated so multi-word labels (e.g. "dining table") stay unambiguous.
    labels = ", ".join(id2label[segment['label_id']] for segment in prediction['segments_info'])
    return 'Panoptic Segmentation image {} Found labels {} in the image '.format(save_mask_path, labels)
@tool
def image_description(img_path:str)->str:
    """Use this tool to describe the image. The tool also helps you to identify the weather in the image."""
    # The docstring above is the tool description the LLM agent sees.
    # NOTE(review): SMOLVLM2 is constructed on every call, which may reload the
    # model weights each time; presumably intentional given memory_efficient=True — confirm.
    smol_vlm = SMOLVLM2(memory_efficient=True)
    query = "Describe the image. Highlight the details in 2-3 lines"
    return smol_vlm.run_inference_on_image(image_path=img_path.strip(), query=query)
@tool
def clipsegmentation_mask(input_data:str)->str:
    """
    The tool helps to extract the object masks from the image.
    The input is a comma separated string: the image path followed by the object names,
    for example "image_store/<image_path>,car,dog".
    """
    # The docstring above is the tool description the LLM agent sees.
    # Split "<image_path>,<obj1>,<obj2>,..." and normalise the pieces. Agent
    # output often has spaces after commas or a trailing comma, which would
    # otherwise produce " car" or "" prompts.
    image_path, *raw_prompts = input_data.split(",")
    image_path = image_path.strip()
    object_prompts = [prompt.strip() for prompt in raw_prompts if prompt.strip()]
    if not object_prompts:
        return "No object names were provided. Pass the image path followed by comma separated object names."
    return clipseg_model.get_segmentation_mask(image_path, object_prompts)
@tool
def generate_bounding_box_tool(input_data:str)->str:
    "use this tool when it is required to detect objects and provide bounding boxes for the given image and list of objects"
    # The docstring above is the tool description the LLM agent sees.
    # input_data: "<image_path>,<obj1>,<obj2>,..."; returns whatever
    # YoloWorld.run_yolo_infer returns for those prompts.
    logging.debug("generate_bounding_box_tool input: %s", input_data)
    image_path, *raw_prompts = input_data.split(",")
    image_path = image_path.strip()
    # Drop blank entries (e.g. from a trailing comma) and surrounding spaces so
    # YOLO-World never receives an empty or padded class prompt.
    object_prompts = [prompt.strip() for prompt in raw_prompts if prompt.strip()]
    if not object_prompts:
        return "No object names were provided. Pass the image path followed by comma separated object names."
    return yolo_world_model.run_yolo_infer(image_path, object_prompts)
@tool
def object_extraction(image_path:str)->str:
    "Use this tool to identify the objects within the image"
    # The docstring above is the tool description the LLM agent sees.
    # Runs Mask2Former panoptic segmentation and reports the label of every
    # predicted segment; a label can appear more than once.
    image_path = image_path.strip()
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None (rather than raising) for missing or
        # unreadable files; fail clearly instead of inside cv2.cvtColor.
        raise ValueError("Unable to read image at path: {}".format(image_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    maskformer_model.to(device)
    inputs = maskformer_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = maskformer_model(**inputs)
    # image.shape[:2] is (height, width), the order target_sizes expects.
    prediction = maskformer_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.shape[:2]]
    )[0]

    id2label = maskformer_model.config.id2label
    objects = [id2label[segment['label_id']] for segment in prediction['segments_info']]
    if not objects:
        return "No objects were detected in the image"
    return "Detected objects are: " + ",".join(objects)
@tool
def get_image_quality(image_path:str)->str:
    """
    This tool helps to find out the parameters of the image. The tool will determine if the image is blurry or not.
    It will also tell you if the image is bright or not.
    """
    # The docstring above is the tool description the LLM agent sees.
    # NOTE: snr_check is imported but currently disabled, so no signal-to-noise
    # line is reported. Re-enable it and update this docstring if SNR is needed.
    image_path = image_path.strip()
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None (rather than raising) for missing or
        # unreadable files; fail clearly instead of inside cv2.cvtColor.
        raise ValueError("Unable to read image at path: {}".format(image_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    brightness_text = brightness_check(image)
    blurry_text = gaussian_noise_check(image)
    return "Image properties are :\n{}\n{}".format(blurry_text, brightness_text)
def get_all_tools():
    """Build the list of tools exposed to the agent.

    Each vision function above is wrapped in a ``Tool`` whose ``name`` and
    ``description`` are what the LLM sees when choosing a tool, followed by a
    DuckDuckGo web-search tool.

    Returns:
        A list of LangChain tools.
    """
    # NOTE(review): each `func` below is a @tool-decorated object, so Tool calls
    # it through BaseTool.__call__, which recent langchain-core versions
    # deprecate in favour of .invoke — confirm against the pinned version.
    image_desc_tool = Tool(
        name='Image_Descprtion_Tool',  # name kept as-is: the agent may rely on it
        func=image_description,
        description="""
        The tool helps to describe the image or create a caption for the image.
        If the user asks to describe or generate a caption for the image, use this tool.
        This tool can also be used to identify the weather within the image.
        user example questions :
        1. Describe the image ?
        2. What does the weather look like in the image ?
        """,
    )
    clipseg_tool = Tool(
        name='ClipSegmentation-tool',
        func=clipsegmentation_mask,
        description="""Use this tool when the user asks to extract the objects from the image.
        The input to the tool is the path of the image followed by the list of objects for which the segmentation mask is to be generated, all comma separated.
        For example :
        Query : Provide a segmentation mask of all road, car and dog in the image
        "action_input": "image_store/<image_path>,road,car,dog"
        The tool will generate the segmentation mask of the objects in the image.
        If the user has not specified the names of the objects, first use the Object description Tool to identify the objects
        and then use this tool to generate the segmentation mask for the objects.
        """,
    )
    bounding_box_generator = Tool(
        name='Bounding Box Generator',
        func=generate_bounding_box_tool,
        description="""The tool helps to provide bounding boxes for the given image and list of objects.
        Use this tool when the user asks to provide bounding boxes for the objects. If the user has not specified the names of the objects,
        then use the Object description Tool to identify the objects and then use this tool to generate the bounding boxes for the objects.
        The input to this tool is the path of the image and the list of objects for which bounding boxes are to be generated, all comma separated.
        For Example :
        "action_input": "image_store/<image_path>,person,dog,sand"
        """,
    )
    object_extractor = Tool(
        name="Object description Tool",
        func=object_extraction,
        description=(
            " The Tool is used to describe the objects within the image. Use this tool if the user specifically asks to identify "
            "what objects can be seen in the image or to identify the objects within the image. "
        ),
    )
    image_parameters_tool = Tool(
        name='Image Parameters_Tool',
        func=get_image_quality,
        # SNR is intentionally not advertised: get_image_quality does not report it.
        description=""" This tool will help you to determine
        - If the image is blurry or not
        - If the image is bright or not
        Based on the tool output take a proper decision regarding the image quality""",
    )
    panoptic_segmentation = Tool(
        name='panoptic_Segmentation_tool',
        func=panoptic_image_segemntation,
        description=(
            "The tool is used to create a Panoptic segmentation mask. It uses a Maskformer network to create a panoptic segmentation of all "
            "the objects present in the image. Use the tool in case the user asks to create a panoptic segmentation or count objects in the image. "
            "The tool also provides a list of objects along with the mask image of all segmented objects found in the image. "
            "Use the tool if the user asks to create a panoptic image segmentation or panoptic mask"
        ),
    )
    return [
        DuckDuckGoSearchResults(),
        image_desc_tool,
        clipseg_tool,
        image_parameters_tool,
        object_extractor,
        bounding_box_generator,
        panoptic_segmentation,
    ]