# Test_Magus / agent.py
# Author: SergeyO7
# Last commit: "Update agent.py" (6a0cd84, verified)
# Captured from the Hugging Face file viewer (raw / history / blame); file size 5.96 kB.
import asyncio
import functools
import os
import re
from io import BytesIO

import requests
import whisper
import yaml
from PIL import Image
from smolagents import CodeAgent, LiteLLMModel, tool, load_tool, DuckDuckGoSearchTool, WikipediaSearchTool #, HfApiModel, OpenAIServerModel
# Simulated additional tools (implementation depends on external APIs or setup)
#@tool
#def GoogleSearchTool(query: str) -> str:
# """Tool for performing Google searches using Custom Search JSON API
# Args:
# query (str): Search query string
# Returns:
# str: Formatted search results
# """
# api_key = os.environ.get("GOOGLE_API_KEY")
# cse_id = os.environ.get("GOOGLE_CSE_ID")
# if not api_key or not cse_id:
# raise ValueError("GOOGLE_API_KEY and GOOGLE_CSE_ID must be set in environment variables.")
# url = "https://www.googleapis.com/customsearch/v1"
# params = {
# "key": api_key,
# "cx": cse_id,
# "q": query,
# "num": 5 # Number of results to return
# }
# try:
# response = requests.get(url, params=params)
# response.raise_for_status()
# results = response.json().get("items", [])
# return "\n".join([f"{item['title']}: {item['link']}" for item in results]) or "No results found."
# except Exception as e:
# return f"Error performing Google search: {str(e)}"
@tool
def ImageAnalysisTool(question: str) -> str:
    """Tool for analyzing images mentioned in the question.

    Downloads the first http(s) URL found in the question and asks a Gemini
    vision model to describe it.

    Args:
        question (str): The question text which may contain an image URL.
    Returns:
        str: Image description or error message.
    """
    # Take the first http(s) URL in the question text.
    match = re.search(r"https?://\S+", question)
    if not match:
        return "No image URL found in the question."

    # \S+ also captures punctuation that ends the surrounding sentence
    # ("see https://host/img.png." / "(https://host/img.png)"); trim it so the
    # request targets the real URL. A ")" is only dropped when it is unbalanced,
    # so URLs that legitimately contain "(...)" survive.
    trailing = ".,;:!?'\""
    image_url = match.group(0).rstrip(trailing)
    while image_url.endswith(")") and image_url.count("(") < image_url.count(")"):
        image_url = image_url[:-1].rstrip(trailing)

    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"
    }
    try:
        # Explicit timeout: without one, requests can block indefinitely on a
        # stalled server and freeze the calling agent.
        http_response = requests.get(image_url, headers=headers, timeout=30)
        http_response.raise_for_status()
        image = Image.open(BytesIO(http_response.content)).convert("RGB")
    except Exception as e:
        return f"Error fetching image: {e}"

    vision_model = LiteLLMModel(
        model_id="gemini/gemini-2.5-pro",
        api_key=os.environ.get("GEMINI_KEY"),
        max_tokens=8192,
    )
    # A tool-less CodeAgent is used purely as a multimodal call wrapper.
    vision_agent = CodeAgent(
        tools=[],
        model=vision_model,
        max_steps=20,
        verbosity_level=2,
    )
    try:
        # NOTE(review): the prompt is chess-specific although the tool name and
        # docstring are generic — confirm this is intended for all image questions.
        description = vision_agent.run(
            "Describe in details the chess position you see in the image.",
            images=[image],
        )
    except Exception as e:
        # Keep the documented contract: return an error message instead of raising.
        return f"Error analyzing image: {e}"
    return f"The image description: '{description}'"
@functools.lru_cache(maxsize=None)
def _load_whisper_model(name: str = "base"):
    """Load a Whisper model once per process and reuse it for later calls."""
    return whisper.load_model(name)


@tool
def SpeechToTextTool(audio_path: str) -> str:
    """Tool for converting an audio file to text using OpenAI Whisper.

    Args:
        audio_path (str): Path to audio file
    Returns:
        str: audio speech text
    """
    # Validate the path before touching the model, so a bad path returns
    # immediately instead of first paying for a full model load.
    if not os.path.exists(audio_path):
        return f"Error: File not found at {audio_path}"
    # The model is cached; only the first transcription pays the load cost.
    result = _load_whisper_model("base").transcribe(audio_path)
    # Strip surrounding whitespace so the transcript is returned clean.
    return result.get("text", "").strip()
#@tool
#def youtube_transcript(url: str) -> str:
# """
# Get transcript of YouTube video.
# Args:
# url: YouTube video url in ""
# """
# video_id = url.partition("https://www.youtube.com/watch?v=")[2]
# transcript = YouTubeTranscriptApi.get_transcript(video_id)
# transcript_text = " ".join([item["text"] for item in transcript])
# return {"youtube_transcript": transcript_text}
#@tool
#class LocalFileAudioTool:
# """Tool for transcribing audio files"""
#
# @tool
# def transcribe(self, file_path: str) -> str:
# """Transcribe audio from file
# Args:
# file_path (str): Path to audio file
# Returns:
# str: Transcription text
# """
# return f"Transcribed audio from '{file_path}' (simulated)."
class MagAgent:
    """Question-answering agent backed by a Gemini model and a small tool set.

    Wraps a smolagents ``CodeAgent`` configured with DuckDuckGo search,
    Wikipedia search, image analysis and speech-to-text tools. Instances are
    awaitable callables: ``answer = await MagAgent()(question)``.
    """

    def __init__(self):
        """Initialize the MagAgent with search tools.

        Reads the Gemini API key from the ``GEMINI_KEY`` environment variable and
        reads ``prompts.yaml`` from the current working directory.

        Raises:
            FileNotFoundError: if ``prompts.yaml`` does not exist.
        """
        print("Initializing MagAgent with search tools...")
        model = LiteLLMModel(
            model_id="gemini/gemini-2.0-flash",
            api_key=os.environ.get("GEMINI_KEY"),
            max_tokens=8192,
        )
        # Load prompt templates. Explicit UTF-8 so the file decodes identically on
        # every platform (the default encoding is locale-dependent).
        with open("prompts.yaml", "r", encoding="utf-8") as stream:
            prompt_templates = yaml.safe_load(stream)
        # NOTE(review): prompt_templates is loaded but not passed to CodeAgent, so
        # the agent runs on smolagents' built-in prompts. Passing
        # prompt_templates=prompt_templates would activate them — first confirm
        # the YAML has every key the installed smolagents version requires.
        self.agent = CodeAgent(
            model=model,
            tools=[
                # youtube_transcript,
                # GoogleSearchTool,
                DuckDuckGoSearchTool(),
                WikipediaSearchTool(),
                ImageAnalysisTool,
                SpeechToTextTool,
                # LocalFileAudioTool()
            ],
        )
        print("MagAgent initialized.")

    async def __call__(self, question: str) -> str:
        """Answer *question* asynchronously using the wrapped CodeAgent.

        The blocking ``agent.run`` call is executed in a worker thread via
        ``asyncio.to_thread`` so the event loop is not blocked.

        Args:
            question: Natural-language question to answer.

        Returns:
            The agent's answer coerced to ``str``; ``"No answer found."`` if the
            agent returns ``None``; a generic fallback message if the answer is
            empty or contains "No Wikipedia page found"; or an error description
            if the agent raises.
        """
        print(f"MagAgent received question (first 50 chars): {question[:50]}...")
        try:
            task = f"Answer the following question accurately and concisely: {question}\n"
            response = await asyncio.to_thread(self.agent.run, task=task)
            # The agent may return non-string final answers (e.g. an int); coerce
            # so the substring check and slicing below are always valid.
            response = str(response) if response is not None else "No answer found."
            if not response or "No Wikipedia page found" in response:
                # Replace empty answers and failed Wikipedia lookups with a fixed message.
                response = "Unable to retrieve exact data. Please refine the question or check external sources."
            print(f"MagAgent response: {response[:50]}...")
            return response
        except Exception as e:
            # Top-level boundary: report the failure as the answer instead of raising.
            error_msg = f"Error processing question: {str(e)}. Check API key or network connectivity."
            print(error_msg)
            return error_msg