Eric Botti
commited on
Commit
·
250cc97
1
Parent(s):
5dbe83d
enhanced output formatting
Browse files- src/agent_interfaces.py +11 -29
- src/game.py +5 -6
- src/models.py +0 -69
- src/output_formats.py +92 -0
- src/prompts.py +2 -55
src/agent_interfaces.py
CHANGED
@@ -5,17 +5,10 @@ from openai import OpenAI
|
|
5 |
from colorama import Fore, Style
|
6 |
from pydantic import BaseModel, ValidationError
|
7 |
|
|
|
8 |
from message import Message, AgentMessage
|
9 |
from data_collection import save
|
10 |
|
11 |
-
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
12 |
-
Here is the output schema:
|
13 |
-
```
|
14 |
-
{schema}
|
15 |
-
```
|
16 |
-
"""
|
17 |
-
|
18 |
-
|
19 |
class BaseAgentInterface:
|
20 |
"""
|
21 |
The interface that agents use to receive info from and interact with the game.
|
@@ -44,11 +37,11 @@ class BaseAgentInterface:
|
|
44 |
self.add_message(response)
|
45 |
return response
|
46 |
|
47 |
-
def respond_to_formatted(self, message: Message, output_format:
|
48 |
"""Adds a message to the message history, and generates a response matching the provided format."""
|
49 |
initial_response = self.respond_to(message)
|
50 |
|
51 |
-
reformat_message = Message(type="format", content=
|
52 |
|
53 |
output = None
|
54 |
retries = 0
|
@@ -56,11 +49,17 @@ class BaseAgentInterface:
|
|
56 |
while not output and retries < max_retries:
|
57 |
try:
|
58 |
formatted_response = self.respond_to(reformat_message)
|
59 |
-
output = output_format.model_validate_json(formatted_response.content)
|
60 |
except ValidationError as e:
|
61 |
if retries > max_retries:
|
62 |
raise e
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
retries += 1
|
65 |
|
66 |
return output
|
@@ -75,23 +74,6 @@ class BaseAgentInterface:
|
|
75 |
def is_ai(self):
|
76 |
return not self.is_human
|
77 |
|
78 |
-
# This should probably be put on a theoretical output format class...
|
79 |
-
# Or maybe on the Message class with a from format constructor
|
80 |
-
@staticmethod
|
81 |
-
def _get_format_instructions(output_format: Type[BaseModel]):
|
82 |
-
schema = output_format.model_json_schema()
|
83 |
-
|
84 |
-
reduced_schema = schema
|
85 |
-
if "title" in reduced_schema:
|
86 |
-
del reduced_schema["title"]
|
87 |
-
if "type" in reduced_schema:
|
88 |
-
del reduced_schema["type"]
|
89 |
-
|
90 |
-
schema_str = json.dumps(reduced_schema, indent=4)
|
91 |
-
|
92 |
-
return FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
93 |
-
|
94 |
-
|
95 |
AgentInterface = NewType("AgentInterface", BaseAgentInterface)
|
96 |
|
97 |
|
|
|
5 |
from colorama import Fore, Style
|
6 |
from pydantic import BaseModel, ValidationError
|
7 |
|
8 |
+
from output_formats import OutputFormat, OutputFormatModel
|
9 |
from message import Message, AgentMessage
|
10 |
from data_collection import save
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
class BaseAgentInterface:
|
13 |
"""
|
14 |
The interface that agents use to receive info from and interact with the game.
|
|
|
37 |
self.add_message(response)
|
38 |
return response
|
39 |
|
40 |
+
def respond_to_formatted(self, message: Message, output_format: OutputFormat, max_retries = 3) -> OutputFormatModel:
|
41 |
"""Adds a message to the message history, and generates a response matching the provided format."""
|
42 |
initial_response = self.respond_to(message)
|
43 |
|
44 |
+
reformat_message = Message(type="format", content=output_format.get_format_instructions())
|
45 |
|
46 |
output = None
|
47 |
retries = 0
|
|
|
49 |
while not output and retries < max_retries:
|
50 |
try:
|
51 |
formatted_response = self.respond_to(reformat_message)
|
52 |
+
output = output_format.output_format_model.model_validate_json(formatted_response.content)
|
53 |
except ValidationError as e:
|
54 |
if retries > max_retries:
|
55 |
raise e
|
56 |
+
retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
|
57 |
+
if output_format.few_shot_examples:
|
58 |
+
self.add_message(retry_message)
|
59 |
+
reformat_message = Message(type="few_shot", content=output_format.get_few_shot()) # not implemented
|
60 |
+
else:
|
61 |
+
reformat_message = retry_message
|
62 |
+
|
63 |
retries += 1
|
64 |
|
65 |
return output
|
|
|
74 |
def is_ai(self):
|
75 |
return not self.is_human
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
AgentInterface = NewType("AgentInterface", BaseAgentInterface)
|
78 |
|
79 |
|
src/game.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Optional, Type
|
|
5 |
from colorama import Fore, Style
|
6 |
|
7 |
from game_utils import *
|
8 |
-
from
|
9 |
from player import Player
|
10 |
from prompts import fetch_prompt, format_prompt
|
11 |
from message import Message
|
@@ -175,7 +175,7 @@ class Game:
|
|
175 |
|
176 |
# Get Player Animal Description
|
177 |
message = Message(type="prompt", content=prompt)
|
178 |
-
response = current_player.interface.respond_to_formatted(message,
|
179 |
|
180 |
self.player_responses.append({"sender": current_player.name, "response": response.description})
|
181 |
|
@@ -190,7 +190,7 @@ class Game:
|
|
190 |
prompt = fetch_prompt("chameleon_guess_animal")
|
191 |
|
192 |
message = Message(type="prompt", content=prompt)
|
193 |
-
response = chameleon.interface.respond_to_formatted(message,
|
194 |
|
195 |
chameleon_animal_guess = response.animal
|
196 |
|
@@ -207,9 +207,8 @@ class Game:
|
|
207 |
|
208 |
# Get Player Vote
|
209 |
message = Message(type="prompt", content=prompt)
|
210 |
-
|
211 |
-
|
212 |
-
# check if a valid player was voted for...
|
213 |
|
214 |
# Add Vote to Player Votes
|
215 |
player_votes.append({"voter": player, "vote": response.vote})
|
|
|
5 |
from colorama import Fore, Style
|
6 |
|
7 |
from game_utils import *
|
8 |
+
from output_formats import *
|
9 |
from player import Player
|
10 |
from prompts import fetch_prompt, format_prompt
|
11 |
from message import Message
|
|
|
175 |
|
176 |
# Get Player Animal Description
|
177 |
message = Message(type="prompt", content=prompt)
|
178 |
+
response = current_player.interface.respond_to_formatted(message, OutputFormat(AnimalDescriptionFormat))
|
179 |
|
180 |
self.player_responses.append({"sender": current_player.name, "response": response.description})
|
181 |
|
|
|
190 |
prompt = fetch_prompt("chameleon_guess_animal")
|
191 |
|
192 |
message = Message(type="prompt", content=prompt)
|
193 |
+
response = chameleon.interface.respond_to_formatted(message, OutputFormat(ChameleonGuessFormat))
|
194 |
|
195 |
chameleon_animal_guess = response.animal
|
196 |
|
|
|
207 |
|
208 |
# Get Player Vote
|
209 |
message = Message(type="prompt", content=prompt)
|
210 |
+
player_names = [p.name for p in self.players]
|
211 |
+
response = player.interface.respond_to_formatted(message, OutputFormat(HerdVoteFormat, player_names))
|
|
|
212 |
|
213 |
# Add Vote to Player Votes
|
214 |
player_votes.append({"voter": player, "vote": response.vote})
|
src/models.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
import random
|
2 |
-
from typing import Annotated, Type, List
|
3 |
-
import json
|
4 |
-
|
5 |
-
from pydantic import BaseModel, field_validator
|
6 |
-
|
7 |
-
MAX_DESCRIPTION_LEN = 10
|
8 |
-
|
9 |
-
|
10 |
-
class AnimalDescriptionModel(BaseModel):
|
11 |
-
# Define fields of our class here
|
12 |
-
description: Annotated[str, "A brief description of the animal"]
|
13 |
-
|
14 |
-
# @field_validator('description')
|
15 |
-
# @classmethod
|
16 |
-
# def check_starting_character(cls, v) -> str:
|
17 |
-
# if not v[0].upper() == 'I':
|
18 |
-
# raise ValueError("Description must begin with 'I'")
|
19 |
-
# return v
|
20 |
-
#
|
21 |
-
# @field_validator('description')
|
22 |
-
# @classmethod
|
23 |
-
# def wordcount(cls, v) -> str:
|
24 |
-
# count = len(v.split())
|
25 |
-
# if count > MAX_DESCRIPTION_LEN:
|
26 |
-
# raise ValueError(f"Animal Description must be {MAX_DESCRIPTION_LEN} words or less")
|
27 |
-
# return v
|
28 |
-
|
29 |
-
|
30 |
-
class ChameleonDecisionModel(BaseModel):
|
31 |
-
will_guess: bool
|
32 |
-
|
33 |
-
|
34 |
-
class AnimalGuessModel(BaseModel):
|
35 |
-
animal_name: str
|
36 |
-
|
37 |
-
|
38 |
-
class ChameleonGuessDecisionModel(BaseModel):
|
39 |
-
decision: Annotated[str, "Must be one of: ['guess', 'pass']"]
|
40 |
-
|
41 |
-
@field_validator('decision')
|
42 |
-
@classmethod
|
43 |
-
def check_decision(cls, v) -> str:
|
44 |
-
if v.lower() not in ['guess', 'pass']:
|
45 |
-
raise ValueError("Decision must be one of: ['guess', 'pass']")
|
46 |
-
return v
|
47 |
-
|
48 |
-
|
49 |
-
class ChameleonGuessAnimalModel(BaseModel):
|
50 |
-
animal: Annotated[str, "The name of the animal the chameleon is guessing"]
|
51 |
-
|
52 |
-
@field_validator('animal')
|
53 |
-
@classmethod
|
54 |
-
def is_one_word(cls, v) -> str:
|
55 |
-
if len(v.split()) > 1:
|
56 |
-
raise ValueError("Animal's name must be one word")
|
57 |
-
return v
|
58 |
-
|
59 |
-
|
60 |
-
class VoteModel(BaseModel):
|
61 |
-
vote: Annotated[str, "The name of the player you are voting for"]
|
62 |
-
|
63 |
-
# @field_validator('vote')
|
64 |
-
# @classmethod
|
65 |
-
# def check_player_exists(cls, v) -> str:
|
66 |
-
# if v.lower() not in [player.lower() for player in players]:
|
67 |
-
# raise ValueError(f"Player {v} does not exist")
|
68 |
-
# return v
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/output_formats.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Annotated, NewType, List, Optional, Type, Literal
|
3 |
+
import json
|
4 |
+
|
5 |
+
from pydantic import BaseModel, field_validator, Field
|
6 |
+
|
7 |
+
MAX_DESCRIPTION_LEN = 10
|
8 |
+
FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
|
9 |
+
Here is the output format:
|
10 |
+
{schema}
|
11 |
+
"""
|
12 |
+
FEW_SHOT_INSTRUCTIONS = """Here are a few examples of correctly formatted responses: \n
|
13 |
+
{examples}
|
14 |
+
"""
|
15 |
+
|
16 |
+
OutputFormatModel = NewType("OutputFormatModel", BaseModel)
|
17 |
+
|
18 |
+
|
19 |
+
class OutputFormat:
|
20 |
+
"""The base class for all output formats."""
|
21 |
+
|
22 |
+
format_instructions: str = FORMAT_INSTRUCTIONS
|
23 |
+
"""Instructions for formatting the output, it is combined with the JSON schema of the output format."""
|
24 |
+
few_shot_instructions: str = FEW_SHOT_INSTRUCTIONS
|
25 |
+
"""Instructions for the few shot examples, it is combined with the few shot examples."""
|
26 |
+
few_shot_examples: Optional[List[dict]] = None
|
27 |
+
"""A list of examples to be shown to the agent to help them understand the desired format of the output."""
|
28 |
+
|
29 |
+
def __init__(self, output_format_model: Type[OutputFormatModel], player_names: List[str] = None):
|
30 |
+
self.output_format_model = output_format_model
|
31 |
+
self.output_format_model.player_names = player_names
|
32 |
+
|
33 |
+
def get_format_instructions(self) -> str:
|
34 |
+
json_format = self.output_format_model().model_dump_json()
|
35 |
+
|
36 |
+
return self.format_instructions.format(schema=json_format)
|
37 |
+
|
38 |
+
def get_few_shot(self, max_examples=3):
|
39 |
+
if len(self.few_shot_examples) <= max_examples:
|
40 |
+
examples = self.few_shot_examples
|
41 |
+
else:
|
42 |
+
examples = random.sample(self.few_shot_examples, max_examples)
|
43 |
+
|
44 |
+
few_shot = "\n\n".join([f"Example Response:\n{json.dumps(example)}" for example in examples])
|
45 |
+
|
46 |
+
return self.few_shot_instructions.format(examples=few_shot)
|
47 |
+
|
48 |
+
|
49 |
+
class AnimalDescriptionFormat(BaseModel):
|
50 |
+
# Define fields of our class here
|
51 |
+
description: str = Field("A brief description of the animal")
|
52 |
+
"""A brief description of the animal"""
|
53 |
+
|
54 |
+
@field_validator('description')
|
55 |
+
@classmethod
|
56 |
+
def check_starting_character(cls, v) -> str:
|
57 |
+
if not v[0].upper() == 'I':
|
58 |
+
raise ValueError("Description must begin with 'I'")
|
59 |
+
return v
|
60 |
+
|
61 |
+
@field_validator('description')
|
62 |
+
@classmethod
|
63 |
+
def wordcount(cls, v) -> str:
|
64 |
+
count = len(v.split())
|
65 |
+
if count > MAX_DESCRIPTION_LEN:
|
66 |
+
raise ValueError(f"Animal Description must be {MAX_DESCRIPTION_LEN} words or less")
|
67 |
+
return v
|
68 |
+
|
69 |
+
|
70 |
+
class ChameleonGuessFormat(BaseModel):
|
71 |
+
animal: str = Field("The name of the animal you think the chameleon is")
|
72 |
+
|
73 |
+
@field_validator('animal')
|
74 |
+
@classmethod
|
75 |
+
def is_one_word(cls, v) -> str:
|
76 |
+
if len(v.split()) > 1:
|
77 |
+
raise ValueError("Animal's name must be one word")
|
78 |
+
return v
|
79 |
+
|
80 |
+
|
81 |
+
class HerdVoteFormat(BaseModel):
|
82 |
+
vote: str = Field("The name of the player you are voting for")
|
83 |
+
"""The name of the player you are voting for"""
|
84 |
+
player_names: List[str] = Field([], exclude=True)
|
85 |
+
"""The names of the players in the game"""
|
86 |
+
|
87 |
+
@field_validator('vote')
|
88 |
+
@classmethod
|
89 |
+
def check_player_exists(cls, v) -> str:
|
90 |
+
if v.lower() not in [player.lower() for player in cls.player_names]:
|
91 |
+
raise ValueError(f"Player {v} does not exist")
|
92 |
+
return v
|
src/prompts.py
CHANGED
@@ -1,66 +1,13 @@
|
|
1 |
-
from models import *
|
2 |
-
from langchain.prompts.few_shot import FewShotPromptTemplate
|
3 |
-
from langchain.prompts.prompt import PromptTemplate
|
4 |
-
|
5 |
-
|
6 |
def fetch_prompt(prompt_name):
|
7 |
"""Fetches a static prompt."""
|
8 |
return prompts[prompt_name]
|
|
|
|
|
9 |
def format_prompt(prompt_name, **kwargs):
|
10 |
"""Fetches a template prompt and populates it."""
|
11 |
return fetch_prompt(prompt_name).format(**kwargs)
|
12 |
|
13 |
|
14 |
-
class Task:
|
15 |
-
def __init__(self, prompt: str, response_format: Type[BaseModel], few_shot_examples: List[dict] = None):
|
16 |
-
self.prompt = prompt
|
17 |
-
self.response_format = response_format
|
18 |
-
self.few_shot_examples = few_shot_examples
|
19 |
-
|
20 |
-
def full_prompt(self, **kwargs):
|
21 |
-
prompt = self.prompt.format(**kwargs)
|
22 |
-
|
23 |
-
format_instructions = self.get_format_instructions()
|
24 |
-
if self.few_shot_examples:
|
25 |
-
few_shot = self.get_few_shot()
|
26 |
-
|
27 |
-
def get_format_instructions(self):
|
28 |
-
schema = self.get_input_schema()
|
29 |
-
format_instructions = FORMAT_INSTRUCTIONS.format(schema=schema)
|
30 |
-
|
31 |
-
return format_instructions
|
32 |
-
|
33 |
-
def get_input_schema(self):
|
34 |
-
schema = self.response_format.model_json_schema()
|
35 |
-
|
36 |
-
reduced_schema = schema
|
37 |
-
if "title" in reduced_schema:
|
38 |
-
del reduced_schema["title"]
|
39 |
-
if "type" in reduced_schema:
|
40 |
-
del reduced_schema["type"]
|
41 |
-
|
42 |
-
schema_str = json.dumps(reduced_schema, indent=4)
|
43 |
-
|
44 |
-
return schema_str
|
45 |
-
|
46 |
-
def get_few_shot(self, max_examples=3):
|
47 |
-
if len(self.few_shot_examples) <= max_examples:
|
48 |
-
examples = self.few_shot_examples
|
49 |
-
else:
|
50 |
-
examples = random.sample(self.few_shot_examples, max_examples)
|
51 |
-
|
52 |
-
few_shot = "\n\n".join([self.format_example(ex) for ex in examples])
|
53 |
-
|
54 |
-
return few_shot
|
55 |
-
|
56 |
-
def format_example(self, example):
|
57 |
-
ex_prompt = self.prompt.format(**example['inputs'])
|
58 |
-
ex_response = example['response']
|
59 |
-
|
60 |
-
return f"Prompt: {ex_prompt}\nResponse: {ex_response}"
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
_game_rules = '''\
|
65 |
You are playing a social deduction game where every player pretends the be the same animal.
|
66 |
During the round players go around the room and make an "I"-statement as if they were the animal.
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
def fetch_prompt(prompt_name):
|
2 |
"""Fetches a static prompt."""
|
3 |
return prompts[prompt_name]
|
4 |
+
|
5 |
+
|
6 |
def format_prompt(prompt_name, **kwargs):
|
7 |
"""Fetches a template prompt and populates it."""
|
8 |
return fetch_prompt(prompt_name).format(**kwargs)
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
_game_rules = '''\
|
12 |
You are playing a social deduction game where every player pretends the be the same animal.
|
13 |
During the round players go around the room and make an "I"-statement as if they were the animal.
|