Eric Botti commited on
Commit
250cc97
·
1 Parent(s): 5dbe83d

enhanced output formatting

Browse files
Files changed (5) hide show
  1. src/agent_interfaces.py +11 -29
  2. src/game.py +5 -6
  3. src/models.py +0 -69
  4. src/output_formats.py +92 -0
  5. src/prompts.py +2 -55
src/agent_interfaces.py CHANGED
@@ -5,17 +5,10 @@ from openai import OpenAI
5
  from colorama import Fore, Style
6
  from pydantic import BaseModel, ValidationError
7
 
 
8
  from message import Message, AgentMessage
9
  from data_collection import save
10
 
11
- FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
12
- Here is the output schema:
13
- ```
14
- {schema}
15
- ```
16
- """
17
-
18
-
19
  class BaseAgentInterface:
20
  """
21
  The interface that agents use to receive info from and interact with the game.
@@ -44,11 +37,11 @@ class BaseAgentInterface:
44
  self.add_message(response)
45
  return response
46
 
47
- def respond_to_formatted(self, message: Message, output_format: Type[BaseModel], max_retries = 3) -> Type[BaseModel]:
48
  """Adds a message to the message history, and generates a response matching the provided format."""
49
  initial_response = self.respond_to(message)
50
 
51
- reformat_message = Message(type="format", content=self._get_format_instructions(output_format))
52
 
53
  output = None
54
  retries = 0
@@ -56,11 +49,17 @@ class BaseAgentInterface:
56
  while not output and retries < max_retries:
57
  try:
58
  formatted_response = self.respond_to(reformat_message)
59
- output = output_format.model_validate_json(formatted_response.content)
60
  except ValidationError as e:
61
  if retries > max_retries:
62
  raise e
63
- reformat_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
 
 
 
 
 
 
64
  retries += 1
65
 
66
  return output
@@ -75,23 +74,6 @@ class BaseAgentInterface:
75
  def is_ai(self):
76
  return not self.is_human
77
 
78
- # This should probably be put on a theoretical output format class...
79
- # Or maybe on the Message class with a from format constructor
80
- @staticmethod
81
- def _get_format_instructions(output_format: Type[BaseModel]):
82
- schema = output_format.model_json_schema()
83
-
84
- reduced_schema = schema
85
- if "title" in reduced_schema:
86
- del reduced_schema["title"]
87
- if "type" in reduced_schema:
88
- del reduced_schema["type"]
89
-
90
- schema_str = json.dumps(reduced_schema, indent=4)
91
-
92
- return FORMAT_INSTRUCTIONS.format(schema=schema_str)
93
-
94
-
95
  AgentInterface = NewType("AgentInterface", BaseAgentInterface)
96
 
97
 
 
5
  from colorama import Fore, Style
6
  from pydantic import BaseModel, ValidationError
7
 
8
+ from output_formats import OutputFormat, OutputFormatModel
9
  from message import Message, AgentMessage
10
  from data_collection import save
11
 
 
 
 
 
 
 
 
 
12
  class BaseAgentInterface:
13
  """
14
  The interface that agents use to receive info from and interact with the game.
 
37
  self.add_message(response)
38
  return response
39
 
40
+ def respond_to_formatted(self, message: Message, output_format: OutputFormat, max_retries = 3) -> OutputFormatModel:
41
  """Adds a message to the message history, and generates a response matching the provided format."""
42
  initial_response = self.respond_to(message)
43
 
44
+ reformat_message = Message(type="format", content=output_format.get_format_instructions())
45
 
46
  output = None
47
  retries = 0
 
49
  while not output and retries < max_retries:
50
  try:
51
  formatted_response = self.respond_to(reformat_message)
52
+ output = output_format.output_format_model.model_validate_json(formatted_response.content)
53
  except ValidationError as e:
54
  if retries > max_retries:
55
  raise e
56
+ retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
57
+ if output_format.few_shot_examples:
58
+ self.add_message(retry_message)
59
+ reformat_message = Message(type="few_shot", content=output_format.get_few_shot()) # not implemented
60
+ else:
61
+ reformat_message = retry_message
62
+
63
  retries += 1
64
 
65
  return output
 
74
  def is_ai(self):
75
  return not self.is_human
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  AgentInterface = NewType("AgentInterface", BaseAgentInterface)
78
 
79
 
src/game.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional, Type
5
  from colorama import Fore, Style
6
 
7
  from game_utils import *
8
- from models import *
9
  from player import Player
10
  from prompts import fetch_prompt, format_prompt
11
  from message import Message
@@ -175,7 +175,7 @@ class Game:
175
 
176
  # Get Player Animal Description
177
  message = Message(type="prompt", content=prompt)
178
- response = current_player.interface.respond_to_formatted(message, AnimalDescriptionModel)
179
 
180
  self.player_responses.append({"sender": current_player.name, "response": response.description})
181
 
@@ -190,7 +190,7 @@ class Game:
190
  prompt = fetch_prompt("chameleon_guess_animal")
191
 
192
  message = Message(type="prompt", content=prompt)
193
- response = chameleon.interface.respond_to_formatted(message, ChameleonGuessAnimalModel)
194
 
195
  chameleon_animal_guess = response.animal
196
 
@@ -207,9 +207,8 @@ class Game:
207
 
208
  # Get Player Vote
209
  message = Message(type="prompt", content=prompt)
210
- response = player.interface.respond_to_formatted(message, VoteModel)
211
-
212
- # check if a valid player was voted for...
213
 
214
  # Add Vote to Player Votes
215
  player_votes.append({"voter": player, "vote": response.vote})
 
5
  from colorama import Fore, Style
6
 
7
  from game_utils import *
8
+ from output_formats import *
9
  from player import Player
10
  from prompts import fetch_prompt, format_prompt
11
  from message import Message
 
175
 
176
  # Get Player Animal Description
177
  message = Message(type="prompt", content=prompt)
178
+ response = current_player.interface.respond_to_formatted(message, OutputFormat(AnimalDescriptionFormat))
179
 
180
  self.player_responses.append({"sender": current_player.name, "response": response.description})
181
 
 
190
  prompt = fetch_prompt("chameleon_guess_animal")
191
 
192
  message = Message(type="prompt", content=prompt)
193
+ response = chameleon.interface.respond_to_formatted(message, OutputFormat(ChameleonGuessFormat))
194
 
195
  chameleon_animal_guess = response.animal
196
 
 
207
 
208
  # Get Player Vote
209
  message = Message(type="prompt", content=prompt)
210
+ player_names = [p.name for p in self.players]
211
+ response = player.interface.respond_to_formatted(message, OutputFormat(HerdVoteFormat, player_names))
 
212
 
213
  # Add Vote to Player Votes
214
  player_votes.append({"voter": player, "vote": response.vote})
src/models.py DELETED
@@ -1,69 +0,0 @@
1
- import random
2
- from typing import Annotated, Type, List
3
- import json
4
-
5
- from pydantic import BaseModel, field_validator
6
-
7
- MAX_DESCRIPTION_LEN = 10
8
-
9
-
10
- class AnimalDescriptionModel(BaseModel):
11
- # Define fields of our class here
12
- description: Annotated[str, "A brief description of the animal"]
13
-
14
- # @field_validator('description')
15
- # @classmethod
16
- # def check_starting_character(cls, v) -> str:
17
- # if not v[0].upper() == 'I':
18
- # raise ValueError("Description must begin with 'I'")
19
- # return v
20
- #
21
- # @field_validator('description')
22
- # @classmethod
23
- # def wordcount(cls, v) -> str:
24
- # count = len(v.split())
25
- # if count > MAX_DESCRIPTION_LEN:
26
- # raise ValueError(f"Animal Description must be {MAX_DESCRIPTION_LEN} words or less")
27
- # return v
28
-
29
-
30
- class ChameleonDecisionModel(BaseModel):
31
- will_guess: bool
32
-
33
-
34
- class AnimalGuessModel(BaseModel):
35
- animal_name: str
36
-
37
-
38
- class ChameleonGuessDecisionModel(BaseModel):
39
- decision: Annotated[str, "Must be one of: ['guess', 'pass']"]
40
-
41
- @field_validator('decision')
42
- @classmethod
43
- def check_decision(cls, v) -> str:
44
- if v.lower() not in ['guess', 'pass']:
45
- raise ValueError("Decision must be one of: ['guess', 'pass']")
46
- return v
47
-
48
-
49
- class ChameleonGuessAnimalModel(BaseModel):
50
- animal: Annotated[str, "The name of the animal the chameleon is guessing"]
51
-
52
- @field_validator('animal')
53
- @classmethod
54
- def is_one_word(cls, v) -> str:
55
- if len(v.split()) > 1:
56
- raise ValueError("Animal's name must be one word")
57
- return v
58
-
59
-
60
- class VoteModel(BaseModel):
61
- vote: Annotated[str, "The name of the player you are voting for"]
62
-
63
- # @field_validator('vote')
64
- # @classmethod
65
- # def check_player_exists(cls, v) -> str:
66
- # if v.lower() not in [player.lower() for player in players]:
67
- # raise ValueError(f"Player {v} does not exist")
68
- # return v
69
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/output_formats.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Annotated, NewType, List, Optional, Type, Literal
3
+ import json
4
+
5
+ from pydantic import BaseModel, field_validator, Field
6
+
7
+ MAX_DESCRIPTION_LEN = 10
8
+ FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
9
+ Here is the output format:
10
+ {schema}
11
+ """
12
+ FEW_SHOT_INSTRUCTIONS = """Here are a few examples of correctly formatted responses: \n
13
+ {examples}
14
+ """
15
+
16
+ OutputFormatModel = NewType("OutputFormatModel", BaseModel)
17
+
18
+
19
+ class OutputFormat:
20
+ """The base class for all output formats."""
21
+
22
+ format_instructions: str = FORMAT_INSTRUCTIONS
23
+ """Instructions for formatting the output, it is combined with the JSON schema of the output format."""
24
+ few_shot_instructions: str = FEW_SHOT_INSTRUCTIONS
25
+ """Instructions for the few shot examples, it is combined with the few shot examples."""
26
+ few_shot_examples: Optional[List[dict]] = None
27
+ """A list of examples to be shown to the agent to help them understand the desired format of the output."""
28
+
29
+ def __init__(self, output_format_model: Type[OutputFormatModel], player_names: List[str] = None):
30
+ self.output_format_model = output_format_model
31
+ self.output_format_model.player_names = player_names
32
+
33
+ def get_format_instructions(self) -> str:
34
+ json_format = self.output_format_model().model_dump_json()
35
+
36
+ return self.format_instructions.format(schema=json_format)
37
+
38
+ def get_few_shot(self, max_examples=3):
39
+ if len(self.few_shot_examples) <= max_examples:
40
+ examples = self.few_shot_examples
41
+ else:
42
+ examples = random.sample(self.few_shot_examples, max_examples)
43
+
44
+ few_shot = "\n\n".join([f"Example Response:\n{json.dumps(example)}" for example in examples])
45
+
46
+ return self.few_shot_instructions.format(examples=few_shot)
47
+
48
+
49
+ class AnimalDescriptionFormat(BaseModel):
50
+ # Define fields of our class here
51
+ description: str = Field("A brief description of the animal")
52
+ """A brief description of the animal"""
53
+
54
+ @field_validator('description')
55
+ @classmethod
56
+ def check_starting_character(cls, v) -> str:
57
+ if not v[0].upper() == 'I':
58
+ raise ValueError("Description must begin with 'I'")
59
+ return v
60
+
61
+ @field_validator('description')
62
+ @classmethod
63
+ def wordcount(cls, v) -> str:
64
+ count = len(v.split())
65
+ if count > MAX_DESCRIPTION_LEN:
66
+ raise ValueError(f"Animal Description must be {MAX_DESCRIPTION_LEN} words or less")
67
+ return v
68
+
69
+
70
+ class ChameleonGuessFormat(BaseModel):
71
+ animal: str = Field("The name of the animal you think the chameleon is")
72
+
73
+ @field_validator('animal')
74
+ @classmethod
75
+ def is_one_word(cls, v) -> str:
76
+ if len(v.split()) > 1:
77
+ raise ValueError("Animal's name must be one word")
78
+ return v
79
+
80
+
81
+ class HerdVoteFormat(BaseModel):
82
+ vote: str = Field("The name of the player you are voting for")
83
+ """The name of the player you are voting for"""
84
+ player_names: List[str] = Field([], exclude=True)
85
+ """The names of the players in the game"""
86
+
87
+ @field_validator('vote')
88
+ @classmethod
89
+ def check_player_exists(cls, v) -> str:
90
+ if v.lower() not in [player.lower() for player in cls.player_names]:
91
+ raise ValueError(f"Player {v} does not exist")
92
+ return v
src/prompts.py CHANGED
@@ -1,66 +1,13 @@
1
- from models import *
2
- from langchain.prompts.few_shot import FewShotPromptTemplate
3
- from langchain.prompts.prompt import PromptTemplate
4
-
5
-
6
  def fetch_prompt(prompt_name):
7
  """Fetches a static prompt."""
8
  return prompts[prompt_name]
 
 
9
  def format_prompt(prompt_name, **kwargs):
10
  """Fetches a template prompt and populates it."""
11
  return fetch_prompt(prompt_name).format(**kwargs)
12
 
13
 
14
- class Task:
15
- def __init__(self, prompt: str, response_format: Type[BaseModel], few_shot_examples: List[dict] = None):
16
- self.prompt = prompt
17
- self.response_format = response_format
18
- self.few_shot_examples = few_shot_examples
19
-
20
- def full_prompt(self, **kwargs):
21
- prompt = self.prompt.format(**kwargs)
22
-
23
- format_instructions = self.get_format_instructions()
24
- if self.few_shot_examples:
25
- few_shot = self.get_few_shot()
26
-
27
- def get_format_instructions(self):
28
- schema = self.get_input_schema()
29
- format_instructions = FORMAT_INSTRUCTIONS.format(schema=schema)
30
-
31
- return format_instructions
32
-
33
- def get_input_schema(self):
34
- schema = self.response_format.model_json_schema()
35
-
36
- reduced_schema = schema
37
- if "title" in reduced_schema:
38
- del reduced_schema["title"]
39
- if "type" in reduced_schema:
40
- del reduced_schema["type"]
41
-
42
- schema_str = json.dumps(reduced_schema, indent=4)
43
-
44
- return schema_str
45
-
46
- def get_few_shot(self, max_examples=3):
47
- if len(self.few_shot_examples) <= max_examples:
48
- examples = self.few_shot_examples
49
- else:
50
- examples = random.sample(self.few_shot_examples, max_examples)
51
-
52
- few_shot = "\n\n".join([self.format_example(ex) for ex in examples])
53
-
54
- return few_shot
55
-
56
- def format_example(self, example):
57
- ex_prompt = self.prompt.format(**example['inputs'])
58
- ex_response = example['response']
59
-
60
- return f"Prompt: {ex_prompt}\nResponse: {ex_response}"
61
-
62
-
63
-
64
  _game_rules = '''\
65
  You are playing a social deduction game where every player pretends the be the same animal.
66
  During the round players go around the room and make an "I"-statement as if they were the animal.
 
 
 
 
 
 
1
  def fetch_prompt(prompt_name):
2
  """Fetches a static prompt."""
3
  return prompts[prompt_name]
4
+
5
+
6
  def format_prompt(prompt_name, **kwargs):
7
  """Fetches a template prompt and populates it."""
8
  return fetch_prompt(prompt_name).format(**kwargs)
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  _game_rules = '''\
12
  You are playing a social deduction game where every player pretends the be the same animal.
13
  During the round players go around the room and make an "I"-statement as if they were the animal.