Eric Botti
commited on
Commit
·
47b6f03
1
Parent(s):
667d00f
output format descriptions and kwargs
Browse files- src/agent_interfaces.py +22 -10
- src/game.py +18 -10
- src/message.py +2 -2
- src/output_formats.py +23 -43
src/agent_interfaces.py
CHANGED
@@ -9,6 +9,7 @@ from output_formats import OutputFormat, OutputFormatModel
|
|
9 |
from message import Message, AgentMessage
|
10 |
from data_collection import save
|
11 |
|
|
|
12 |
class BaseAgentInterface:
|
13 |
"""
|
14 |
The interface that agents use to receive info from and interact with the game.
|
@@ -37,7 +38,13 @@ class BaseAgentInterface:
|
|
37 |
self.add_message(response)
|
38 |
return response
|
39 |
|
40 |
-
def respond_to_formatted(
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
"""Adds a message to the message history, and generates a response matching the provided format."""
|
42 |
initial_response = self.respond_to(message)
|
43 |
|
@@ -49,16 +56,18 @@ class BaseAgentInterface:
|
|
49 |
while not output and retries < max_retries:
|
50 |
try:
|
51 |
formatted_response = self.respond_to(reformat_message)
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
except ValidationError as e:
|
54 |
if retries > max_retries:
|
55 |
raise e
|
56 |
retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
|
57 |
-
|
58 |
-
self.add_message(retry_message)
|
59 |
-
reformat_message = Message(type="few_shot", content=output_format.get_few_shot()) # not implemented
|
60 |
-
else:
|
61 |
-
reformat_message = retry_message
|
62 |
|
63 |
retries += 1
|
64 |
|
@@ -74,11 +83,13 @@ class BaseAgentInterface:
|
|
74 |
def is_ai(self):
|
75 |
return not self.is_human
|
76 |
|
|
|
77 |
AgentInterface = NewType("AgentInterface", BaseAgentInterface)
|
78 |
|
79 |
|
80 |
class OpenAIAgentInterface(BaseAgentInterface):
|
81 |
"""An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""
|
|
|
82 |
def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
|
83 |
super().__init__(agent_id)
|
84 |
self.model_name = model_name
|
@@ -97,15 +108,16 @@ class OpenAIAgentInterface(BaseAgentInterface):
|
|
97 |
|
98 |
|
99 |
class HumanAgentInterface(BaseAgentInterface):
|
100 |
-
|
101 |
is_human = True
|
102 |
|
103 |
def respond_to_formatted(self, message: Message, output_format: OutputFormat, **kwargs) -> OutputFormatModel:
|
104 |
"""For Human agents, we can trust them enough to format their own responses... for now"""
|
105 |
response = super().respond_to(message)
|
106 |
# only works because current outputs have only 1 field...
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
|
110 |
return output
|
111 |
|
|
|
9 |
from message import Message, AgentMessage
|
10 |
from data_collection import save
|
11 |
|
12 |
+
|
13 |
class BaseAgentInterface:
|
14 |
"""
|
15 |
The interface that agents use to receive info from and interact with the game.
|
|
|
38 |
self.add_message(response)
|
39 |
return response
|
40 |
|
41 |
+
def respond_to_formatted(
|
42 |
+
self,
|
43 |
+
message: Message,
|
44 |
+
output_format: OutputFormat,
|
45 |
+
max_retries=3,
|
46 |
+
**kwargs
|
47 |
+
) -> OutputFormatModel:
|
48 |
"""Adds a message to the message history, and generates a response matching the provided format."""
|
49 |
initial_response = self.respond_to(message)
|
50 |
|
|
|
56 |
while not output and retries < max_retries:
|
57 |
try:
|
58 |
formatted_response = self.respond_to(reformat_message)
|
59 |
+
if kwargs:
|
60 |
+
fields = json.loads(formatted_response.content)
|
61 |
+
fields.update(kwargs)
|
62 |
+
output = output_format.model_validate(fields)
|
63 |
+
else:
|
64 |
+
output = output_format.model_validate_json(formatted_response.content)
|
65 |
+
|
66 |
except ValidationError as e:
|
67 |
if retries > max_retries:
|
68 |
raise e
|
69 |
retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
|
70 |
+
reformat_message = retry_message
|
|
|
|
|
|
|
|
|
71 |
|
72 |
retries += 1
|
73 |
|
|
|
83 |
def is_ai(self):
|
84 |
return not self.is_human
|
85 |
|
86 |
+
|
87 |
AgentInterface = NewType("AgentInterface", BaseAgentInterface)
|
88 |
|
89 |
|
90 |
class OpenAIAgentInterface(BaseAgentInterface):
|
91 |
"""An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""
|
92 |
+
|
93 |
def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
|
94 |
super().__init__(agent_id)
|
95 |
self.model_name = model_name
|
|
|
108 |
|
109 |
|
110 |
class HumanAgentInterface(BaseAgentInterface):
|
|
|
111 |
is_human = True
|
112 |
|
113 |
def respond_to_formatted(self, message: Message, output_format: OutputFormat, **kwargs) -> OutputFormatModel:
|
114 |
"""For Human agents, we can trust them enough to format their own responses... for now"""
|
115 |
response = super().respond_to(message)
|
116 |
# only works because current outputs have only 1 field...
|
117 |
+
fields = {output_format.model_fields.copy().popitem()[0], response.content}
|
118 |
+
if kwargs:
|
119 |
+
fields.update(kwargs)
|
120 |
+
output = output_format.model_validate(fields)
|
121 |
|
122 |
return output
|
123 |
|
src/game.py
CHANGED
@@ -13,7 +13,7 @@ from agent_interfaces import HumanAgentCLI, OpenAIAgentInterface
|
|
13 |
|
14 |
# Default Values
|
15 |
NUMBER_OF_PLAYERS = 6
|
16 |
-
WINNING_SCORE =
|
17 |
|
18 |
|
19 |
class Game:
|
@@ -131,7 +131,17 @@ class Game:
|
|
131 |
"""Sets up the game. This includes assigning roles and gathering player names."""
|
132 |
self.game_message(fetch_prompt("game_rules"))
|
133 |
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
# Log Game Info
|
137 |
game_log = {
|
@@ -144,7 +154,7 @@ class Game:
|
|
144 |
|
145 |
log(game_log, game_log_path)
|
146 |
|
147 |
-
async def run_round(self):
|
148 |
"""Starts the round."""
|
149 |
|
150 |
# Phase I: Choose Animal and Assign Roles
|
@@ -175,7 +185,7 @@ class Game:
|
|
175 |
|
176 |
# Get Player Animal Description
|
177 |
message = Message(type="prompt", content=prompt)
|
178 |
-
response = current_player.interface.respond_to_formatted(message,
|
179 |
|
180 |
self.player_responses.append({"sender": current_player.name, "response": response.description})
|
181 |
|
@@ -190,7 +200,7 @@ class Game:
|
|
190 |
prompt = fetch_prompt("chameleon_guess_animal")
|
191 |
|
192 |
message = Message(type="prompt", content=prompt)
|
193 |
-
response = chameleon.interface.respond_to_formatted(message,
|
194 |
|
195 |
chameleon_animal_guess = response.animal
|
196 |
|
@@ -208,14 +218,13 @@ class Game:
|
|
208 |
# Get Player Vote
|
209 |
message = Message(type="prompt", content=prompt)
|
210 |
player_names = [p.name for p in self.players]
|
211 |
-
response = player.interface.respond_to_formatted(message,
|
212 |
|
213 |
# Add Vote to Player Votes
|
214 |
player_votes.append({"voter": player, "vote": response.vote})
|
215 |
if player.interface.is_ai:
|
216 |
self.debug_message(f"{player.name} voted for {response.vote}")
|
217 |
|
218 |
-
|
219 |
self.game_message("All players have voted!")
|
220 |
formatted_votes = '\n'.join([f'{vote["voter"].name}: {vote["vote"]}' for vote in player_votes])
|
221 |
self.game_message(f"Votes:\n{formatted_votes}")
|
@@ -225,7 +234,7 @@ class Game:
|
|
225 |
|
226 |
# Phase V: Assign Points
|
227 |
|
228 |
-
self.game_message(f"The round is over.
|
229 |
self.game_message(
|
230 |
f"The Chameleon was {chameleon.name}, and they guessed the secret animal was {chameleon_animal_guess}.")
|
231 |
self.game_message(f"The secret animal was actually was {herd_animal}.")
|
@@ -259,8 +268,7 @@ class Game:
|
|
259 |
|
260 |
self.game_message(f"Current Game Score: {player_points}")
|
261 |
|
262 |
-
|
263 |
-
round_log = {
|
264 |
"herd_animal": herd_animal,
|
265 |
"chameleon_name": chameleon.name,
|
266 |
"chameleon_guess": chameleon_animal_guess,
|
|
|
13 |
|
14 |
# Default Values
|
15 |
NUMBER_OF_PLAYERS = 6
|
16 |
+
WINNING_SCORE = 3
|
17 |
|
18 |
|
19 |
class Game:
|
|
|
131 |
"""Sets up the game. This includes assigning roles and gathering player names."""
|
132 |
self.game_message(fetch_prompt("game_rules"))
|
133 |
|
134 |
+
winner = None
|
135 |
+
round_number = 0
|
136 |
+
|
137 |
+
while not winner:
|
138 |
+
round_results = await self.run_round()
|
139 |
+
round_number += 1
|
140 |
+
|
141 |
+
# Check for a Winner
|
142 |
+
for player in self.players:
|
143 |
+
if player.points >= self.winning_score:
|
144 |
+
winner = player # ignoring the possibility of a tie for now
|
145 |
|
146 |
# Log Game Info
|
147 |
game_log = {
|
|
|
154 |
|
155 |
log(game_log, game_log_path)
|
156 |
|
157 |
+
async def run_round(self) -> dict:
|
158 |
"""Starts the round."""
|
159 |
|
160 |
# Phase I: Choose Animal and Assign Roles
|
|
|
185 |
|
186 |
# Get Player Animal Description
|
187 |
message = Message(type="prompt", content=prompt)
|
188 |
+
response = current_player.interface.respond_to_formatted(message, AnimalDescriptionFormat)
|
189 |
|
190 |
self.player_responses.append({"sender": current_player.name, "response": response.description})
|
191 |
|
|
|
200 |
prompt = fetch_prompt("chameleon_guess_animal")
|
201 |
|
202 |
message = Message(type="prompt", content=prompt)
|
203 |
+
response = chameleon.interface.respond_to_formatted(message, ChameleonGuessFormat)
|
204 |
|
205 |
chameleon_animal_guess = response.animal
|
206 |
|
|
|
218 |
# Get Player Vote
|
219 |
message = Message(type="prompt", content=prompt)
|
220 |
player_names = [p.name for p in self.players]
|
221 |
+
response = player.interface.respond_to_formatted(message, HerdVoteFormat, player_names=player_names)
|
222 |
|
223 |
# Add Vote to Player Votes
|
224 |
player_votes.append({"voter": player, "vote": response.vote})
|
225 |
if player.interface.is_ai:
|
226 |
self.debug_message(f"{player.name} voted for {response.vote}")
|
227 |
|
|
|
228 |
self.game_message("All players have voted!")
|
229 |
formatted_votes = '\n'.join([f'{vote["voter"].name}: {vote["vote"]}' for vote in player_votes])
|
230 |
self.game_message(f"Votes:\n{formatted_votes}")
|
|
|
234 |
|
235 |
# Phase V: Assign Points
|
236 |
|
237 |
+
self.game_message(f"The round is over. Calculating results...")
|
238 |
self.game_message(
|
239 |
f"The Chameleon was {chameleon.name}, and they guessed the secret animal was {chameleon_animal_guess}.")
|
240 |
self.game_message(f"The secret animal was actually was {herd_animal}.")
|
|
|
268 |
|
269 |
self.game_message(f"Current Game Score: {player_points}")
|
270 |
|
271 |
+
return {
|
|
|
272 |
"herd_animal": herd_animal,
|
273 |
"chameleon_name": chameleon.name,
|
274 |
"chameleon_guess": chameleon_animal_guess,
|
src/message.py
CHANGED
@@ -11,7 +11,7 @@ class Message(BaseModel):
|
|
11 |
content: str
|
12 |
"""The content of the message."""
|
13 |
|
14 |
-
@
|
15 |
def conversation_role(self) -> str:
|
16 |
"""The message type in the format used by the LLM."""
|
17 |
|
@@ -20,7 +20,7 @@ class Message(BaseModel):
|
|
20 |
# This can be counterintuitive since they can be controlled by either human or ai
|
21 |
# Further, The programmatic messages from the game are always "user"
|
22 |
|
23 |
-
if self.type
|
24 |
return "user"
|
25 |
else:
|
26 |
return "assistant"
|
|
|
11 |
content: str
|
12 |
"""The content of the message."""
|
13 |
|
14 |
+
@property
|
15 |
def conversation_role(self) -> str:
|
16 |
"""The message type in the format used by the LLM."""
|
17 |
|
|
|
20 |
# This can be counterintuitive since they can be controlled by either human or ai
|
21 |
# Further, The programmatic messages from the game are always "user"
|
22 |
|
23 |
+
if self.type != "agent":
|
24 |
return "user"
|
25 |
else:
|
26 |
return "assistant"
|
src/output_formats.py
CHANGED
@@ -2,52 +2,33 @@ import random
|
|
2 |
from typing import Annotated, NewType, List, Optional, Type, Literal
|
3 |
import json
|
4 |
|
5 |
-
from pydantic import BaseModel, field_validator, Field
|
6 |
|
7 |
FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
|
8 |
Here is the output format:
|
9 |
{schema}
|
10 |
"""
|
11 |
-
FEW_SHOT_INSTRUCTIONS = """Here are a few examples of correctly formatted responses: \n
|
12 |
-
{examples}
|
13 |
-
"""
|
14 |
-
|
15 |
-
OutputFormatModel = NewType("OutputFormatModel", BaseModel)
|
16 |
-
|
17 |
-
|
18 |
-
class OutputFormat:
|
19 |
-
"""The base class for all output formats."""
|
20 |
-
|
21 |
-
format_instructions: str = FORMAT_INSTRUCTIONS
|
22 |
-
"""Instructions for formatting the output, it is combined with the JSON schema of the output format."""
|
23 |
-
few_shot_instructions: str = FEW_SHOT_INSTRUCTIONS
|
24 |
-
"""Instructions for the few shot examples, it is combined with the few shot examples."""
|
25 |
-
few_shot_examples: Optional[List[dict]] = None
|
26 |
-
"""A list of examples to be shown to the agent to help them understand the desired format of the output."""
|
27 |
|
28 |
-
def __init__(self, output_format_model: Type[OutputFormatModel], player_names: List[str] = None):
|
29 |
-
self.output_format_model = output_format_model
|
30 |
-
self.output_format_model.player_names = player_names
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
|
38 |
-
|
39 |
-
examples = self.few_shot_examples
|
40 |
-
else:
|
41 |
-
examples = random.sample(self.few_shot_examples, max_examples)
|
42 |
|
43 |
-
few_shot = "\n\n".join([f"Example Response:\n{json.dumps(example)}" for example in examples])
|
44 |
|
45 |
-
|
46 |
|
47 |
|
48 |
-
class AnimalDescriptionFormat(
|
49 |
# Define fields of our class here
|
50 |
-
description: str = Field("A brief description of the animal")
|
51 |
"""A brief description of the animal"""
|
52 |
|
53 |
@field_validator('description')
|
@@ -58,8 +39,8 @@ class AnimalDescriptionFormat(BaseModel):
|
|
58 |
return v
|
59 |
|
60 |
|
61 |
-
class ChameleonGuessFormat(
|
62 |
-
animal: str = Field(
|
63 |
|
64 |
@field_validator('animal')
|
65 |
@classmethod
|
@@ -69,15 +50,14 @@ class ChameleonGuessFormat(BaseModel):
|
|
69 |
return v
|
70 |
|
71 |
|
72 |
-
class HerdVoteFormat(
|
73 |
player_names: List[str] = Field([], exclude=True)
|
74 |
"""The names of the players in the game"""
|
75 |
-
vote: str = Field("The name of the player you are voting for")
|
76 |
"""The name of the player you are voting for"""
|
77 |
|
78 |
-
@
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
return v
|
|
|
2 |
from typing import Annotated, NewType, List, Optional, Type, Literal
|
3 |
import json
|
4 |
|
5 |
+
from pydantic import BaseModel, field_validator, Field, model_validator
|
6 |
|
7 |
FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
|
8 |
Here is the output format:
|
9 |
{schema}
|
10 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
|
|
|
|
|
|
12 |
|
13 |
+
class OutputFormatModel(BaseModel):
|
14 |
+
@classmethod
|
15 |
+
def get_format_instructions(cls) -> str:
|
16 |
+
"""Returns a string with instructions on how to format the output."""
|
17 |
+
json_format = {}
|
18 |
+
for field in cls.model_fields:
|
19 |
+
if not cls.model_fields[field].exclude:
|
20 |
+
json_format[field] = cls.model_fields[field].description
|
21 |
|
22 |
+
# In the future, we could instead use get_annotations() to get the field descriptions
|
23 |
+
return FORMAT_INSTRUCTIONS.format(schema=json_format)
|
|
|
|
|
|
|
24 |
|
|
|
25 |
|
26 |
+
OutputFormat = NewType("OutputFormat", OutputFormatModel)
|
27 |
|
28 |
|
29 |
+
class AnimalDescriptionFormat(OutputFormatModel):
|
30 |
# Define fields of our class here
|
31 |
+
description: str = Field(description="A brief description of the animal")
|
32 |
"""A brief description of the animal"""
|
33 |
|
34 |
@field_validator('description')
|
|
|
39 |
return v
|
40 |
|
41 |
|
42 |
+
class ChameleonGuessFormat(OutputFormatModel):
|
43 |
+
animal: str = Field(description='Name of the animal you think the Herd is in its singular form, e.g. "animal" not "animals"')
|
44 |
|
45 |
@field_validator('animal')
|
46 |
@classmethod
|
|
|
50 |
return v
|
51 |
|
52 |
|
53 |
+
class HerdVoteFormat(OutputFormatModel):
|
54 |
player_names: List[str] = Field([], exclude=True)
|
55 |
"""The names of the players in the game"""
|
56 |
+
vote: str = Field(description="The name of the player you are voting for")
|
57 |
"""The name of the player you are voting for"""
|
58 |
|
59 |
+
@model_validator(mode="after")
|
60 |
+
def check_player_exists(self) -> "HerdVoteFormat":
|
61 |
+
if self.vote.lower() not in [player.lower() for player in self.player_names]:
|
62 |
+
raise ValueError(f"Player {self.vote} does not exist, please vote for one of {self.player_names}")
|
63 |
+
return self
|
|