real-jiakai commited on
Commit
dfd19f5
Β·
verified Β·
1 Parent(s): 8887a2c

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +180 -0
agent.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ agent.py ­– Gemini-smolagents baseline using google-genai SDK
3
+ ------------------------------------------------------------
4
+ Environment variables
5
+ ---------------------
6
+ GOOGLE_API_KEY API key from Google AI Studio
7
+ Optional:
8
+ GAIA_API_URL GAIA evaluation endpoint (default: official URL)
9
+
10
+ This file defines:
11
+ β€’ GeminiModel – wraps google-genai for smolagents
12
+ β€’ gaia_file_reader – custom tool to fetch <file:xyz> attachments
13
+ β€’ GeminiAgent – CodeAgent with Python / Search / File tools + Gemini model
14
+ """
15
+
16
+ import os
17
+ import re
18
+ import base64
19
+ import mimetypes
20
+ import requests
21
+ import google.genai as genai
22
+ from google.genai import types as gtypes
23
+ from smolagents import (
24
+ CodeAgent,
25
+ DuckDuckGoSearchTool,
26
+ PythonInterpreterTool,
27
+ BaseModel,
28
+ tool,
29
+ )
30
+
31
+ # --------------------------------------------------------------------------- #
32
+ # Constants & helpers
33
+ # --------------------------------------------------------------------------- #
34
+ DEFAULT_API_URL = os.getenv(
35
+ "GAIA_API_URL", "https://agents-course-unit4-scoring.hf.space"
36
+ )
37
+ FILE_TAG = re.compile(r"<file:([^>]+)>")
38
+
39
+
40
+ def _download_file(file_id: str) -> bytes:
41
+ """Download the attachment for a GAIA task."""
42
+ url = f"{DEFAULT_API_URL}/files/{file_id}"
43
+ resp = requests.get(url, timeout=30)
44
+ resp.raise_for_status()
45
+ return resp.content
46
+
47
+
48
+ # --------------------------------------------------------------------------- #
49
+ # Model wrapper
50
+ # --------------------------------------------------------------------------- #
51
+ class GeminiModel(BaseModel):
52
+ """
53
+ Thin adapter around google-genai.Client so it can be used by smolagents.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ model_name: str = "gemini-2.0-flash",
59
+ temperature: float = 0.1,
60
+ max_tokens: int = 128,
61
+ ):
62
+ api_key = os.getenv("GOOGLE_API_KEY")
63
+ if not api_key:
64
+ raise EnvironmentError("GOOGLE_API_KEY is not set.")
65
+ # One client per process is enough
66
+ self.client = genai.Client(api_key=api_key)
67
+ self.model_name = model_name
68
+ self.temperature = temperature
69
+ self.max_tokens = max_tokens
70
+
71
+ # ---------- Text-only convenience ---------- #
72
+ def call(self, prompt: str, **kwargs) -> str:
73
+ response = self.client.models.generate_content(
74
+ model=self.model_name,
75
+ contents=prompt,
76
+ generation_config=gtypes.GenerateContentConfig(
77
+ temperature=self.temperature,
78
+ max_output_tokens=self.max_tokens,
79
+ ),
80
+ )
81
+ return response.text.strip()
82
+
83
+ # ---------- smolagents will use this when messages are present ---------- #
84
+ def call_messages(self, messages, **kwargs) -> str:
85
+ """
86
+ `messages` is a list of dictionaries with keys 'role' | 'content'.
87
+ If `content` is already a list[types.Content], we forward it as-is.
88
+ Otherwise we concatenate to a single string prompt.
89
+ """
90
+ sys_msg, user_msg = messages # CodeAgent always sends two
91
+ if isinstance(user_msg["content"], list):
92
+ # Multimodal path – pass system text first, then structured user parts
93
+ contents = [sys_msg["content"], *user_msg["content"]]
94
+ else:
95
+ # Text prompt path
96
+ contents = f"{sys_msg['content']}\n\n{user_msg['content']}"
97
+ response = self.client.models.generate_content(
98
+ model=self.model_name,
99
+ contents=contents,
100
+ generation_config=gtypes.GenerateContentConfig(
101
+ temperature=self.temperature,
102
+ max_output_tokens=self.max_tokens,
103
+ ),
104
+ )
105
+ return response.text.strip()
106
+
107
+
108
+ # --------------------------------------------------------------------------- #
109
+ # Custom tool: fetch GAIA attachments
110
+ # --------------------------------------------------------------------------- #
111
+ @tool(name="gaia_file_reader", description="Download attachment referenced as <file:id>")
112
+ def gaia_file_reader(file_id: str) -> str:
113
+ """
114
+ Returns:
115
+ β€’ base64-str for binary files (images, pdf, etc.)
116
+ β€’ decoded text for utf-8 files
117
+ """
118
+ try:
119
+ raw = _download_file(file_id)
120
+ mime = mimetypes.guess_type(file_id)[0] or "application/octet-stream"
121
+ if mime.startswith("text") or mime in ("application/json",):
122
+ return raw.decode(errors="ignore")
123
+ return base64.b64encode(raw).decode()
124
+ except Exception as exc:
125
+ return f"ERROR downloading {file_id}: {exc}"
126
+
127
+
128
+ # --------------------------------------------------------------------------- #
129
+ # Final agent class
130
+ # --------------------------------------------------------------------------- #
131
+ class GeminiAgent:
132
+ """
133
+ Exposed to `app.py` – instantiated once and then called per question.
134
+ """
135
+
136
+ def __init__(self):
137
+ model = GeminiModel()
138
+ tools = [
139
+ PythonInterpreterTool(), # maths, csv, small image ops
140
+ DuckDuckGoSearchTool(), # quick web look-ups
141
+ gaia_file_reader, # our custom file tool
142
+ ]
143
+ self.system_prompt = (
144
+ "You are a concise, highly accurate assistant. "
145
+ "Unless explicitly required, reply with ONE short sentence. "
146
+ "Use the provided tools if needed. "
147
+ "All answers are graded by exact string match."
148
+ )
149
+ self.agent = CodeAgent(
150
+ model=model,
151
+ tools=tools,
152
+ system_prompt=self.system_prompt,
153
+ )
154
+ print("βœ… GeminiAgent (google-genai) initialised.")
155
+
156
+ # ----------- Main entry point for app.py ----------- #
157
+ def __call__(self, question: str) -> str:
158
+ file_ids = FILE_TAG.findall(question)
159
+ if file_ids:
160
+ # Build multimodal user content
161
+ parts: list[gtypes.Part] = []
162
+ text_part = FILE_TAG.sub("", question).strip()
163
+ if text_part:
164
+ parts.append(gtypes.Part.from_text(text_part))
165
+ for fid in file_ids:
166
+ try:
167
+ img_bytes = _download_file(fid)
168
+ mime = mimetypes.guess_type(fid)[0] or "image/png"
169
+ parts.append(gtypes.Part.from_bytes(data=img_bytes, mime_type=mime))
170
+ except Exception as exc:
171
+ parts.append(gtypes.Part.from_text(f"[FILE {fid} ERROR: {exc}]"))
172
+ messages = [
173
+ {"role": "system", "content": self.system_prompt},
174
+ {"role": "user", "content": parts},
175
+ ]
176
+ answer = self.agent.model.call_messages(messages)
177
+ else:
178
+ answer = self.agent(question)
179
+ # Trim trailing punctuation – GAIA scoring is case-/punctuation-sensitive
180
+ return answer.rstrip(" .\n\r\t")