# # imports
# import os
# import json
# import base64
# from io import BytesIO
# from dotenv import load_dotenv
# from openai import OpenAI
# import gradio as gr
# import numpy as np
# from PIL import Image, ImageDraw
# import requests
# import torch
# from transformers import (
#     AutoProcessor,
#     Owlv2ForObjectDetection,
#     AutoModelForZeroShotObjectDetection
# )
# # from transformers import AutoProcessor, Owlv2ForObjectDetection
# from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
#
# # Initialization
# load_dotenv()
# os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
# PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
# MODEL = "gpt-4o"
# openai = OpenAI()
#
# # Initialize models
# device = "cuda" if torch.cuda.is_available() else "cpu"
#
# # Owlv2
# owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
# owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
#
# # DINO
# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
#
# system_message = """You are an expert in object detection. When users mention:
# 1. "count [object(s)]" - Use detect_objects with the proper format for the selected model
# 2. "detect [object(s)]" - Same as count
# 3. "show [object(s)]" - Same as count
# For the DINO model: format queries as "a [object]." (e.g., "a frog.")
# For the Owlv2 model: format queries as ["a photo of [object]", "a photo of [object2]"]
# Always use the object detection tool when counting/detecting is mentioned."""
# system_message += "\nAlways be accurate. If you don't know the answer, say so."
#
# class State:
#     def __init__(self):
#         self.current_image = None
#         self.last_prediction = None
#         self.current_model = "owlv2"  # Default model
#
# state = State()
#
# def get_preprocessed_image(pixel_values):
#     pixel_values = pixel_values.squeeze().numpy()
#     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
#     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
#     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
#     return unnormalized_image
#
# def encode_image_to_base64(image_array):
#     if image_array is None:
#         return None
#     image = Image.fromarray(image_array)
#     buffered = BytesIO()
#     image.save(buffered, format="JPEG")
#     return base64.b64encode(buffered.getvalue()).decode('utf-8')
#
# def format_query_for_model(text_input, model_type="owlv2"):
#     """Format the query based on the selected model's requirements."""
#     # Extract the object words (e.g., "detect a lion" -> "lion")
#     text = text_input.lower()
#     words = [w.strip('.,?!') for w in text.split()
#              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
#     if model_type == "owlv2":
#         # Owlv2 expects a flat list of text queries, not a nested list
#         queries = ["a photo of " + obj for obj in words]
#         print("Owlv2 queries:", queries)
#         return queries
#     else:  # DINO
#         # DINO expects a single phrase ending in a period, e.g. "a frog."
#         query = f"a {' '.join(words)}."
#         print("DINO query:", query)
#         return query
#
# def detect_objects(query_text):
#     if state.current_image is None:
#         return {"count": 0, "message": "No image provided"}
#     image = Image.fromarray(state.current_image)
#     draw = ImageDraw.Draw(image)
#     if state.current_model == "owlv2":
#         # For Owlv2, pass the list of text queries directly
#         inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
#         with torch.no_grad():
#             outputs = owlv2_model(**inputs)
#         results = owlv2_processor.post_process_object_detection(
#             outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
#         )
#     else:  # DINO
#         # For DINO, pass the single text query
#         inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
#         with torch.no_grad():
#             outputs = dino_model(**inputs)
#         results = dino_processor.post_process_grounded_object_detection(
#             outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
#             target_sizes=[image.size[::-1]]
#         )
#     # Draw detection boxes
#     boxes = results[0]["boxes"]
#     scores = results[0]["scores"]
#     for box, score in zip(boxes, scores):
#         box = [round(i) for i in box.tolist()]
#         draw.rectangle(box, outline="red", width=3)
#         draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")
#     state.last_prediction = np.array(image)
#     return {
#         "count": len(boxes),
#         "confidence": scores.tolist(),
#         "message": f"Detected {len(boxes)} objects"
#     }
#
# def identify_plant():
#     if state.current_image is None:
#         return {"error": "No image provided"}
#     image = Image.fromarray(state.current_image)
#     img_byte_arr = BytesIO()
#     image.save(img_byte_arr, format='JPEG')
#     img_byte_arr = img_byte_arr.getvalue()
#     api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
#     files = [('images', ('image.jpg', img_byte_arr))]
#     data = {'organs': ['leaf']}
#     try:
#         response = requests.post(api_endpoint, files=files, data=data)
#         if response.status_code == 200:
#             result = response.json()
#             best_match = result['results'][0]
#             return {
#                 "scientific_name": best_match['species']['scientificName'],
#                 "common_names": best_match['species'].get('commonNames', []),
#                 "family": best_match['species']['family']['scientificName'],
#                 "genus": best_match['species']['genus']['scientificName'],
#                 "confidence": f"{best_match['score']*100:.1f}%"
#             }
#         else:
#             return {"error": f"API Error: {response.status_code}"}
#     except Exception as e:
#         return {"error": f"Error: {str(e)}"}
#
# # Tool definitions
# object_detection_function = {
#     "name": "detect_objects",
#     "description": "Use this function to detect and count objects in images based on text queries.",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "query_text": {
#                 "type": "array",
#                 "description": "List of text queries describing objects to detect",
#                 "items": {"type": "string"}
#             }
#         }
#     }
# }
#
# plant_identification_function = {
#     "name": "identify_plant",
#     "description": "Use this when asked about plant species identification or botanical classification.",
#     "parameters": {
#         "type": "object",
#         "properties": {},
#         "required": []
#     }
# }
#
# tools = [
#     {"type": "function", "function": object_detection_function},
#     {"type": "function", "function": plant_identification_function}
# ]
#
# def format_tool_response(tool_response_content):
#     data = json.loads(tool_response_content)
#     if "error" in data:
#         return f"Error: {data['error']}"
#     elif "scientific_name" in data:
#         return f"""📋 Plant Identification Results:
# 🌿 Scientific Name: {data['scientific_name']}
# 👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
# 👪 Family: {data['family']}
# 🎯 Confidence: {data['confidence']}"""
#     else:
#         return f"I detected {data['count']} objects in the image."
#
# def chat(message, image, history):
#     if image is not None:
#         state.current_image = image
#     if state.current_image is None:
#         return "Please upload an image first.", None
#     base64_image = encode_image_to_base64(state.current_image)
#     messages = [{"role": "system", "content": system_message}]
#     for human, assistant in history:
#         messages.append({"role": "user", "content": human})
#         messages.append({"role": "assistant", "content": assistant})
#     # Extract the objects to detect from the user message
#     # (this could be enhanced with better NLP)
#     objects_to_detect = message.lower()
#     formatted_query = format_query_for_model(objects_to_detect, state.current_model)
#     messages.append({
#         "role": "user",
#         "content": [
#             {"type": "text", "text": message},
#             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
#         ]
#     })
#     response = openai.chat.completions.create(
#         model=MODEL,
#         messages=messages,
#         tools=tools,
#         max_tokens=300
#     )
#     if response.choices[0].finish_reason == "tool_calls":
#         assistant_message = response.choices[0].message
#         messages.append(assistant_message)
#         for tool_call in assistant_message.tool_calls:
#             if tool_call.function.name == "detect_objects":
#                 results = detect_objects(formatted_query)
#             else:
#                 results = identify_plant()
#             tool_response = {
#                 "role": "tool",
#                 "content": json.dumps(results),
#                 "tool_call_id": tool_call.id
#             }
#             messages.append(tool_response)
#         response = openai.chat.completions.create(
#             model=MODEL,
#             messages=messages,
#             max_tokens=300
#         )
#     return response.choices[0].message.content, state.last_prediction
#
# def update_model(choice):
#     print(f"Model switched to: {choice}")
#     state.current_model = choice.lower()
#     return f"Model switched to {choice}"
#
# # Create Gradio interface
# with gr.Blocks() as demo:
#     gr.Markdown("# Object Detection and Plant Analysis System")
#     with gr.Row():
#         with gr.Column():
#             model_choice = gr.Radio(
#                 choices=["Owlv2", "DINO"],
#                 value="Owlv2",
#                 label="Select Detection Model",
#                 interactive=True
#             )
#             image_input = gr.Image(type="numpy", label="Upload Image")
#             text_input = gr.Textbox(
#                 label="Ask about the image",
#                 placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
#             )
#             with gr.Row():
#                 submit_btn = gr.Button("Analyze")
#                 reset_btn = gr.Button("Reset")
#         with gr.Column():
#             chatbot = gr.Chatbot()
#             # output_image = gr.Image(label="Detected Objects")
#             output_image = gr.Image(type="numpy", label="Detected Objects")
#
#     def process_interaction(message, image, history):
#         response, pred_image = chat(message, image, history)
#         history.append((message, response))
#         return "", pred_image, history
#
#     def reset_interface():
#         state.current_image = None
#         state.last_prediction = None
#         return None, None, None, []
#
#     model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
#     submit_btn.click(
#         fn=process_interaction,
#         inputs=[text_input, image_input, chatbot],
#         outputs=[text_input, output_image, chatbot]
#     )
#     reset_btn.click(
#         fn=reset_interface,
#         inputs=[],
#         outputs=[image_input, output_image, text_input, chatbot]
#     )
#     gr.Markdown("""## Instructions
# 1. Select the detection model (Owlv2 or DINO)
# 2. Upload an image
# 3. Ask specific questions about objects or plants
# 4. Click Analyze to get results""")
#
# demo.launch(share=True)


import os
import re
import io
import uuid
import contextlib
import shutil

import gradio as gr
from PIL import Image

# Required packages:
# pip install vision-agent gradio openai anthropic
from vision_agent.agent import VisionAgentCoderV2
from vision_agent.models import AgentMessage

#############################################
# GLOBAL INITIALIZATION
#############################################

# Create a temporary directory for saved images (individual files get unique names)
TEMP_DIR = "temp_images"
if not os.path.exists(TEMP_DIR):
    os.makedirs(TEMP_DIR)

# Initialize VisionAgentCoderV2 with verbose logging so the generated code has detailed print outputs.
agent = VisionAgentCoderV2(verbose=True)

#############################################
# UTILITY: SAVE UPLOADED IMAGE TO A TEMP FILE
#############################################

def save_uploaded_image(image):
    """
    Saves the uploaded image (a numpy array) to a temporary file.
    Returns the filename (including path) to be passed as media to VisionAgent.
    """
    # Generate a unique filename
    filename = os.path.join(TEMP_DIR, f"{uuid.uuid4().hex}.jpg")
    im = Image.fromarray(image)
    im.save(filename)
    return filename

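# Illustrative usage (values are hypothetical): a numpy RGB array becomes a JPEG on disk,
# and the returned path is what later gets passed to VisionAgent as `media`.
#   import numpy as np
#   path = save_uploaded_image(np.zeros((64, 64, 3), dtype=np.uint8))
#   # path -> something like 'temp_images/<32-character-hex-id>.jpg'
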
#############################################
# UTILITY: PARSE FILENAMES FROM save_image(...)
#############################################

def parse_saved_image_filenames(code_str):
    """
    Find all filenames in lines that look like:
        save_image(..., 'filename.jpg')
    Returns a list of the extracted filenames.
    """
    pattern = r"save_image\s*\(\s*[^,]+,\s*'([^']+)'\s*\)"
    return re.findall(pattern, code_str)

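# Quick illustration (made-up strings, not actual VisionAgent output) of what the
# pattern above does and does not match:
#   parse_saved_image_filenames("save_image(result_img, 'detections.jpg')")
#   # -> ['detections.jpg']
#   parse_saved_image_filenames('save_image(result_img, "detections.jpg")')
#   # -> [] (the pattern only matches single-quoted filenames)
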
#############################################
# UTILITY: EXECUTE CODE, CAPTURE STDOUT, IDENTIFY IMAGES
#############################################

def run_and_capture_with_images(code_str):
    """
    Executes the given code_str, capturing stdout and returning:
      - output: a string with all print statements (the step logs)
      - existing_images: list of filenames that were saved and exist on disk.
    """
    # Parse the code for image filenames saved via save_image
    filenames = parse_saved_image_filenames(code_str)
    # Capture stdout using a StringIO buffer
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        # IMPORTANT: exec runs the generated code in this process; only do this with code you trust.
        exec(code_str, globals(), locals())
    # Gather all printed output
    output = buf.getvalue()
    # Check which of the parsed filenames exist on disk (prepend TEMP_DIR if needed)
    existing_images = []
    for fn in filenames:
        # If the filename is not an absolute path, assume it lives in TEMP_DIR
        if not os.path.isabs(fn):
            fn = os.path.join(TEMP_DIR, fn)
        if os.path.exists(fn):
            existing_images.append(fn)
    return output, existing_images

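# Minimal sketch of the capture behaviour (the snippet below is hand-written, not real
# VisionAgent output): printed text is collected into `output`, and only files written via a
# matching save_image(..., '<name>.jpg') call are reported back as images.
#   demo_snippet = "print('2 + 2 =', 2 + 2)"
#   logs, images = run_and_capture_with_images(demo_snippet)
#   # logs -> '2 + 2 = 4\n'; images -> [] (the snippet never calls save_image)
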
#############################################
# CHAT FUNCTION: PROCESS USER PROMPT & IMAGE
#############################################

def chat(prompt, image, history):
    """
    When the user sends a prompt and optionally an image, do the following:
      1. Save the image to a temp file.
      2. Use VisionAgentCoderV2 to generate code for the task.
      3. Execute the generated code, capturing its stdout logs and any saved image files.
      4. Append the logs and image gallery info to the conversation history.
    """
    # Validate that an image was provided.
    if image is None:
        history.append(("System", "Please upload an image."))
        return history, None

    # Save the uploaded image for use in the generated code.
    image_path = save_uploaded_image(image)

    # Generate the code with VisionAgent using the user prompt and the image filename.
    code_context = agent.generate_code(
        [
            AgentMessage(
                role="user",
                content=prompt,
                media=[image_path]
            )
        ]
    )

    # Combine the generated code and its test snippet.
    generated_code = code_context.code + "\n" + code_context.test

    # Run the generated code and capture output and any saved images.
    stdout_text, image_files = run_and_capture_with_images(generated_code)

    # Format the response text (the captured logs).
    response_text = f"**Execution Logs:**\n{stdout_text}\n"
    if image_files:
        response_text += "\n**Saved Images:** " + ", ".join(image_files)
    else:
        response_text += "\nNo images were saved by the generated code."

    # Append the prompt and response to the chat history.
    history.append((prompt, response_text))
    # Optionally, you could clear the image input after use.
    return history, image_files

#############################################
# GRADIO CHAT INTERFACE
#############################################

with gr.Blocks() as demo:
    gr.Markdown("# VisionAgent Chat App")
    gr.Markdown(
        """
        This chat app lets you enter a prompt (e.g., "Count the number of cacao oranges in the image")
        along with an image. The app then uses VisionAgentCoderV2 to generate multi-step code, executes it,
        and returns the detailed logs and any saved images.
        """
    )
    with gr.Row():
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Chat History")
            prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., Count the number of cacao oranges in the image")
            submit_btn = gr.Button("Send")
        with gr.Column(scale=5):
            image_input = gr.Image(label="Upload Image", type="numpy")
            gallery = gr.Gallery(label="Generated Images", columns=2, height="auto")

    # Clear chat history button
    clear_btn = gr.Button("Clear Chat")

    # Chat function wrapper (takes the prompt, image, and current chat history)
    def user_chat_wrapper(prompt, image, history):
        history = history or []
        history, image_files = chat(prompt, image, history)
        return history, image_files

    submit_btn.click(fn=user_chat_wrapper, inputs=[prompt_input, image_input, chatbot], outputs=[chatbot, gallery])
    clear_btn.click(lambda: ([], None), None, [chatbot, gallery])

demo.launch()