Spaces:
Sleeping
Sleeping
Upload base_simulator.py
Browse files- simulators/base_simulator.py +275 -0
simulators/base_simulator.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
''' Base class for simulating games.'''
|
2 |
+
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from typing import Dict, Any, List
|
6 |
+
from abc import ABC
|
7 |
+
import random
|
8 |
+
from utils.llm_utils import generate_prompt, llm_decide_move
|
9 |
+
from enum import Enum, unique
|
10 |
+
|
11 |
+
|
12 |
+
@unique
|
13 |
+
class PlayerId(Enum):
|
14 |
+
CHANCE = -1
|
15 |
+
SIMULTANEOUS = -2
|
16 |
+
INVALID = -3
|
17 |
+
TERMINAL = -4
|
18 |
+
MEAN_FIELD = -5
|
19 |
+
|
20 |
+
@classmethod
|
21 |
+
def from_value(cls, value: int):
|
22 |
+
"""Returns the PlayerId corresponding to a given integer value.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
value (int): The numerical value to map to a PlayerId.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
PlayerId: The matching enum member, or raises a ValueError if invalid.
|
29 |
+
"""
|
30 |
+
for member in cls:
|
31 |
+
if member.value == value:
|
32 |
+
return member
|
33 |
+
if value >= 0: # Positive integers represent default players
|
34 |
+
return None # No enum corresponds to these values directly
|
35 |
+
raise ValueError(f"Unknown player ID value: {value}")
|
36 |
+
|
37 |
+
|
38 |
+
class PlayerType(Enum):
|
39 |
+
HUMAN = "human"
|
40 |
+
RANDOM_BOT = "random_bot"
|
41 |
+
LLM = "llm"
|
42 |
+
SELF_PLAY = "self_play"
|
43 |
+
|
44 |
+
|
45 |
+
class GameSimulator(ABC):
|
46 |
+
"""Base class for simulating games with LLMs.
|
47 |
+
|
48 |
+
Handles common functionality like state transitions, scoring, and logging.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, game: Any, game_name: str, llms: Dict[str, Any],
|
52 |
+
player_type: Dict[str, str], max_game_rounds: int = None):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
game (Any): The OpenSpiel game object being simulated.
|
56 |
+
game_name (str): A human-readable name for the game (for logging and reporting).
|
57 |
+
llms (Dict[str, Any]): A dictionary mapping player names (e.g., "Player 1")
|
58 |
+
to their corresponding LLM instances. Can be empty if no LLMs are used.
|
59 |
+
player_type (Dict[str, str]): A dictionary mapping player names to their types.
|
60 |
+
max_game_rounds (int): Maximum number of rounds for iterated games. Ignored by single-shot games.
|
61 |
+
"""
|
62 |
+
self.game = game
|
63 |
+
self.game_name = game_name
|
64 |
+
self.llms = llms
|
65 |
+
self.player_type = player_type
|
66 |
+
self.max_game_rounds = max_game_rounds # For iterated games
|
67 |
+
self.scores = {name: 0 for name in self.llms.keys()} # Initialize scores
|
68 |
+
|
69 |
+
def simulate(self, rounds: int = 1, log_fn=None) -> Dict[str, Any]:
|
70 |
+
"""Simulates a game for multiple rounds and computes metrics .
|
71 |
+
|
72 |
+
Args:
|
73 |
+
rounds: Number of times the game should be played.
|
74 |
+
log_fn: Optional function to log intermediate states.
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Dict[str, Any]: Summary of results for all rounds.
|
78 |
+
"""
|
79 |
+
outcomes = self._initialize_outcomes() # Reset the outcomes dictionary
|
80 |
+
|
81 |
+
for _ in range(rounds):
|
82 |
+
self.scores = {name: 0 for name in self.llms.keys()} # Reset scores
|
83 |
+
state = self.game.new_initial_state()
|
84 |
+
|
85 |
+
while not state.is_terminal():
|
86 |
+
if self.max_game_rounds is not None and state.move_number() >= self.max_game_rounds:
|
87 |
+
# If max_game_rounds is specified, terminate the game after the maximum number of rounds.
|
88 |
+
# The state.move_number() method tracks the number of moves (or rounds) within the game.
|
89 |
+
# This ensures that iterated games, such as the Iterated Prisoner's Dilemma,
|
90 |
+
# stop after the specified number of rounds, even if the game would naturally continue.
|
91 |
+
break
|
92 |
+
if log_fn:
|
93 |
+
log_fn(state)
|
94 |
+
|
95 |
+
# Collect actions
|
96 |
+
current_player = state.current_player()
|
97 |
+
player_id = self.normalize_player_id(current_player)
|
98 |
+
|
99 |
+
if player_id == PlayerId.CHANCE.value:
|
100 |
+
# Handle chance nodes where the environment acts randomly.
|
101 |
+
self._handle_chance_node(state)
|
102 |
+
elif player_id == PlayerId.SIMULTANEOUS.value:
|
103 |
+
# Handle simultaneous moves for all players.
|
104 |
+
actions = self._collect_actions(state)
|
105 |
+
state.apply_actions(actions)
|
106 |
+
elif player_id == PlayerId.TERMINAL.value:
|
107 |
+
break
|
108 |
+
elif current_player >= 0: # Default players (turn-based)
|
109 |
+
legal_actions = state.legal_actions(current_player)
|
110 |
+
action = self._get_action(current_player, state, legal_actions)
|
111 |
+
state.apply_action(action)
|
112 |
+
else:
|
113 |
+
raise ValueError(f"Unexpected player ID: {current_player}")
|
114 |
+
|
115 |
+
# Record outcomes
|
116 |
+
final_scores = state.returns()
|
117 |
+
self._record_outcomes(final_scores, outcomes)
|
118 |
+
|
119 |
+
return outcomes
|
120 |
+
|
121 |
+
def _handle_chance_node(self, state: Any):
|
122 |
+
"""Handle chance nodes. Default behavior raises an error."""
|
123 |
+
raise NotImplementedError("Chance node handling not implemented for this game.")
|
124 |
+
|
125 |
+
|
126 |
+
def _collect_actions(self, state: Any) -> List[int]:
|
127 |
+
"""Collects actions for all players in a simultaneous-move game.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
state: The current game state.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
List[int]: Actions chosen by all players.
|
134 |
+
"""
|
135 |
+
return [
|
136 |
+
self._get_action(player, state, state.legal_actions(player))
|
137 |
+
for player in range(self.game.num_players())
|
138 |
+
]
|
139 |
+
|
140 |
+
def _initialize_outcomes(self) -> Dict[str, Any]:
|
141 |
+
"""Initializes the outcomes dictionary."""
|
142 |
+
return {"wins": {name: 0 for name in self.llms.keys()},
|
143 |
+
"losses": {name: 0 for name in self.llms.keys()},
|
144 |
+
"ties": 0
|
145 |
+
}
|
146 |
+
|
147 |
+
|
148 |
+
def _get_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
|
149 |
+
"""Gets the action for the current player.
|
150 |
+
|
151 |
+
Args:
|
152 |
+
player: The index of the current player.
|
153 |
+
state: The current game state.
|
154 |
+
legal_actions: The legal actions available for the player.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
int: The action selected by the player.
|
158 |
+
"""
|
159 |
+
player_name = f"Player {player + 1}" # Map index to player name
|
160 |
+
player_type = self.player_type.get(player_name)
|
161 |
+
|
162 |
+
if player_type == PlayerType.HUMAN.value:
|
163 |
+
return self._get_human_action(state, legal_actions)
|
164 |
+
if player_type == PlayerType.RANDOM_BOT.value:
|
165 |
+
return random.choice(legal_actions)
|
166 |
+
if player_type == PlayerType.LLM.value:
|
167 |
+
return self._get_llm_action(player, state, legal_actions)
|
168 |
+
|
169 |
+
raise ValueError(f"Unknown player type for {player_name}: {player_type}")
|
170 |
+
|
171 |
+
|
172 |
+
def _get_human_action(self, state: Any, legal_actions: List[int]) -> int:
|
173 |
+
"""Handles input for human players."""
|
174 |
+
print(f"Current state of {self.game_name}:\n{state}")
|
175 |
+
print(f"Your options: {legal_actions}") # Display legal moves to the user
|
176 |
+
while True:
|
177 |
+
try:
|
178 |
+
action = int(input("Enter your action (number): "))
|
179 |
+
if action in legal_actions: # Validate the move
|
180 |
+
return action
|
181 |
+
except ValueError:
|
182 |
+
pass
|
183 |
+
print("Invalid action. Please choose from:", legal_actions)
|
184 |
+
|
185 |
+
def _get_llm_action(self, player: int, state: Any, legal_actions: List[int]) -> int:
|
186 |
+
"""Handles LLM-based decisions."""
|
187 |
+
player_name = f"Player {player + 1}"
|
188 |
+
llm = self.llms[player_name]
|
189 |
+
prompt = generate_prompt(self.game_name, str(state), legal_actions)
|
190 |
+
return llm_decide_move(llm, prompt, tuple(legal_actions))
|
191 |
+
|
192 |
+
def _apply_default_action(self, state):
|
193 |
+
"""
|
194 |
+
Applies a default action when the current player is invalid.
|
195 |
+
"""
|
196 |
+
state.apply_action(random.choice(state.legal_actions()))
|
197 |
+
|
198 |
+
def _record_outcomes(self, final_scores: List[float], outcomes: Dict[str, Any]) -> str:
|
199 |
+
"""Records the outcome of a single game round.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
final_scores (List[float]): Final cumulative scores of all players.
|
203 |
+
outcomes (Dict[str, Any]): Dictionary to record wins, losses, and ties.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
str: Name of the winner or "tie" if there is no single winner.
|
207 |
+
"""
|
208 |
+
# Check if all scores are equal (a tie)
|
209 |
+
if all(score == final_scores[0] for score in final_scores):
|
210 |
+
outcomes["ties"] += 1
|
211 |
+
return "tie"
|
212 |
+
|
213 |
+
# Find the maximum score and determine winners
|
214 |
+
max_score = max(final_scores)
|
215 |
+
winners = [name for i, name in enumerate(self.llms.keys()) if final_scores[i] == max_score]
|
216 |
+
|
217 |
+
# Track losers as players who do not have the maximum score
|
218 |
+
losers = [name for i, name in enumerate(self.llms.keys()) if final_scores[i] != max_score]
|
219 |
+
|
220 |
+
# If there is one winner, record it; otherwise, record as a tie
|
221 |
+
if len(winners) == 1:
|
222 |
+
outcomes["wins"][winners[0]] += 1
|
223 |
+
for loser in losers:
|
224 |
+
outcomes["losses"][loser] += 1
|
225 |
+
return winners[0]
|
226 |
+
else:
|
227 |
+
outcomes["ties"] += 1
|
228 |
+
return "tie"
|
229 |
+
|
230 |
+
|
231 |
+
def save_results(self, state: Any, final_scores: List[float]) -> None:
|
232 |
+
"""Save simulation results to a JSON file."""
|
233 |
+
results = self._prepare_results(state, final_scores)
|
234 |
+
filename = self._get_results_filename()
|
235 |
+
|
236 |
+
with open(filename, "w") as f:
|
237 |
+
json.dump(results, f, indent=4)
|
238 |
+
print(f"Results saved to {filename}")
|
239 |
+
|
240 |
+
def _prepare_results(self, state: Any, final_scores: List[float]) -> Dict[str, Any]:
|
241 |
+
"""Prepares the results dictionary for JSON serialization."""
|
242 |
+
final_scores = final_scores.tolist() if hasattr(final_scores, "tolist") else final_scores
|
243 |
+
return {
|
244 |
+
"game_name": self.game_name,
|
245 |
+
"final_state": str(state),
|
246 |
+
"scores": self.scores,
|
247 |
+
"returns": final_scores,
|
248 |
+
"history": state.history_str(),
|
249 |
+
}
|
250 |
+
|
251 |
+
def _get_results_filename(self) -> str:
|
252 |
+
"""Generates the filename for saving results."""
|
253 |
+
results_dir = "results"
|
254 |
+
os.makedirs(results_dir, exist_ok=True)
|
255 |
+
return os.path.join(results_dir, f"{self.game_name.lower().replace(' ', '_')}_results.json")
|
256 |
+
|
257 |
+
def log_progress(self, state: Any) -> None:
|
258 |
+
"""Log the current game state."""
|
259 |
+
print(f"Current state of {self.game_name}:\n{state}")
|
260 |
+
|
261 |
+
def normalize_player_id(self,player_id):
|
262 |
+
"""Normalize player_id to its integer value for consistent comparisons.
|
263 |
+
|
264 |
+
This is needed as OpenSpiel has ambiguous representation of the playerID
|
265 |
+
|
266 |
+
Args:
|
267 |
+
player_id (Union[int, PlayerId]): The player ID, which can be an
|
268 |
+
integer or a PlayerId enum instance.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
int: The integer value of the player ID.
|
272 |
+
"""
|
273 |
+
if isinstance(player_id, PlayerId):
|
274 |
+
return player_id.value # Extract the integer value from the enum
|
275 |
+
return player_id # If already an integer, return it as is
|