Eric Botti commited on
Commit
f596e58
·
1 Parent(s): 1e1d212

reconstructed players using langchain

Browse files
Files changed (7) hide show
  1. src/agents.py +0 -15
  2. src/controllers.py +28 -0
  3. src/game.py +25 -23
  4. src/models.py +13 -4
  5. src/parser.py +0 -96
  6. src/player.py +83 -47
  7. src/prompts.py +62 -7
src/agents.py DELETED
@@ -1,15 +0,0 @@
1
- from kani import Kani
2
-
3
- class LogMessagesKani(Kani):
4
- def __init__(self, engine, log_filepath: str = None, *args, **kwargs):
5
- super().__init__(engine, *args, **kwargs)
6
- self.log_filepath = log_filepath
7
-
8
- async def add_to_history(self, message, *args, **kwargs):
9
- await super().add_to_history(message, *args, **kwargs)
10
-
11
- # Logs Message to File
12
- if self.log_filepath:
13
- with open(self.log_filepath, "a+") as log_file:
14
- log_file.write(message.model_dump_json())
15
- log_file.write("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/controllers.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from langchain_core.runnables import RunnableLambda
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.messages import AIMessage
6
+
7
+
8
+ def player_input(prompt):
9
+ print(prompt)
10
+ # even though they are human, we still need to return an AIMessage, since the HumanMessages are from the GameMaster
11
+ response = AIMessage(content=input())
12
+ return response
13
+
14
+
15
+ def controller_from_name(name: str):
16
+ if name == "tgi":
17
+ return ChatOpenAI(
18
+ api_base=os.environ['HF_ENDPOINT_URL'] + "/v1/",
19
+ api_key=os.environ['HF_API_TOKEN']
20
+ )
21
+ elif name == "openai":
22
+ return ChatOpenAI(model="gpt-3.5-turbo")
23
+ # elif name == "ollama":
24
+ # return ollama_controller
25
+ elif name == "human":
26
+ return RunnableLambda(player_input)
27
+ else:
28
+ raise ValueError(f"Unknown controller name: {name}")
src/game.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
 
2
 
3
  from game_utils import *
4
  from models import *
5
  from player import Player
6
- from parser import ParserKani
7
  from prompts import fetch_prompt, format_prompt
8
 
9
  # Default Values
@@ -13,7 +13,6 @@ NUMBER_OF_PLAYERS = 5
13
  class Game:
14
  log_dir = os.path.join(os.pardir, "experiments")
15
  player_log_file = "{player_id}.jsonl"
16
- parser_log_file = "{game_id}-parser.jsonl"
17
  game_log_file = "{game_id}-game.jsonl"
18
 
19
  def __init__(
@@ -24,6 +23,9 @@ class Game:
24
 
25
  # Game ID
26
  self.game_id = game_id()
 
 
 
27
 
28
  # Gather Player Names
29
  if human_name:
@@ -44,7 +46,7 @@ class Game:
44
  controller = "human"
45
  else:
46
  name = ai_names.pop()
47
- controller = "ai"
48
 
49
  if self.chameleon_index == i:
50
  role = "chameleon"
@@ -63,13 +65,18 @@ class Game:
63
  # Game State
64
  self.player_responses = []
65
 
66
- # Parser
67
- parser_log_path = os.path.join(self.log_dir, self.parser_log_file.format(game_id=self.game_id))
68
- self.parser = ParserKani.default(parser_log_path)
69
-
70
- def format_responses(self) -> str:
71
  """Formats the responses of the players into a single string."""
72
- return "\n".join([f" - {response['sender']}: {response['response']}" for response in self.player_responses])
 
 
 
 
 
 
 
 
 
73
 
74
  def get_player_names(self) -> list[str]:
75
  """Returns the names of the players."""
@@ -93,19 +100,16 @@ class Game:
93
  prompt = format_prompt("herd_animal", animal=herd_animal, player_responses=self.format_responses())
94
 
95
  # Get Player Animal Description
96
- response = await player.respond_to(prompt)
97
- # Parse Animal Description
98
- output = await self.parser.parse(prompt, response, AnimalDescriptionModel)
99
 
100
- self.player_responses.append({"sender": player.name, "response": output.description})
101
 
102
  # Phase II: Chameleon Decides if they want to guess the animal (secretly)
103
  prompt = format_prompt("chameleon_guess_decision", player_responses=self.format_responses())
104
 
105
- response = await self.players[self.chameleon_index].respond_to(prompt)
106
- output = await self.parser.parse(prompt, response, ChameleonGuessDecisionModel)
107
 
108
- if output.decision == "guess":
109
  chameleon_will_guess = True
110
  else:
111
  chameleon_will_guess = False
@@ -115,10 +119,9 @@ class Game:
115
  # Chameleon Guesses Animal
116
  prompt = fetch_prompt("chameleon_guess_animal")
117
 
118
- response = await self.players[self.chameleon_index].respond_to(prompt)
119
- output = await self.parser.parse(prompt, response, ChameleonGuessAnimalModel)
120
 
121
- if output.animal == herd_animal:
122
  winner = "chameleon"
123
  else:
124
  winner = "herd"
@@ -130,14 +133,12 @@ class Game:
130
  prompt = format_prompt("vote", player_responses=self.format_responses())
131
 
132
  # Get Player Vote
133
- response = await player.respond_to(prompt)
134
- # Parse Vote
135
- output = await self.parser.parse(prompt, response, VoteModel)
136
 
137
  # check if a valid player was voted for...
138
 
139
  # Add Vote to Player Votes
140
- player_votes.append(output.vote)
141
 
142
  print(player_votes)
143
 
@@ -160,6 +161,7 @@ class Game:
160
  # Log Game Info
161
  game_log = {
162
  "game_id": self.game_id,
 
163
  "herd_animal": herd_animal,
164
  "number_of_players": len(self.players),
165
  "human_player": self.players[self.human_index].id if self.human_index else "None",
 
1
  import os
2
+ from datetime import datetime
3
 
4
  from game_utils import *
5
  from models import *
6
  from player import Player
 
7
  from prompts import fetch_prompt, format_prompt
8
 
9
  # Default Values
 
13
  class Game:
14
  log_dir = os.path.join(os.pardir, "experiments")
15
  player_log_file = "{player_id}.jsonl"
 
16
  game_log_file = "{game_id}-game.jsonl"
17
 
18
  def __init__(
 
23
 
24
  # Game ID
25
  self.game_id = game_id()
26
+ self.start_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
27
+ self.log_dir = os.path.join(self.log_dir, f"{self.start_time}-{self.game_id}")
28
+ os.makedirs(self.log_dir, exist_ok=True)
29
 
30
  # Gather Player Names
31
  if human_name:
 
46
  controller = "human"
47
  else:
48
  name = ai_names.pop()
49
+ controller = "openai"
50
 
51
  if self.chameleon_index == i:
52
  role = "chameleon"
 
65
  # Game State
66
  self.player_responses = []
67
 
68
+ def format_responses(self, exclude: str = None) -> str:
 
 
 
 
69
  """Formats the responses of the players into a single string."""
70
+ if len(self.player_responses) == 0:
71
+ return "None, you are the first player!"
72
+ else:
73
+ formatted_responses = ""
74
+ for response in self.player_responses:
75
+ # Used to exclude the player who is currently responding, so they don't vote for themselves like a fool
76
+ if response["sender"] != exclude:
77
+ formatted_responses += f" - {response['sender']}: {response['response']}\n"
78
+
79
+ return formatted_responses
80
 
81
  def get_player_names(self) -> list[str]:
82
  """Returns the names of the players."""
 
100
  prompt = format_prompt("herd_animal", animal=herd_animal, player_responses=self.format_responses())
101
 
102
  # Get Player Animal Description
103
+ response = await player.respond_to(prompt, AnimalDescriptionModel)
 
 
104
 
105
+ self.player_responses.append({"sender": player.name, "response": response.description})
106
 
107
  # Phase II: Chameleon Decides if they want to guess the animal (secretly)
108
  prompt = format_prompt("chameleon_guess_decision", player_responses=self.format_responses())
109
 
110
+ response = await self.players[self.chameleon_index].respond_to(prompt, ChameleonGuessDecisionModel)
 
111
 
112
+ if response.decision == "guess":
113
  chameleon_will_guess = True
114
  else:
115
  chameleon_will_guess = False
 
119
  # Chameleon Guesses Animal
120
  prompt = fetch_prompt("chameleon_guess_animal")
121
 
122
+ response = await self.players[self.chameleon_index].respond_to(prompt, ChameleonGuessAnimalModel)
 
123
 
124
+ if response.animal == herd_animal:
125
  winner = "chameleon"
126
  else:
127
  winner = "herd"
 
133
  prompt = format_prompt("vote", player_responses=self.format_responses())
134
 
135
  # Get Player Vote
136
+ response = await player.respond_to(prompt, VoteModel)
 
 
137
 
138
  # check if a valid player was voted for...
139
 
140
  # Add Vote to Player Votes
141
+ player_votes.append(response.vote)
142
 
143
  print(player_votes)
144
 
 
161
  # Log Game Info
162
  game_log = {
163
  "game_id": self.game_id,
164
+ "start_time": self.start_time,
165
  "herd_animal": herd_animal,
166
  "number_of_players": len(self.players),
167
  "human_player": self.players[self.human_index].id if self.human_index else "None",
src/models.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import Annotated
 
 
2
 
3
  from pydantic import BaseModel, field_validator
4
 
@@ -7,7 +9,7 @@ MAX_DESCRIPTION_LEN = 10
7
 
8
  class AnimalDescriptionModel(BaseModel):
9
  # Define fields of our class here
10
- description: str
11
 
12
  # @field_validator('description')
13
  # @classmethod
@@ -45,11 +47,18 @@ class ChameleonGuessDecisionModel(BaseModel):
45
 
46
 
47
  class ChameleonGuessAnimalModel(BaseModel):
48
- animal: str
 
 
 
 
 
 
 
49
 
50
 
51
  class VoteModel(BaseModel):
52
- vote: str
53
 
54
  # @field_validator('vote')
55
  # @classmethod
 
1
+ import random
2
+ from typing import Annotated, Type, List
3
+ import json
4
 
5
  from pydantic import BaseModel, field_validator
6
 
 
9
 
10
  class AnimalDescriptionModel(BaseModel):
11
  # Define fields of our class here
12
+ description: Annotated[str, "A brief description of the animal"]
13
 
14
  # @field_validator('description')
15
  # @classmethod
 
47
 
48
 
49
  class ChameleonGuessAnimalModel(BaseModel):
50
+ animal: Annotated[str, "The name of the animal the chameleon is guessing"]
51
+
52
+ @field_validator('animal')
53
+ @classmethod
54
+ def is_one_word(cls, v) -> str:
55
+ if len(v.split()) > 1:
56
+ raise ValueError("Animal's name must be one word")
57
+ return v
58
 
59
 
60
  class VoteModel(BaseModel):
61
+ vote: Annotated[str, "The name of the player you are voting for"]
62
 
63
  # @field_validator('vote')
64
  # @classmethod
src/parser.py DELETED
@@ -1,96 +0,0 @@
1
- from typing import Type
2
- import asyncio
3
- import json
4
-
5
- from kani.engines.openai import OpenAIEngine
6
- from pydantic import BaseModel, ValidationError
7
-
8
- from agents import LogMessagesKani
9
-
10
- FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
11
- Here is the output schema:
12
- ```
13
- {schema}
14
- ```
15
- """
16
-
17
- parser_prompt = """\
18
- The user gave the following output to the prompt:
19
- Prompt:
20
- {prompt}
21
- Output:
22
- {message}
23
-
24
- {format_instructions}
25
- """
26
-
27
-
28
- class ParserKani(LogMessagesKani):
29
- def __init__(self, engine, *args, **kwargs):
30
- super().__init__(engine, *args, **kwargs)
31
-
32
- async def parse(self, prompt: str, message: str, format_model: Type[BaseModel], max_retries: int = 3, **kwargs):
33
- format_instructions = self.get_format_instructions(format_model)
34
-
35
- parser_instructions = parser_prompt.format(
36
- prompt=prompt,
37
- message=message,
38
- format_instructions=format_instructions
39
- )
40
-
41
- response = await self.chat_round_str(parser_instructions, **kwargs)
42
-
43
- try:
44
- output = format_model.model_validate_json(response)
45
- except ValidationError as e:
46
- print(f"Output did not conform to the expected format: {e}")
47
- raise e
48
-
49
- # Clear the Chat History after successful parse
50
- self.chat_history = []
51
-
52
- return output
53
-
54
- @staticmethod
55
- def get_format_instructions(format_model: Type[BaseModel]):
56
- schema = format_model.model_json_schema()
57
-
58
- reduced_schema = schema
59
- if "title" in reduced_schema:
60
- del reduced_schema["title"]
61
- if "type" in reduced_schema:
62
- del reduced_schema["type"]
63
-
64
- schema_str = json.dumps(reduced_schema, indent=4)
65
-
66
- return FORMAT_INSTRUCTIONS.format(schema=schema_str)
67
-
68
- @classmethod
69
- def default(cls, log_filepath: str = None):
70
- """Default ParserKani with OpenAIEngine."""
71
- engine = OpenAIEngine(model="gpt-3.5-turbo")
72
- return cls(engine, log_filepath=log_filepath)
73
-
74
-
75
-
76
- # Testing
77
- # parser = ParserKani(engine=OpenAIEngine(model="gpt-3.5-turbo"))
78
- #
79
- # sample_prompt = """\
80
- # Below are the responses from all players. Now it is time to vote. Choose from the players below who you think the Chameleon is.
81
- # - Mallory: I am tall and have a long neck.
82
- # - Jack: I am a herbivore and have a long neck.
83
- # - Jill: I am a herbivore and have a long neck.
84
- # - Bob: I am tall and have a long neck.
85
- # - Courtney: I am tall and have a long neck.
86
- # """
87
- #
88
- # sample_message = """\
89
- # I think the Chameleon is Mallory.
90
- # """
91
- #
92
- # test_output = asyncio.run(parser.parse(prompt=sample_prompt, message=sample_message, format_model=VoteModel))
93
- #
94
- # print(test_output)
95
- #
96
- # print(VoteModel.model_validate_json(test_output))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/player.py CHANGED
@@ -1,44 +1,35 @@
1
  import os
2
- import json
3
- import asyncio
4
 
5
- import openai
6
- from agents import LogMessagesKani
7
- from kani import ChatMessage
8
- from kani.engines.openai import OpenAIEngine
9
 
10
- from game_utils import log
 
 
11
 
12
- # api_type = "tgi"
13
- api_type = "openai"
14
- # api_type = "ollama"
15
-
16
- match api_type:
17
- case "tgi":
18
- # Using TGI Inference Endpoints from Hugging Face
19
- default_engine = OpenAIEngine( # type: ignore
20
- api_base=os.environ['HF_ENDPOINT_URL'] + "/v1/",
21
- api_key=os.environ['HF_API_TOKEN']
22
- )
23
- case "openai":
24
- # Using OpenAI GPT-3.5 Turbo
25
- default_engine = OpenAIEngine(model="gpt-3.5-turbo") # type: ignore
26
- case "ollama":
27
- # Using Ollama
28
- default_engine = OpenAIEngine(
29
- api_base="http://localhost:11434/v1",
30
- api_key="ollama",
31
- model="mistral"
32
- )
33
 
 
 
34
 
35
  class Player:
36
- def __init__(self, name: str, controller_type: str, role: str, id: str = None, log_filepath: str = None):
 
 
 
 
 
 
 
37
  self.name = name
38
  self.id = id
39
- self.controller = controller_type
40
- if controller_type == "ai":
41
- self.kani = LogMessagesKani(default_engine, log_filepath=log_filepath)
 
 
 
 
42
 
43
  self.role = role
44
  self.messages = []
@@ -50,26 +41,71 @@ class Player:
50
  "id": self.id,
51
  "name": self.name,
52
  "role": self.role,
53
- "controller": controller_type,
 
 
 
54
  }
55
  log(player_info, log_filepath)
56
 
57
- async def respond_to(self, prompt: str) -> str:
58
- """Makes the player respond to a prompt. Returns the response."""
59
- if self.controller == "human":
60
- # We're pretending the human is an ai for logging purposes... I don't love this but it's fine for now
61
- log(ChatMessage.user(prompt).model_dump_json(), self.log_filepath)
62
- print(prompt)
63
- output = input()
64
- log(ChatMessage.assistant(output).model_dump_json(), self.log_filepath)
65
-
66
- return output
67
-
68
- elif self.controller == "ai":
69
- output = await self.kani.chat_round_str(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- return output
72
 
 
73
 
 
74
 
 
75
 
 
 
1
  import os
2
+ from typing import Type
 
3
 
4
+ from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
 
 
 
5
 
6
+ from langchain.output_parsers import PydanticOutputParser
7
+ from langchain_core.prompts import PromptTemplate
8
+ from langchain_core.messages import HumanMessage
9
 
10
+ from pydantic import BaseModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ from game_utils import log
13
+ from controllers import controller_from_name
14
 
15
  class Player:
16
+ def __init__(
17
+ self,
18
+ name: str,
19
+ controller: str,
20
+ role: str,
21
+ id: str = None,
22
+ log_filepath: str = None
23
+ ):
24
  self.name = name
25
  self.id = id
26
+
27
+ if controller == "human":
28
+ self.controller_type = "human"
29
+ else:
30
+ self.controller_type = "ai"
31
+
32
+ self.controller = controller_from_name(controller)
33
 
34
  self.role = role
35
  self.messages = []
 
41
  "id": self.id,
42
  "name": self.name,
43
  "role": self.role,
44
+ "controller": {
45
+ "name": controller,
46
+ "type": self.controller_type
47
+ }
48
  }
49
  log(player_info, log_filepath)
50
 
51
+ # initialize the runnables
52
+ self.generate = RunnableLambda(self._generate)
53
+ self.format_output = RunnableLambda(self._output_formatter)
54
+
55
+ async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries=3):
56
+ """Makes the player respond to a prompt. Returns the response in the specified format."""
57
+ message = HumanMessage(content=prompt)
58
+ output = await self.generate.ainvoke(message)
59
+ if self.controller_type == "ai":
60
+ retries = 0
61
+ try:
62
+ output = await self.format_output.ainvoke({"output_format": output_format})
63
+ except ValueError as e:
64
+ if retries < max_retries:
65
+ self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
66
+ output = await self.format_output.ainvoke({"output_format": output_format})
67
+ retries += 1
68
+ else:
69
+ raise e
70
+ else:
71
+ # Convert the human message to the pydantic object format
72
+ field_name = output_format.model_fields.copy().popitem()[0] # only works because current outputs have only 1 field
73
+ output = output_format.model_validate({field_name: output.content})
74
+
75
+ return output
76
+
77
+ def add_to_history(self, message):
78
+ self.messages.append(message)
79
+ # log(message.model_dump_json(), self.log_filepath)
80
+
81
+ def _generate(self, message: HumanMessage):
82
+ """Entry point for the Runnable generating responses, automatically logs the message."""
83
+ self.add_to_history(message)
84
+
85
+ # AI's need to be fed the whole message history, but humans can just go back and look at it
86
+ if self.controller_type == "human":
87
+ response = self.controller.invoke(message.content)
88
+ else:
89
+ response = self.controller.invoke(self.messages)
90
+
91
+ self.add_to_history(response)
92
+
93
+ return response
94
+
95
+ def _output_formatter(self, inputs: dict):
96
+ """Formats the output of the response."""
97
+ output_format: BaseModel = inputs["output_format"]
98
+
99
+ prompt_template = PromptTemplate.from_template(
100
+ "Please rewrite your previous response using the following format: \n\n{format_instructions}"
101
+ )
102
 
103
+ parser = PydanticOutputParser(pydantic_object=output_format)
104
 
105
+ prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
106
 
107
+ message = HumanMessage(content=prompt.text)
108
 
109
+ response = self.generate.invoke(message)
110
 
111
+ return parser.invoke(response)
src/prompts.py CHANGED
@@ -1,12 +1,66 @@
 
 
 
 
 
1
  def fetch_prompt(prompt_name):
2
  """Fetches a static prompt."""
3
  return prompts[prompt_name]
4
-
5
  def format_prompt(prompt_name, **kwargs):
6
  """Fetches a template prompt and populates it."""
7
  return fetch_prompt(prompt_name).format(**kwargs)
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  _game_rules = '''\
11
  GAME RULES: You are playing a social deduction game where every player pretends the be the same animal.
12
  During the round players go around the room and make an "I"-statement as if they were the animal.
@@ -17,11 +71,11 @@ The other players must be careful not to give away too much information with the
17
 
18
  _herd_animal = """\
19
  You are a {animal}, keep this a secret at all costs.
20
- In 10 words or less give a description of yourself starting with "I". The description should not give away too much information about the {animal} as you do not want the Chameleon to be able to guess what animal you are. Do not repeat responses from other players.
21
- If the Chameleon can guess what animal you really are you will LOSE.
22
  Previously Mentioned Descriptions:
23
  {player_responses}
24
- """
25
 
26
  _chameleon_animal = """\
27
  You are the Chameleon, keep this a secret at all costs.
@@ -31,7 +85,7 @@ If no one else has said anything try to say something generic that could be true
31
  If the other players realize you are the Chameleon you will LOSE.
32
  Previously Mentioned Descriptions:
33
  {player_responses}
34
- """
35
 
36
  _all_responses = """\
37
  Below are the responses from all the other players.
@@ -42,6 +96,8 @@ _chameleon_guess_decision = """\
42
  You now have the opportunity to guess what animal the other players are pretending to be.
43
  If you guess correctly you will WIN, if you guess incorrectly you will LOSE.
44
  If you believe you know what animal the other players are pretending to be make choose to GUESS, otherwise choose to PASS.
 
 
45
  """
46
 
47
  _chameleon_guess_animal = """\
@@ -53,10 +109,9 @@ Now it is time to vote. Choose from the players above who you think the Chameleo
53
  """
54
 
55
  prompts = {
56
- "herd_animal": _game_rules + _herd_animal,
57
  "chameleon_animal": _game_rules + _chameleon_animal,
58
  "chameleon_guess_decision": _all_responses + _chameleon_guess_decision,
59
  "chameleon_guess_animal": _chameleon_guess_animal,
60
  "vote": _all_responses + _vote_prompt
61
  }
62
-
 
1
+ from models import *
2
+ from langchain.prompts.few_shot import FewShotPromptTemplate
3
+ from langchain.prompts.prompt import PromptTemplate
4
+
5
+
6
  def fetch_prompt(prompt_name):
7
  """Fetches a static prompt."""
8
  return prompts[prompt_name]
 
9
  def format_prompt(prompt_name, **kwargs):
10
  """Fetches a template prompt and populates it."""
11
  return fetch_prompt(prompt_name).format(**kwargs)
12
 
13
 
14
+ class Task:
15
+ def __init__(self, prompt: str, response_format: Type[BaseModel], few_shot_examples: List[dict] = None):
16
+ self.prompt = prompt
17
+ self.response_format = response_format
18
+ self.few_shot_examples = few_shot_examples
19
+
20
+ def full_prompt(self, **kwargs):
21
+ prompt = self.prompt.format(**kwargs)
22
+
23
+ format_instructions = self.get_format_instructions()
24
+ if self.few_shot_examples:
25
+ few_shot = self.get_few_shot()
26
+
27
+ def get_format_instructions(self):
28
+ schema = self.get_input_schema()
29
+ format_instructions = FORMAT_INSTRUCTIONS.format(schema=schema)
30
+
31
+ return format_instructions
32
+
33
+ def get_input_schema(self):
34
+ schema = self.response_format.model_json_schema()
35
+
36
+ reduced_schema = schema
37
+ if "title" in reduced_schema:
38
+ del reduced_schema["title"]
39
+ if "type" in reduced_schema:
40
+ del reduced_schema["type"]
41
+
42
+ schema_str = json.dumps(reduced_schema, indent=4)
43
+
44
+ return schema_str
45
+
46
+ def get_few_shot(self, max_examples=3):
47
+ if len(self.few_shot_examples) <= max_examples:
48
+ examples = self.few_shot_examples
49
+ else:
50
+ examples = random.sample(self.few_shot_examples, max_examples)
51
+
52
+ few_shot = "\n\n".join([self.format_example(ex) for ex in examples])
53
+
54
+ return few_shot
55
+
56
+ def format_example(self, example):
57
+ ex_prompt = self.prompt.format(**example['inputs'])
58
+ ex_response = example['response']
59
+
60
+ return f"Prompt: {ex_prompt}\nResponse: {ex_response}"
61
+
62
+
63
+
64
  _game_rules = '''\
65
  GAME RULES: You are playing a social deduction game where every player pretends the be the same animal.
66
  During the round players go around the room and make an "I"-statement as if they were the animal.
 
71
 
72
  _herd_animal = """\
73
  You are a {animal}, keep this a secret at all costs.
74
+ In as few words as possible describe of yourself starting with "I". Your description should be vague but true, \
75
+ since if the Chameleon can guess animal you are, you will LOSE. Do not repeat responses from other players.
76
  Previously Mentioned Descriptions:
77
  {player_responses}
78
+ Your Response: """
79
 
80
  _chameleon_animal = """\
81
  You are the Chameleon, keep this a secret at all costs.
 
85
  If the other players realize you are the Chameleon you will LOSE.
86
  Previously Mentioned Descriptions:
87
  {player_responses}
88
+ Your Response: """
89
 
90
  _all_responses = """\
91
  Below are the responses from all the other players.
 
96
  You now have the opportunity to guess what animal the other players are pretending to be.
97
  If you guess correctly you will WIN, if you guess incorrectly you will LOSE.
98
  If you believe you know what animal the other players are pretending to be make choose to GUESS, otherwise choose to PASS.
99
+ Your response should be one of ("GUESS", "PASS")
100
+ Your Response:
101
  """
102
 
103
  _chameleon_guess_animal = """\
 
109
  """
110
 
111
  prompts = {
112
+ "herd_animal": _game_rules + _herd_animal,
113
  "chameleon_animal": _game_rules + _chameleon_animal,
114
  "chameleon_guess_decision": _all_responses + _chameleon_guess_decision,
115
  "chameleon_guess_animal": _chameleon_guess_animal,
116
  "vote": _all_responses + _vote_prompt
117
  }