"""
agent.py — Gemini-smolagents baseline using the google-genai SDK.

Environment variables
---------------------
GOOGLE_API_KEY    API key from Google AI Studio

Optional:
GAIA_API_URL      GAIA evaluation endpoint (default: official URL)

This file defines:
  • GeminiModel      — wraps google-genai for smolagents
  • gaia_file_reader — custom tool to fetch <file:xyz> attachments
  • GeminiAgent      — CodeAgent with Python / Search / File tools + Gemini model
"""
|
|
|
import os |
|
import re |
|
import base64 |
|
import mimetypes |
|
import requests |
|
import google.genai as genai |
|
from google.genai import types as gtypes |
|
from smolagents import ( |
|
CodeAgent, |
|
DuckDuckGoSearchTool, |
|
PythonInterpreterTool, |
|
BaseModel, |
|
tool, |
|
) |
|
|
|
|
|
|
|
|
|
# GAIA scoring endpoint; override with the GAIA_API_URL environment variable.
DEFAULT_API_URL = os.getenv(
    "GAIA_API_URL", "https://agents-course-unit4-scoring.hf.space"
)

# Matches attachment placeholders such as "<file:abc123>" embedded in a
# question; group 1 captures the file id.
FILE_TAG = re.compile(r"<file:([^>]+)>")
|
|
|
|
|
def _download_file(file_id: str) -> bytes:
    """Fetch the raw bytes of a GAIA task attachment by its file id."""
    response = requests.get(f"{DEFAULT_API_URL}/files/{file_id}", timeout=30)
    response.raise_for_status()
    return response.content
|
|
|
|
|
|
|
|
|
|
|
class GeminiModel(BaseModel):
    """
    Thin adapter around google-genai.Client so it can be used by smolagents.

    Attributes:
        client:      google-genai client authenticated via GOOGLE_API_KEY.
        model_name:  Gemini model identifier passed to generate_content.
        temperature: sampling temperature for generation.
        max_tokens:  cap on output tokens per response.
    """

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        temperature: float = 0.1,
        max_tokens: int = 128,
    ):
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise EnvironmentError("GOOGLE_API_KEY is not set.")

        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _generate(self, contents) -> str:
        """Run one generate_content call and return the stripped text."""
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            # google-genai expects the keyword `config=`; the older
            # google-generativeai SDK used `generation_config=`, which this
            # SDK rejects with a TypeError.
            config=gtypes.GenerateContentConfig(
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            ),
        )
        # response.text can be None (e.g. safety-blocked candidates); avoid
        # AttributeError on .strip().
        return (response.text or "").strip()

    def call(self, prompt: str, **kwargs) -> str:
        """Generate a completion for a single string prompt."""
        return self._generate(prompt)

    def call_messages(self, messages, **kwargs) -> str:
        """
        Generate a completion for a [system, user] message pair.

        `messages` is a list of two dicts with keys 'role' | 'content'.
        If the user 'content' is already a list (e.g. [types.Part, ...]),
        it is forwarded as multimodal contents prefixed by the system text;
        otherwise both contents are concatenated into one string prompt.
        """
        sys_msg, user_msg = messages
        if isinstance(user_msg["content"], list):
            # Multimodal: system text first, then the user's Part objects.
            contents = [sys_msg["content"], *user_msg["content"]]
        else:
            contents = f"{sys_msg['content']}\n\n{user_msg['content']}"
        return self._generate(contents)
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): smolagents' `tool` is a bare decorator — it derives the tool
# name from the function name and the description from the docstring, and it
# requires an `Args:` section documenting each parameter. Calling it as
# `tool(name=..., description=...)` is not its signature.
@tool
def gaia_file_reader(file_id: str) -> str:
    """
    Download the attachment referenced as <file:id> in a GAIA task.

    Args:
        file_id: Identifier of the attachment to download.

    Returns:
        Decoded text for utf-8/text files, a base64 string for binary files
        (images, pdf, etc.), or an "ERROR ..." message on failure.
    """
    try:
        raw = _download_file(file_id)
        mime = mimetypes.guess_type(file_id)[0] or "application/octet-stream"
        # Text-like payloads come back as readable text; everything else is
        # base64-encoded so it survives the string-only tool interface.
        if mime.startswith("text") or mime in ("application/json",):
            return raw.decode(errors="ignore")
        return base64.b64encode(raw).decode()
    except Exception as exc:
        # Best-effort tool: surface the failure to the agent as a string
        # instead of crashing the whole run.
        return f"ERROR downloading {file_id}: {exc}"
|
|
|
|
|
|
|
|
|
|
|
class GeminiAgent:
    """
    Exposed to `app.py` — instantiated once and then called per question.
    """

    def __init__(self):
        model = GeminiModel()
        tools = [
            PythonInterpreterTool(),
            DuckDuckGoSearchTool(),
            gaia_file_reader,
        ]
        self.system_prompt = (
            "You are a concise, highly accurate assistant. "
            "Unless explicitly required, reply with ONE short sentence. "
            "Use the provided tools if needed. "
            "All answers are graded by exact string match."
        )
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            system_prompt=self.system_prompt,
        )
        # The original literal was split across two lines by a mojibake'd
        # emoji, which is a syntax error; restored as one line.
        print("✅ GeminiAgent (google-genai) initialised.")

    def __call__(self, question: str) -> str:
        """Answer one GAIA question, inlining any <file:id> attachments."""
        file_ids = FILE_TAG.findall(question)
        if file_ids:
            # Multimodal path: build a Part list (question text + attachment
            # bytes) and call Gemini directly via the model adapter.
            parts: list[gtypes.Part] = []
            text_part = FILE_TAG.sub("", question).strip()
            if text_part:
                # Part.from_text takes keyword-only `text=` in google-genai.
                parts.append(gtypes.Part.from_text(text=text_part))
            for fid in file_ids:
                try:
                    file_bytes = _download_file(fid)
                    mime = mimetypes.guess_type(fid)[0] or "image/png"
                    parts.append(
                        gtypes.Part.from_bytes(data=file_bytes, mime_type=mime)
                    )
                except Exception as exc:
                    # Keep going: report the broken attachment inline so the
                    # model still sees the rest of the question.
                    parts.append(
                        gtypes.Part.from_text(text=f"[FILE {fid} ERROR: {exc}]")
                    )
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": parts},
            ]
            answer = self.agent.model.call_messages(messages)
        else:
            answer = self.agent(question)

        # GAIA grades by exact string match: trim trailing punctuation
        # and whitespace the model tends to append.
        return answer.rstrip(" .\n\r\t")
|
|