Eric Botti
committed on
Commit
·
dfdde45
1
Parent(s):
758a706
redid controllers into AgentInterface class, unified message system
Browse files- src/agent_interfaces.py +148 -0
- src/controllers.py +0 -21
- src/game.py +45 -68
- src/message.py +33 -30
- src/player.py +5 -129
src/agent_interfaces.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Type, NewType
|
2 |
+
import json
|
3 |
+
|
4 |
+
from openai import OpenAI
|
5 |
+
from colorama import Fore, Style
|
6 |
+
from pydantic import BaseModel, ValidationError
|
7 |
+
|
8 |
+
from message import Message, AgentMessage
|
9 |
+
|
10 |
+
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
11 |
+
Here is the output schema:
|
12 |
+
```
|
13 |
+
{schema}
|
14 |
+
```
|
15 |
+
"""
|
16 |
+
|
17 |
+
|
18 |
+
class BaseAgentInterface:
    """
    The interface that agents use to receive info from and interact with the game.
    This is the base class and should not be used directly.
    """

    # Subclasses representing human-controlled agents override this to True.
    is_human: bool = False

    def __init__(
            self,
            agent_id: str = None
    ):
        self.id = agent_id
        # Full ordered history of AgentMessages seen/produced by this agent.
        self.messages = []

    def add_message(self, message: Message):
        """Adds a message to the message history, without generating a response."""
        bound_message = AgentMessage.from_message(message, self.id, len(self.messages))
        self.messages.append(bound_message)

    def respond_to(self, message: Message) -> Message:
        """Adds a message to the message history, and generates a response message."""
        self.add_message(message)
        response = Message(type="agent", content=self._generate_response())
        self.add_message(response)
        return response

    def respond_to_formatted(self, message: Message, output_format: Type[BaseModel], max_retries: int = 3) -> BaseModel:
        """Adds a message to the message history, and generates a response matching the provided format.

        Raises:
            ValidationError: if the agent fails to produce output matching
                `output_format` after `max_retries` retry attempts.
        """
        self.respond_to(message)

        reformat_message = Message(type="format", content=self._get_format_instructions(output_format))

        output = None
        retries = 0

        while output is None:
            try:
                formatted_response = self.respond_to(reformat_message)
                output = output_format.model_validate_json(formatted_response.content)
            except ValidationError as e:
                # BUG FIX: the original loop condition (`retries < max_retries`)
                # made the inner `retries > max_retries` check unreachable, so
                # after exhausting retries the method silently returned None.
                # Now the last ValidationError is raised instead.
                if retries >= max_retries:
                    raise e
                reformat_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
                retries += 1

        return output

    def _generate_response(self) -> str:
        """Generates a response from the Agent."""
        # This is the BaseAgent class, and thus has no response logic
        # Subclasses should implement this method to generate a response using the message history
        raise NotImplementedError

    @property
    def is_ai(self):
        return not self.is_human

    # This should probably be put on a theoretical output format class...
    # Or maybe on the Message class with a from format constructor
    @staticmethod
    def _get_format_instructions(output_format: Type[BaseModel]):
        """Builds format instructions from the JSON schema of `output_format`."""
        schema = output_format.model_json_schema()

        # Drop schema fields that add noise without helping the model.
        reduced_schema = schema
        if "title" in reduced_schema:
            del reduced_schema["title"]
        if "type" in reduced_schema:
            del reduced_schema["type"]

        schema_str = json.dumps(reduced_schema, indent=4)

        return FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
91 |
+
|
92 |
+
|
93 |
+
# NOTE(review): typing.NewType is meant to create a distinct nominal type, not
# an alias; if this is only used in annotations, a plain alias
# (AgentInterface = BaseAgentInterface) may be the intent — confirm before changing.
AgentInterface = NewType("AgentInterface", BaseAgentInterface)
|
94 |
+
|
95 |
+
|
96 |
+
class OpenAIAgentInterface(BaseAgentInterface):
    """An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""

    def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
        super().__init__(agent_id)
        self.model_name = model_name
        self.client = OpenAI()

    def _generate_response(self) -> str:
        """Generates a response by replaying the full message history through the chat API."""
        history = [msg.to_openai() for msg in self.messages]

        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=history,
        )

        return completion.choices[0].message.content
|
113 |
+
|
114 |
+
|
115 |
+
class HumanAgentInterface(BaseAgentInterface):
    """Base class for agents driven by a human rather than a model."""

    is_human = True

    def respond_to_formatted(self, message: Message, output_format: Type[BaseModel], **kwargs) -> Type[BaseModel]:
        """For Human agents, we can trust them enough to format their own responses... for now"""
        response = super().respond_to(message)
        # only works because current outputs have only 1 field...
        field_name, _ = output_format.model_fields.copy().popitem()
        return output_format.model_validate({field_name: response.content})
|
127 |
+
|
128 |
+
|
129 |
+
class HumanAgentCLI(HumanAgentInterface):
    """A Human agent that uses the command line interface to generate responses."""

    def __init__(self, agent_id: str):
        super().__init__(agent_id)

    def add_message(self, message: Message):
        """Records the message, then echoes it to the terminal when appropriate."""
        super().add_message(message)
        if message.type == "verbose":
            print(Fore.GREEN + message.content + Style.RESET_ALL)
        elif message.type == "debug":
            print(Fore.YELLOW + "DEBUG: " + message.content + Style.RESET_ALL)
        elif message.type != "agent":
            # Prevents the agent from seeing its own messages on the command line
            print(message.content)

    def _generate_response(self) -> str:
        """Reads the human player's response from stdin."""
        return input()
|
src/controllers.py
DELETED
@@ -1,21 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
from langchain_core.runnables import Runnable
|
4 |
-
from langchain_openai import ChatOpenAI
|
5 |
-
from langchain_core.messages import AIMessage
|
6 |
-
|
7 |
-
MAX_TOKENS = 50
|
8 |
-
|
9 |
-
|
10 |
-
def controller_from_name(name: str) -> Runnable:
|
11 |
-
if name == "tgi":
|
12 |
-
return ChatOpenAI(
|
13 |
-
api_base=os.environ['HF_ENDPOINT_URL'] + "/v1/",
|
14 |
-
api_key=os.environ['HF_API_TOKEN']
|
15 |
-
)
|
16 |
-
elif name == "openai":
|
17 |
-
return ChatOpenAI(model="gpt-3.5-turbo", max_tokens=MAX_TOKENS)
|
18 |
-
elif name == "ollama":
|
19 |
-
return ChatOpenAI(model="mistral", openai_api_key="ollama", openai_api_base="http://localhost:11434/v1", max_tokens=MAX_TOKENS)
|
20 |
-
else:
|
21 |
-
raise ValueError(f"Unknown controller name: {name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/game.py
CHANGED
@@ -8,15 +8,14 @@ from game_utils import *
|
|
8 |
from models import *
|
9 |
from player import Player
|
10 |
from prompts import fetch_prompt, format_prompt
|
11 |
-
|
12 |
-
from
|
13 |
-
from langchain_core.messages import AIMessage
|
14 |
-
from controllers import controller_from_name
|
15 |
|
16 |
# Default Values
|
17 |
NUMBER_OF_PLAYERS = 6
|
18 |
WINNING_SCORE = 11
|
19 |
|
|
|
20 |
class Game:
|
21 |
|
22 |
log_dir = os.path.join(os.pardir, "experiments")
|
@@ -40,11 +39,6 @@ class Game:
|
|
40 |
# Game ID
|
41 |
self.game_id = game_id()
|
42 |
self.start_time = datetime.now().strftime('%y%m%d-%H%M%S')
|
43 |
-
self.log_dir = os.path.join(self.log_dir, f"{self.start_time}-{self.game_id}")
|
44 |
-
os.makedirs(self.log_dir, exist_ok=True)
|
45 |
-
|
46 |
-
# Choose Chameleon
|
47 |
-
self.chameleon_index = random_index(number_of_players)
|
48 |
|
49 |
# Gather Player Names
|
50 |
if human_name:
|
@@ -62,28 +56,22 @@ class Game:
|
|
62 |
# Add Players
|
63 |
self.players = []
|
64 |
for i in range(0, number_of_players):
|
|
|
|
|
65 |
if self.human_index == i:
|
66 |
name = human_name
|
67 |
-
|
68 |
-
controller = RunnableLambda(self.human_input)
|
69 |
else:
|
70 |
name = ai_names.pop()
|
71 |
-
|
72 |
-
controller = controller_from_name(controller_name)
|
73 |
-
|
74 |
-
if self.chameleon_index == i:
|
75 |
-
role = "chameleon"
|
76 |
-
else:
|
77 |
-
role = "herd"
|
78 |
|
79 |
-
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
self.players.append(Player(name, controller, controller_name, player_id, log_filepath=player_log_path))
|
87 |
|
88 |
# Game State
|
89 |
self.player_responses = []
|
@@ -101,55 +89,43 @@ class Game:
|
|
101 |
|
102 |
return formatted_responses
|
103 |
|
|
|
|
|
|
|
|
|
104 |
|
105 |
def game_message(
|
106 |
-
self,
|
107 |
recipient: Optional[Player] = None, # If None, message is broadcast to all players
|
108 |
exclude: bool = False # If True, the message is broadcast to all players except the chosen player
|
109 |
):
|
110 |
"""Sends a message to a player. No response is expected, however it will be included next time the player is prompted"""
|
|
|
|
|
111 |
if exclude or not recipient:
|
112 |
for player in self.players:
|
113 |
if player != recipient:
|
114 |
-
player.
|
115 |
-
|
116 |
-
self.human_message(message)
|
117 |
-
if self.verbose and not self.human_index:
|
118 |
-
self.human_message(message)
|
119 |
else:
|
120 |
-
recipient.
|
121 |
-
if recipient.controller_type == "human":
|
122 |
-
self.human_message(message)
|
123 |
-
|
124 |
-
async def instructional_message(self, message: str, player: Player, output_format: Type[BaseModel]):
|
125 |
-
"""Sends a message to a specific player and gets their response."""
|
126 |
-
if player.controller_type == "human":
|
127 |
-
self.human_message(message)
|
128 |
-
response = await player.respond_to(message, output_format)
|
129 |
-
return response
|
130 |
-
|
131 |
-
# The following methods are used to broadcast messages to a human.
|
132 |
-
# They are design so that they can be overridden by a subclass for a different player interface.
|
133 |
-
@staticmethod
|
134 |
-
async def human_input(prompt: str) -> AIMessage:
|
135 |
-
"""Gets input from the human player."""
|
136 |
-
response = AIMessage(content=input())
|
137 |
-
return response
|
138 |
-
|
139 |
-
@staticmethod
|
140 |
-
def human_message(message: str):
|
141 |
-
"""Sends a message for the human player to read. No response is expected."""
|
142 |
-
print(message)
|
143 |
|
144 |
-
def verbose_message(self,
|
145 |
"""Sends a message for the human player to read. No response is expected."""
|
146 |
if self.verbose:
|
147 |
-
|
|
|
|
|
|
|
148 |
|
149 |
-
def debug_message(self,
|
150 |
"""Sends a message for a human observer. These messages contain secret information about the players such as their role."""
|
151 |
if self.debug:
|
152 |
-
|
|
|
|
|
|
|
|
|
153 |
|
154 |
async def start(self):
|
155 |
"""Sets up the game. This includes assigning roles and gathering player names."""
|
@@ -168,8 +144,6 @@ class Game:
|
|
168 |
|
169 |
log(game_log, game_log_path)
|
170 |
|
171 |
-
|
172 |
-
|
173 |
async def run_round(self):
|
174 |
"""Starts the round."""
|
175 |
|
@@ -194,13 +168,14 @@ class Game:
|
|
194 |
|
195 |
self.game_message(f"Each player will now take turns describing themselves:")
|
196 |
for i, current_player in enumerate(self.players):
|
197 |
-
if current_player.
|
198 |
self.verbose_message(f"{current_player.name} is thinking...")
|
199 |
|
200 |
prompt = fetch_prompt("player_describe_animal")
|
201 |
|
202 |
# Get Player Animal Description
|
203 |
-
|
|
|
204 |
|
205 |
self.player_responses.append({"sender": current_player.name, "response": response.description})
|
206 |
|
@@ -209,12 +184,13 @@ class Game:
|
|
209 |
# Phase III: Chameleon Guesses the Animal
|
210 |
|
211 |
self.game_message("All players have spoken. The Chameleon will now guess the secret animal...")
|
212 |
-
if
|
213 |
self.verbose_message("The Chameleon is thinking...")
|
214 |
|
215 |
prompt = fetch_prompt("chameleon_guess_animal")
|
216 |
|
217 |
-
|
|
|
218 |
|
219 |
chameleon_animal_guess = response.animal
|
220 |
|
@@ -224,19 +200,20 @@ class Game:
|
|
224 |
player_votes = []
|
225 |
for player in self.players:
|
226 |
if player.role == "herd":
|
227 |
-
if player.is_ai
|
228 |
self.verbose_message(f"{player.name} is thinking...")
|
229 |
|
230 |
prompt = format_prompt("vote", player_responses=self.format_responses(exclude=player.name))
|
231 |
|
232 |
# Get Player Vote
|
233 |
-
|
|
|
234 |
|
235 |
# check if a valid player was voted for...
|
236 |
|
237 |
# Add Vote to Player Votes
|
238 |
player_votes.append({"voter": player, "vote": response.vote})
|
239 |
-
if player.is_ai
|
240 |
self.debug_message(f"{player.name} voted for {response.vote}")
|
241 |
|
242 |
|
@@ -285,7 +262,7 @@ class Game:
|
|
285 |
# Log Round Info
|
286 |
round_log = {
|
287 |
"herd_animal": herd_animal,
|
288 |
-
"chameleon_name":
|
289 |
"chameleon_guess": chameleon_animal_guess,
|
290 |
"herd_votes": player_votes,
|
291 |
}
|
|
|
8 |
from models import *
|
9 |
from player import Player
|
10 |
from prompts import fetch_prompt, format_prompt
|
11 |
+
from message import Message
|
12 |
+
from agent_interfaces import HumanAgentCLI, OpenAIAgentInterface
|
|
|
|
|
13 |
|
14 |
# Default Values
|
15 |
NUMBER_OF_PLAYERS = 6
|
16 |
WINNING_SCORE = 11
|
17 |
|
18 |
+
|
19 |
class Game:
|
20 |
|
21 |
log_dir = os.path.join(os.pardir, "experiments")
|
|
|
39 |
# Game ID
|
40 |
self.game_id = game_id()
|
41 |
self.start_time = datetime.now().strftime('%y%m%d-%H%M%S')
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
# Gather Player Names
|
44 |
if human_name:
|
|
|
56 |
# Add Players
|
57 |
self.players = []
|
58 |
for i in range(0, number_of_players):
|
59 |
+
player_id = f"{self.game_id}-{i + 1}"
|
60 |
+
|
61 |
if self.human_index == i:
|
62 |
name = human_name
|
63 |
+
interface = HumanAgentCLI(player_id)
|
|
|
64 |
else:
|
65 |
name = ai_names.pop()
|
66 |
+
interface = OpenAIAgentInterface(player_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
+
self.players.append(Player(name, player_id, interface))
|
69 |
|
70 |
+
# Add Observer - an Agent who can see all the messages, but doesn't actually play
|
71 |
+
if self.verbose or self.debug and not self.human_index:
|
72 |
+
self.observer = HumanAgentCLI("{self.game_id}-observer")
|
73 |
+
else:
|
74 |
+
self.observer = None
|
|
|
75 |
|
76 |
# Game State
|
77 |
self.player_responses = []
|
|
|
89 |
|
90 |
return formatted_responses
|
91 |
|
92 |
+
def observer_message(self, message: Message):
|
93 |
+
"""Sends a message to the observer if there is one."""
|
94 |
+
if self.observer:
|
95 |
+
self.observer.add_message(message)
|
96 |
|
97 |
def game_message(
|
98 |
+
self, content: str,
|
99 |
recipient: Optional[Player] = None, # If None, message is broadcast to all players
|
100 |
exclude: bool = False # If True, the message is broadcast to all players except the chosen player
|
101 |
):
|
102 |
"""Sends a message to a player. No response is expected, however it will be included next time the player is prompted"""
|
103 |
+
message = Message(type="info", content=content)
|
104 |
+
|
105 |
if exclude or not recipient:
|
106 |
for player in self.players:
|
107 |
if player != recipient:
|
108 |
+
player.interface.add_message(message)
|
109 |
+
self.observer_message(message)
|
|
|
|
|
|
|
110 |
else:
|
111 |
+
recipient.interface.add_message(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
+
def verbose_message(self, content: str):
|
114 |
"""Sends a message for the human player to read. No response is expected."""
|
115 |
if self.verbose:
|
116 |
+
message = Message(type="verbose", content=content)
|
117 |
+
if self.human_index:
|
118 |
+
self.players[self.human_index].interface.add_message(message)
|
119 |
+
self.observer_message(message)
|
120 |
|
121 |
+
def debug_message(self, content: str):
|
122 |
"""Sends a message for a human observer. These messages contain secret information about the players such as their role."""
|
123 |
if self.debug:
|
124 |
+
message = Message(type="debug", content=content)
|
125 |
+
if self.human_index:
|
126 |
+
self.players[self.human_index].interface.add_message(message)
|
127 |
+
self.observer_message(message)
|
128 |
+
|
129 |
|
130 |
async def start(self):
|
131 |
"""Sets up the game. This includes assigning roles and gathering player names."""
|
|
|
144 |
|
145 |
log(game_log, game_log_path)
|
146 |
|
|
|
|
|
147 |
async def run_round(self):
|
148 |
"""Starts the round."""
|
149 |
|
|
|
168 |
|
169 |
self.game_message(f"Each player will now take turns describing themselves:")
|
170 |
for i, current_player in enumerate(self.players):
|
171 |
+
if current_player.interface.is_ai:
|
172 |
self.verbose_message(f"{current_player.name} is thinking...")
|
173 |
|
174 |
prompt = fetch_prompt("player_describe_animal")
|
175 |
|
176 |
# Get Player Animal Description
|
177 |
+
message = Message(type="prompt", content=prompt)
|
178 |
+
response = current_player.interface.respond_to_formatted(message, AnimalDescriptionModel)
|
179 |
|
180 |
self.player_responses.append({"sender": current_player.name, "response": response.description})
|
181 |
|
|
|
184 |
# Phase III: Chameleon Guesses the Animal
|
185 |
|
186 |
self.game_message("All players have spoken. The Chameleon will now guess the secret animal...")
|
187 |
+
if chameleon.interface.is_ai or self.observer:
|
188 |
self.verbose_message("The Chameleon is thinking...")
|
189 |
|
190 |
prompt = fetch_prompt("chameleon_guess_animal")
|
191 |
|
192 |
+
message = Message(type="prompt", content=prompt)
|
193 |
+
response = chameleon.interface.respond_to_formatted(message, ChameleonGuessAnimalModel)
|
194 |
|
195 |
chameleon_animal_guess = response.animal
|
196 |
|
|
|
200 |
player_votes = []
|
201 |
for player in self.players:
|
202 |
if player.role == "herd":
|
203 |
+
if player.interface.is_ai:
|
204 |
self.verbose_message(f"{player.name} is thinking...")
|
205 |
|
206 |
prompt = format_prompt("vote", player_responses=self.format_responses(exclude=player.name))
|
207 |
|
208 |
# Get Player Vote
|
209 |
+
message = Message(type="prompt", content=prompt)
|
210 |
+
response = player.interface.respond_to_formatted(message, VoteModel)
|
211 |
|
212 |
# check if a valid player was voted for...
|
213 |
|
214 |
# Add Vote to Player Votes
|
215 |
player_votes.append({"voter": player, "vote": response.vote})
|
216 |
+
if player.interface.is_ai:
|
217 |
self.debug_message(f"{player.name} voted for {response.vote}")
|
218 |
|
219 |
|
|
|
262 |
# Log Round Info
|
263 |
round_log = {
|
264 |
"herd_animal": herd_animal,
|
265 |
+
"chameleon_name": chameleon.name,
|
266 |
"chameleon_guess": chameleon_animal_guess,
|
267 |
"herd_votes": player_votes,
|
268 |
}
|
src/message.py
CHANGED
@@ -1,32 +1,13 @@
|
|
1 |
from typing import Literal
|
2 |
from pydantic import BaseModel, computed_field
|
3 |
|
4 |
-
"""
|
5 |
-
Right now we have two separate systems that use the word "message":
|
6 |
-
|
7 |
-
1. The Game class uses messages to communicate with the players
|
8 |
-
- "game" messages pile up in the queue and are responded to by the player once an "instructional" message is sent.
|
9 |
-
- "verbose", and "debug" currently for the human player only
|
10 |
-
This does **NOT** use the Message class defined below
|
11 |
-
|
12 |
-
2. The Player class uses messages to communicate with the controller (either the AI or the human)
|
13 |
-
- "prompt" type messages come from the Game and are responded to by the player.
|
14 |
-
- "retry", "error", and "format" are internal messages used by the player to ensure the correct format
|
15 |
-
- "player" is used to communicate with the AI or human player.
|
16 |
-
All of these messages are logged, and use the Message class defined below
|
17 |
-
|
18 |
-
For the future we should investigate redesigning/merging these two systems to avoid confusion
|
19 |
-
"""
|
20 |
-
|
21 |
-
MessageType = Literal["prompt", "player", "retry", "error", "format"]
|
22 |
|
23 |
class Message(BaseModel):
|
24 |
-
|
25 |
-
|
26 |
-
message_number: int
|
27 |
-
"""The number of the message, indicating the order in which it was sent."""
|
28 |
type: MessageType
|
29 |
-
"""The type of the message.
|
30 |
content: str
|
31 |
"""The content of the message."""
|
32 |
|
@@ -44,13 +25,35 @@ class Message(BaseModel):
|
|
44 |
else:
|
45 |
return "assistant"
|
46 |
|
47 |
-
@
|
48 |
-
def
|
49 |
-
"""Returns
|
50 |
-
return
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
def to_controller(self) -> tuple[str, str]:
|
53 |
-
"""Returns the message in a format that can be used by the controller."""
|
54 |
-
return self.conversation_role, self.content
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import Literal
|
2 |
from pydantic import BaseModel, computed_field
|
3 |
|
4 |
+
MessageType = Literal["prompt", "info", "agent", "retry", "error", "format", "verbose", "debug"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
class Message(BaseModel):
|
7 |
+
"""A generic message, these are used to communicate between the game and the players."""
|
8 |
+
|
|
|
|
|
9 |
type: MessageType
|
10 |
+
"""The type of the message."""
|
11 |
content: str
|
12 |
"""The content of the message."""
|
13 |
|
|
|
25 |
else:
|
26 |
return "assistant"
|
27 |
|
28 |
+
@property
|
29 |
+
def requires_response(self) -> bool:
|
30 |
+
"""Returns True if the message requires a response."""
|
31 |
+
return self.type in ["prompt", "retry", "format"]
|
32 |
+
|
33 |
+
def to_openai(self) -> dict[str, str]:
|
34 |
+
"""Returns the message in an OpenAI API compatible format."""
|
35 |
+
return {"role": self.conversation_role, "content": self.content}
|
36 |
|
|
|
|
|
|
|
37 |
|
38 |
+
class AgentMessage(Message):
    """A message bound to a specific agent, this happens when an agent receives a message from the game."""

    agent_id: str
    """The id of the controller that the message was sent by/to."""
    message_number: int
    """The number of the message, indicating the order in which it was sent."""

    @computed_field
    def message_id(self) -> str:
        """Returns the message id in the format used by the LLM."""
        return f"{self.agent_id}-{self.message_number}"

    @classmethod
    def from_message(cls, message: Message, agent_id: str, message_number: int) -> "AgentMessage":
        """Creates an AgentMessage from a Message."""
        return cls(
            type=message.type,
            content=message.content,
            agent_id=agent_id,
            message_number=message_number,
        )
|
src/player.py
CHANGED
@@ -1,18 +1,7 @@
|
|
1 |
-
import
|
2 |
-
from typing import Type, Literal, List
|
3 |
import logging
|
4 |
|
5 |
-
from
|
6 |
-
|
7 |
-
from langchain.output_parsers import PydanticOutputParser
|
8 |
-
from langchain_core.prompts import PromptTemplate
|
9 |
-
|
10 |
-
from langchain_core.exceptions import OutputParserException
|
11 |
-
|
12 |
-
from pydantic import BaseModel
|
13 |
-
|
14 |
-
from game_utils import log
|
15 |
-
from message import Message, MessageType
|
16 |
|
17 |
Role = Literal["chameleon", "herd"]
|
18 |
|
@@ -34,43 +23,12 @@ class Player:
|
|
34 |
def __init__(
|
35 |
self,
|
36 |
name: str,
|
37 |
-
|
38 |
-
|
39 |
-
player_id: str = None,
|
40 |
-
log_filepath: str = None
|
41 |
):
|
42 |
self.name = name
|
43 |
self.id = player_id
|
44 |
-
|
45 |
-
if controller_name == "human":
|
46 |
-
self.controller_type = "human"
|
47 |
-
else:
|
48 |
-
self.controller_type = "ai"
|
49 |
-
|
50 |
-
self.controller = controller
|
51 |
-
"""The controller for the player."""
|
52 |
-
self.log_filepath = log_filepath
|
53 |
-
"""The filepath to the log file. If None, no logs will be written."""
|
54 |
-
self.messages: list[Message] = []
|
55 |
-
"""The messages the player has sent and received."""
|
56 |
-
self.prompt_queue: List[str] = []
|
57 |
-
"""A queue of prompts to be added to the next prompt."""
|
58 |
-
|
59 |
-
if log_filepath:
|
60 |
-
player_info = {
|
61 |
-
"id": self.id,
|
62 |
-
"name": self.name,
|
63 |
-
"role": self.role,
|
64 |
-
"controller": {
|
65 |
-
"name": controller_name,
|
66 |
-
"type": self.controller_type
|
67 |
-
}
|
68 |
-
}
|
69 |
-
log(player_info, log_filepath)
|
70 |
-
|
71 |
-
# initialize the runnables
|
72 |
-
self.generate = RunnableLambda(self._generate)
|
73 |
-
self.format_output = RunnableLambda(self._output_formatter)
|
74 |
|
75 |
def assign_role(self, role: Role):
|
76 |
self.role = role
|
@@ -78,85 +36,3 @@ class Player:
|
|
78 |
self.rounds_played_as_chameleon += 1
|
79 |
elif role == "herd":
|
80 |
self.rounds_played_as_herd += 1
|
81 |
-
|
82 |
-
async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries=3):
|
83 |
-
"""Makes the player respond to a prompt. Returns the response in the specified format."""
|
84 |
-
if self.prompt_queue:
|
85 |
-
# If there are prompts in the queue, add them to the current prompt
|
86 |
-
prompt = "\n".join(self.prompt_queue + [prompt])
|
87 |
-
# Clear the prompt queue
|
88 |
-
self.prompt_queue = []
|
89 |
-
|
90 |
-
message = self.player_message("prompt", prompt)
|
91 |
-
output = await self.generate.ainvoke(message)
|
92 |
-
if self.controller_type == "ai":
|
93 |
-
retries = 0
|
94 |
-
try:
|
95 |
-
output = await self.format_output.ainvoke({"output_format": output_format})
|
96 |
-
except OutputParserException as e:
|
97 |
-
if retries < max_retries:
|
98 |
-
retries += 1
|
99 |
-
logger.warning(f"Player {self.id} failed to format response: {output} due to an exception: {e} \n\n Retrying {retries}/{max_retries}")
|
100 |
-
retry_message = self.player_message("retry", f"Error formatting response: {e} \n\n Please try again.")
|
101 |
-
self.add_to_history(retry_message)
|
102 |
-
output = await self.format_output.ainvoke({"output_format": output_format})
|
103 |
-
|
104 |
-
else:
|
105 |
-
error_message = self.player_message("error", f"Error formatting response: {e} \n\n Max retries reached.")
|
106 |
-
self.add_to_history(error_message)
|
107 |
-
logging.error(f"Max retries reached due to Error: {e}")
|
108 |
-
raise e
|
109 |
-
else:
|
110 |
-
# Convert the human message to the pydantic object format
|
111 |
-
field_name = output_format.model_fields.copy().popitem()[0] # only works because current outputs have only 1 field
|
112 |
-
output = output_format.model_validate({field_name: output.content})
|
113 |
-
|
114 |
-
return output
|
115 |
-
|
116 |
-
def player_message(self, message_type: MessageType, content: str) -> Message:
|
117 |
-
"""Creates a message assigned to the player."""
|
118 |
-
return Message(player_id=self.id, message_number=len(self.messages), type=message_type, content=content)
|
119 |
-
|
120 |
-
|
121 |
-
def add_to_history(self, message: Message):
|
122 |
-
self.messages.append(message)
|
123 |
-
log(message.model_dump(), self.log_filepath)
|
124 |
-
|
125 |
-
def is_human(self):
|
126 |
-
return self.controller_type == "human"
|
127 |
-
|
128 |
-
def is_ai(self):
|
129 |
-
return not self.is_human()
|
130 |
-
|
131 |
-
async def _generate(self, message: Message):
|
132 |
-
"""Entry point for the Runnable generating responses, automatically logs the message."""
|
133 |
-
self.add_to_history(message)
|
134 |
-
|
135 |
-
# AI's need to be fed the whole message history, but humans can just go back and look at it
|
136 |
-
if self.controller_type == "human":
|
137 |
-
response = await self.controller.ainvoke(message.content)
|
138 |
-
else:
|
139 |
-
formatted_messages = [message.to_controller() for message in self.messages]
|
140 |
-
response = await self.controller.ainvoke(formatted_messages)
|
141 |
-
|
142 |
-
self.add_to_history(self.player_message("player", response.content))
|
143 |
-
|
144 |
-
return response
|
145 |
-
|
146 |
-
async def _output_formatter(self, inputs: dict):
|
147 |
-
"""Formats the output of the response."""
|
148 |
-
output_format: BaseModel = inputs["output_format"]
|
149 |
-
|
150 |
-
prompt_template = PromptTemplate.from_template(
|
151 |
-
"Please rewrite your previous response using the following format: \n\n{format_instructions}"
|
152 |
-
)
|
153 |
-
|
154 |
-
parser = PydanticOutputParser(pydantic_object=output_format)
|
155 |
-
|
156 |
-
prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
|
157 |
-
|
158 |
-
message = self.player_message("format", prompt.text)
|
159 |
-
|
160 |
-
response = await self.generate.ainvoke(message)
|
161 |
-
|
162 |
-
return await parser.ainvoke(response)
|
|
|
1 |
+
from typing import Literal
|
|
|
2 |
import logging
|
3 |
|
4 |
+
from agent_interfaces import AgentInterface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
Role = Literal["chameleon", "herd"]
|
7 |
|
|
|
23 |
def __init__(
|
24 |
self,
|
25 |
name: str,
|
26 |
+
player_id: str,
|
27 |
+
interface: AgentInterface
|
|
|
|
|
28 |
):
|
29 |
self.name = name
|
30 |
self.id = player_id
|
31 |
+
self.interface = interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def assign_role(self, role: Role):
|
34 |
self.role = role
|
|
|
36 |
self.rounds_played_as_chameleon += 1
|
37 |
elif role == "herd":
|
38 |
self.rounds_played_as_herd += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|