Eric Botti commited on
Commit
47b6f03
·
1 Parent(s): 667d00f

output format descriptions and kwargs

Browse files
Files changed (4) hide show
  1. src/agent_interfaces.py +22 -10
  2. src/game.py +18 -10
  3. src/message.py +2 -2
  4. src/output_formats.py +23 -43
src/agent_interfaces.py CHANGED
@@ -9,6 +9,7 @@ from output_formats import OutputFormat, OutputFormatModel
9
  from message import Message, AgentMessage
10
  from data_collection import save
11
 
 
12
  class BaseAgentInterface:
13
  """
14
  The interface that agents use to receive info from and interact with the game.
@@ -37,7 +38,13 @@ class BaseAgentInterface:
37
  self.add_message(response)
38
  return response
39
 
40
- def respond_to_formatted(self, message: Message, output_format: OutputFormat, max_retries = 3) -> OutputFormatModel:
 
 
 
 
 
 
41
  """Adds a message to the message history, and generates a response matching the provided format."""
42
  initial_response = self.respond_to(message)
43
 
@@ -49,16 +56,18 @@ class BaseAgentInterface:
49
  while not output and retries < max_retries:
50
  try:
51
  formatted_response = self.respond_to(reformat_message)
52
- output = output_format.output_format_model.model_validate_json(formatted_response.content)
 
 
 
 
 
 
53
  except ValidationError as e:
54
  if retries > max_retries:
55
  raise e
56
  retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
57
- if output_format.few_shot_examples:
58
- self.add_message(retry_message)
59
- reformat_message = Message(type="few_shot", content=output_format.get_few_shot()) # not implemented
60
- else:
61
- reformat_message = retry_message
62
 
63
  retries += 1
64
 
@@ -74,11 +83,13 @@ class BaseAgentInterface:
74
  def is_ai(self):
75
  return not self.is_human
76
 
 
77
  AgentInterface = NewType("AgentInterface", BaseAgentInterface)
78
 
79
 
80
  class OpenAIAgentInterface(BaseAgentInterface):
81
  """An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""
 
82
  def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
83
  super().__init__(agent_id)
84
  self.model_name = model_name
@@ -97,15 +108,16 @@ class OpenAIAgentInterface(BaseAgentInterface):
97
 
98
 
99
  class HumanAgentInterface(BaseAgentInterface):
100
-
101
  is_human = True
102
 
103
  def respond_to_formatted(self, message: Message, output_format: OutputFormat, **kwargs) -> OutputFormatModel:
104
  """For Human agents, we can trust them enough to format their own responses... for now"""
105
  response = super().respond_to(message)
106
  # only works because current outputs have only 1 field...
107
- field_name = output_format.output_format_model.model_fields.copy().popitem()[0]
108
- output = output_format.output_format_model.model_validate({field_name: response.content})
 
 
109
 
110
  return output
111
 
 
9
  from message import Message, AgentMessage
10
  from data_collection import save
11
 
12
+
13
  class BaseAgentInterface:
14
  """
15
  The interface that agents use to receive info from and interact with the game.
 
38
  self.add_message(response)
39
  return response
40
 
41
+ def respond_to_formatted(
42
+ self,
43
+ message: Message,
44
+ output_format: OutputFormat,
45
+ max_retries=3,
46
+ **kwargs
47
+ ) -> OutputFormatModel:
48
  """Adds a message to the message history, and generates a response matching the provided format."""
49
  initial_response = self.respond_to(message)
50
 
 
56
  while not output and retries < max_retries:
57
  try:
58
  formatted_response = self.respond_to(reformat_message)
59
+ if kwargs:
60
+ fields = json.loads(formatted_response.content)
61
+ fields.update(kwargs)
62
+ output = output_format.model_validate(fields)
63
+ else:
64
+ output = output_format.model_validate_json(formatted_response.content)
65
+
66
  except ValidationError as e:
67
  if retries > max_retries:
68
  raise e
69
  retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
70
+ reformat_message = retry_message
 
 
 
 
71
 
72
  retries += 1
73
 
 
83
  def is_ai(self):
84
  return not self.is_human
85
 
86
+
87
  AgentInterface = NewType("AgentInterface", BaseAgentInterface)
88
 
89
 
90
  class OpenAIAgentInterface(BaseAgentInterface):
91
  """An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""
92
+
93
  def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
94
  super().__init__(agent_id)
95
  self.model_name = model_name
 
108
 
109
 
110
  class HumanAgentInterface(BaseAgentInterface):
 
111
  is_human = True
112
 
113
  def respond_to_formatted(self, message: Message, output_format: OutputFormat, **kwargs) -> OutputFormatModel:
114
  """For Human agents, we can trust them enough to format their own responses... for now"""
115
  response = super().respond_to(message)
116
  # only works because current outputs have only 1 field...
117
+ fields = {output_format.model_fields.copy().popitem()[0], response.content}
118
+ if kwargs:
119
+ fields.update(kwargs)
120
+ output = output_format.model_validate(fields)
121
 
122
  return output
123
 
src/game.py CHANGED
@@ -13,7 +13,7 @@ from agent_interfaces import HumanAgentCLI, OpenAIAgentInterface
13
 
14
  # Default Values
15
  NUMBER_OF_PLAYERS = 6
16
- WINNING_SCORE = 11
17
 
18
 
19
  class Game:
@@ -131,7 +131,17 @@ class Game:
131
  """Sets up the game. This includes assigning roles and gathering player names."""
132
  self.game_message(fetch_prompt("game_rules"))
133
 
134
- await self.run_round()
 
 
 
 
 
 
 
 
 
 
135
 
136
  # Log Game Info
137
  game_log = {
@@ -144,7 +154,7 @@ class Game:
144
 
145
  log(game_log, game_log_path)
146
 
147
- async def run_round(self):
148
  """Starts the round."""
149
 
150
  # Phase I: Choose Animal and Assign Roles
@@ -175,7 +185,7 @@ class Game:
175
 
176
  # Get Player Animal Description
177
  message = Message(type="prompt", content=prompt)
178
- response = current_player.interface.respond_to_formatted(message, OutputFormat(AnimalDescriptionFormat))
179
 
180
  self.player_responses.append({"sender": current_player.name, "response": response.description})
181
 
@@ -190,7 +200,7 @@ class Game:
190
  prompt = fetch_prompt("chameleon_guess_animal")
191
 
192
  message = Message(type="prompt", content=prompt)
193
- response = chameleon.interface.respond_to_formatted(message, OutputFormat(ChameleonGuessFormat))
194
 
195
  chameleon_animal_guess = response.animal
196
 
@@ -208,14 +218,13 @@ class Game:
208
  # Get Player Vote
209
  message = Message(type="prompt", content=prompt)
210
  player_names = [p.name for p in self.players]
211
- response = player.interface.respond_to_formatted(message, OutputFormat(HerdVoteFormat, player_names))
212
 
213
  # Add Vote to Player Votes
214
  player_votes.append({"voter": player, "vote": response.vote})
215
  if player.interface.is_ai:
216
  self.debug_message(f"{player.name} voted for {response.vote}")
217
 
218
-
219
  self.game_message("All players have voted!")
220
  formatted_votes = '\n'.join([f'{vote["voter"].name}: {vote["vote"]}' for vote in player_votes])
221
  self.game_message(f"Votes:\n{formatted_votes}")
@@ -225,7 +234,7 @@ class Game:
225
 
226
  # Phase V: Assign Points
227
 
228
- self.game_message(f"The round is over. Caclulating results...")
229
  self.game_message(
230
  f"The Chameleon was {chameleon.name}, and they guessed the secret animal was {chameleon_animal_guess}.")
231
  self.game_message(f"The secret animal was actually was {herd_animal}.")
@@ -259,8 +268,7 @@ class Game:
259
 
260
  self.game_message(f"Current Game Score: {player_points}")
261
 
262
- # Log Round Info
263
- round_log = {
264
  "herd_animal": herd_animal,
265
  "chameleon_name": chameleon.name,
266
  "chameleon_guess": chameleon_animal_guess,
 
13
 
14
  # Default Values
15
  NUMBER_OF_PLAYERS = 6
16
+ WINNING_SCORE = 3
17
 
18
 
19
  class Game:
 
131
  """Sets up the game. This includes assigning roles and gathering player names."""
132
  self.game_message(fetch_prompt("game_rules"))
133
 
134
+ winner = None
135
+ round_number = 0
136
+
137
+ while not winner:
138
+ round_results = await self.run_round()
139
+ round_number += 1
140
+
141
+ # Check for a Winner
142
+ for player in self.players:
143
+ if player.points >= self.winning_score:
144
+ winner = player # ignoring the possibility of a tie for now
145
 
146
  # Log Game Info
147
  game_log = {
 
154
 
155
  log(game_log, game_log_path)
156
 
157
+ async def run_round(self) -> dict:
158
  """Starts the round."""
159
 
160
  # Phase I: Choose Animal and Assign Roles
 
185
 
186
  # Get Player Animal Description
187
  message = Message(type="prompt", content=prompt)
188
+ response = current_player.interface.respond_to_formatted(message, AnimalDescriptionFormat)
189
 
190
  self.player_responses.append({"sender": current_player.name, "response": response.description})
191
 
 
200
  prompt = fetch_prompt("chameleon_guess_animal")
201
 
202
  message = Message(type="prompt", content=prompt)
203
+ response = chameleon.interface.respond_to_formatted(message, ChameleonGuessFormat)
204
 
205
  chameleon_animal_guess = response.animal
206
 
 
218
  # Get Player Vote
219
  message = Message(type="prompt", content=prompt)
220
  player_names = [p.name for p in self.players]
221
+ response = player.interface.respond_to_formatted(message, HerdVoteFormat, player_names=player_names)
222
 
223
  # Add Vote to Player Votes
224
  player_votes.append({"voter": player, "vote": response.vote})
225
  if player.interface.is_ai:
226
  self.debug_message(f"{player.name} voted for {response.vote}")
227
 
 
228
  self.game_message("All players have voted!")
229
  formatted_votes = '\n'.join([f'{vote["voter"].name}: {vote["vote"]}' for vote in player_votes])
230
  self.game_message(f"Votes:\n{formatted_votes}")
 
234
 
235
  # Phase V: Assign Points
236
 
237
+ self.game_message(f"The round is over. Calculating results...")
238
  self.game_message(
239
  f"The Chameleon was {chameleon.name}, and they guessed the secret animal was {chameleon_animal_guess}.")
240
  self.game_message(f"The secret animal was actually was {herd_animal}.")
 
268
 
269
  self.game_message(f"Current Game Score: {player_points}")
270
 
271
+ return {
 
272
  "herd_animal": herd_animal,
273
  "chameleon_name": chameleon.name,
274
  "chameleon_guess": chameleon_animal_guess,
src/message.py CHANGED
@@ -11,7 +11,7 @@ class Message(BaseModel):
11
  content: str
12
  """The content of the message."""
13
 
14
- @computed_field
15
  def conversation_role(self) -> str:
16
  """The message type in the format used by the LLM."""
17
 
@@ -20,7 +20,7 @@ class Message(BaseModel):
20
  # This can be counterintuitive since they can be controlled by either human or ai
21
  # Further, The programmatic messages from the game are always "user"
22
 
23
- if self.type in ["prompt", "retry", "error", "format"]:
24
  return "user"
25
  else:
26
  return "assistant"
 
11
  content: str
12
  """The content of the message."""
13
 
14
+ @property
15
  def conversation_role(self) -> str:
16
  """The message type in the format used by the LLM."""
17
 
 
20
  # This can be counterintuitive since they can be controlled by either human or ai
21
  # Further, The programmatic messages from the game are always "user"
22
 
23
+ if self.type != "agent":
24
  return "user"
25
  else:
26
  return "assistant"
src/output_formats.py CHANGED
@@ -2,52 +2,33 @@ import random
2
  from typing import Annotated, NewType, List, Optional, Type, Literal
3
  import json
4
 
5
- from pydantic import BaseModel, field_validator, Field
6
 
7
  FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
8
  Here is the output format:
9
  {schema}
10
  """
11
- FEW_SHOT_INSTRUCTIONS = """Here are a few examples of correctly formatted responses: \n
12
- {examples}
13
- """
14
-
15
- OutputFormatModel = NewType("OutputFormatModel", BaseModel)
16
-
17
-
18
- class OutputFormat:
19
- """The base class for all output formats."""
20
-
21
- format_instructions: str = FORMAT_INSTRUCTIONS
22
- """Instructions for formatting the output, it is combined with the JSON schema of the output format."""
23
- few_shot_instructions: str = FEW_SHOT_INSTRUCTIONS
24
- """Instructions for the few shot examples, it is combined with the few shot examples."""
25
- few_shot_examples: Optional[List[dict]] = None
26
- """A list of examples to be shown to the agent to help them understand the desired format of the output."""
27
 
28
- def __init__(self, output_format_model: Type[OutputFormatModel], player_names: List[str] = None):
29
- self.output_format_model = output_format_model
30
- self.output_format_model.player_names = player_names
31
 
32
- def get_format_instructions(self) -> str:
33
- json_format = self.output_format_model().model_dump_json()
34
-
35
- return self.format_instructions.format(schema=json_format)
 
 
 
 
36
 
37
- def get_few_shot(self, max_examples=3):
38
- if len(self.few_shot_examples) <= max_examples:
39
- examples = self.few_shot_examples
40
- else:
41
- examples = random.sample(self.few_shot_examples, max_examples)
42
 
43
- few_shot = "\n\n".join([f"Example Response:\n{json.dumps(example)}" for example in examples])
44
 
45
- return self.few_shot_instructions.format(examples=few_shot)
46
 
47
 
48
- class AnimalDescriptionFormat(BaseModel):
49
  # Define fields of our class here
50
- description: str = Field("A brief description of the animal")
51
  """A brief description of the animal"""
52
 
53
  @field_validator('description')
@@ -58,8 +39,8 @@ class AnimalDescriptionFormat(BaseModel):
58
  return v
59
 
60
 
61
- class ChameleonGuessFormat(BaseModel):
62
- animal: str = Field("The name of the animal you think the chameleon is")
63
 
64
  @field_validator('animal')
65
  @classmethod
@@ -69,15 +50,14 @@ class ChameleonGuessFormat(BaseModel):
69
  return v
70
 
71
 
72
- class HerdVoteFormat(BaseModel):
73
  player_names: List[str] = Field([], exclude=True)
74
  """The names of the players in the game"""
75
- vote: str = Field("The name of the player you are voting for")
76
  """The name of the player you are voting for"""
77
 
78
- @field_validator('vote')
79
- @classmethod
80
- def check_player_exists(cls, v) -> str:
81
- if v.lower() not in [player.lower() for player in cls.player_names]:
82
- raise ValueError(f"Player {v} does not exist")
83
- return v
 
2
  from typing import Annotated, NewType, List, Optional, Type, Literal
3
  import json
4
 
5
+ from pydantic import BaseModel, field_validator, Field, model_validator
6
 
7
  FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
8
  Here is the output format:
9
  {schema}
10
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
 
12
 
13
+ class OutputFormatModel(BaseModel):
14
+ @classmethod
15
+ def get_format_instructions(cls) -> str:
16
+ """Returns a string with instructions on how to format the output."""
17
+ json_format = {}
18
+ for field in cls.model_fields:
19
+ if not cls.model_fields[field].exclude:
20
+ json_format[field] = cls.model_fields[field].description
21
 
22
+ # In the future, we could instead use get_annotations() to get the field descriptions
23
+ return FORMAT_INSTRUCTIONS.format(schema=json_format)
 
 
 
24
 
 
25
 
26
+ OutputFormat = NewType("OutputFormat", OutputFormatModel)
27
 
28
 
29
+ class AnimalDescriptionFormat(OutputFormatModel):
30
  # Define fields of our class here
31
+ description: str = Field(description="A brief description of the animal")
32
  """A brief description of the animal"""
33
 
34
  @field_validator('description')
 
39
  return v
40
 
41
 
42
+ class ChameleonGuessFormat(OutputFormatModel):
43
+ animal: str = Field(description='Name of the animal you think the Herd is in its singular form, e.g. "animal" not "animals"')
44
 
45
  @field_validator('animal')
46
  @classmethod
 
50
  return v
51
 
52
 
53
+ class HerdVoteFormat(OutputFormatModel):
54
  player_names: List[str] = Field([], exclude=True)
55
  """The names of the players in the game"""
56
+ vote: str = Field(description="The name of the player you are voting for")
57
  """The name of the player you are voting for"""
58
 
59
+ @model_validator(mode="after")
60
+ def check_player_exists(self) -> "HerdVoteFormat":
61
+ if self.vote.lower() not in [player.lower() for player in self.player_names]:
62
+ raise ValueError(f"Player {self.vote} does not exist, please vote for one of {self.player_names}")
63
+ return self