# Scraped file-viewer header (not part of the program):
#   real-jiakai's picture | "Update agent.py" | commit c1d3919 (verified)
#   raw | history | blame | 6.33 kB
"""
agent.py – Gemini-smolagents baseline using google-genai SDK
-----------------------------------------------------------
Environment
-----------
GOOGLE_API_KEY – API key from Google AI Studio
GAIA_API_URL – (optional) override for the GAIA scoring endpoint
"""
from __future__ import annotations
import base64
import mimetypes
import os
import re
from typing import List
import google.genai as genai
import requests
from google.genai import types as gtypes
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
PythonInterpreterTool,
tool,
)
# --------------------------------------------------------------------------- #
# constants & helpers
# --------------------------------------------------------------------------- #
# Base URL of the GAIA scoring service. The GAIA_API_URL environment variable,
# when set, takes precedence over the built-in default.
DEFAULT_API_URL = os.environ.get(
    "GAIA_API_URL",
    "https://agents-course-unit4-scoring.hf.space",
)

# Matches an attachment placeholder such as <file:xyz>; group 1 is the file id.
FILE_TAG = re.compile(r"<file:([^>]+)>")
def _download_file(file_id: str) -> bytes:
    """
    Fetch the raw bytes of a GAIA task attachment.

    Issues a GET to ``{DEFAULT_API_URL}/files/{file_id}`` with a 30 second
    timeout.

    Args:
        file_id: identifier of the attachment to download.

    Returns:
        The response body as bytes.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(f"{DEFAULT_API_URL}/files/{file_id}", timeout=30)
    # Surface 4xx/5xx responses as exceptions instead of returning an error page.
    response.raise_for_status()
    return response.content
# --------------------------------------------------------------------------- #
# model wrapper
# --------------------------------------------------------------------------- #
class GeminiModel:
    """
    Minimal adapter around google-genai.Client so the instance itself is
    callable (required by smolagents).

    The API key is read from the GOOGLE_API_KEY environment variable when the
    instance is created. Every request goes through
    ``client.models.generate_content`` using the configured model name,
    temperature and output-token limit.
    """

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        temperature: float = 0.1,
        max_tokens: int = 128,
    ):
        """
        Create the underlying google-genai client.

        Args:
            model_name: model identifier passed to ``generate_content``.
            temperature: sampling temperature passed in the request config.
            max_tokens: value sent as ``max_output_tokens``.

        Raises:
            EnvironmentError: if GOOGLE_API_KEY is unset or empty.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise EnvironmentError("GOOGLE_API_KEY is not set.")
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens

    # internal helpers ------------------------------------------------------- #
    @staticmethod
    def _extract_text(resp) -> str:
        """
        Return the stripped text of a response, or "" when it has none.

        The SDK types the response's ``text`` attribute as optional, so calling
        ``.strip()`` on it directly can raise AttributeError; an empty string
        keeps callers (which go on to call ``str`` methods) working.
        """
        text = getattr(resp, "text", None)
        return text.strip() if text else ""

    def _genai_call(
        self,
        contents,
        system_instruction: str,
    ) -> str:
        """
        Send ``contents`` to the model and return the reply text.

        Args:
            contents: prompt payload forwarded unchanged to
                ``generate_content`` (a string or a list of Content objects
                in this file).
            system_instruction: system prompt placed in the request config.

        Returns:
            The whitespace-stripped reply, or "" if the response has no text.
        """
        resp = self.client.models.generate_content(
            model=self.model_name,
            contents=contents,
            config=gtypes.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            ),
        )
        return self._extract_text(resp)

    # public helpers --------------------------------------------------------- #
    def __call__(self, prompt: str, system_instruction: str, **__) -> str:
        """
        Plain-text path: send a single prompt string to the model.

        Extra keyword arguments are accepted and ignored.
        NOTE(review): smolagents may invoke models with a message list and
        other keyword arguments — confirm this signature against the
        installed smolagents version before relying on CodeAgent.run().
        """
        return self._genai_call(prompt, system_instruction)

    def call_parts(
        self,
        parts: List[gtypes.Part],
        system_instruction: str,
    ) -> str:
        """Multimodal path used by GeminiAgent for <file:…> questions."""
        # Wrap the parts in a single user turn before sending.
        user_content = gtypes.Content(role="user", parts=parts)
        return self._genai_call([user_content], system_instruction)
# --------------------------------------------------------------------------- #
# custom tool: fetch GAIA attachments
# --------------------------------------------------------------------------- #
@tool
def gaia_file_reader(file_id: str) -> str:
    """
    Download a GAIA attachment and return its contents.
    Args:
        file_id: identifier that appears inside a <file:...> placeholder.
    Returns:
        base64-encoded string for binary files (images, PDFs, …) or decoded
        UTF-8 text for textual files.
    """
    # NOTE: the docstring above is kept verbatim; the @tool decorator
    # presumably uses it as the tool description shown to the model.
    try:
        payload = _download_file(file_id)
        # The MIME type is guessed from the id alone (its extension, if any).
        guessed, _ = mimetypes.guess_type(file_id)
        mime = guessed if guessed else "application/octet-stream"
        is_textual = mime.startswith("text") or mime == "application/json"
        if not is_textual:
            return base64.b64encode(payload).decode()
        # Undecodable bytes are dropped rather than raising.
        return payload.decode(errors="ignore")
    except Exception as exc:  # pragma: no cover
        # Best effort: report the failure as text instead of raising.
        return f"ERROR downloading {file_id}: {exc}"
# --------------------------------------------------------------------------- #
# final agent
# --------------------------------------------------------------------------- #
class GeminiAgent:
    """
    Instantiated once in app.py; called once per question.

    Questions containing ``<file:...>`` placeholders are answered through the
    multimodal path (attachments are downloaded and sent as binary parts);
    all other questions are sent to the model as plain text.
    """

    SYSTEM_PROMPT = (
        "You are a concise, highly accurate assistant. "
        "Unless explicitly required, reply with ONE short sentence. "
        "Use the provided tools if needed. "
        "All answers are graded by exact string match."
    )

    def __init__(self):
        """Build the Gemini model wrapper and the smolagents CodeAgent."""
        self.model = GeminiModel()
        self.agent = CodeAgent(
            model=self.model,
            tools=[
                PythonInterpreterTool(),
                DuckDuckGoSearchTool(),
                gaia_file_reader,
            ],
            verbosity_level=0,
        )
        # The check-mark was mis-encoded (mojibake) in the original message.
        print("✅ GeminiAgent initialised.")

    # --------------------------------------------------------------------- #
    # main entry point
    # --------------------------------------------------------------------- #
    def __call__(self, question: str) -> str:
        """
        Answer a single question.

        Args:
            question: the task text, optionally containing one or more
                ``<file:ID>`` placeholders.

        Returns:
            The model's reply with trailing spaces, periods and newline/tab
            characters removed.
        """
        file_ids = FILE_TAG.findall(question)

        # ---------- multimodal branch (images / files) -------------------- #
        if file_ids:
            parts: List[gtypes.Part] = []
            # The question text with every placeholder removed.
            text_part = FILE_TAG.sub("", question).strip()
            if text_part:
                # `text` must be passed by keyword: current google-genai
                # releases declare it keyword-only.
                parts.append(gtypes.Part.from_text(text=text_part))
            for fid in file_ids:
                try:
                    img_bytes = _download_file(fid)
                    # Unknown extensions are sent as PNG.
                    mime = mimetypes.guess_type(fid)[0] or "image/png"
                    parts.append(
                        gtypes.Part.from_bytes(data=img_bytes, mime_type=mime)
                    )
                except Exception as exc:
                    # Best effort: tell the model the attachment is missing
                    # instead of aborting the whole question.
                    parts.append(
                        gtypes.Part.from_text(text=f"[FILE {fid} ERROR: {exc}]")
                    )
            answer = self.model.call_parts(
                parts, system_instruction=self.SYSTEM_PROMPT
            )
        # ---------- plain-text branch ------------------------------------- #
        else:
            # Prepend system prompt to make sure CodeAgent->model sees it.
            # NOTE(review): this calls the model directly, so the CodeAgent's
            # tools are not exercised on this path — confirm that is intended.
            prompt = f"{self.SYSTEM_PROMPT}\n\n{question}"
            answer = self.agent.model(prompt, system_instruction=self.SYSTEM_PROMPT)

        return answer.rstrip(" .\n\r\t")