"""Base class for simulating games."""

import json
import os
import random
from abc import ABC
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union

from utils.llm_utils import generate_prompt, llm_decide_move


@unique
class PlayerId(Enum):
    """Special (negative) player identifiers used for non-player nodes."""

    CHANCE = -1
    SIMULTANEOUS = -2
    INVALID = -3
    TERMINAL = -4
    MEAN_FIELD = -5

    @classmethod
    def from_value(cls, value: int) -> Optional["PlayerId"]:
        """Return the PlayerId corresponding to a given integer value.

        Args:
            value (int): The numerical value to map to a PlayerId.

        Returns:
            Optional[PlayerId]: The matching enum member, or ``None`` when
            ``value`` is non-negative (regular players have no enum member).

        Raises:
            ValueError: If ``value`` is negative and matches no member.
        """
        for member in cls:
            if member.value == value:
                return member
        if value >= 0:
            # Non-negative integers represent default (regular) players;
            # no enum member corresponds to them directly.
            return None
        raise ValueError(f"Unknown player ID value: {value}")


class PlayerType(Enum):
    """Kinds of agents that can control a player seat."""

    HUMAN = "human"
    RANDOM_BOT = "random_bot"
    LLM = "llm"
    SELF_PLAY = "self_play"


class GameSimulator(ABC):
    """Base class for simulating games with LLMs.

    Handles common functionality like state transitions, scoring, and logging.
    """

    def __init__(self,
                 game: Any,
                 game_name: str,
                 llms: Dict[str, Any],
                 player_type: Dict[str, str],
                 max_game_rounds: Optional[int] = None):
        """Store the game configuration and initialize per-player scores.

        Args:
            game (Any): The OpenSpiel game object being simulated.
            game_name (str): A human-readable name for the game
                (for logging and reporting).
            llms (Dict[str, Any]): A dictionary mapping player names
                (e.g., "Player 1") to their corresponding LLM instances.
                Can be empty if no LLMs are used.
            player_type (Dict[str, str]): A dictionary mapping player names
                to their types.
            max_game_rounds (Optional[int]): Maximum number of rounds for
                iterated games. Ignored by single-shot games.
        """
        self.game = game
        self.game_name = game_name
        self.llms = llms
        self.player_type = player_type
        self.max_game_rounds = max_game_rounds  # For iterated games
        self.scores = {name: 0 for name in self.llms}  # Initialize scores

    def simulate(self, rounds: int = 1, log_fn=None) -> Dict[str, Any]:
        """Simulate the game for multiple rounds and compute metrics.

        Args:
            rounds: Number of times the game should be played.
            log_fn: Optional function called with the state before each move.

        Returns:
            Dict[str, Any]: Summary of results (wins, losses, ties) for all
            rounds.

        Raises:
            ValueError: If the state reports a player ID that is neither a
                regular player nor a handled special node.
        """
        outcomes = self._initialize_outcomes()  # Fresh outcomes dictionary

        for _ in range(rounds):
            self.scores = {name: 0 for name in self.llms}  # Reset scores
            state = self.game.new_initial_state()

            while not state.is_terminal():
                # For iterated games (e.g. Iterated Prisoner's Dilemma), stop
                # once the number of moves reaches max_game_rounds, even if
                # the game would naturally continue.
                if (self.max_game_rounds is not None
                        and state.move_number() >= self.max_game_rounds):
                    break

                if log_fn:
                    log_fn(state)

                current_player = state.current_player()
                player_id = self.normalize_player_id(current_player)

                if player_id == PlayerId.CHANCE.value:
                    # Chance node: the environment acts randomly.
                    self._handle_chance_node(state)
                elif player_id == PlayerId.SIMULTANEOUS.value:
                    # Simultaneous node: every player moves at once.
                    actions = self._collect_actions(state)
                    state.apply_actions(actions)
                elif player_id == PlayerId.TERMINAL.value:
                    break
                elif player_id >= 0:
                    # Regular turn-based player. The normalized integer is
                    # used here so that an enum-typed ID (e.g. INVALID) falls
                    # through to the ValueError below instead of raising a
                    # TypeError on the comparison.
                    legal_actions = state.legal_actions(player_id)
                    action = self._get_action(player_id, state, legal_actions)
                    state.apply_action(action)
                else:
                    raise ValueError(
                        f"Unexpected player ID: {current_player}")

            # Record outcomes
            final_scores = state.returns()
            self._record_outcomes(final_scores, outcomes)

        return outcomes

    def _handle_chance_node(self, state: Any):
        """Handle chance nodes. Default behavior raises an error."""
        raise NotImplementedError(
            "Chance node handling not implemented for this game.")

    def _collect_actions(self, state: Any) -> List[int]:
        """Collect actions for all players in a simultaneous-move game.

        Args:
            state: The current game state.

        Returns:
            List[int]: Actions chosen by all players, ordered by player index.
        """
        return [
            self._get_action(player, state, state.legal_actions(player))
            for player in range(self.game.num_players())
        ]

    def _initialize_outcomes(self) -> Dict[str, Any]:
        """Return a zeroed outcomes dictionary keyed by the names in llms."""
        return {
            "wins": {name: 0 for name in self.llms},
            "losses": {name: 0 for name in self.llms},
            "ties": 0,
        }

    def _get_action(self, player: int, state: Any,
                    legal_actions: List[int]) -> int:
        """Get the action for the given player according to its type.

        Args:
            player: The zero-based index of the current player.
            state: The current game state.
            legal_actions: The legal actions available for the player.

        Returns:
            int: The action selected by the player.

        Raises:
            ValueError: If the player's type is missing or not handled here.
        """
        player_name = f"Player {player + 1}"  # Map index to player name
        player_type = self.player_type.get(player_name)

        if player_type == PlayerType.HUMAN.value:
            return self._get_human_action(state, legal_actions)
        if player_type == PlayerType.RANDOM_BOT.value:
            return random.choice(legal_actions)
        if player_type == PlayerType.LLM.value:
            return self._get_llm_action(player, state, legal_actions)

        # NOTE(review): PlayerType.SELF_PLAY is declared but not handled
        # here, so it ends up in this error path - confirm whether intended.
        raise ValueError(
            f"Unknown player type for {player_name}: {player_type}")

    def _get_human_action(self, state: Any, legal_actions: List[int]) -> int:
        """Prompt on stdin until the user enters one of the legal actions."""
        print(f"Current state of {self.game_name}:\n{state}")
        print(f"Your options: {legal_actions}")  # Display legal moves
        while True:
            try:
                action = int(input("Enter your action (number): "))
                if action in legal_actions:  # Validate the move
                    return action
            except ValueError:
                # Non-numeric input: fall through to the retry message.
                pass
            print("Invalid action. Please choose from:", legal_actions)

    def _get_llm_action(self, player: int, state: Any,
                        legal_actions: List[int]) -> int:
        """Ask the player's LLM to choose among the legal actions."""
        player_name = f"Player {player + 1}"
        llm = self.llms[player_name]
        prompt = generate_prompt(self.game_name, str(state), legal_actions)
        return llm_decide_move(llm, prompt, tuple(legal_actions))

    def _apply_default_action(self, state):
        """Apply a random legal action when the current player is invalid."""
        state.apply_action(random.choice(state.legal_actions()))

    def _record_outcomes(self, final_scores: List[float],
                         outcomes: Dict[str, Any]) -> str:
        """Record the outcome of a single game round.

        Args:
            final_scores (List[float]): Final cumulative scores of all
                players, indexed in the same order as the keys of ``llms``.
            outcomes (Dict[str, Any]): Dictionary to record wins, losses,
                and ties; updated in place.

        Returns:
            str: Name of the winner or "tie" if there is no single winner.
        """
        # All scores equal (including the single-player case) is a tie.
        if all(score == final_scores[0] for score in final_scores):
            outcomes["ties"] += 1
            return "tie"

        # Partition the named players into top scorers and everyone else.
        max_score = max(final_scores)
        winners: List[str] = []
        losers: List[str] = []
        for i, name in enumerate(self.llms):
            if final_scores[i] == max_score:
                winners.append(name)
            else:
                losers.append(name)

        # Only a unique top scorer counts as a win; anything else is a tie.
        if len(winners) != 1:
            outcomes["ties"] += 1
            return "tie"

        outcomes["wins"][winners[0]] += 1
        for loser in losers:
            outcomes["losses"][loser] += 1
        return winners[0]

    def save_results(self, state: Any, final_scores: List[float]) -> None:
        """Save simulation results to a JSON file under ``results/``."""
        results = self._prepare_results(state, final_scores)
        filename = self._get_results_filename()

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)
        print(f"Results saved to {filename}")

    def _prepare_results(self, state: Any,
                         final_scores: List[float]) -> Dict[str, Any]:
        """Prepare the results dictionary for JSON serialization."""
        # Array-like returns (anything exposing tolist) become plain lists
        # so that json.dump can serialize them.
        if hasattr(final_scores, "tolist"):
            final_scores = final_scores.tolist()
        return {
            "game_name": self.game_name,
            "final_state": str(state),
            "scores": self.scores,
            "returns": final_scores,
            "history": state.history_str(),
        }

    def _get_results_filename(self) -> str:
        """Return the results file path, creating the directory if needed."""
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)
        return os.path.join(
            results_dir,
            f"{self.game_name.lower().replace(' ', '_')}_results.json")

    def log_progress(self, state: Any) -> None:
        """Log the current game state."""
        print(f"Current state of {self.game_name}:\n{state}")

    def normalize_player_id(self, player_id: Union[int, PlayerId]) -> int:
        """Normalize player_id to its integer value for consistent comparisons.

        This is needed as OpenSpiel has an ambiguous representation of the
        player ID.

        Args:
            player_id (Union[int, PlayerId]): The player ID, which can be an
                integer or a PlayerId enum instance.

        Returns:
            int: The integer value of the player ID.
        """
        if isinstance(player_id, PlayerId):
            return player_id.value  # Extract the integer value from the enum
        return player_id  # Already an integer; return it as is