import os
import logging

import cv2
import requests
import torch
from PIL import Image

from langchain_core.tools import Tool
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_groq import ChatGroq

from utils import draw_panoptic_segmentation
from tool_utils.clip_segmentation import CLIPSEG
from tool_utils.object_extractor import create_object_extraction_chain
from tool_utils.yolo_world import YoloWorld
from tool_utils.image_qualitycheck import brightness_check, gaussian_noise_check, snr_check

try:
    from transformers import BlipProcessor, BlipForConditionalGeneration
    from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
except ImportError as err:
    logging.error("Import error: {}".format(err))
    raise

device = 'cuda' if torch.cuda.is_available() else 'cpu'

logging.info("Loading Foundation Models")
try:
    clipseg_model = CLIPSEG()
except Exception as err :
    logging.error("Unable to clipseg model {}".format(err))
try:
    maskformer_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
    maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
except:
    logging.error("Unable to Maskformer model {}".format(err))


def get_groq_model(model_name="gemma2-9b-it"):
    if not os.environ.get("GROQ_API_KEY"):
        logging.warning("GROQ_API_KEY is not set; ChatGroq calls will fail to authenticate.")
    return ChatGroq(model=model_name)

def panoptic_image_segmentation(image_path: str) -> str:
    """
    Creates a panoptic segmentation mask. Uses the Mask2Former network to produce a panoptic
    segmentation of all the objects present in the image. Use this tool when the user asks
    for a panoptic segmentation.
    """
    if image_path.startswith('http'):
        image = Image.open(requests.get(image_path, stream=True).raw).convert('RGB')
    else:
        image = Image.open(image_path).convert('RGB')
    maskformer_model.to(device)
    inputs = maskformer_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = maskformer_model(**inputs)

    prediction = maskformer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
    save_mask_path = draw_panoptic_segmentation(maskformer_model, prediction['segmentation'], prediction['segments_info'])
    labels = []
    for segment in prediction['segments_info']:
        labels.append(maskformer_model.config.id2label[segment['label_id']])
    return 'Panoptic segmentation image {} created with labels {}'.format(save_mask_path, labels)
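
# Example invocation (hypothetical path; the output string below is illustrative):
#   panoptic_image_segmentation("images/street.jpg")
#   -> "Panoptic segmentation image outputs/mask.png created with labels ['car', 'road']"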

def image_description(img_path: str) -> str:
    """
    Use this tool to describe the image.
    The tool also helps you to identify the weather in the image.
    """
    hf_model = "Salesforce/blip-image-captioning-base"
    if img_path.startswith('http'):
        image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')
    try:
        processor = BlipProcessor.from_pretrained(hf_model)
        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
    except Exception as err:
        logging.error("Unable to load the BLIP model: {}".format(err))
        raise

    logging.info("Image caption model loaded!")

    # unconditional image captioning
    inputs = processor(image, return_tensors='pt').to(device)
    output = caption_model.generate(**inputs, max_new_tokens=50)
    caption = processor.decode(output[0], skip_special_tokens=True)

    # conditional image captioning
    obj_text = "Total number of objects in image "
    inputs_2 = processor(image, obj_text, return_tensors='pt').to(device)
    out_2 = caption_model.generate(**inputs_2, max_new_tokens=50)
    object_caption = processor.decode(out_2[0], skip_special_tokens=True)

    # clear the GPU cache
    if device == 'cuda':
        torch.cuda.empty_cache()
    return caption + " . " + object_caption + " ."


def clipsegmentation_mask(input_data: str) -> str:
    """
    Extracts object segmentation masks from an image using CLIPSeg.
    Input is a comma-separated string: the image path followed by the object prompts,
    e.g. "<image_path>,<object_1>,<object_2>".
    """
    data = input_data.split(",")
    image_path = data[0]
    object_prompts = [prompt.strip() for prompt in data[1:]]
    masks = clipseg_model.get_segmentation_mask(image_path, object_prompts)
    return masks
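
# Example invocation (hypothetical path and prompts):
#   clipsegmentation_mask("images/street.jpg,car,road,dog")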

def generate_bounding_box_tool(input_data: str) -> str:
    """
    Use this tool when object detection with bounding boxes is required.
    Input is a comma-separated string: the image path followed by the objects to detect.
    """
    yolo_world_model = YoloWorld()
    data = input_data.split(",")
    image_path = data[0]
    object_prompts = [prompt.strip() for prompt in data[1:]]
    object_data = yolo_world_model.run_inference(image_path, object_prompts)
    return object_data

def object_extraction(img_path: str) -> list:
    """Use this tool to identify the objects within the image."""
    hf_model = "Salesforce/blip-image-captioning-base"
    if img_path.startswith('http'):
        image = Image.open(requests.get(img_path, stream=True).raw).convert('RGB')
    else:
        image = Image.open(img_path).convert('RGB')
    try:
        processor = BlipProcessor.from_pretrained(hf_model)
        caption_model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)
    except Exception as err:
        logging.error("Unable to load the BLIP model: {}".format(err))
        raise

    logging.info("Image caption model loaded!")

    # unconditional image captioning
    inputs = processor(image, return_tensors='pt').to(device)
    output = caption_model.generate(**inputs, max_new_tokens=50)
    llm = get_groq_model()
    getobject_chain = create_object_extraction_chain(llm=llm)

    extracted_objects = getobject_chain.invoke({
        'context': processor.decode(output[0], skip_special_tokens=True)
    }).objects

    logging.info("Extracted objects: {}".format(extracted_objects))
    # clear the GPU cache
    if device == 'cuda':
        torch.cuda.empty_cache()

    return extracted_objects.split(',')

def get_image_quality(image_path: str) -> str:
    """
    This tool reports quality parameters of the image: whether it is blurry,
    whether it is bright, and its signal-to-noise ratio (SNR).
    Example outputs:
    example 1: Image is blurry. Image is not bright. Signal to Noise is less than 1 - More Noise in image
    example 2: Image is not blurry. Image is bright. Signal to Noise is greater than 1 - More Signal in image
    """
    image = cv2.imread(image_path)
    if image is None:
        return "Unable to read image at {}".format(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    brightness_text = brightness_check(image)
    blurry_text = gaussian_noise_check(image)
    snr_text = snr_check(image)
    return "Image properties are:\n{}\n{}\n{}".format(blurry_text, brightness_text, snr_text)



def get_all_tools():
    """Wraps the functions above as LangChain Tool objects and returns them."""
    image_desc_tool = Tool(
        name='Image_Description_Tool',
        func=image_description,
        description="""
            The tool helps to describe the image or create a caption for it.
            If the user asks to describe or generate a caption for the image, use this tool.
            This tool can also be used to identify the weather within the image.
            Example user questions:
                1. Describe the image?
                2. What does the weather look like in the image?
            """
    )

    clipseg_tool = Tool(
        name='Clip_Segmentation_Tool',
        func=clipsegmentation_mask,
        description="""Use this tool when the user asks to generate segmentation masks for the objects they name.
            The input to the tool is the path of the image followed by the list of objects for which segmentation
            masks are to be generated.
            Example query: Provide a segmentation mask of all road, car and dog in the image.
            If the user has not named the objects, first use the object extraction tool to identify the objects,
            then use this tool to generate the segmentation masks.
            """
    )

    bounding_box_generator = Tool(
        name='Bounding_Box_Generator',
        func=generate_bounding_box_tool,
        description="The tool provides bounding boxes for the given image and list of objects. "
                    "Use this tool when the user asks for bounding boxes for objects. If the user has not "
                    "specified the names of the objects, use the object extraction tool to identify them first, "
                    "then use this tool to generate their bounding boxes. The input to this tool is the path of "
                    "the image and the list of objects for which bounding boxes are to be generated."
    )

    object_extractor = Tool(
        name="Object_Extraction_Tool",
        func=object_extraction,
        description="The tool is used to extract objects within the image. Use this tool if the user specifically "
                    "asks to identify which objects are visible in the image."
    )

    image_parameters_tool = Tool(
        name='Image_Parameters_Tool',
        func=get_image_quality,
        description="""This tool will help you to determine
            - if the image is blurry or not
            - if the image is bright/sharp or not
            - the SNR of the image
            Based on the tool output, make a proper decision regarding the image quality."""
    )

    panoptic_segmentation = Tool(
        name='Panoptic_Segmentation_Tool',
        func=panoptic_image_segmentation,
        description="The tool is used to create a panoptic segmentation mask. It uses the Mask2Former network to "
                    "create a panoptic segmentation of all the objects present in the image. Use the tool when the "
                    "user asks to create a panoptic segmentation or to count objects in the image. The tool also "
                    "provides a list of objects along with the mask image of all segmented objects found in the image."
    )

    return [
        DuckDuckGoSearchResults(),
        image_desc_tool,
        clipseg_tool,
        image_parameters_tool,
        object_extractor,
        bounding_box_generator,
        panoptic_segmentation,
    ]
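

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of one way these tools could be wired into a ReAct-style
# agent. Assumes the classic `langchain` package is installed and GROQ_API_KEY
# is set; the query and image path are hypothetical.
if __name__ == "__main__":
    from langchain.agents import AgentType, initialize_agent

    agent = initialize_agent(
        tools=get_all_tools(),
        llm=get_groq_model(),
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    print(agent.run("Describe the image at sample.jpg"))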