Eric Botti commited on
Commit
975158e
·
1 Parent(s): 2acaa11

added output parsing for LLMs

Browse files
Files changed (4) hide show
  1. src/game.py +49 -42
  2. src/main.py +2 -1
  3. src/models.py +43 -0
  4. src/parser.py +95 -0
src/game.py CHANGED
@@ -1,6 +1,11 @@
1
- from game_utils import fetch_prompt, random_animal, random_names, random_index
 
 
 
2
  from player import Player
3
 
 
 
4
  # Default Values
5
  NUMBER_OF_PLAYERS = 5
6
 
@@ -37,18 +42,11 @@ class Game:
37
 
38
  self.players.append(Player(name, controller, role))
39
 
40
-
41
  self.player_responses = []
42
 
43
- print("Game Created")
44
-
45
-
46
- def broadcast(self, message):
47
- """Sends a message to all the players, no response required."""
48
- for player_index in range(0, len(self.players)):
49
- self.players[player_index].add_message(message)
50
- if self.human_index == player_index:
51
- print(message)
52
 
53
  def format_responses(self) -> str:
54
  """Formats the responses of the players into a single string."""
@@ -58,34 +56,48 @@ class Game:
58
  """Returns the names of the players."""
59
  return [player.name for player in self.players]
60
 
61
- def start(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  """Starts the game."""
63
- print("Welcome to Chameleon! This is a social deduction game powered by LLMs.")
64
 
65
  self.player_responses = []
66
  herd_animal = random_animal()
67
 
68
  # Collect Player Animal Descriptions
69
  for player in self.players:
70
- match player.role:
71
- case "herd":
72
- prompt_template = fetch_prompt("herd_animal")
73
- prompt = prompt_template.format(animal=herd_animal, player_responses=self.format_responses())
74
-
75
- case "chameleon":
76
- prompt_template = fetch_prompt("chameleon_animal")
77
- prompt = prompt_template.format(player_responses=self.format_responses())
78
 
79
- response = player.collect_input(prompt)
80
- self.player_responses.append({"sender": player.name, "response": response})
 
 
81
 
82
- self.player_votes = []
83
 
84
  # Show All Player Responses
85
- self.broadcast(self.format_responses())
86
 
87
  # Chameleon Decides if they want to guess the animal
88
- # TODO: Add Chameleon Guess Decision Logic
89
  chameleon_will_guess = False
90
 
91
  if chameleon_will_guess:
@@ -94,30 +106,25 @@ class Game:
94
  pass
95
  else:
96
  # All Players Vote for Chameleon
 
97
  for player in self.players:
98
  prompt_template = fetch_prompt("vote")
99
- prompt = prompt_template.format(players=self.get_player_names())
 
 
 
 
 
100
 
101
- response = player.collect_input(prompt)
102
- self.player_responses.append(response)
 
 
103
 
104
  # Assign Points
105
  # Chameleon Wins - 3 Points
106
  # Herd Wins by Failed Chameleon Guess - 1 Point (each)
107
  # Herd Wins by Correctly Guessing Chameleon - 2 points (each)
108
 
109
- @staticmethod
110
- def validate_animal_description(self, description: str) -> bool:
111
- """Validates that the description starts with I and is less than 10 words."""
112
- if not description.startswith("I"):
113
- return False
114
-
115
- if len(description.split(" ")) > 10:
116
- return False
117
 
118
- return True
119
 
120
- def validate_vote(self, vote: str) -> bool:
121
- """Validates that the vote is for a valid player."""
122
- player_names = [player.name.lower() for player in self.players]
123
- return vote.lower() in player_names
 
1
+ import asyncio
2
+
3
+ from game_utils import *
4
+ from models import *
5
  from player import Player
6
 
7
+ from parser import ParserKani
8
+
9
  # Default Values
10
  NUMBER_OF_PLAYERS = 5
11
 
 
42
 
43
  self.players.append(Player(name, controller, role))
44
 
45
+ # Game State
46
  self.player_responses = []
47
 
48
+ # Parser
49
+ self.parser = ParserKani.default()
 
 
 
 
 
 
 
50
 
51
  def format_responses(self) -> str:
52
  """Formats the responses of the players into a single string."""
 
56
  """Returns the names of the players."""
57
  return [player.name for player in self.players]
58
 
59
+ @staticmethod
60
+ def player_action(prompt: str, player, validator: callable = None):
61
+ """Prompts the player to take an action and validates the response."""
62
+ max_attempts = 3
63
+ response = player.respond_to(prompt)
64
+
65
+ if validator:
66
+ attempts = 0
67
+ while not validator(response):
68
+ attempts += 1
69
+ if attempts >= max_attempts:
70
+ raise ValueError(f"Player {player.name} did not provide a valid response to the following prompt:\n{prompt} Response: {response}")
71
+ response = player.respond_to(prompt)
72
+
73
+ return response
74
+
75
+ async def start(self):
76
  """Starts the game."""
77
+ # print("Welcome to Chameleon! This is a social deduction game powered by LLMs.")
78
 
79
  self.player_responses = []
80
  herd_animal = random_animal()
81
 
82
  # Collect Player Animal Descriptions
83
  for player in self.players:
84
+ if player.role == "chameleon":
85
+ prompt_template = fetch_prompt("chameleon_animal")
86
+ prompt = prompt_template.format(player_responses=self.format_responses())
87
+ else:
88
+ prompt_template = fetch_prompt("herd_animal")
89
+ prompt = prompt_template.format(animal=herd_animal, player_responses=self.format_responses())
 
 
90
 
91
+ # Get Player Animal Description
92
+ response = await player.respond_to(prompt)
93
+ # Parse Animal Description
94
+ output = await self.parser.parse(prompt, response, AnimalDescriptionModel)
95
 
96
+ self.player_responses.append({"sender": player.name, "response": output.description})
97
 
98
  # Show All Player Responses
 
99
 
100
  # Chameleon Decides if they want to guess the animal
 
101
  chameleon_will_guess = False
102
 
103
  if chameleon_will_guess:
 
106
  pass
107
  else:
108
  # All Players Vote for Chameleon
109
+ player_votes = []
110
  for player in self.players:
111
  prompt_template = fetch_prompt("vote")
112
+ prompt = prompt_template.format(player_responses=self.format_responses())
113
+
114
+ # Get Player Vote
115
+ response = await player.respond_to(prompt)
116
+ # Parse Vote
117
+ output = await self.parser.parse(prompt, response, VoteModel)
118
 
119
+ # Add Vote to Player Votes
120
+ player_votes.append(output.vote)
121
+
122
+ print(player_votes)
123
 
124
  # Assign Points
125
  # Chameleon Wins - 3 Points
126
  # Herd Wins by Failed Chameleon Guess - 1 Point (each)
127
  # Herd Wins by Correctly Guessing Chameleon - 2 points (each)
128
 
 
 
 
 
 
 
 
 
129
 
 
130
 
 
 
 
 
src/main.py CHANGED
@@ -1,4 +1,5 @@
1
  from game import Game
 
2
  from player import Player
3
 
4
  def main():
@@ -7,7 +8,7 @@ def main():
7
 
8
  game = Game(human_name=name)
9
 
10
- game.start()
11
 
12
 
13
  if __name__ == "__main__":
 
1
  from game import Game
2
+ import asyncio
3
  from player import Player
4
 
5
  def main():
 
8
 
9
  game = Game(human_name=name)
10
 
11
+ asyncio.run(game.start())
12
 
13
 
14
  if __name__ == "__main__":
src/models.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, field_validator
2
+
3
+ MAX_DESCRIPTION_LEN = 10
4
+
5
+
6
+ class AnimalDescriptionModel(BaseModel):
7
+ # Define fields of our class here
8
+ description: str
9
+
10
+ # @field_validator('description')
11
+ # @classmethod
12
+ # def check_starting_character(cls, v) -> str:
13
+ # if not v[0].upper() == 'I':
14
+ # raise ValueError("Description must begin with 'I'")
15
+ # return v
16
+ #
17
+ # @field_validator('description')
18
+ # @classmethod
19
+ # def wordcount(cls, v) -> str:
20
+ # count = len(v.split())
21
+ # if count > MAX_DESCRIPTION_LEN:
22
+ # raise ValueError(f"Animal Description must be {MAX_DESCRIPTION_LEN} words or less")
23
+ # return v
24
+
25
+
26
+ class ChameleonDecisionModel(BaseModel):
27
+ will_guess: bool
28
+
29
+
30
+ class AnimalGuessModel(BaseModel):
31
+ animal_name: str
32
+
33
+
34
+ class VoteModel(BaseModel):
35
+ vote: str
36
+
37
+ # @field_validator('vote')
38
+ # @classmethod
39
+ # def check_player_exists(cls, v) -> str:
40
+ # if v.lower() not in [player.lower() for player in players]:
41
+ # raise ValueError(f"Player {v} does not exist")
42
+ # return v
43
+
src/parser.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+ import asyncio
3
+ import json
4
+
5
+ from kani import Kani
6
+ from kani.engines.openai import OpenAIEngine
7
+ from pydantic import BaseModel, ValidationError
8
+
9
+ FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
10
+ Here is the output schema:
11
+ ```
12
+ {schema}
13
+ ```
14
+ """
15
+
16
+ parser_prompt = """\
17
+ The user gave the following output to the prompt:
18
+ Prompt:
19
+ {prompt}
20
+ Output:
21
+ {message}
22
+
23
+ {format_instructions}
24
+ """
25
+
26
+
27
+ class ParserKani(Kani):
28
+ def __init__(self, engine, *args, **kwargs):
29
+ super().__init__(engine, *args, **kwargs)
30
+
31
+ async def parse(self, prompt: str, message: str, format_model: Type[BaseModel], max_retries: int = 3, **kwargs):
32
+ format_instructions = self.get_format_instructions(format_model)
33
+
34
+ parser_instructions = parser_prompt.format(
35
+ prompt=prompt,
36
+ message=message,
37
+ format_instructions=format_instructions
38
+ )
39
+
40
+ response = await self.chat_round_str(parser_instructions, **kwargs)
41
+
42
+ try:
43
+ output = format_model.model_validate_json(response)
44
+ except ValidationError as e:
45
+ print(f"Output did not conform to the expected format: {e}")
46
+ raise e
47
+
48
+ # Clear the Chat History after successful parse
49
+ self.chat_history = []
50
+
51
+ return output
52
+
53
+ @staticmethod
54
+ def get_format_instructions(format_model: Type[BaseModel]):
55
+ schema = format_model.model_json_schema()
56
+
57
+ reduced_schema = schema
58
+ if "title" in reduced_schema:
59
+ del reduced_schema["title"]
60
+ if "type" in reduced_schema:
61
+ del reduced_schema["type"]
62
+
63
+ schema_str = json.dumps(reduced_schema, indent=4)
64
+
65
+ return FORMAT_INSTRUCTIONS.format(schema=schema_str)
66
+
67
+ @classmethod
68
+ def default(cls):
69
+ """Default ParserKani with OpenAIEngine."""
70
+ engine = OpenAIEngine(model="gpt-3.5-turbo")
71
+ return cls(engine)
72
+
73
+
74
+
75
+ # Testing
76
+ # parser = ParserKani(engine=OpenAIEngine(model="gpt-3.5-turbo"))
77
+ #
78
+ # sample_prompt = """\
79
+ # Below are the responses from all players. Now it is time to vote. Choose from the players below who you think the Chameleon is.
80
+ # - Mallory: I am tall and have a long neck.
81
+ # - Jack: I am a herbivore and have a long neck.
82
+ # - Jill: I am a herbivore and have a long neck.
83
+ # - Bob: I am tall and have a long neck.
84
+ # - Courtney: I am tall and have a long neck.
85
+ # """
86
+ #
87
+ # sample_message = """\
88
+ # I think the Chameleon is Mallory.
89
+ # """
90
+ #
91
+ # test_output = asyncio.run(parser.parse(prompt=sample_prompt, message=sample_message, format_model=VoteModel))
92
+ #
93
+ # print(test_output)
94
+ #
95
+ # print(VoteModel.model_validate_json(test_output))