PLBot committed
Commit 300f4a0 · verified · 1 Parent(s): a31a574

Create find_image_online.py

Files changed (1):
  1. tools/find_image_online.py +91 -0
tools/find_image_online.py ADDED
@@ -0,0 +1,91 @@
+from typing import Any, Optional
+import re
+from smolagents.tools import Tool
+from smolagents.agent_types import AgentImage
+import requests
+from io import BytesIO
+import os
+import tempfile
+
+class FindImageOnlineTool(Tool):
+    name = "find_image_online"
+    description = "Searches for images online based on a query and returns an image that matches the description."
+    inputs = {'query': {'type': 'string', 'description': 'The search query for the image you want to find.'}}
+    output_type = "image"
+
+    def __init__(self, web_search_tool=None, visit_webpage_tool=None):
+        super().__init__()
+        self.web_search_tool = web_search_tool
+        self.visit_webpage_tool = visit_webpage_tool
+        self.is_initialized = True
+
+    def extract_image_urls(self, markdown_content):
+        # Extract image URLs from markdown using regex
+        # Look for standard markdown image patterns ![alt](url)
+        md_image_pattern = r'!\[.*?\]\((https?://[^)]+\.(jpg|jpeg|png|gif|webp))\)'
+        md_images = re.findall(md_image_pattern, markdown_content)
+
+        # Also look for direct URLs that end with image extensions
+        direct_url_pattern = r'(https?://[^\s)]+\.(jpg|jpeg|png|gif|webp))'
+        direct_urls = re.findall(direct_url_pattern, markdown_content)
+
+        # Combine and deduplicate results
+        image_urls = [url for url, _ in md_images] + [url for url, _ in direct_urls]
+        return list(set(image_urls))
+
+    def download_image(self, url):
+        try:
+            response = requests.get(url, stream=True, timeout=10)
+            response.raise_for_status()
+
+            # Create a temporary file with appropriate extension
+            ext = os.path.splitext(url)[1]
+            if not ext or ext not in ['.jpg', '.jpeg', '.png', '.gif', '.webp']:
+                ext = '.jpg'  # Default extension
+
+            temp_file = tempfile.NamedTemporaryFile(suffix=ext, delete=False)
+            temp_file.write(response.content)
+            temp_file.close()
+
+            return temp_file.name, url
+        except Exception as e:
+            print(f"Error downloading image from {url}: {str(e)}")
+            return None, url
+
+    def forward(self, query: str) -> Any:
+        if not self.web_search_tool or not self.visit_webpage_tool:
+            return "Error: Web search and visit webpage tools must be provided."
+
+        try:
+            # Step 1: Search for the query + "image"
+            search_query = f"{query} image"
+            search_results = self.web_search_tool.forward(search_query)
+
+            # Step 2: Extract URLs from search results
+            url_pattern = r'\((https?://[^)]+)\)'
+            urls = re.findall(url_pattern, search_results)
+
+            # Step 3: Visit each page and look for images
+            for url in urls[:3]:  # Limit to first 3 results for efficiency
+                try:
+                    page_content = self.visit_webpage_tool.forward(url)
+                    image_urls = self.extract_image_urls(page_content)
+
+                    # Step 4: Download the first valid image found
+                    for img_url in image_urls[:5]:  # Try up to 5 images per page
+                        img_path, source_url = self.download_image(img_url)
+                        if img_path:
+                            # Return both the image and the source information
+                            return {
+                                "image": AgentImage(img_path),
+                                "source_url": source_url,
+                                "page_url": url,
+                                "query": query
+                            }
+                except Exception as e:
+                    continue  # Try the next URL if this one fails
+
+            return f"Could not find a suitable image for '{query}'. Please try a different query."
+
+        except Exception as e:
+            return f"Error finding image: {str(e)}"