Spaces:
Running
Running
Update agents.py
Browse files
agents.py
CHANGED
@@ -56,6 +56,44 @@ STANDARD_TOOL_SCHEMA = {
|
|
56 |
},
|
57 |
}
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
class LLMAgentBase(Player):
|
61 |
def __init__(self, *args, **kwargs):
|
@@ -202,26 +240,15 @@ class GeminiAgent(LLMAgentBase):
|
|
202 |
if not used_api_key:
|
203 |
raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
|
204 |
|
205 |
-
# Initialize Gemini client
|
206 |
-
genai.
|
207 |
|
208 |
-
# Configure the
|
209 |
-
self.
|
210 |
-
{
|
211 |
-
"function_declarations": list(self.standard_tools.values())
|
212 |
-
}
|
213 |
-
]
|
214 |
-
|
215 |
-
# Initialize the model
|
216 |
-
self.model = genai.GenerativeModel(
|
217 |
-
model_name=self.model_name,
|
218 |
-
tools=self.gemini_tool_config
|
219 |
-
)
|
220 |
|
221 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
222 |
"""Sends state to the Gemini API and gets back the function call decision."""
|
223 |
prompt = (
|
224 |
-
"You are a skilled Pokemon battle AI. Your goal is to win the battle. "
|
225 |
"Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
|
226 |
"Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
|
227 |
"Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
|
@@ -231,49 +258,40 @@ class GeminiAgent(LLMAgentBase):
|
|
231 |
)
|
232 |
|
233 |
try:
|
234 |
-
#
|
235 |
-
|
236 |
-
|
237 |
-
|
|
|
|
|
|
|
|
|
|
|
238 |
)
|
239 |
-
print("GEMINI RESPONSE : ",response)
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
function_name = fc.name
|
261 |
-
# Convert arguments to dict
|
262 |
-
arguments = {}
|
263 |
-
if fc.args:
|
264 |
-
arguments = {k: v for k, v in fc.args.items()}
|
265 |
-
|
266 |
-
if function_name in self.standard_tools:
|
267 |
-
return {"decision": {"name": function_name, "arguments": arguments}}
|
268 |
-
else:
|
269 |
-
return {"error": f"Model called unknown function '{function_name}'. Args: {arguments}"}
|
270 |
|
271 |
-
#
|
272 |
-
|
273 |
-
part.text if hasattr(part, 'text') else str(part)
|
274 |
-
for part in candidate.content.parts
|
275 |
-
])
|
276 |
-
return {"error": f"Gemini did not return a function call. Response: {text_content[:100]}..."}
|
277 |
|
278 |
except Exception as e:
|
279 |
print(f"Unexpected error during Gemini processing: {e}")
|
@@ -293,8 +311,8 @@ class OpenAIAgent(LLMAgentBase):
|
|
293 |
raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
|
294 |
self.openai_client = AsyncOpenAI(api_key=used_api_key)
|
295 |
|
296 |
-
#
|
297 |
-
self.openai_tools = list(
|
298 |
|
299 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
300 |
system_prompt = (
|
@@ -354,8 +372,17 @@ class MistralAgent(LLMAgentBase):
|
|
354 |
raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
|
355 |
self.mistral_client = Mistral(api_key=used_api_key)
|
356 |
|
357 |
-
# Convert standard schema to Mistral's tool format
|
358 |
-
self.mistral_tools =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
360 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
361 |
system_prompt = (
|
@@ -368,23 +395,29 @@ class MistralAgent(LLMAgentBase):
|
|
368 |
user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
|
369 |
|
370 |
try:
|
371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
model=self.model,
|
373 |
-
messages=
|
374 |
-
{"role": "system", "content": system_prompt},
|
375 |
-
{"role": "user", "content": user_prompt}
|
376 |
-
],
|
377 |
tools=self.mistral_tools,
|
378 |
-
tool_choice="
|
379 |
-
temperature=0.
|
380 |
)
|
381 |
-
print("Mistral RESPONSE : ",response)
|
382 |
-
|
383 |
# Check for tool calls in the response
|
384 |
-
|
|
|
385 |
tool_call = message.tool_calls[0] # Get the first tool call
|
386 |
function_name = tool_call.function.name
|
387 |
try:
|
|
|
388 |
arguments = json.loads(tool_call.function.arguments or '{}')
|
389 |
if function_name in self.standard_tools:
|
390 |
return {"decision": {"name": function_name, "arguments": arguments}}
|
@@ -393,9 +426,11 @@ class MistralAgent(LLMAgentBase):
|
|
393 |
except json.JSONDecodeError:
|
394 |
return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
|
395 |
else:
|
396 |
-
# Model
|
397 |
return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
|
398 |
|
399 |
except Exception as e:
|
400 |
print(f"Error during Mistral API call: {e}")
|
|
|
|
|
401 |
return {"error": f"Unexpected error: {str(e)}"}
|
|
|
56 |
},
|
57 |
}
|
58 |
|
59 |
+
# --- OpenAI Tools Schema (with 'type' field) ---
|
60 |
+
OPENAI_TOOL_SCHEMA = {
|
61 |
+
"choose_move": {
|
62 |
+
"type": "function",
|
63 |
+
"function": {
|
64 |
+
"name": "choose_move",
|
65 |
+
"description": "Selects and executes an available attacking or status move.",
|
66 |
+
"parameters": {
|
67 |
+
"type": "object",
|
68 |
+
"properties": {
|
69 |
+
"move_name": {
|
70 |
+
"type": "string",
|
71 |
+
"description": "The exact name or ID (e.g., 'thunderbolt', 'swordsdance') of the move to use. Must be one of the available moves.",
|
72 |
+
},
|
73 |
+
},
|
74 |
+
"required": ["move_name"],
|
75 |
+
},
|
76 |
+
}
|
77 |
+
},
|
78 |
+
"choose_switch": {
|
79 |
+
"type": "function",
|
80 |
+
"function": {
|
81 |
+
"name": "choose_switch",
|
82 |
+
"description": "Selects an available Pokémon from the bench to switch into.",
|
83 |
+
"parameters": {
|
84 |
+
"type": "object",
|
85 |
+
"properties": {
|
86 |
+
"pokemon_name": {
|
87 |
+
"type": "string",
|
88 |
+
"description": "The exact name of the Pokémon species to switch to (e.g., 'Pikachu', 'Charizard'). Must be one of the available switches.",
|
89 |
+
},
|
90 |
+
},
|
91 |
+
"required": ["pokemon_name"],
|
92 |
+
},
|
93 |
+
}
|
94 |
+
},
|
95 |
+
}
|
96 |
+
|
97 |
|
98 |
class LLMAgentBase(Player):
|
99 |
def __init__(self, *args, **kwargs):
|
|
|
240 |
if not used_api_key:
|
241 |
raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
|
242 |
|
243 |
+
# Initialize Gemini client using the correct API
|
244 |
+
self.genai_client = genai.Client(api_key=used_api_key)
|
245 |
|
246 |
+
# Configure the tools for function calling
|
247 |
+
self.function_declarations = list(self.standard_tools.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
|
249 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
250 |
"""Sends state to the Gemini API and gets back the function call decision."""
|
251 |
prompt = (
|
|
|
252 |
"Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
|
253 |
"Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
|
254 |
"Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
|
|
|
258 |
)
|
259 |
|
260 |
try:
|
261 |
+
# Configure tools using the Gemini API format
|
262 |
+
tools = genai.types.Tool(function_declarations=self.function_declarations)
|
263 |
+
config = genai.types.GenerateContentConfig(tools=[tools])
|
264 |
+
|
265 |
+
# Send request to the model
|
266 |
+
response = self.genai_client.models.generate_content(
|
267 |
+
model=self.model_name,
|
268 |
+
contents=prompt,
|
269 |
+
config=config
|
270 |
)
|
271 |
+
print("GEMINI RESPONSE : ", response)
|
272 |
+
|
273 |
+
# Check for function calls in the response
|
274 |
+
if (hasattr(response, 'candidates') and
|
275 |
+
response.candidates and
|
276 |
+
hasattr(response.candidates[0], 'content') and
|
277 |
+
hasattr(response.candidates[0].content, 'parts') and
|
278 |
+
response.candidates[0].content.parts and
|
279 |
+
hasattr(response.candidates[0].content.parts[0], 'function_call')):
|
280 |
+
|
281 |
+
function_call = response.candidates[0].content.parts[0].function_call
|
282 |
+
function_name = function_call.name
|
283 |
+
# Get arguments
|
284 |
+
arguments = {}
|
285 |
+
if hasattr(function_call, 'args'):
|
286 |
+
arguments = function_call.args
|
287 |
+
|
288 |
+
if function_name in self.standard_tools:
|
289 |
+
return {"decision": {"name": function_name, "arguments": arguments}}
|
290 |
+
else:
|
291 |
+
return {"error": f"Model called unknown function '{function_name}'."}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
|
293 |
+
# No function call found
|
294 |
+
return {"error": "Gemini did not return a function call."}
|
|
|
|
|
|
|
|
|
295 |
|
296 |
except Exception as e:
|
297 |
print(f"Unexpected error during Gemini processing: {e}")
|
|
|
311 |
raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
|
312 |
self.openai_client = AsyncOpenAI(api_key=used_api_key)
|
313 |
|
314 |
+
# Use the OpenAI-specific schema with type field
|
315 |
+
self.openai_tools = list(OPENAI_TOOL_SCHEMA.values())
|
316 |
|
317 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
318 |
system_prompt = (
|
|
|
372 |
raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
|
373 |
self.mistral_client = Mistral(api_key=used_api_key)
|
374 |
|
375 |
+
# Convert standard schema to Mistral's tool format with "function" wrapper
|
376 |
+
self.mistral_tools = []
|
377 |
+
for tool_name, tool_schema in self.standard_tools.items():
|
378 |
+
self.mistral_tools.append({
|
379 |
+
"type": "function",
|
380 |
+
"function": {
|
381 |
+
"name": tool_schema["name"],
|
382 |
+
"description": tool_schema["description"],
|
383 |
+
"parameters": tool_schema["parameters"]
|
384 |
+
}
|
385 |
+
})
|
386 |
|
387 |
async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
|
388 |
system_prompt = (
|
|
|
395 |
user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
|
396 |
|
397 |
try:
|
398 |
+
# Create the messages array
|
399 |
+
messages = [
|
400 |
+
{"role": "system", "content": system_prompt},
|
401 |
+
{"role": "user", "content": user_prompt}
|
402 |
+
]
|
403 |
+
|
404 |
+
# Call the Mistral API with tool_choice set to "any" to force tool usage
|
405 |
+
response = self.mistral_client.chat.complete(
|
406 |
model=self.model,
|
407 |
+
messages=messages,
|
|
|
|
|
|
|
408 |
tools=self.mistral_tools,
|
409 |
+
tool_choice="any", # Force the model to use a tool
|
410 |
+
temperature=0.3,
|
411 |
)
|
412 |
+
print("Mistral RESPONSE : ", response)
|
413 |
+
|
414 |
# Check for tool calls in the response
|
415 |
+
message = response.choices[0].message
|
416 |
+
if hasattr(message, 'tool_calls') and message.tool_calls:
|
417 |
tool_call = message.tool_calls[0] # Get the first tool call
|
418 |
function_name = tool_call.function.name
|
419 |
try:
|
420 |
+
# Parse the function arguments from JSON string
|
421 |
arguments = json.loads(tool_call.function.arguments or '{}')
|
422 |
if function_name in self.standard_tools:
|
423 |
return {"decision": {"name": function_name, "arguments": arguments}}
|
|
|
426 |
except json.JSONDecodeError:
|
427 |
return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
|
428 |
else:
|
429 |
+
# Model did not return a tool call
|
430 |
return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
|
431 |
|
432 |
except Exception as e:
|
433 |
print(f"Error during Mistral API call: {e}")
|
434 |
+
import traceback
|
435 |
+
traceback.print_exc()
|
436 |
return {"error": f"Unexpected error: {str(e)}"}
|