paulopontesm commited on
Commit
bafe111
·
1 Parent(s): 3b241b7

add image_generator tool

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +2 -1
  3. tools/image_generator.py +31 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .venv
2
+ __pycache__
3
+ .gradio
app.py CHANGED
@@ -6,6 +6,7 @@ import yaml
6
  from tools.final_answer import FinalAnswerTool
7
  from tools.visit_webpage import VisitWebpageTool
8
  from tools.web_search import DuckDuckGoSearchTool
 
9
 
10
  from Gradio_UI import GradioUI
11
 
@@ -59,7 +60,7 @@ with open("prompts.yaml", 'r') as stream:
59
 
60
  agent = CodeAgent(
61
  model=model,
62
- tools=[visit_webpage, web_search, final_answer], ## add your tools here (don't remove final answer)
63
  max_steps=6,
64
  verbosity_level=1,
65
  grammar=None,
 
6
  from tools.final_answer import FinalAnswerTool
7
  from tools.visit_webpage import VisitWebpageTool
8
  from tools.web_search import DuckDuckGoSearchTool
9
+ from tools.image_generator import ImageGeneratorTool
10
 
11
  from Gradio_UI import GradioUI
12
 
 
60
 
61
  agent = CodeAgent(
62
  model=model,
63
+ tools=[visit_webpage, web_search, final_answer, image_generation_tool], ## add your tools here (don't remove final answer)
64
  max_steps=6,
65
  verbosity_level=1,
66
  grammar=None,
tools/image_generator.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents.tools import Tool
2
+
3
+ from PIL.Image import Image
4
+
5
+
6
+ class ImageGeneratorTool(Tool):
7
+ name = "image_generator"
8
+ description = "Generates an image based on your query."
9
+ inputs = {
10
+ "query": {
11
+ "type": "string",
12
+ "description": "The query to generate an image for.",
13
+ }
14
+ }
15
+ output_type = "any"
16
+
17
+ def __init__(self, **kwargs):
18
+ super().__init__()
19
+ try:
20
+ from huggingface_hub import InferenceClient
21
+ except ImportError as e:
22
+ raise ImportError(
23
+ "You must install package `huggingface_hub` to run this tool: for instance run `pip install huggingface_hub`."
24
+ ) from e
25
+ self.client = InferenceClient("black-forest-labs/FLUX.1-dev")
26
+
27
+ def forward(self, query: str) -> Image:
28
+ image: Image = self.client.text_to_image(query)
29
+ if image is None:
30
+ raise Exception("No results found! Try a less restrictive/shorter query.")
31
+ return image