Jofthomas committed on
Commit f645c4a · verified · 1 Parent(s): 2a4fcda

Update agents.py

Files changed (1)
  1. agents.py +119 -131
agents.py CHANGED
@@ -3,31 +3,26 @@ import json
 import asyncio
 import random
 
-
-import os
-import json
-import asyncio
-import random
-
 # --- OpenAI ---
-from openai import AsyncOpenAI, APIError # (Keep if needed for other parts)
 
 # --- Google Gemini ---
 from google import genai
 from google.genai import types
-# It's good practice to also import potential exceptions
-
 
 # --- Mistral AI ---
-from mistralai.async_client import MistralAsyncClient # (Keep if needed)
 
 # --- Poke-Env ---
 from poke_env.player import Player
 from poke_env.environment.battle import Battle
 from poke_env.environment.move import Move
 from poke_env.environment.pokemon import Pokemon
 
-# --- Helper Function & Base Class (Assuming they are defined above) ---
 def normalize_name(name: str) -> str:
     """Lowercase and remove non-alphanumeric characters."""
     return "".join(filter(str.isalnum, name)).lower()
@@ -64,14 +59,13 @@ STANDARD_TOOL_SCHEMA = {
 }
 
 
-class LLMAgentBase(Player): # Make sure this base class exists
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.standard_tools = STANDARD_TOOL_SCHEMA
-        self.battle_history = [] # Example attribute
 
     def _format_battle_state(self, battle: Battle) -> str:
-        # (Implementation as provided in the question)
         active_pkmn = battle.active_pokemon
         active_pkmn_info = f"Your active Pokemon: {active_pkmn.species} " \
             f"(Type: {'/'.join(map(str, active_pkmn.types))}) " \
@@ -117,9 +111,7 @@ class LLMAgentBase(Player): # Make sure this base class exists
             f"Opponent Side Conditions: {battle.opponent_side_conditions}"
         return state_str.strip()
 
-
-    def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
-        # (Implementation as provided in the question)
         normalized_name = normalize_name(move_name)
         # Prioritize exact ID match
         for move in battle.available_moves:
@@ -132,8 +124,7 @@ class LLMAgentBase(Player): # Make sure this base class exists
                 return move
         return None
 
-    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
-        # (Implementation as provided in the question)
         normalized_name = normalize_name(pokemon_name)
         for pkmn in battle.available_switches:
             # Normalize the species name for comparison
@@ -142,7 +133,6 @@ class LLMAgentBase(Player): # Make sure this base class exists
         return None
 
     async def choose_move(self, battle: Battle) -> str:
-        # (Implementation as provided in the question - relies on _get_llm_decision)
         battle_state_str = self._format_battle_state(battle)
         decision_result = await self._get_llm_decision(battle_state_str)
         decision = decision_result.get("decision")
@@ -160,8 +150,7 @@ class LLMAgentBase(Player): # Make sure this base class exists
                 if chosen_move and chosen_move in battle.available_moves:
                     action_taken = True
                     chat_msg = f"AI Decision: Using move '{chosen_move.id}'."
-                    print(chat_msg) # Print to console for debugging
-                    # await self.send_message(chat_msg, battle=battle) # Uncomment if send_message exists
                    return self.create_order(chosen_move)
                 else:
                     fallback_reason = f"LLM chose unavailable/invalid move '{move_name}'."
@@ -174,8 +163,7 @@ class LLMAgentBase(Player): # Make sure this base class exists
                 if chosen_switch and chosen_switch in battle.available_switches:
                     action_taken = True
                     chat_msg = f"AI Decision: Switching to '{chosen_switch.species}'."
-                    print(chat_msg) # Print to console for debugging
-                    # await self.send_message(chat_msg, battle=battle) # Uncomment if send_message exists
                    return self.create_order(chosen_switch)
                 else:
                     fallback_reason = f"LLM chose unavailable/invalid switch '{pokemon_name}'."
@@ -185,55 +173,53 @@ class LLMAgentBase(Player): # Make sure this base class exists
                 fallback_reason = f"LLM called unknown function '{function_name}'."
 
         if not action_taken:
-            if not fallback_reason: # If no specific reason yet, check for API errors
                 if error_message:
                     fallback_reason = f"API Error: {error_message}"
-                elif decision is None: # Model didn't call a function or response was bad
                     fallback_reason = "LLM did not provide a valid function call."
-                else: # Should not happen if logic above is correct
                     fallback_reason = "Unknown error processing LLM decision."
 
             print(f"Warning: {fallback_reason} Choosing random action.")
-            # await self.send_message(f"AI Fallback: {fallback_reason} Choosing random action.", battle=battle) # Uncomment
 
-            # Use poke-env's built-in random choice
             if battle.available_moves or battle.available_switches:
                 return self.choose_random_move(battle)
             else:
                 print("AI Fallback: No moves or switches available. Using Struggle/Default.")
-                # await self.send_message("AI Fallback: No moves or switches available. Using Struggle.", battle=battle) # Uncomment
-                return self.choose_default_move(battle) # Handles struggle
 
-    async def _get_llm_decision(self, battle_state: str) -> dict:
        raise NotImplementedError("Subclasses must implement _get_llm_decision")
 
 
 # --- Google Gemini Agent ---
 class GeminiAgent(LLMAgentBase):
     """Uses Google Gemini API for decisions."""
-    def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-flash", *args, **kwargs): # Default to flash for speed/cost
         super().__init__(*args, **kwargs)
         self.model_name = model
         used_api_key = api_key or os.environ.get("GOOGLE_API_KEY")
-        self.model_name=model
         if not used_api_key:
             raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
-        self.client = genai.Client(
-            api_key='GEMINI_API_KEY',
-            http_options=types.HttpOptions(api_version='v1alpha')
         )
-        # --- Correct Tool Definition ---
-        # Create a list of function declaration dictionaries from the values in STANDARD_TOOL_SCHEMA
-        function_declarations = list(self.standard_tools.values())
-        # Create the Tool object expected by the API
-        self.gemini_tool_config = types.Tool(function_declarations=function_declarations)
-        # --- End Tool Definition ---
-
-        # --- Correct Model Initialization ---
-        # Pass the Tool object directly to the model's 'tools' parameter
 
-        # --- End Model Initialization ---
-
-    async def _get_llm_decision(self, battle_state: str) -> dict:
         """Sends state to the Gemini API and gets back the function call decision."""
         prompt = (
             "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
@@ -246,63 +232,64 @@ class GeminiAgent(LLMAgentBase):
         )
 
         try:
-            # --- Correct API Call ---
-            # Call generate_content_async directly on the model object.
-            # Tools are already configured in the model, no need to pass config here.
-            response = await client.aio.models.generate_content(
-                model=self.model_name,
-                contents=prompt
             )
-            # --- End API Call ---
-
-            # --- Response Parsing (Your logic was already good here) ---
-            # Check candidates and parts safely
             if not response.candidates:
-                finish_reason_str = "No candidates found"
-                try: finish_reason_str = response.prompt_feedback.block_reason.name
-                except AttributeError: pass
-                return {"error": f"Gemini response issue. Reason: {finish_reason_str}"}
 
             candidate = response.candidates[0]
             if not candidate.content or not candidate.content.parts:
                 finish_reason_str = "Unknown"
-                try: finish_reason_str = candidate.finish_reason.name
-                except AttributeError: pass
                 return {"error": f"Gemini response issue. Finish Reason: {finish_reason_str}"}
 
-            part = candidate.content.parts[0]
-
-            # Check for function_call attribute
-            if hasattr(part, 'function_call') and part.function_call:
-                fc = part.function_call
-                function_name = fc.name
-                # fc.args is a proto_plus.MapComposite, convert to dict
-                arguments = dict(fc.args) if fc.args else {}
-
-                if function_name in self.standard_tools:
-                    # Valid function call found
-                    return {"decision": {"name": function_name, "arguments": arguments}}
-                else:
-                    # Model hallucinated a function name
-                    return {"error": f"Model called unknown function '{function_name}'. Args: {arguments}"}
-            elif hasattr(part, 'text'):
-                # Handle case where the model returns text instead of a function call
-                text_response = part.text
-                return {"error": f"Gemini did not return a function call. Response: {text_response}"}
-            else:
-                # Unexpected part type
-                return {"error": f"Gemini response part type unknown. Part: {part}"}
-            # --- End Response Parsing ---
         except Exception as e:
-            # Catch any other unexpected errors during the API call or processing
             print(f"Unexpected error during Gemini processing: {e}")
             import traceback
-            traceback.print_exc() # Print stack trace for debugging
             return {"error": f"Unexpected error: {str(e)}"}
 
 
 # --- OpenAI Agent ---
 class OpenAIAgent(LLMAgentBase):
     """Uses OpenAI API for decisions."""
-    def __init__(self, api_key: str | None = None, model: str = "gpt-4o", *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model = model
         used_api_key = api_key or os.environ.get("OPENAI_API_KEY")
@@ -311,9 +298,9 @@ class OpenAIAgent(LLMAgentBase):
         self.openai_client = AsyncOpenAI(api_key=used_api_key)
 
         # Convert standard schema to OpenAI's format
-        self.openai_functions = [v for k, v in self.standard_tools.items()]
 
-    async def _get_llm_decision(self, battle_state: str) -> dict:
         system_prompt = (
             "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
             "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
@@ -330,25 +317,26 @@ class OpenAIAgent(LLMAgentBase):
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": user_prompt},
                 ],
-                functions=STANDARD_TOOL_SCHEMA,
-                function_call="auto",
                 temperature=0.5,
             )
             message = response.choices[0].message
-            print("MESSAGE BACK : ", message)
-            if message.function_call:
-                function_name = message.function_call.name
 
                 try:
-                    # Ensure arguments is always a dict, even if empty/null
-                    arguments = json.loads(message.function_call.arguments or '{}')
-                    if function_name in self.standard_tools: # Validate function name
-                        return {"decision": {"name": function_name, "arguments": arguments}}
                     else:
-                        return {"error": f"Model called unknown function '{function_name}'."}
                 except json.JSONDecodeError:
-                    return {"error": f"Error decoding function call arguments: {message.function_call.arguments}"}
             else:
-                # Model decided not to call a function (or generated text instead)
                 return {"error": f"OpenAI did not return a function call. Response: {message.content}"}
 
         except APIError as e:
@@ -362,7 +350,7 @@ class OpenAIAgent(LLMAgentBase):
 # --- Mistral Agent ---
 class MistralAgent(LLMAgentBase):
     """Uses Mistral AI API for decisions."""
-    def __init__(self, api_key: str | None = None, model: str = "mistral-large-latest", *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model = model
         used_api_key = api_key or os.environ.get("MISTRAL_API_KEY")
@@ -370,51 +358,51 @@ class MistralAgent(LLMAgentBase):
             raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
         self.mistral_client = MistralAsyncClient(api_key=used_api_key)
 
-        # Convert standard schema to Mistral's tool format (very similar to OpenAI's)
-        self.mistral_tools = STANDARD_TOOL_SCHEMA
 
-    async def _get_llm_decision(self, battle_state: str) -> dict:
         system_prompt = (
             "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
             "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
             "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
             "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
-            "Use the provided tools/functions to indicate your choice."
         )
         user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
 
         try:
-            response = await self.mistral_client.chat.complete(
                 model=self.model,
                 messages=[
-                    {"role": "system", "content": f"{system}"},
-                    {"role": "user", "content": f"{user_prompt}"}
                 ],
                 tools=self.mistral_tools,
-                tool_choice="auto", # Let the model choose
                 temperature=0.5,
             )
 
             message = response.choices[0].message
-            # Mistral returns tool_calls as a list
             if message.tool_calls:
-                tool_call = response.choices[0].message.tool_calls[0]
                 function_name = tool_call.function.name
-                function_params = json.loads(tool_call.function.arguments)
-                print("\nfunction_name: ", function_name, "\nfunction_params: ", function_params)
-                if function_name and function_params: # Validate function name
-                    return {"decision": {"name": function_name, "arguments": arguments}}
-
             else:
-                # Model decided not to call a tool (or generated text instead)
-                return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
 
-            # Mistral client might raise specific exceptions, add them here if known
-            # from mistralai.exceptions import MistralAPIException # Example
-        except Exception as e: # Catch general exceptions for now
             print(f"Error during Mistral API call: {e}")
-            # Try to get specific details if it's a known exception type
-            error_details = str(e)
-            # if isinstance(e, MistralAPIException): # Example
-            # error_details = f"{e.status_code} - {e.message}"
-            return {"error": f"Mistral API Error: {error_details}"}
 
 import asyncio
 import random
 
 # --- OpenAI ---
+from openai import AsyncOpenAI, APIError
 
 # --- Google Gemini ---
 from google import genai
 from google.genai import types
+from google.api_core import exceptions as google_exceptions
 
 # --- Mistral AI ---
+from mistralai.async_client import MistralAsyncClient
+from mistralai.exceptions import MistralAPIError
 
 # --- Poke-Env ---
 from poke_env.player import Player
 from poke_env.environment.battle import Battle
 from poke_env.environment.move import Move
 from poke_env.environment.pokemon import Pokemon
+from typing import Optional, Dict, Any, Union
 
+# --- Helper Function & Base Class ---
 def normalize_name(name: str) -> str:
     """Lowercase and remove non-alphanumeric characters."""
     return "".join(filter(str.isalnum, name)).lower()
 
 }
 
 
+class LLMAgentBase(Player):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.standard_tools = STANDARD_TOOL_SCHEMA
+        self.battle_history = []
 
     def _format_battle_state(self, battle: Battle) -> str:
         active_pkmn = battle.active_pokemon
         active_pkmn_info = f"Your active Pokemon: {active_pkmn.species} " \
             f"(Type: {'/'.join(map(str, active_pkmn.types))}) " \
 
             f"Opponent Side Conditions: {battle.opponent_side_conditions}"
         return state_str.strip()
 
+    def _find_move_by_name(self, battle: Battle, move_name: str) -> Optional[Move]:
         normalized_name = normalize_name(move_name)
         # Prioritize exact ID match
         for move in battle.available_moves:
 
                 return move
         return None
 
+    def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Optional[Pokemon]:
         normalized_name = normalize_name(pokemon_name)
         for pkmn in battle.available_switches:
             # Normalize the species name for comparison
 
         return None
 
     async def choose_move(self, battle: Battle) -> str:
         battle_state_str = self._format_battle_state(battle)
         decision_result = await self._get_llm_decision(battle_state_str)
         decision = decision_result.get("decision")
 
                 if chosen_move and chosen_move in battle.available_moves:
                     action_taken = True
                     chat_msg = f"AI Decision: Using move '{chosen_move.id}'."
+                    print(chat_msg)
                     return self.create_order(chosen_move)
                 else:
                     fallback_reason = f"LLM chose unavailable/invalid move '{move_name}'."
 
                 if chosen_switch and chosen_switch in battle.available_switches:
                     action_taken = True
                     chat_msg = f"AI Decision: Switching to '{chosen_switch.species}'."
+                    print(chat_msg)
                     return self.create_order(chosen_switch)
                 else:
                     fallback_reason = f"LLM chose unavailable/invalid switch '{pokemon_name}'."
 
                 fallback_reason = f"LLM called unknown function '{function_name}'."
 
         if not action_taken:
+            if not fallback_reason:
                 if error_message:
                     fallback_reason = f"API Error: {error_message}"
+                elif decision is None:
                     fallback_reason = "LLM did not provide a valid function call."
+                else:
                     fallback_reason = "Unknown error processing LLM decision."
 
             print(f"Warning: {fallback_reason} Choosing random action.")
 
 
             if battle.available_moves or battle.available_switches:
                 return self.choose_random_move(battle)
             else:
                 print("AI Fallback: No moves or switches available. Using Struggle/Default.")
+                return self.choose_default_move(battle)
 
 
+    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
         raise NotImplementedError("Subclasses must implement _get_llm_decision")
 
+
 # --- Google Gemini Agent ---
 class GeminiAgent(LLMAgentBase):
     """Uses Google Gemini API for decisions."""
+    def __init__(self, api_key: str = None, model: str = "gemini-1.5-flash", *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_name = model
         used_api_key = api_key or os.environ.get("GOOGLE_API_KEY")
 
         if not used_api_key:
             raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
+
+        # Initialize Gemini client
+        genai.configure(api_key=used_api_key)
+
+        # Configure the model with tools
+        self.gemini_tool_config = [
+            {
+                "function_declarations": list(self.standard_tools.values())
+            }
+        ]
+
+        # Initialize the model
+        self.model = genai.GenerativeModel(
+            model_name=self.model_name,
+            tools=self.gemini_tool_config
         )
 
+    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
         """Sends state to the Gemini API and gets back the function call decision."""
         prompt = (
             "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
 
         )
 
         try:
+            # Use the async API for Gemini
+            response = await self.model.generate_content_async(
+                prompt,
+                generation_config={"temperature": 0.5}
             )
+
             if not response.candidates:
+                finish_reason_str = "No candidates found"
+                try:
+                    finish_reason_str = response.prompt_feedback.block_reason.name
+                except AttributeError:
+                    pass
+                return {"error": f"Gemini response issue. Reason: {finish_reason_str}"}
 
             candidate = response.candidates[0]
             if not candidate.content or not candidate.content.parts:
                 finish_reason_str = "Unknown"
+                try:
+                    finish_reason_str = candidate.finish_reason.name
+                except AttributeError:
+                    pass
                 return {"error": f"Gemini response issue. Finish Reason: {finish_reason_str}"}
 
+            for part in candidate.content.parts:
+                if hasattr(part, 'function_call') and part.function_call:
+                    fc = part.function_call
+                    function_name = fc.name
+                    # Convert arguments to dict
+                    arguments = {}
+                    if fc.args:
+                        arguments = {k: v for k, v in fc.args.items()}
+
+                    if function_name in self.standard_tools:
+                        return {"decision": {"name": function_name, "arguments": arguments}}
+                    else:
+                        return {"error": f"Model called unknown function '{function_name}'. Args: {arguments}"}
+
+            # If we got here, no function call was found in any part
+            text_content = " ".join([
+                part.text if hasattr(part, 'text') else str(part)
+                for part in candidate.content.parts
+            ])
+            return {"error": f"Gemini did not return a function call. Response: {text_content[:100]}..."}
+
+        except google_exceptions.GoogleAPIError as e:
+            print(f"Google API error: {e}")
+            return {"error": f"Google API error: {str(e)}"}
         except Exception as e:
             print(f"Unexpected error during Gemini processing: {e}")
             import traceback
+            traceback.print_exc()
             return {"error": f"Unexpected error: {str(e)}"}
+
+
 # --- OpenAI Agent ---
 class OpenAIAgent(LLMAgentBase):
     """Uses OpenAI API for decisions."""
+    def __init__(self, api_key: str = None, model: str = "gpt-4o", *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model = model
         used_api_key = api_key or os.environ.get("OPENAI_API_KEY")
 
         self.openai_client = AsyncOpenAI(api_key=used_api_key)
 
         # Convert standard schema to OpenAI's format
+        self.openai_tools = list(self.standard_tools.values())
 
+    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
         system_prompt = (
             "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
             "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
 
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": user_prompt},
                 ],
+                tools=self.openai_tools,
+                tool_choice="auto", # Let the model choose
                 temperature=0.5,
             )
             message = response.choices[0].message
+
+            # Check for tool calls in the response
+            if message.tool_calls:
+                tool_call = message.tool_calls[0] # Get the first tool call
+                function_name = tool_call.function.name
                 try:
+                    arguments = json.loads(tool_call.function.arguments or '{}')
+                    if function_name in self.standard_tools:
+                        return {"decision": {"name": function_name, "arguments": arguments}}
                     else:
+                        return {"error": f"Model called unknown function '{function_name}'."}
                 except json.JSONDecodeError:
+                    return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
             else:
+                # Model decided not to call a function
                 return {"error": f"OpenAI did not return a function call. Response: {message.content}"}
 
         except APIError as e:
 
 # --- Mistral Agent ---
 class MistralAgent(LLMAgentBase):
     """Uses Mistral AI API for decisions."""
+    def __init__(self, api_key: str = None, model: str = "mistral-large-latest", *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model = model
         used_api_key = api_key or os.environ.get("MISTRAL_API_KEY")
 
             raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
         self.mistral_client = MistralAsyncClient(api_key=used_api_key)
 
+        # Convert standard schema to Mistral's tool format
+        self.mistral_tools = list(self.standard_tools.values())
 
+    async def _get_llm_decision(self, battle_state: str) -> Dict[str, Any]:
         system_prompt = (
             "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
             "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
             "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
             "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
+            "Use the provided tools to indicate your choice."
         )
         user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
 
         try:
+            response = await self.mistral_client.chat(
                 model=self.model,
                 messages=[
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt}
                 ],
                 tools=self.mistral_tools,
+                tool_choice="auto", # Let the model choose
                 temperature=0.5,
             )
 
             message = response.choices[0].message
+            # Check for tool calls in the response
             if message.tool_calls:
+                tool_call = message.tool_calls[0] # Get the first tool call
                 function_name = tool_call.function.name
+                try:
+                    arguments = json.loads(tool_call.function.arguments or '{}')
+                    if function_name in self.standard_tools:
+                        return {"decision": {"name": function_name, "arguments": arguments}}
+                    else:
+                        return {"error": f"Model called unknown function '{function_name}'."}
+                except json.JSONDecodeError:
+                    return {"error": f"Error decoding function arguments: {tool_call.function.arguments}"}
             else:
+                # Model decided not to call a tool
+                return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
 
+        except MistralAPIError as e:
+            print(f"Error during Mistral API call: {e}")
+            return {"error": f"Mistral API Error: {str(e)}"}
+        except Exception as e:
             print(f"Error during Mistral API call: {e}")
+            return {"error": f"Unexpected error: {str(e)}"}