lcipolina commited on
Commit
fda5d3e
·
verified ·
1 Parent(s): e55fbd6

Upload base_simulator.py

Browse files
Files changed (1) hide show
  1. simulators/base_simulator.py +275 -0
simulators/base_simulator.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ''' Base class for simulating games.'''
2
+
3
+ import os
4
+ import json
5
+ from typing import Dict, Any, List
6
+ from abc import ABC
7
+ import random
8
+ from utils.llm_utils import generate_prompt, llm_decide_move
9
+ from enum import Enum, unique
10
+
11
+
12
+ @unique
13
+ class PlayerId(Enum):
14
+ CHANCE = -1
15
+ SIMULTANEOUS = -2
16
+ INVALID = -3
17
+ TERMINAL = -4
18
+ MEAN_FIELD = -5
19
+
20
+ @classmethod
21
+ def from_value(cls, value: int):
22
+ """Returns the PlayerId corresponding to a given integer value.
23
+
24
+ Args:
25
+ value (int): The numerical value to map to a PlayerId.
26
+
27
+ Returns:
28
+ PlayerId: The matching enum member, or raises a ValueError if invalid.
29
+ """
30
+ for member in cls:
31
+ if member.value == value:
32
+ return member
33
+ if value >= 0: # Positive integers represent default players
34
+ return None # No enum corresponds to these values directly
35
+ raise ValueError(f"Unknown player ID value: {value}")
36
+
37
+
38
+ class PlayerType(Enum):
39
+ HUMAN = "human"
40
+ RANDOM_BOT = "random_bot"
41
+ LLM = "llm"
42
+ SELF_PLAY = "self_play"
43
+
44
+
45
+ class GameSimulator(ABC):
46
+ """Base class for simulating games with LLMs.
47
+
48
+ Handles common functionality like state transitions, scoring, and logging.
49
+ """
50
+
51
+ def __init__(self, game: Any, game_name: str, llms: Dict[str, Any],
52
+ player_type: Dict[str, str], max_game_rounds: int = None):
53
+ """
54
+ Args:
55
+ game (Any): The OpenSpiel game object being simulated.
56
+ game_name (str): A human-readable name for the game (for logging and reporting).
57
+ llms (Dict[str, Any]): A dictionary mapping player names (e.g., "Player 1")
58
+ to their corresponding LLM instances. Can be empty if no LLMs are used.
59
+ player_type (Dict[str, str]): A dictionary mapping player names to their types.
60
+ max_game_rounds (int): Maximum number of rounds for iterated games. Ignored by single-shot games.
61
+ """
62
+ self.game = game
63
+ self.game_name = game_name
64
+ self.llms = llms
65
+ self.player_type = player_type
66
+ self.max_game_rounds = max_game_rounds # For iterated games
67
+ self.scores = {name: 0 for name in self.llms.keys()} # Initialize scores
68
+
69
+ def simulate(self, rounds: int = 1, log_fn=None) -> Dict[str, Any]:
70
+ """Simulates a game for multiple rounds and computes metrics .
71
+
72
+ Args:
73
+ rounds: Number of times the game should be played.
74
+ log_fn: Optional function to log intermediate states.
75
+
76
+ Returns:
77
+ Dict[str, Any]: Summary of results for all rounds.
78
+ """
79
+ outcomes = self._initialize_outcomes() # Reset the outcomes dictionary
80
+
81
+ for _ in range(rounds):
82
+ self.scores = {name: 0 for name in self.llms.keys()} # Reset scores
83
+ state = self.game.new_initial_state()
84
+
85
+ while not state.is_terminal():
86
+ if self.max_game_rounds is not None and state.move_number() >= self.max_game_rounds:
87
+ # If max_game_rounds is specified, terminate the game after the maximum number of rounds.
88
+ # The state.move_number() method tracks the number of moves (or rounds) within the game.
89
+ # This ensures that iterated games, such as the Iterated Prisoner's Dilemma,
90
+ # stop after the specified number of rounds, even if the game would naturally continue.
91
+ break
92
+ if log_fn:
93
+ log_fn(state)
94
+
95
+ # Collect actions
96
+ current_player = state.current_player()
97
+ player_id = self.normalize_player_id(current_player)
98
+
99
+ if player_id == PlayerId.CHANCE.value:
100
+ # Handle chance nodes where the environment acts randomly.
101
+ self._handle_chance_node(state)
102
+ elif player_id == PlayerId.SIMULTANEOUS.value:
103
+ # Handle simultaneous moves for all players.
104
+ actions = self._collect_actions(state)
105
+ state.apply_actions(actions)
106
+ elif player_id == PlayerId.TERMINAL.value:
107
+ break
108
+ elif current_player >= 0: # Default players (turn-based)
109
+ legal_actions = state.legal_actions(current_player)
110
+ action = self._get_action(current_player, state, legal_actions)
111
+ state.apply_action(action)
112
+ else:
113
+ raise ValueError(f"Unexpected player ID: {current_player}")
114
+
115
+ # Record outcomes
116
+ final_scores = state.returns()
117
+ self._record_outcomes(final_scores, outcomes)
118
+
119
+ return outcomes
120
+
121
+ def _handle_chance_node(self, state: Any):
122
+ """Handle chance nodes. Default behavior raises an error."""
123
+ raise NotImplementedError("Chance node handling not implemented for this game.")
124
+
125
+
126
+ def _collect_actions(self, state: Any) -> List[int]:
127
+ """Collects actions for all players in a simultaneous-move game.
128
+
129
+ Args:
130
+ state: The current game state.
131
+
132
+ Returns:
133
+ List[int]: Actions chosen by all players.
134
+ """
135
+ return [
136
+ self._get_action(player, state, state.legal_actions(player))
137
+ for player in range(self.game.num_players())
138
+ ]
139
+
140
+ def _initialize_outcomes(self) -> Dict[str, Any]:
141
+ """Initializes the outcomes dictionary."""
142
+ return {"wins": {name: 0 for name in self.llms.keys()},
143
+ "losses": {name: 0 for name in self.llms.keys()},
144
+ "ties": 0
145
+ }
146
+
147
+
148
+ def _get_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
149
+ """Gets the action for the current player.
150
+
151
+ Args:
152
+ player: The index of the current player.
153
+ state: The current game state.
154
+ legal_actions: The legal actions available for the player.
155
+
156
+ Returns:
157
+ int: The action selected by the player.
158
+ """
159
+ player_name = f"Player {player + 1}" # Map index to player name
160
+ player_type = self.player_type.get(player_name)
161
+
162
+ if player_type == PlayerType.HUMAN.value:
163
+ return self._get_human_action(state, legal_actions)
164
+ if player_type == PlayerType.RANDOM_BOT.value:
165
+ return random.choice(legal_actions)
166
+ if player_type == PlayerType.LLM.value:
167
+ return self._get_llm_action(player, state, legal_actions)
168
+
169
+ raise ValueError(f"Unknown player type for {player_name}: {player_type}")
170
+
171
+
172
+ def _get_human_action(self, state: Any, legal_actions: List[int]) -> int:
173
+ """Handles input for human players."""
174
+ print(f"Current state of {self.game_name}:\n{state}")
175
+ print(f"Your options: {legal_actions}") # Display legal moves to the user
176
+ while True:
177
+ try:
178
+ action = int(input("Enter your action (number): "))
179
+ if action in legal_actions: # Validate the move
180
+ return action
181
+ except ValueError:
182
+ pass
183
+ print("Invalid action. Please choose from:", legal_actions)
184
+
185
+ def _get_llm_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
186
+ """Handles LLM-based decisions."""
187
+ player_name = f"Player {player + 1}"
188
+ llm = self.llms[player_name]
189
+ prompt = generate_prompt(self.game_name, str(state), legal_actions)
190
+ return llm_decide_move(llm, prompt, tuple(legal_actions))
191
+
192
+ def _apply_default_action(self, state):
193
+ """
194
+ Applies a default action when the current player is invalid.
195
+ """
196
+ state.apply_action(random.choice(state.legal_actions()))
197
+
198
+ def _record_outcomes(self, final_scores: List[float], outcomes: Dict[str, Any]) -> str:
199
+ """Records the outcome of a single game round.
200
+
201
+ Args:
202
+ final_scores (List[float]): Final cumulative scores of all players.
203
+ outcomes (Dict[str, Any]): Dictionary to record wins, losses, and ties.
204
+
205
+ Returns:
206
+ str: Name of the winner or "tie" if there is no single winner.
207
+ """
208
+ # Check if all scores are equal (a tie)
209
+ if all(score == final_scores[0] for score in final_scores):
210
+ outcomes["ties"] += 1
211
+ return "tie"
212
+
213
+ # Find the maximum score and determine winners
214
+ max_score = max(final_scores)
215
+ winners = [name for i, name in enumerate(self.llms.keys()) if final_scores[i] == max_score]
216
+
217
+ # Track losers as players who do not have the maximum score
218
+ losers = [name for i, name in enumerate(self.llms.keys()) if final_scores[i] != max_score]
219
+
220
+ # If there is one winner, record it; otherwise, record as a tie
221
+ if len(winners) == 1:
222
+ outcomes["wins"][winners[0]] += 1
223
+ for loser in losers:
224
+ outcomes["losses"][loser] += 1
225
+ return winners[0]
226
+ else:
227
+ outcomes["ties"] += 1
228
+ return "tie"
229
+
230
+
231
+ def save_results(self, state: Any, final_scores: List[float]) -> None:
232
+ """Save simulation results to a JSON file."""
233
+ results = self._prepare_results(state, final_scores)
234
+ filename = self._get_results_filename()
235
+
236
+ with open(filename, "w") as f:
237
+ json.dump(results, f, indent=4)
238
+ print(f"Results saved to {filename}")
239
+
240
+ def _prepare_results(self, state: Any, final_scores: List[float]) -> Dict[str, Any]:
241
+ """Prepares the results dictionary for JSON serialization."""
242
+ final_scores = final_scores.tolist() if hasattr(final_scores, "tolist") else final_scores
243
+ return {
244
+ "game_name": self.game_name,
245
+ "final_state": str(state),
246
+ "scores": self.scores,
247
+ "returns": final_scores,
248
+ "history": state.history_str(),
249
+ }
250
+
251
+ def _get_results_filename(self) -> str:
252
+ """Generates the filename for saving results."""
253
+ results_dir = "results"
254
+ os.makedirs(results_dir, exist_ok=True)
255
+ return os.path.join(results_dir, f"{self.game_name.lower().replace(' ', '_')}_results.json")
256
+
257
+ def log_progress(self, state: Any) -> None:
258
+ """Log the current game state."""
259
+ print(f"Current state of {self.game_name}:\n{state}")
260
+
261
+ def normalize_player_id(self,player_id):
262
+ """Normalize player_id to its integer value for consistent comparisons.
263
+
264
+ This is needed as OpenSpiel has ambiguous representation of the playerID
265
+
266
+ Args:
267
+ player_id (Union[int, PlayerId]): The player ID, which can be an
268
+ integer or a PlayerId enum instance.
269
+
270
+ Returns:
271
+ int: The integer value of the player ID.
272
+ """
273
+ if isinstance(player_id, PlayerId):
274
+ return player_id.value # Extract the integer value from the enum
275
+ return player_id # If already an integer, return it as is