real-jiakai committed on
Commit
c1d3919
·
verified ·
1 Parent(s): ce2d7d4

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +74 -84
agent.py CHANGED
@@ -1,24 +1,22 @@
1
  """
2
- agent.py – Gemini-smolagents baseline using google-genai SDK
3
- ------------------------------------------------------------
4
- Environment variables
5
- ---------------------
6
- GOOGLE_API_KEY API key from Google AI Studio
7
- Optional:
8
- GAIA_API_URL GAIA evaluation endpoint (default: official URL)
9
-
10
- This file defines:
11
- • GeminiModel – wraps google-genai for smolagents
12
- • gaia_file_reader – custom tool to fetch <file:xyz> attachments
13
- • GeminiAgent – CodeAgent with Python / Search / File tools + Gemini model
14
  """
15
 
16
- import os
17
- import re
18
  import base64
19
  import mimetypes
20
- import requests
 
 
 
21
  import google.genai as genai
 
22
  from google.genai import types as gtypes
23
  from smolagents import (
24
  CodeAgent,
@@ -28,13 +26,12 @@ from smolagents import (
28
  )
29
 
30
  # --------------------------------------------------------------------------- #
31
- # Constants & helpers
32
  # --------------------------------------------------------------------------- #
33
  DEFAULT_API_URL = os.getenv(
34
  "GAIA_API_URL", "https://agents-course-unit4-scoring.hf.space"
35
  )
36
- FILE_TAG = re.compile(r"<file:([^>]+)>")
37
-
38
 
39
  def _download_file(file_id: str) -> bytes:
40
  """Download the attachment for a GAIA task."""
@@ -43,13 +40,13 @@ def _download_file(file_id: str) -> bytes:
43
  resp.raise_for_status()
44
  return resp.content
45
 
46
-
47
  # --------------------------------------------------------------------------- #
48
- # Model wrapper
49
  # --------------------------------------------------------------------------- #
50
  class GeminiModel:
51
  """
52
- Thin adapter around google-genai Client for smolagents.
 
53
  """
54
 
55
  def __init__(
@@ -62,47 +59,44 @@ class GeminiModel:
62
  if not api_key:
63
  raise EnvironmentError("GOOGLE_API_KEY is not set.")
64
  self.client = genai.Client(api_key=api_key)
 
65
  self.model_name = model_name
66
  self.temperature = temperature
67
  self.max_tokens = max_tokens
68
 
69
- # ---------- main generation helpers ---------- #
70
- def call(self, prompt: str, **kwargs) -> str:
71
- """Text-only helper used by __call__."""
72
- resp = self.client.models.generate_content(
73
- model=self.model_name,
74
- contents=prompt,
75
- config=gtypes.GenerateContentConfig(
76
- temperature=self.temperature,
77
- max_output_tokens=self.max_tokens,
78
- ),
79
- )
80
- return resp.text.strip()
81
-
82
- def call_messages(self, messages, **kwargs) -> str:
83
- sys_msg, user_msg = messages
84
- contents = (
85
- [sys_msg["content"], *user_msg["content"]]
86
- if isinstance(user_msg["content"], list)
87
- else f"{sys_msg['content']}\n\n{user_msg['content']}"
88
- )
89
  resp = self.client.models.generate_content(
90
  model=self.model_name,
91
  contents=contents,
92
  config=gtypes.GenerateContentConfig(
 
93
  temperature=self.temperature,
94
  max_output_tokens=self.max_tokens,
95
  ),
96
  )
97
  return resp.text.strip()
98
 
99
- # ---------- make the instance itself callable ---------- #
100
- def __call__(self, prompt: str, **kwargs) -> str: # <-- NEW
101
- return self.call(prompt, **kwargs)
 
102
 
 
 
 
 
 
 
 
 
103
 
104
  # --------------------------------------------------------------------------- #
105
- # Custom tool: fetch GAIA attachments
106
  # --------------------------------------------------------------------------- #
107
  @tool
108
  def gaia_file_reader(file_id: str) -> str:
@@ -110,12 +104,11 @@ def gaia_file_reader(file_id: str) -> str:
110
  Download a GAIA attachment and return its contents.
111
 
112
  Args:
113
- file_id: The identifier that appears inside a <file:...> placeholder
114
- in the GAIA question prompt.
115
 
116
  Returns:
117
- A base-64 string for binary files (images, PDF, etc.) or UTF-8 text for
118
- plain-text files.
119
  """
120
  try:
121
  raw = _download_file(file_id)
@@ -123,47 +116,49 @@ def gaia_file_reader(file_id: str) -> str:
123
  if mime.startswith("text") or mime in ("application/json",):
124
  return raw.decode(errors="ignore")
125
  return base64.b64encode(raw).decode()
126
- except Exception as exc:
127
  return f"ERROR downloading {file_id}: {exc}"
128
 
129
-
130
  # --------------------------------------------------------------------------- #
131
- # Final agent class
132
  # --------------------------------------------------------------------------- #
133
  class GeminiAgent:
134
- def __init__(self):
135
- self.system_prompt = (
136
- "You are a concise, highly accurate assistant. "
137
- "Unless explicitly required, reply with ONE short sentence. "
138
- "Use the provided tools if needed. "
139
- "All answers are graded by exact string match."
140
- )
141
 
142
- model = GeminiModel()
143
- tools = [
144
- PythonInterpreterTool(),
145
- DuckDuckGoSearchTool(),
146
- gaia_file_reader,
147
- ]
148
 
149
- # ✨ system_prompt removed – newest smolagents doesn't take it
 
150
  self.agent = CodeAgent(
151
- model=model,
152
- tools=tools,
153
- # any other kwargs (executor_type, additional_authorized_imports…)
 
 
 
154
  verbosity_level=0,
155
  )
156
- print("βœ… GeminiAgent ready.")
157
 
 
 
 
158
  def __call__(self, question: str) -> str:
159
  file_ids = FILE_TAG.findall(question)
160
 
161
- # -------- multimodal branch -------- #
162
  if file_ids:
163
- parts: list[gtypes.Part] = []
 
164
  text_part = FILE_TAG.sub("", question).strip()
165
  if text_part:
166
  parts.append(gtypes.Part.from_text(text_part))
 
167
  for fid in file_ids:
168
  try:
169
  img_bytes = _download_file(fid)
@@ -172,19 +167,14 @@ class GeminiAgent:
172
  gtypes.Part.from_bytes(data=img_bytes, mime_type=mime)
173
  )
174
  except Exception as exc:
175
- parts.append(
176
- gtypes.Part.from_text(f"[FILE {fid} ERROR: {exc}]")
177
- )
178
- messages = [
179
- {"role": "system", "content": self.system_prompt},
180
- {"role": "user", "content": parts},
181
- ]
182
- answer = self.agent.model.call_messages(messages)
183
 
184
- # -------- text-only branch -------- #
185
  else:
186
- # prepend system prompt to the user question
187
- full_prompt = f"{self.system_prompt}\n\n{question}"
188
- answer = self.agent(full_prompt)
189
 
190
  return answer.rstrip(" .\n\r\t")
 
1
  """
2
+ agent.py – Gemini-smolagents baseline using google-genai SDK
3
+ -----------------------------------------------------------
4
+ Environment
5
+ -----------
6
+ GOOGLE_API_KEY – API key from Google AI Studio
7
+ GAIA_API_URL – (optional) override for the GAIA scoring endpoint
 
 
 
 
 
 
8
  """
9
 
10
+ from __future__ import annotations
11
+
12
  import base64
13
  import mimetypes
14
+ import os
15
+ import re
16
+ from typing import List
17
+
18
  import google.genai as genai
19
+ import requests
20
  from google.genai import types as gtypes
21
  from smolagents import (
22
  CodeAgent,
 
26
  )
27
 
28
  # --------------------------------------------------------------------------- #
29
+ # constants & helpers
30
  # --------------------------------------------------------------------------- #
31
  DEFAULT_API_URL = os.getenv(
32
  "GAIA_API_URL", "https://agents-course-unit4-scoring.hf.space"
33
  )
34
+ FILE_TAG = re.compile(r"<file:([^>]+)>") # <file:xyz>
 
35
 
36
  def _download_file(file_id: str) -> bytes:
37
  """Download the attachment for a GAIA task."""
 
40
  resp.raise_for_status()
41
  return resp.content
42
 
 
43
  # --------------------------------------------------------------------------- #
44
+ # model wrapper
45
  # --------------------------------------------------------------------------- #
46
  class GeminiModel:
47
  """
48
+ Minimal adapter around google-genai.Client so the instance itself is
49
+ callable (required by smolagents).
50
  """
51
 
52
  def __init__(
 
59
  if not api_key:
60
  raise EnvironmentError("GOOGLE_API_KEY is not set.")
61
  self.client = genai.Client(api_key=api_key)
62
+
63
  self.model_name = model_name
64
  self.temperature = temperature
65
  self.max_tokens = max_tokens
66
 
67
+ # internal helper -------------------------------------------------------- #
68
+ def _genai_call(
69
+ self,
70
+ contents,
71
+ system_instruction: str,
72
+ ) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  resp = self.client.models.generate_content(
74
  model=self.model_name,
75
  contents=contents,
76
  config=gtypes.GenerateContentConfig(
77
+ system_instruction=system_instruction,
78
  temperature=self.temperature,
79
  max_output_tokens=self.max_tokens,
80
  ),
81
  )
82
  return resp.text.strip()
83
 
84
+ # public helpers --------------------------------------------------------- #
85
+ def __call__(self, prompt: str, system_instruction: str, **__) -> str:
86
+ """Used by CodeAgent for plain-text questions."""
87
+ return self._genai_call(prompt, system_instruction)
88
 
89
+ def call_parts(
90
+ self,
91
+ parts: List[gtypes.Part],
92
+ system_instruction: str,
93
+ ) -> str:
94
+ """Multimodal path used by GeminiAgent for <file:…> questions."""
95
+ user_content = gtypes.Content(role="user", parts=parts)
96
+ return self._genai_call([user_content], system_instruction)
97
 
98
  # --------------------------------------------------------------------------- #
99
+ # custom tool: fetch GAIA attachments
100
  # --------------------------------------------------------------------------- #
101
  @tool
102
  def gaia_file_reader(file_id: str) -> str:
 
104
  Download a GAIA attachment and return its contents.
105
 
106
  Args:
107
+ file_id: identifier that appears inside a <file:...> placeholder.
 
108
 
109
  Returns:
110
+ base64-encoded string for binary files (images, PDFs, …) or decoded
111
+ UTF-8 text for textual files.
112
  """
113
  try:
114
  raw = _download_file(file_id)
 
116
  if mime.startswith("text") or mime in ("application/json",):
117
  return raw.decode(errors="ignore")
118
  return base64.b64encode(raw).decode()
119
+ except Exception as exc: # pragma: no cover
120
  return f"ERROR downloading {file_id}: {exc}"
121
 
 
122
  # --------------------------------------------------------------------------- #
123
+ # final agent
124
  # --------------------------------------------------------------------------- #
125
  class GeminiAgent:
126
+ """Instantiated once in app.py; called once per question."""
 
 
 
 
 
 
127
 
128
+ SYSTEM_PROMPT = (
129
+ "You are a concise, highly accurate assistant. "
130
+ "Unless explicitly required, reply with ONE short sentence. "
131
+ "Use the provided tools if needed. "
132
+ "All answers are graded by exact string match."
133
+ )
134
 
135
+ def __init__(self):
136
+ self.model = GeminiModel()
137
  self.agent = CodeAgent(
138
+ model=self.model,
139
+ tools=[
140
+ PythonInterpreterTool(),
141
+ DuckDuckGoSearchTool(),
142
+ gaia_file_reader,
143
+ ],
144
  verbosity_level=0,
145
  )
146
+ print("βœ… GeminiAgent initialised.")
147
 
148
+ # --------------------------------------------------------------------- #
149
+ # main entry point
150
+ # --------------------------------------------------------------------- #
151
  def __call__(self, question: str) -> str:
152
  file_ids = FILE_TAG.findall(question)
153
 
154
+ # ---------- multimodal branch (images / files) -------------------- #
155
  if file_ids:
156
+ parts: List[gtypes.Part] = []
157
+
158
  text_part = FILE_TAG.sub("", question).strip()
159
  if text_part:
160
  parts.append(gtypes.Part.from_text(text_part))
161
+
162
  for fid in file_ids:
163
  try:
164
  img_bytes = _download_file(fid)
 
167
  gtypes.Part.from_bytes(data=img_bytes, mime_type=mime)
168
  )
169
  except Exception as exc:
170
+ parts.append(gtypes.Part.from_text(f"[FILE {fid} ERROR: {exc}]"))
171
+
172
+ answer = self.model.call_parts(parts, system_instruction=self.SYSTEM_PROMPT)
 
 
 
 
 
173
 
174
+ # ---------- plain-text branch ------------------------------------- #
175
  else:
176
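+ # Prepend the system prompt to the question text. NOTE: this calls the model directly (not the CodeAgent), and the same prompt is also passed as system_instruction.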
177
+ prompt = f"{self.SYSTEM_PROMPT}\n\n{question}"
178
+ answer = self.agent.model(prompt, system_instruction=self.SYSTEM_PROMPT)
179
 
180
  return answer.rstrip(" .\n\r\t")