"""
agent.py — Gemini-smolagents baseline using google-genai SDK
-----------------------------------------------------------
Environment
-----------
GOOGLE_API_KEY — API key from Google AI Studio
GAIA_API_URL   — (optional) override for the GAIA scoring endpoint
"""
|
|
|
from __future__ import annotations |
|
|
|
import base64 |
|
import mimetypes |
|
import os |
|
import re |
|
from typing import List |
|
|
|
import google.genai as genai |
|
import requests |
|
from google.genai import types as gtypes |
|
from smolagents import ( |
|
CodeAgent, |
|
DuckDuckGoSearchTool, |
|
PythonInterpreterTool, |
|
tool, |
|
) |
|
|
|
|
|
|
|
|
|
# Base URL of the GAIA scoring service. The GAIA_API_URL environment
# variable, when set, takes precedence over the built-in default.
DEFAULT_API_URL = os.environ.get(
    "GAIA_API_URL",
    "https://agents-course-unit4-scoring.hf.space",
)

# Matches a "<file:ID>" placeholder and captures ID (everything up to ">").
FILE_TAG = re.compile(r"<file:([^>]+)>")
|
|
|
def _download_file(file_id: str) -> bytes:
    """Download the attachment for a GAIA task.

    Args:
        file_id: Attachment identifier. It is percent-encoded before being
            placed in the URL path.

    Returns:
        The raw response body as bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``). Other exceptions raised by
            ``requests.get`` (e.g. on timeout) propagate unchanged.
    """
    # Imported locally so this function does not depend on a file-level
    # import being added elsewhere.
    from urllib.parse import quote

    # The id comes from free text / tool arguments. With safe="" characters
    # such as "/", "?" and "#" are sent as data instead of altering the
    # structure of the request URL.
    url = f"{DEFAULT_API_URL}/files/{quote(file_id, safe='')}"
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    return resp.content
|
|
|
|
|
|
|
|
|
class GeminiModel:
    """
    Minimal adapter around google-genai.Client so the instance itself is
    callable (required by smolagents).
    """

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        temperature: float = 0.1,
        max_tokens: int = 128,
    ):
        """Build the client from the ``GOOGLE_API_KEY`` environment variable.

        Args:
            model_name: Model identifier passed to ``generate_content``.
            temperature: Sampling temperature passed in the request config.
            max_tokens: Passed as ``max_output_tokens`` in the request config.

        Raises:
            EnvironmentError: If ``GOOGLE_API_KEY`` is unset or empty.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise EnvironmentError("GOOGLE_API_KEY is not set.")
        self.client = genai.Client(api_key=api_key)

        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _genai_call(
        self,
        contents,
        system_instruction: str,
    ) -> str:
        """Send ``contents`` to the model and return the stripped reply text.

        Args:
            contents: Forwarded unchanged as ``contents`` to
                ``client.models.generate_content`` (a plain string or a list
                of ``Content`` objects in this file).
            system_instruction: System prompt placed in the request config.

        Returns:
            The response text with surrounding whitespace removed, or an
            empty string when the response carries no text.
        """
        resp = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=gtypes.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            ),
        )
        # ``resp.text`` is None when the response has no text part; guard so
        # callers get "" instead of an AttributeError on ``None.strip()``.
        return (resp.text or "").strip()

    def __call__(self, prompt: str, system_instruction: str, **__) -> str:
        """Plain-text path: send ``prompt`` and return the reply text.

        Any extra keyword arguments are accepted and ignored.
        """
        return self._genai_call(prompt, system_instruction)

    def call_parts(
        self,
        parts: List[gtypes.Part],
        system_instruction: str,
    ) -> str:
        """Multimodal path used by GeminiAgent for <file:...> questions.

        Wraps ``parts`` in a single user-role ``Content`` and sends it.
        """
        user_content = gtypes.Content(role="user", parts=parts)
        return self._genai_call([user_content], system_instruction)
|
|
|
|
|
|
|
|
|
@tool
def gaia_file_reader(file_id: str) -> str:
    """
    Download a GAIA attachment and return its contents.

    Args:
        file_id: identifier that appears inside a <file:...> placeholder.

    Returns:
        base64-encoded string for binary files (images, PDFs, ...) or decoded
        UTF-8 text for textual files.
    """
    # NOTE: the docstring above is consumed by the @tool decorator (it becomes
    # the tool description), so keep it plain ASCII and in "Args:" form.
    try:
        raw = _download_file(file_id)
        # The MIME type is guessed from the id's file extension only; ids
        # without a recognised extension are treated as binary.
        mime = mimetypes.guess_type(file_id)[0] or "application/octet-stream"
        if mime.startswith("text") or mime in ("application/json",):
            # Undecodable bytes are dropped rather than raising.
            return raw.decode(errors="ignore")
        return base64.b64encode(raw).decode()
    except Exception as exc:
        # Best effort by design: report the failure as the tool's output
        # instead of raising.
        return f"ERROR downloading {file_id}: {exc}"
|
|
|
|
|
|
|
|
|
class GeminiAgent:
    """Instantiated once in app.py; called once per question."""

    SYSTEM_PROMPT = (
        "You are a concise, highly accurate assistant. "
        "Unless explicitly required, reply with ONE short sentence. "
        "Use the provided tools if needed. "
        "All answers are graded by exact string match."
    )

    def __init__(self):
        """Create the Gemini model adapter and the smolagents CodeAgent."""
        self.model = GeminiModel()
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                PythonInterpreterTool(),
                DuckDuckGoSearchTool(),
                gaia_file_reader,
            ],
            verbosity_level=0,
        )
        print("✅ GeminiAgent initialised.")

    def __call__(self, question: str) -> str:
        """Answer one question and return the cleaned reply.

        Questions containing ``<file:ID>`` placeholders take the multimodal
        path: each referenced file is downloaded and sent to the model with
        the remaining question text. All other questions are sent as plain
        text.

        Args:
            question: The question text, optionally with ``<file:ID>`` tags.

        Returns:
            The model's reply with trailing spaces, periods and newline/tab
            characters removed.
        """
        file_ids = FILE_TAG.findall(question)

        if file_ids:
            parts: List[gtypes.Part] = []

            # Question text with the placeholders removed.
            text_part = FILE_TAG.sub("", question).strip()
            if text_part:
                # ``text`` must be passed by keyword to Part.from_text.
                parts.append(gtypes.Part.from_text(text=text_part))

            for fid in file_ids:
                try:
                    file_bytes = _download_file(fid)
                    # Falls back to PNG when the id has no known extension.
                    mime = mimetypes.guess_type(fid)[0] or "image/png"
                    parts.append(
                        gtypes.Part.from_bytes(data=file_bytes, mime_type=mime)
                    )
                except Exception as exc:
                    # Best effort: tell the model the file is unavailable
                    # instead of aborting the whole question.
                    parts.append(
                        gtypes.Part.from_text(text=f"[FILE {fid} ERROR: {exc}]")
                    )

            answer = self.model.call_parts(
                parts, system_instruction=self.SYSTEM_PROMPT
            )
        else:
            # NOTE(review): this calls the model directly rather than running
            # the CodeAgent, and the system prompt is sent both inside the
            # prompt and as system_instruction - confirm this is intended.
            prompt = f"{self.SYSTEM_PROMPT}\n\n{question}"
            answer = self.agent.model(prompt, system_instruction=self.SYSTEM_PROMPT)

        return answer.rstrip(" .\n\r\t")
|
|