"""
agent.py - Gemini-smolagents baseline using the google-genai SDK
-----------------------------------------------------------------
Environment
-----------
GOOGLE_API_KEY - API key from Google AI Studio
GAIA_API_URL   - (optional) override for the GAIA scoring endpoint
"""
from __future__ import annotations
import base64
import mimetypes
import os
import re
from typing import List
import google.genai as genai
import requests
from google.genai import types as gtypes
from smolagents import (
CodeAgent,
DuckDuckGoSearchTool,
PythonInterpreterTool,
tool,
)
# --------------------------------------------------------------------------- #
# constants & helpers
# --------------------------------------------------------------------------- #
DEFAULT_API_URL = os.getenv(
"GAIA_API_URL", "https://agents-course-unit4-scoring.hf.space"
)
FILE_TAG = re.compile(r"<file:([^>]+)>") # <file:xyz>
def _download_file(file_id: str) -> bytes:
"""Download the attachment for a GAIA task."""
url = f"{DEFAULT_API_URL}/files/{file_id}"
resp = requests.get(url, timeout=30)
resp.raise_for_status()
return resp.content
# --------------------------------------------------------------------------- #
# model wrapper
# --------------------------------------------------------------------------- #
class GeminiModel:
"""
Minimal adapter around google-genai.Client so the instance itself is
callable (required by smolagents).
"""
def __init__(
self,
model_name: str = "gemini-2.0-flash",
temperature: float = 0.1,
max_tokens: int = 128,
):
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
raise EnvironmentError("GOOGLE_API_KEY is not set.")
self.client = genai.Client(api_key=api_key)
self.model_name = model_name
self.temperature = temperature
self.max_tokens = max_tokens
# internal helper -------------------------------------------------------- #
def _genai_call(
self,
contents,
system_instruction: str,
) -> str:
resp = self.client.models.generate_content(
model=self.model_name,
contents=contents,
config=gtypes.GenerateContentConfig(
system_instruction=system_instruction,
temperature=self.temperature,
max_output_tokens=self.max_tokens,
),
)
        # resp.text can be None (e.g. blocked or empty responses); guard it.
        return (resp.text or "").strip()
# public helpers --------------------------------------------------------- #
def __call__(self, prompt: str, system_instruction: str, **__) -> str:
"""Used by CodeAgent for plain-text questions."""
return self._genai_call(prompt, system_instruction)
def call_parts(
self,
parts: List[gtypes.Part],
system_instruction: str,
) -> str:
"""Multimodal path used by GeminiAgent for <file:β¦> questions."""
user_content = gtypes.Content(role="user", parts=parts)
return self._genai_call([user_content], system_instruction)
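
# Usage sketch for GeminiModel (illustrative only; needs GOOGLE_API_KEY and
# network access, and the byte payload below is a placeholder):
#
#   model = GeminiModel(max_tokens=64)
#   # plain-text path, as used by CodeAgent
#   model("What is the capital of France?", system_instruction="Answer briefly.")
#   # multimodal path, as used by GeminiAgent below
#   part = gtypes.Part.from_bytes(data=b"<image bytes>", mime_type="image/png")
#   model.call_parts([part], system_instruction="Describe the image.")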
# --------------------------------------------------------------------------- #
# custom tool: fetch GAIA attachments
# --------------------------------------------------------------------------- #
@tool
def gaia_file_reader(file_id: str) -> str:
"""
Download a GAIA attachment and return its contents.
Args:
file_id: identifier that appears inside a <file:...> placeholder.
Returns:
        base64-encoded string for binary files (images, PDFs, ...) or decoded
        UTF-8 text for textual files.
"""
try:
raw = _download_file(file_id)
mime = mimetypes.guess_type(file_id)[0] or "application/octet-stream"
if mime.startswith("text") or mime in ("application/json",):
return raw.decode(errors="ignore")
return base64.b64encode(raw).decode()
except Exception as exc: # pragma: no cover
return f"ERROR downloading {file_id}: {exc}"
# --------------------------------------------------------------------------- #
# final agent
# --------------------------------------------------------------------------- #
class GeminiAgent:
"""Instantiated once in app.py; called once per question."""
SYSTEM_PROMPT = (
"You are a concise, highly accurate assistant. "
"Unless explicitly required, reply with ONE short sentence. "
"Use the provided tools if needed. "
"All answers are graded by exact string match."
)
def __init__(self):
self.model = GeminiModel()
self.agent = CodeAgent(
model=self.model,
tools=[
PythonInterpreterTool(),
DuckDuckGoSearchTool(),
gaia_file_reader,
],
verbosity_level=0,
)
print("β
GeminiAgent initialised.")
# --------------------------------------------------------------------- #
# main entry point
# --------------------------------------------------------------------- #
def __call__(self, question: str) -> str:
file_ids = FILE_TAG.findall(question)
# ---------- multimodal branch (images / files) -------------------- #
if file_ids:
parts: List[gtypes.Part] = []
text_part = FILE_TAG.sub("", question).strip()
if text_part:
                parts.append(gtypes.Part.from_text(text=text_part))
for fid in file_ids:
try:
img_bytes = _download_file(fid)
mime = mimetypes.guess_type(fid)[0] or "image/png"
parts.append(
gtypes.Part.from_bytes(data=img_bytes, mime_type=mime)
)
except Exception as exc:
                    parts.append(
                        gtypes.Part.from_text(text=f"[FILE {fid} ERROR: {exc}]")
                    )
answer = self.model.call_parts(parts, system_instruction=self.SYSTEM_PROMPT)
# ---------- plain-text branch ------------------------------------- #
else:
            # Plain-text questions go straight to the Gemini model; the system
            # prompt is prepended to the user message and also passed as the
            # system_instruction, so the model sees it either way.
            prompt = f"{self.SYSTEM_PROMPT}\n\n{question}"
            answer = self.model(prompt, system_instruction=self.SYSTEM_PROMPT)
return answer.rstrip(" .\n\r\t")
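

# Minimal manual smoke test (a sketch, not part of the graded pipeline): assumes
# GOOGLE_API_KEY is set and network access is available; the sample question is
# made up for illustration.
if __name__ == "__main__":
    demo_agent = GeminiAgent()
    print(demo_agent("What is 2 + 2?"))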