Jofthomas commited on
Commit
22cdc3d
·
verified ·
1 Parent(s): deb30f4

Create agents.py

Browse files
Files changed (1) hide show
  1. agents.py +420 -0
agents.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import random
5
+
6
+
7
+ import os
8
+ import json
9
+ import asyncio
10
+ import random
11
+
12
+ # --- OpenAI ---
13
+ from openai import AsyncOpenAI, APIError # (Keep if needed for other parts)
14
+
15
+ # --- Google Gemini ---
16
+ from google import genai
17
+ from google.genai import types
18
+ # It's good practice to also import potential exceptions
19
+
20
+
21
+ # --- Mistral AI ---
22
+ from mistralai.async_client import MistralAsyncClient # (Keep if needed)
23
+
24
+ # --- Poke-Env ---
25
+ from poke_env.player import Player
26
+ from poke_env.environment.battle import Battle
27
+ from poke_env.environment.move import Move
28
+ from poke_env.environment.pokemon import Pokemon
29
+
30
+ # --- Helper Function & Base Class (Assuming they are defined above) ---
31
+ def normalize_name(name: str) -> str:
32
+ """Lowercase and remove non-alphanumeric characters."""
33
+ return "".join(filter(str.isalnum, name)).lower()
34
+
35
+ STANDARD_TOOL_SCHEMA = {
36
+ "choose_move": {
37
+ "name": "choose_move",
38
+ "description": "Selects and executes an available attacking or status move.",
39
+ "parameters": {
40
+ "type": "object",
41
+ "properties": {
42
+ "move_name": {
43
+ "type": "string",
44
+ "description": "The exact name or ID (e.g., 'thunderbolt', 'swordsdance') of the move to use. Must be one of the available moves.",
45
+ },
46
+ },
47
+ "required": ["move_name"],
48
+ },
49
+ },
50
+ "choose_switch": {
51
+ "name": "choose_switch",
52
+ "description": "Selects an available Pokémon from the bench to switch into.",
53
+ "parameters": {
54
+ "type": "object",
55
+ "properties": {
56
+ "pokemon_name": {
57
+ "type": "string",
58
+ "description": "The exact name of the Pokémon species to switch to (e.g., 'Pikachu', 'Charizard'). Must be one of the available switches.",
59
+ },
60
+ },
61
+ "required": ["pokemon_name"],
62
+ },
63
+ },
64
+ }
65
+
66
+
67
+ class LLMAgentBase(Player): # Make sure this base class exists
68
+ def __init__(self, *args, **kwargs):
69
+ super().__init__(*args, **kwargs)
70
+ self.standard_tools = STANDARD_TOOL_SCHEMA
71
+ self.battle_history = [] # Example attribute
72
+
73
+ def _format_battle_state(self, battle: Battle) -> str:
74
+ # (Implementation as provided in the question)
75
+ active_pkmn = battle.active_pokemon
76
+ active_pkmn_info = f"Your active Pokemon: {active_pkmn.species} " \
77
+ f"(Type: {'/'.join(map(str, active_pkmn.types))}) " \
78
+ f"HP: {active_pkmn.current_hp_fraction * 100:.1f}% " \
79
+ f"Status: {active_pkmn.status.name if active_pkmn.status else 'None'} " \
80
+ f"Boosts: {active_pkmn.boosts}"
81
+
82
+ opponent_pkmn = battle.opponent_active_pokemon
83
+ opp_info_str = "Unknown"
84
+ if opponent_pkmn:
85
+ opp_info_str = f"{opponent_pkmn.species} " \
86
+ f"(Type: {'/'.join(map(str, opponent_pkmn.types))}) " \
87
+ f"HP: {opponent_pkmn.current_hp_fraction * 100:.1f}% " \
88
+ f"Status: {opponent_pkmn.status.name if opponent_pkmn.status else 'None'} " \
89
+ f"Boosts: {opponent_pkmn.boosts}"
90
+ opponent_pkmn_info = f"Opponent's active Pokemon: {opp_info_str}"
91
+
92
+ available_moves_info = "Available moves:\n"
93
+ if battle.available_moves:
94
+ available_moves_info += "\n".join(
95
+ [f"- {move.id} (Type: {move.type}, BP: {move.base_power}, Acc: {move.accuracy}, PP: {move.current_pp}/{move.max_pp}, Cat: {move.category.name})"
96
+ for move in battle.available_moves]
97
+ )
98
+ else:
99
+ available_moves_info += "- None (Must switch or Struggle)"
100
+
101
+ available_switches_info = "Available switches:\n"
102
+ if battle.available_switches:
103
+ available_switches_info += "\n".join(
104
+ [f"- {pkmn.species} (HP: {pkmn.current_hp_fraction * 100:.1f}%, Status: {pkmn.status.name if pkmn.status else 'None'})"
105
+ for pkmn in battle.available_switches]
106
+ )
107
+ else:
108
+ available_switches_info += "- None"
109
+
110
+ state_str = f"{active_pkmn_info}\n" \
111
+ f"{opponent_pkmn_info}\n\n" \
112
+ f"{available_moves_info}\n\n" \
113
+ f"{available_switches_info}\n\n" \
114
+ f"Weather: {battle.weather}\n" \
115
+ f"Terrains: {battle.fields}\n" \
116
+ f"Your Side Conditions: {battle.side_conditions}\n" \
117
+ f"Opponent Side Conditions: {battle.opponent_side_conditions}"
118
+ return state_str.strip()
119
+
120
+
121
+ def _find_move_by_name(self, battle: Battle, move_name: str) -> Move | None:
122
+ # (Implementation as provided in the question)
123
+ normalized_name = normalize_name(move_name)
124
+ # Prioritize exact ID match
125
+ for move in battle.available_moves:
126
+ if move.id == normalized_name:
127
+ return move
128
+ # Fallback: Check display name (less reliable)
129
+ for move in battle.available_moves:
130
+ if move.name.lower() == move_name.lower():
131
+ print(f"Warning: Matched move by display name '{move.name}' instead of ID '{move.id}'. Input was '{move_name}'.")
132
+ return move
133
+ return None
134
+
135
+ def _find_pokemon_by_name(self, battle: Battle, pokemon_name: str) -> Pokemon | None:
136
+ # (Implementation as provided in the question)
137
+ normalized_name = normalize_name(pokemon_name)
138
+ for pkmn in battle.available_switches:
139
+ # Normalize the species name for comparison
140
+ if normalize_name(pkmn.species) == normalized_name:
141
+ return pkmn
142
+ return None
143
+
144
+ async def choose_move(self, battle: Battle) -> str:
145
+ # (Implementation as provided in the question - relies on _get_llm_decision)
146
+ battle_state_str = self._format_battle_state(battle)
147
+ decision_result = await self._get_llm_decision(battle_state_str)
148
+ decision = decision_result.get("decision")
149
+ error_message = decision_result.get("error")
150
+ action_taken = False
151
+ fallback_reason = ""
152
+
153
+ if decision:
154
+ function_name = decision.get("name")
155
+ args = decision.get("arguments", {})
156
+ if function_name == "choose_move":
157
+ move_name = args.get("move_name")
158
+ if move_name:
159
+ chosen_move = self._find_move_by_name(battle, move_name)
160
+ if chosen_move and chosen_move in battle.available_moves:
161
+ action_taken = True
162
+ chat_msg = f"AI Decision: Using move '{chosen_move.id}'."
163
+ print(chat_msg) # Print to console for debugging
164
+ # await self.send_message(chat_msg, battle=battle) # Uncomment if send_message exists
165
+ return self.create_order(chosen_move)
166
+ else:
167
+ fallback_reason = f"LLM chose unavailable/invalid move '{move_name}'."
168
+ else:
169
+ fallback_reason = "LLM 'choose_move' called without 'move_name'."
170
+ elif function_name == "choose_switch":
171
+ pokemon_name = args.get("pokemon_name")
172
+ if pokemon_name:
173
+ chosen_switch = self._find_pokemon_by_name(battle, pokemon_name)
174
+ if chosen_switch and chosen_switch in battle.available_switches:
175
+ action_taken = True
176
+ chat_msg = f"AI Decision: Switching to '{chosen_switch.species}'."
177
+ print(chat_msg) # Print to console for debugging
178
+ # await self.send_message(chat_msg, battle=battle) # Uncomment if send_message exists
179
+ return self.create_order(chosen_switch)
180
+ else:
181
+ fallback_reason = f"LLM chose unavailable/invalid switch '{pokemon_name}'."
182
+ else:
183
+ fallback_reason = "LLM 'choose_switch' called without 'pokemon_name'."
184
+ else:
185
+ fallback_reason = f"LLM called unknown function '{function_name}'."
186
+
187
+ if not action_taken:
188
+ if not fallback_reason: # If no specific reason yet, check for API errors
189
+ if error_message:
190
+ fallback_reason = f"API Error: {error_message}"
191
+ elif decision is None: # Model didn't call a function or response was bad
192
+ fallback_reason = "LLM did not provide a valid function call."
193
+ else: # Should not happen if logic above is correct
194
+ fallback_reason = "Unknown error processing LLM decision."
195
+
196
+ print(f"Warning: {fallback_reason} Choosing random action.")
197
+ # await self.send_message(f"AI Fallback: {fallback_reason} Choosing random action.", battle=battle) # Uncomment
198
+
199
+ # Use poke-env's built-in random choice
200
+ if battle.available_moves or battle.available_switches:
201
+ return self.choose_random_move(battle)
202
+ else:
203
+ print("AI Fallback: No moves or switches available. Using Struggle/Default.")
204
+ # await self.send_message("AI Fallback: No moves or switches available. Using Struggle.", battle=battle) # Uncomment
205
+ return self.choose_default_move(battle) # Handles struggle
206
+
207
+ async def _get_llm_decision(self, battle_state: str) -> dict:
208
+ raise NotImplementedError("Subclasses must implement _get_llm_decision")
209
+
210
+ # --- Google Gemini Agent ---
211
+ class GeminiAgent(LLMAgentBase):
212
+ """Uses Google Gemini API for decisions."""
213
+ def __init__(self, api_key: str | None = None, model: str = "gemini-1.5-flash", *args, **kwargs): # Default to flash for speed/cost
214
+ super().__init__(*args, **kwargs)
215
+ self.model_name = model
216
+ used_api_key = api_key or os.environ.get("GOOGLE_API_KEY")
217
+ self.model_name=model
218
+ if not used_api_key:
219
+ raise ValueError("Google API key not provided or found in GOOGLE_API_KEY env var.")
220
+ self.client = genai.Client(
221
+ api_key='GEMINI_API_KEY',
222
+ http_options=types.HttpOptions(api_version='v1alpha')
223
+ )
224
+ # --- Correct Tool Definition ---
225
+ # Create a list of function declaration dictionaries from the values in STANDARD_TOOL_SCHEMA
226
+ function_declarations = list(self.standard_tools.values())
227
+ # Create the Tool object expected by the API
228
+ self.gemini_tool_config = types.Tool(function_declarations=function_declarations)
229
+ # --- End Tool Definition ---
230
+
231
+ # --- Correct Model Initialization ---
232
+ # Pass the Tool object directly to the model's 'tools' parameter
233
+
234
+ # --- End Model Initialization ---
235
+
236
+ async def _get_llm_decision(self, battle_state: str) -> dict:
237
+ """Sends state to the Gemini API and gets back the function call decision."""
238
+ prompt = (
239
+ "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
240
+ "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
241
+ "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
242
+ "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
243
+ "Use the provided functions to indicate your choice.\n\n"
244
+ f"Current Battle State:\n{battle_state}\n\n"
245
+ "Choose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
246
+ )
247
+
248
+ try:
249
+ # --- Correct API Call ---
250
+ # Call generate_content_async directly on the model object.
251
+ # Tools are already configured in the model, no need to pass config here.
252
+ response = await client.aio.models.generate_content(
253
+ model=self.model_name,
254
+ contents=prompt
255
+ )
256
+ # --- End API Call ---
257
+
258
+ # --- Response Parsing (Your logic was already good here) ---
259
+ # Check candidates and parts safely
260
+ if not response.candidates:
261
+ finish_reason_str = "No candidates found"
262
+ try: finish_reason_str = response.prompt_feedback.block_reason.name
263
+ except AttributeError: pass
264
+ return {"error": f"Gemini response issue. Reason: {finish_reason_str}"}
265
+
266
+ candidate = response.candidates[0]
267
+ if not candidate.content or not candidate.content.parts:
268
+ finish_reason_str = "Unknown"
269
+ try: finish_reason_str = candidate.finish_reason.name
270
+ except AttributeError: pass
271
+ return {"error": f"Gemini response issue. Finish Reason: {finish_reason_str}"}
272
+
273
+ part = candidate.content.parts[0]
274
+
275
+ # Check for function_call attribute
276
+ if hasattr(part, 'function_call') and part.function_call:
277
+ fc = part.function_call
278
+ function_name = fc.name
279
+ # fc.args is a proto_plus.MapComposite, convert to dict
280
+ arguments = dict(fc.args) if fc.args else {}
281
+
282
+ if function_name in self.standard_tools:
283
+ # Valid function call found
284
+ return {"decision": {"name": function_name, "arguments": arguments}}
285
+ else:
286
+ # Model hallucinated a function name
287
+ return {"error": f"Model called unknown function '{function_name}'. Args: {arguments}"}
288
+ elif hasattr(part, 'text'):
289
+ # Handle case where the model returns text instead of a function call
290
+ text_response = part.text
291
+ return {"error": f"Gemini did not return a function call. Response: {text_response}"}
292
+ else:
293
+ # Unexpected part type
294
+ return {"error": f"Gemini response part type unknown. Part: {part}"}
295
+ # --- End Response Parsing ---
296
+ except Exception as e:
297
+ # Catch any other unexpected errors during the API call or processing
298
+ print(f"Unexpected error during Gemini processing: {e}")
299
+ import traceback
300
+ traceback.print_exc() # Print stack trace for debugging
301
+ return {"error": f"Unexpected error: {str(e)}"}
302
+ # --- OpenAI Agent ---
303
+ class OpenAIAgent(LLMAgentBase):
304
+ """Uses OpenAI API for decisions."""
305
+ def __init__(self, api_key: str | None = None, model: str = "gpt-4o", *args, **kwargs):
306
+ super().__init__(*args, **kwargs)
307
+ self.model = model
308
+ used_api_key = api_key or os.environ.get("OPENAI_API_KEY")
309
+ if not used_api_key:
310
+ raise ValueError("OpenAI API key not provided or found in OPENAI_API_KEY env var.")
311
+ self.openai_client = AsyncOpenAI(api_key=used_api_key)
312
+
313
+ # Convert standard schema to OpenAI's format
314
+ self.openai_functions = [v for k, v in self.standard_tools.items()]
315
+
316
+ async def _get_llm_decision(self, battle_state: str) -> dict:
317
+ system_prompt = (
318
+ "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
319
+ "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
320
+ "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
321
+ "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
322
+ "Use the provided functions to indicate your choice."
323
+ )
324
+ user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
325
+
326
+ try:
327
+ response = await self.openai_client.chat.completions.create(
328
+ model=self.model,
329
+ messages=[
330
+ {"role": "system", "content": system_prompt},
331
+ {"role": "user", "content": user_prompt},
332
+ ],
333
+ functions=STANDARD_TOOL_SCHEMA,
334
+ function_call="auto",
335
+ temperature=0.5,
336
+ )
337
+ message = response.choices[0].message
338
+ print("MESSAGE BACK : ", message)
339
+ if message.function_call:
340
+ function_name = message.function_call.name
341
+ try:
342
+ # Ensure arguments is always a dict, even if empty/null
343
+ arguments = json.loads(message.function_call.arguments or '{}')
344
+ if function_name in self.standard_tools: # Validate function name
345
+ return {"decision": {"name": function_name, "arguments": arguments}}
346
+ else:
347
+ return {"error": f"Model called unknown function '{function_name}'."}
348
+ except json.JSONDecodeError:
349
+ return {"error": f"Error decoding function call arguments: {message.function_call.arguments}"}
350
+ else:
351
+ # Model decided not to call a function (or generated text instead)
352
+ return {"error": f"OpenAI did not return a function call. Response: {message.content}"}
353
+
354
+ except APIError as e:
355
+ print(f"Error during OpenAI API call: {e}")
356
+ return {"error": f"OpenAI API Error: {e.status_code} - {e.message}"}
357
+ except Exception as e:
358
+ print(f"Unexpected error during OpenAI API call: {e}")
359
+ return {"error": f"Unexpected error: {e}"}
360
+
361
+
362
+ # --- Mistral Agent ---
363
+ class MistralAgent(LLMAgentBase):
364
+ """Uses Mistral AI API for decisions."""
365
+ def __init__(self, api_key: str | None = None, model: str = "mistral-large-latest", *args, **kwargs):
366
+ super().__init__(*args, **kwargs)
367
+ self.model = model
368
+ used_api_key = api_key or os.environ.get("MISTRAL_API_KEY")
369
+ if not used_api_key:
370
+ raise ValueError("Mistral API key not provided or found in MISTRAL_API_KEY env var.")
371
+ self.mistral_client = MistralAsyncClient(api_key=used_api_key)
372
+
373
+ # Convert standard schema to Mistral's tool format (very similar to OpenAI's)
374
+ self.mistral_tools = STANDARD_TOOL_SCHEMA
375
+
376
+ async def _get_llm_decision(self, battle_state: str) -> dict:
377
+ system_prompt = (
378
+ "You are a skilled Pokemon battle AI. Your goal is to win the battle. "
379
+ "Based on the current battle state, decide the best action: either use an available move or switch to an available Pokémon. "
380
+ "Consider type matchups, HP, status conditions, field effects, entry hazards, and potential opponent actions. "
381
+ "Only choose actions listed as available using their exact ID (for moves) or species name (for switches). "
382
+ "Use the provided tools/functions to indicate your choice."
383
+ )
384
+ user_prompt = f"Current Battle State:\n{battle_state}\n\nChoose the best action by calling the appropriate function ('choose_move' or 'choose_switch')."
385
+
386
+ try:
387
+ response = await self.mistral_client.chat.complete(
388
+ model=self.model,
389
+ messages=[
390
+ {"role": "system", "content": f"{system}"},
391
+ {"role": "user", "content": f"{user_prompt}"}
392
+ ],
393
+ tools=self.mistral_tools,
394
+ tool_choice="auto", # Let the model choose
395
+ temperature=0.5,
396
+ )
397
+
398
+ message = response.choices[0].message
399
+ # Mistral returns tool_calls as a list
400
+ if message.tool_calls:
401
+ tool_call = response.choices[0].message.tool_calls[0]
402
+ function_name = tool_call.function.name
403
+ function_params = json.loads(tool_call.function.arguments)
404
+ print("\nfunction_name: ", function_name, "\nfunction_params: ", function_params)
405
+ if function_name and function_params: # Validate function name
406
+ return {"decision": {"name": function_name, "arguments": arguments}}
407
+
408
+ else:
409
+ # Model decided not to call a tool (or generated text instead)
410
+ return {"error": f"Mistral did not return a tool call. Response: {message.content}"}
411
+
412
+ # Mistral client might raise specific exceptions, add them here if known
413
+ # from mistralai.exceptions import MistralAPIException # Example
414
+ except Exception as e: # Catch general exceptions for now
415
+ print(f"Error during Mistral API call: {e}")
416
+ # Try to get specific details if it's a known exception type
417
+ error_details = str(e)
418
+ # if isinstance(e, MistralAPIException): # Example
419
+ # error_details = f"{e.status_code} - {e.message}"
420
+ return {"error": f"Mistral API Error: {error_details}"}