Jofthomas committed on
Commit
6eaa8f9
·
verified ·
1 Parent(s): e3cef11

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +106 -71
agents.py CHANGED
@@ -56,6 +56,44 @@ STANDARD_TOOL_SCHEMA = {
56
  },
57
  }
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  class LLMAgentBase(Player):
61
  def __init__(self, *args, **kwargs):
@@ -202,26 +240,15 @@ class GeminiAgent(LLMAgentBase):
202
  if not used_api_key:
203
  raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
204
 
205
- # Initialize Gemini client
206
- genai.configure(api_key=used_api_key)
207
 
208
- # Configure the model with tools
209
- self.gemini_tool_config = [
210
- {
211
- "function_declarations": list(self.standard_tools.values())
212
- }
213
- ]
214
-
215
- # Initialize the model
216
- self.model = genai.GenerativeModel(
217
- model_name=self.model_name,
218
- tools=self.gemini_tool_config
219
- )
220
 
221
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
222
  """Sends state to the Gemini API and gets back the function call decision."""
223
  prompt = (
224
- "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
225
  "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
226
  "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
227
  "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
@@ -231,49 +258,40 @@ class GeminiAgent(LLMAgentBase):
231
  )
232
 
233
  try:
234
- # Use the async API for Gemini
235
- response = await self.model.generate_content_async(
236
- prompt,
237
- generation_config={"temperature": 0.5}
 
 
 
 
 
238
  )
239
- print("GEMINI RESPONSE : ",response)
240
- if not response.candidates:
241
- finish_reason_str = "No candidates found"
242
- try:
243
- finish_reason_str = response.prompt_feedback.block_reason.name
244
- except AttributeError:
245
- pass
246
- return {"error": f"Gemini response issue. Reason: {finish_reason_str}"}
247
-
248
- candidate = response.candidates[0]
249
- if not candidate.content or not candidate.content.parts:
250
- finish_reason_str = "Unknown"
251
- try:
252
- finish_reason_str = candidate.finish_reason.name
253
- except AttributeError:
254
- pass
255
- return {"error": f"Gemini response issue. Finish Reason: {finish_reason_str}"}
256
-
257
- for part in candidate.content.parts:
258
- if hasattr(part, 'function_call') and part.function_call:
259
- fc = part.function_call
260
- function_name = fc.name
261
- # Convert arguments to dict
262
- arguments = {}
263
- if fc.args:
264
- arguments = {k: v for k, v in fc.args.items()}
265
-
266
- if function_name in self.standard_tools:
267
- return {"decision": {"name": function_name, "arguments": arguments}}
268
- else:
269
- return {"error": f"Model called unknown function '{function_name}'. Args: {arguments}"}
270
 
271
- # If we got here, no function call was found in any part
272
- text_content = " ".join([
273
- part.text if hasattr(part, 'text') else str(part)
274
- for part in candidate.content.parts
275
- ])
276
- return {"error": f"Gemini did not return a function call. Response: {text_content[:100]}..."}
277
 
278
  except Exception as e:
279
  print(f"Unexpected error during Gemini processing: {e}")
@@ -293,8 +311,8 @@ class OpenAIAgent(LLMAgentBase):
293
  raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
294
  self.openai_client = AsyncOpenAI(api_key=used_api_key)
295
 
296
- # Convert standard schema to OpenAI's format
297
- self.openai_tools = list(self.standard_tools.values())
298
 
299
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
300
  system_prompt = (
@@ -354,8 +372,17 @@ class MistralAgent(LLMAgentBase):
354
  raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
355
  self.mistral_client = Mistral(api_key=used_api_key)
356
 
357
- # Convert standard schema to Mistral's tool format
358
- self.mistral_tools = list(self.standard_tools.values())
 
 
 
 
 
 
 
 
 
359
 
360
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
361
  system_prompt = (
@@ -368,23 +395,29 @@ class MistralAgent(LLMAgentBase):
368
  user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
369
 
370
  try:
371
- response = await self.mistral_client.chat.complete(
 
 
 
 
 
 
 
372
  model=self.model,
373
- messages=[
374
- {"role": "system", "content": system_prompt},
375
- {"role": "user", "content": user_prompt}
376
- ],
377
  tools=self.mistral_tools,
378
- tool_choice="auto", # Let the model choose
379
- temperature=0.5,
380
  )
381
- print("Mistral RESPONSE : ",response)
382
- message = response.choices[0].message
383
  # Check for tool calls in the response
384
- if message.tool_calls:
 
385
  tool_call = message.tool_calls[0] # Get the first tool call
386
  function_name = tool_call.function.name
387
  try:
 
388
  arguments = json.loads(tool_call.function.arguments or '{}')
389
  if function_name in self.standard_tools:
390
  return {"decision": {"name": function_name, "arguments": arguments}}
@@ -393,9 +426,11 @@ class MistralAgent(LLMAgentBase):
393
  except json.JSONDecodeError:
394
  return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
395
  else:
396
- # Model decided not to call a tool
397
  return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
398
 
399
  except Exception as e:
400
  print(f"Error during Mistral API call: {e}")
 
 
401
  return {"error": f"Unexpected error: {str(e)}"}
 
56
  },
57
  }
58
 
59
+ # --- OpenAI Tools Schema (with 'type' field) ---
60
+ OPENAI_TOOL_SCHEMA = {
61
+ "choose_move": {
62
+ "type": "function",
63
+ "function": {
64
+ "name": "choose_move",
65
+ "description": "Selects and executes an available attacking or status move.",
66
+ "parameters": {
67
+ "type": "object",
68
+ "properties": {
69
+ "move_name": {
70
+ "type": "string",
71
+ "description": "The exact name or ID (e.g., 'thunderbolt', 'swordsdance') of the move to use. Must be one of the available moves.",
72
+ },
73
+ },
74
+ "required": ["move_name"],
75
+ },
76
+ }
77
+ },
78
+ "choose_switch": {
79
+ "type": "function",
80
+ "function": {
81
+ "name": "choose_switch",
82
+ "description": "Selects an available Pokémon from the bench to switch into.",
83
+ "parameters": {
84
+ "type": "object",
85
+ "properties": {
86
+ "pokemon_name": {
87
+ "type": "string",
88
+ "description": "The exact name of the Pokémon species to switch to (e.g., 'Pikachu', 'Charizard'). Must be one of the available switches.",
89
+ },
90
+ },
91
+ "required": ["pokemon_name"],
92
+ },
93
+ }
94
+ },
95
+ }
96
+
97
 
98
  class LLMAgentBase(Player):
99
  def __init__(self, *args, **kwargs):
 
240
  if not used_api_key:
241
  raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
242
 
243
+ # Initialize Gemini client using the correct API
244
+ self.genai_client = genai.Client(api_key=used_api_key)
245
 
246
+ # Configure the tools for function calling
247
+ self.function_declarations = list(self.standard_tools.values())
 
 
 
 
 
 
 
 
 
 
248
 
249
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
250
  """Sends state to the Gemini API and gets back the function call decision."""
251
  prompt = (
 
252
  "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
253
  "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
254
  "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
 
258
  )
259
 
260
  try:
261
+ # Configure tools using the Gemini API format
262
+ tools = genai.types.Tool(function_declarations=self.function_declarations)
263
+ config = genai.types.GenerateContentConfig(tools=[tools])
264
+
265
+ # Send request to the model
266
+ response = self.genai_client.models.generate_content(
267
+ model=self.model_name,
268
+ contents=prompt,
269
+ config=config
270
  )
271
+ print("GEMINI RESPONSE : ", response)
272
+
273
+ # Check for function calls in the response
274
+ if (hasattr(response, 'candidates') and
275
+ response.candidates and
276
+ hasattr(response.candidates[0], 'content') and
277
+ hasattr(response.candidates[0].content, 'parts') and
278
+ response.candidates[0].content.parts and
279
+ hasattr(response.candidates[0].content.parts[0], 'function_call')):
280
+
281
+ function_call = response.candidates[0].content.parts[0].function_call
282
+ function_name = function_call.name
283
+ # Get arguments
284
+ arguments = {}
285
+ if hasattr(function_call, 'args'):
286
+ arguments = function_call.args
287
+
288
+ if function_name in self.standard_tools:
289
+ return {"decision": {"name": function_name, "arguments": arguments}}
290
+ else:
291
+ return {"error": f"Model called unknown function '{function_name}'."}
 
 
 
 
 
 
 
 
 
 
292
 
293
+ # No function call found
294
+ return {"error": "Gemini did not return a function call."}
 
 
 
 
295
 
296
  except Exception as e:
297
  print(f"Unexpected error during Gemini processing: {e}")
 
311
  raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
312
  self.openai_client = AsyncOpenAI(api_key=used_api_key)
313
 
314
+ # Use the OpenAI-specific schema with type field
315
+ self.openai_tools = list(OPENAI_TOOL_SCHEMA.values())
316
 
317
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
318
  system_prompt = (
 
372
  raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
373
  self.mistral_client = Mistral(api_key=used_api_key)
374
 
375
+ # Convert standard schema to Mistral's tool format with "function" wrapper
376
+ self.mistral_tools = []
377
+ for tool_name, tool_schema in self.standard_tools.items():
378
+ self.mistral_tools.append({
379
+ "type": "function",
380
+ "function": {
381
+ "name": tool_schema["name"],
382
+ "description": tool_schema["description"],
383
+ "parameters": tool_schema["parameters"]
384
+ }
385
+ })
386
 
387
  async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
388
  system_prompt = (
 
395
  user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
396
 
397
  try:
398
+ # Create the messages array
399
+ messages = [
400
+ {"role": "system", "content": system_prompt},
401
+ {"role": "user", "content": user_prompt}
402
+ ]
403
+
404
+ # Call the Mistral API with tool_choice set to "any" to force tool usage
405
+ response = self.mistral_client.chat.complete(
406
  model=self.model,
407
+ messages=messages,
 
 
 
408
  tools=self.mistral_tools,
409
+ tool_choice="any", # Force the model to use a tool
410
+ temperature=0.3,
411
  )
412
+ print("Mistral RESPONSE : ", response)
413
+
414
  # Check for tool calls in the response
415
+ message = response.choices[0].message
416
+ if hasattr(message, 'tool_calls') and message.tool_calls:
417
  tool_call = message.tool_calls[0] # Get the first tool call
418
  function_name = tool_call.function.name
419
  try:
420
+ # Parse the function arguments from JSON string
421
  arguments = json.loads(tool_call.function.arguments or '{}')
422
  if function_name in self.standard_tools:
423
  return {"decision": {"name": function_name, "arguments": arguments}}
 
426
  except json.JSONDecodeError:
427
  return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
428
  else:
429
+ # Model did not return a tool call
430
  return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
431
 
432
  except Exception as e:
433
  print(f"Error during Mistral API call: {e}")
434
+ import traceback
435
+ traceback.print_exc()
436
  return {"error": f"Unexpected error: {str(e)}"}