Eric Botti
commited on
Commit
·
f596e58
1
Parent(s):
1e1d212
reconstructed players using langchain
Browse files- src/agents.py +0 -15
- src/controllers.py +28 -0
- src/game.py +25 -23
- src/models.py +13 -4
- src/parser.py +0 -96
- src/player.py +83 -47
- src/prompts.py +62 -7
src/agents.py
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
from kani import Kani
|
2 |
-
|
3 |
-
class LogMessagesKani(Kani):
|
4 |
-
def __init__(self, engine, log_filepath: str = None, *args, **kwargs):
|
5 |
-
super().__init__(engine, *args, **kwargs)
|
6 |
-
self.log_filepath = log_filepath
|
7 |
-
|
8 |
-
async def add_to_history(self, message, *args, **kwargs):
|
9 |
-
await super().add_to_history(message, *args, **kwargs)
|
10 |
-
|
11 |
-
# Logs Message to File
|
12 |
-
if self.log_filepath:
|
13 |
-
with open(self.log_filepath, "a+") as log_file:
|
14 |
-
log_file.write(message.model_dump_json())
|
15 |
-
log_file.write("\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/controllers.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from langchain_core.runnables import RunnableLambda
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from langchain_core.messages import AIMessage
|
6 |
+
|
7 |
+
|
8 |
+
def player_input(prompt):
|
9 |
+
print(prompt)
|
10 |
+
# even though they are human, we still need to return an AIMessage, since the HumanMessages are from the GameMaster
|
11 |
+
response = AIMessage(content=input())
|
12 |
+
return response
|
13 |
+
|
14 |
+
|
15 |
+
def controller_from_name(name: str):
|
16 |
+
if name == "tgi":
|
17 |
+
return ChatOpenAI(
|
18 |
+
api_base=os.environ['HF_ENDPOINT_URL'] + "/v1/",
|
19 |
+
api_key=os.environ['HF_API_TOKEN']
|
20 |
+
)
|
21 |
+
elif name == "openai":
|
22 |
+
return ChatOpenAI(model="gpt-3.5-turbo")
|
23 |
+
# elif name == "ollama":
|
24 |
+
# return ollama_controller
|
25 |
+
elif name == "human":
|
26 |
+
return RunnableLambda(player_input)
|
27 |
+
else:
|
28 |
+
raise ValueError(f"Unknown controller name: {name}")
|
src/game.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import os
|
|
|
2 |
|
3 |
from game_utils import *
|
4 |
from models import *
|
5 |
from player import Player
|
6 |
-
from parser import ParserKani
|
7 |
from prompts import fetch_prompt, format_prompt
|
8 |
|
9 |
# Default Values
|
@@ -13,7 +13,6 @@ NUMBER_OF_PLAYERS = 5
|
|
13 |
class Game:
|
14 |
log_dir = os.path.join(os.pardir, "experiments")
|
15 |
player_log_file = "{player_id}.jsonl"
|
16 |
-
parser_log_file = "{game_id}-parser.jsonl"
|
17 |
game_log_file = "{game_id}-game.jsonl"
|
18 |
|
19 |
def __init__(
|
@@ -24,6 +23,9 @@ class Game:
|
|
24 |
|
25 |
# Game ID
|
26 |
self.game_id = game_id()
|
|
|
|
|
|
|
27 |
|
28 |
# Gather Player Names
|
29 |
if human_name:
|
@@ -44,7 +46,7 @@ class Game:
|
|
44 |
controller = "human"
|
45 |
else:
|
46 |
name = ai_names.pop()
|
47 |
-
controller = "
|
48 |
|
49 |
if self.chameleon_index == i:
|
50 |
role = "chameleon"
|
@@ -63,13 +65,18 @@ class Game:
|
|
63 |
# Game State
|
64 |
self.player_responses = []
|
65 |
|
66 |
-
|
67 |
-
parser_log_path = os.path.join(self.log_dir, self.parser_log_file.format(game_id=self.game_id))
|
68 |
-
self.parser = ParserKani.default(parser_log_path)
|
69 |
-
|
70 |
-
def format_responses(self) -> str:
|
71 |
"""Formats the responses of the players into a single string."""
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
def get_player_names(self) -> list[str]:
|
75 |
"""Returns the names of the players."""
|
@@ -93,19 +100,16 @@ class Game:
|
|
93 |
prompt = format_prompt("herd_animal", animal=herd_animal, player_responses=self.format_responses())
|
94 |
|
95 |
# Get Player Animal Description
|
96 |
-
response = await player.respond_to(prompt)
|
97 |
-
# Parse Animal Description
|
98 |
-
output = await self.parser.parse(prompt, response, AnimalDescriptionModel)
|
99 |
|
100 |
-
self.player_responses.append({"sender": player.name, "response":
|
101 |
|
102 |
# Phase II: Chameleon Decides if they want to guess the animal (secretly)
|
103 |
prompt = format_prompt("chameleon_guess_decision", player_responses=self.format_responses())
|
104 |
|
105 |
-
response = await self.players[self.chameleon_index].respond_to(prompt)
|
106 |
-
output = await self.parser.parse(prompt, response, ChameleonGuessDecisionModel)
|
107 |
|
108 |
-
if
|
109 |
chameleon_will_guess = True
|
110 |
else:
|
111 |
chameleon_will_guess = False
|
@@ -115,10 +119,9 @@ class Game:
|
|
115 |
# Chameleon Guesses Animal
|
116 |
prompt = fetch_prompt("chameleon_guess_animal")
|
117 |
|
118 |
-
response = await self.players[self.chameleon_index].respond_to(prompt)
|
119 |
-
output = await self.parser.parse(prompt, response, ChameleonGuessAnimalModel)
|
120 |
|
121 |
-
if
|
122 |
winner = "chameleon"
|
123 |
else:
|
124 |
winner = "herd"
|
@@ -130,14 +133,12 @@ class Game:
|
|
130 |
prompt = format_prompt("vote", player_responses=self.format_responses())
|
131 |
|
132 |
# Get Player Vote
|
133 |
-
response = await player.respond_to(prompt)
|
134 |
-
# Parse Vote
|
135 |
-
output = await self.parser.parse(prompt, response, VoteModel)
|
136 |
|
137 |
# check if a valid player was voted for...
|
138 |
|
139 |
# Add Vote to Player Votes
|
140 |
-
player_votes.append(
|
141 |
|
142 |
print(player_votes)
|
143 |
|
@@ -160,6 +161,7 @@ class Game:
|
|
160 |
# Log Game Info
|
161 |
game_log = {
|
162 |
"game_id": self.game_id,
|
|
|
163 |
"herd_animal": herd_animal,
|
164 |
"number_of_players": len(self.players),
|
165 |
"human_player": self.players[self.human_index].id if self.human_index else "None",
|
|
|
1 |
import os
|
2 |
+
from datetime import datetime
|
3 |
|
4 |
from game_utils import *
|
5 |
from models import *
|
6 |
from player import Player
|
|
|
7 |
from prompts import fetch_prompt, format_prompt
|
8 |
|
9 |
# Default Values
|
|
|
13 |
class Game:
|
14 |
log_dir = os.path.join(os.pardir, "experiments")
|
15 |
player_log_file = "{player_id}.jsonl"
|
|
|
16 |
game_log_file = "{game_id}-game.jsonl"
|
17 |
|
18 |
def __init__(
|
|
|
23 |
|
24 |
# Game ID
|
25 |
self.game_id = game_id()
|
26 |
+
self.start_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
27 |
+
self.log_dir = os.path.join(self.log_dir, f"{self.start_time}-{self.game_id}")
|
28 |
+
os.makedirs(self.log_dir, exist_ok=True)
|
29 |
|
30 |
# Gather Player Names
|
31 |
if human_name:
|
|
|
46 |
controller = "human"
|
47 |
else:
|
48 |
name = ai_names.pop()
|
49 |
+
controller = "openai"
|
50 |
|
51 |
if self.chameleon_index == i:
|
52 |
role = "chameleon"
|
|
|
65 |
# Game State
|
66 |
self.player_responses = []
|
67 |
|
68 |
+
def format_responses(self, exclude: str = None) -> str:
|
|
|
|
|
|
|
|
|
69 |
"""Formats the responses of the players into a single string."""
|
70 |
+
if len(self.player_responses) == 0:
|
71 |
+
return "None, you are the first player!"
|
72 |
+
else:
|
73 |
+
formatted_responses = ""
|
74 |
+
for response in self.player_responses:
|
75 |
+
# Used to exclude the player who is currently responding, so they don't vote for themselves like a fool
|
76 |
+
if response["sender"] != exclude:
|
77 |
+
formatted_responses += f" - {response['sender']}: {response['response']}\n"
|
78 |
+
|
79 |
+
return formatted_responses
|
80 |
|
81 |
def get_player_names(self) -> list[str]:
|
82 |
"""Returns the names of the players."""
|
|
|
100 |
prompt = format_prompt("herd_animal", animal=herd_animal, player_responses=self.format_responses())
|
101 |
|
102 |
# Get Player Animal Description
|
103 |
+
response = await player.respond_to(prompt, AnimalDescriptionModel)
|
|
|
|
|
104 |
|
105 |
+
self.player_responses.append({"sender": player.name, "response": response.description})
|
106 |
|
107 |
# Phase II: Chameleon Decides if they want to guess the animal (secretly)
|
108 |
prompt = format_prompt("chameleon_guess_decision", player_responses=self.format_responses())
|
109 |
|
110 |
+
response = await self.players[self.chameleon_index].respond_to(prompt, ChameleonGuessDecisionModel)
|
|
|
111 |
|
112 |
+
if response.decision == "guess":
|
113 |
chameleon_will_guess = True
|
114 |
else:
|
115 |
chameleon_will_guess = False
|
|
|
119 |
# Chameleon Guesses Animal
|
120 |
prompt = fetch_prompt("chameleon_guess_animal")
|
121 |
|
122 |
+
response = await self.players[self.chameleon_index].respond_to(prompt, ChameleonGuessAnimalModel)
|
|
|
123 |
|
124 |
+
if response.animal == herd_animal:
|
125 |
winner = "chameleon"
|
126 |
else:
|
127 |
winner = "herd"
|
|
|
133 |
prompt = format_prompt("vote", player_responses=self.format_responses())
|
134 |
|
135 |
# Get Player Vote
|
136 |
+
response = await player.respond_to(prompt, VoteModel)
|
|
|
|
|
137 |
|
138 |
# check if a valid player was voted for...
|
139 |
|
140 |
# Add Vote to Player Votes
|
141 |
+
player_votes.append(response.vote)
|
142 |
|
143 |
print(player_votes)
|
144 |
|
|
|
161 |
# Log Game Info
|
162 |
game_log = {
|
163 |
"game_id": self.game_id,
|
164 |
+
"start_time": self.start_time,
|
165 |
"herd_animal": herd_animal,
|
166 |
"number_of_players": len(self.players),
|
167 |
"human_player": self.players[self.human_index].id if self.human_index else "None",
|
src/models.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
-
|
|
|
|
|
2 |
|
3 |
from pydantic import BaseModel, field_validator
|
4 |
|
@@ -7,7 +9,7 @@ MAX_DESCRIPTION_LEN = 10
|
|
7 |
|
8 |
class AnimalDescriptionModel(BaseModel):
|
9 |
# Define fields of our class here
|
10 |
-
description: str
|
11 |
|
12 |
# @field_validator('description')
|
13 |
# @classmethod
|
@@ -45,11 +47,18 @@ class ChameleonGuessDecisionModel(BaseModel):
|
|
45 |
|
46 |
|
47 |
class ChameleonGuessAnimalModel(BaseModel):
|
48 |
-
animal: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
class VoteModel(BaseModel):
|
52 |
-
vote: str
|
53 |
|
54 |
# @field_validator('vote')
|
55 |
# @classmethod
|
|
|
1 |
+
import random
|
2 |
+
from typing import Annotated, Type, List
|
3 |
+
import json
|
4 |
|
5 |
from pydantic import BaseModel, field_validator
|
6 |
|
|
|
9 |
|
10 |
class AnimalDescriptionModel(BaseModel):
|
11 |
# Define fields of our class here
|
12 |
+
description: Annotated[str, "A brief description of the animal"]
|
13 |
|
14 |
# @field_validator('description')
|
15 |
# @classmethod
|
|
|
47 |
|
48 |
|
49 |
class ChameleonGuessAnimalModel(BaseModel):
|
50 |
+
animal: Annotated[str, "The name of the animal the chameleon is guessing"]
|
51 |
+
|
52 |
+
@field_validator('animal')
|
53 |
+
@classmethod
|
54 |
+
def is_one_word(cls, v) -> str:
|
55 |
+
if len(v.split()) > 1:
|
56 |
+
raise ValueError("Animal's name must be one word")
|
57 |
+
return v
|
58 |
|
59 |
|
60 |
class VoteModel(BaseModel):
|
61 |
+
vote: Annotated[str, "The name of the player you are voting for"]
|
62 |
|
63 |
# @field_validator('vote')
|
64 |
# @classmethod
|
src/parser.py
DELETED
@@ -1,96 +0,0 @@
|
|
1 |
-
from typing import Type
|
2 |
-
import asyncio
|
3 |
-
import json
|
4 |
-
|
5 |
-
from kani.engines.openai import OpenAIEngine
|
6 |
-
from pydantic import BaseModel, ValidationError
|
7 |
-
|
8 |
-
from agents import LogMessagesKani
|
9 |
-
|
10 |
-
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
11 |
-
Here is the output schema:
|
12 |
-
```
|
13 |
-
{schema}
|
14 |
-
```
|
15 |
-
"""
|
16 |
-
|
17 |
-
parser_prompt = """\
|
18 |
-
The user gave the following output to the prompt:
|
19 |
-
Prompt:
|
20 |
-
{prompt}
|
21 |
-
Output:
|
22 |
-
{message}
|
23 |
-
|
24 |
-
{format_instructions}
|
25 |
-
"""
|
26 |
-
|
27 |
-
|
28 |
-
class ParserKani(LogMessagesKani):
|
29 |
-
def __init__(self, engine, *args, **kwargs):
|
30 |
-
super().__init__(engine, *args, **kwargs)
|
31 |
-
|
32 |
-
async def parse(self, prompt: str, message: str, format_model: Type[BaseModel], max_retries: int = 3, **kwargs):
|
33 |
-
format_instructions = self.get_format_instructions(format_model)
|
34 |
-
|
35 |
-
parser_instructions = parser_prompt.format(
|
36 |
-
prompt=prompt,
|
37 |
-
message=message,
|
38 |
-
format_instructions=format_instructions
|
39 |
-
)
|
40 |
-
|
41 |
-
response = await self.chat_round_str(parser_instructions, **kwargs)
|
42 |
-
|
43 |
-
try:
|
44 |
-
output = format_model.model_validate_json(response)
|
45 |
-
except ValidationError as e:
|
46 |
-
print(f"Output did not conform to the expected format: {e}")
|
47 |
-
raise e
|
48 |
-
|
49 |
-
# Clear the Chat History after successful parse
|
50 |
-
self.chat_history = []
|
51 |
-
|
52 |
-
return output
|
53 |
-
|
54 |
-
@staticmethod
|
55 |
-
def get_format_instructions(format_model: Type[BaseModel]):
|
56 |
-
schema = format_model.model_json_schema()
|
57 |
-
|
58 |
-
reduced_schema = schema
|
59 |
-
if "title" in reduced_schema:
|
60 |
-
del reduced_schema["title"]
|
61 |
-
if "type" in reduced_schema:
|
62 |
-
del reduced_schema["type"]
|
63 |
-
|
64 |
-
schema_str = json.dumps(reduced_schema, indent=4)
|
65 |
-
|
66 |
-
return FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
67 |
-
|
68 |
-
@classmethod
|
69 |
-
def default(cls, log_filepath: str = None):
|
70 |
-
"""Default ParserKani with OpenAIEngine."""
|
71 |
-
engine = OpenAIEngine(model="gpt-3.5-turbo")
|
72 |
-
return cls(engine, log_filepath=log_filepath)
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
# Testing
|
77 |
-
# parser = ParserKani(engine=OpenAIEngine(model="gpt-3.5-turbo"))
|
78 |
-
#
|
79 |
-
# sample_prompt = """\
|
80 |
-
# Below are the responses from all players. Now it is time to vote. Choose from the players below who you think the Chameleon is.
|
81 |
-
# - Mallory: I am tall and have a long neck.
|
82 |
-
# - Jack: I am a herbivore and have a long neck.
|
83 |
-
# - Jill: I am a herbivore and have a long neck.
|
84 |
-
# - Bob: I am tall and have a long neck.
|
85 |
-
# - Courtney: I am tall and have a long neck.
|
86 |
-
# """
|
87 |
-
#
|
88 |
-
# sample_message = """\
|
89 |
-
# I think the Chameleon is Mallory.
|
90 |
-
# """
|
91 |
-
#
|
92 |
-
# test_output = asyncio.run(parser.parse(prompt=sample_prompt, message=sample_message, format_model=VoteModel))
|
93 |
-
#
|
94 |
-
# print(test_output)
|
95 |
-
#
|
96 |
-
# print(VoteModel.model_validate_json(test_output))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/player.py
CHANGED
@@ -1,44 +1,35 @@
|
|
1 |
import os
|
2 |
-
import
|
3 |
-
import asyncio
|
4 |
|
5 |
-
import
|
6 |
-
from agents import LogMessagesKani
|
7 |
-
from kani import ChatMessage
|
8 |
-
from kani.engines.openai import OpenAIEngine
|
9 |
|
10 |
-
from
|
|
|
|
|
11 |
|
12 |
-
|
13 |
-
api_type = "openai"
|
14 |
-
# api_type = "ollama"
|
15 |
-
|
16 |
-
match api_type:
|
17 |
-
case "tgi":
|
18 |
-
# Using TGI Inference Endpoints from Hugging Face
|
19 |
-
default_engine = OpenAIEngine( # type: ignore
|
20 |
-
api_base=os.environ['HF_ENDPOINT_URL'] + "/v1/",
|
21 |
-
api_key=os.environ['HF_API_TOKEN']
|
22 |
-
)
|
23 |
-
case "openai":
|
24 |
-
# Using OpenAI GPT-3.5 Turbo
|
25 |
-
default_engine = OpenAIEngine(model="gpt-3.5-turbo") # type: ignore
|
26 |
-
case "ollama":
|
27 |
-
# Using Ollama
|
28 |
-
default_engine = OpenAIEngine(
|
29 |
-
api_base="http://localhost:11434/v1",
|
30 |
-
api_key="ollama",
|
31 |
-
model="mistral"
|
32 |
-
)
|
33 |
|
|
|
|
|
34 |
|
35 |
class Player:
|
36 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
self.name = name
|
38 |
self.id = id
|
39 |
-
|
40 |
-
if
|
41 |
-
self.
|
|
|
|
|
|
|
|
|
42 |
|
43 |
self.role = role
|
44 |
self.messages = []
|
@@ -50,26 +41,71 @@ class Player:
|
|
50 |
"id": self.id,
|
51 |
"name": self.name,
|
52 |
"role": self.role,
|
53 |
-
"controller":
|
|
|
|
|
|
|
54 |
}
|
55 |
log(player_info, log_filepath)
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
-
|
72 |
|
|
|
73 |
|
|
|
74 |
|
|
|
75 |
|
|
|
|
1 |
import os
|
2 |
+
from typing import Type
|
|
|
3 |
|
4 |
+
from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
|
|
|
|
|
|
|
5 |
|
6 |
+
from langchain.output_parsers import PydanticOutputParser
|
7 |
+
from langchain_core.prompts import PromptTemplate
|
8 |
+
from langchain_core.messages import HumanMessage
|
9 |
|
10 |
+
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
from game_utils import log
|
13 |
+
from controllers import controller_from_name
|
14 |
|
15 |
class Player:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
name: str,
|
19 |
+
controller: str,
|
20 |
+
role: str,
|
21 |
+
id: str = None,
|
22 |
+
log_filepath: str = None
|
23 |
+
):
|
24 |
self.name = name
|
25 |
self.id = id
|
26 |
+
|
27 |
+
if controller == "human":
|
28 |
+
self.controller_type = "human"
|
29 |
+
else:
|
30 |
+
self.controller_type = "ai"
|
31 |
+
|
32 |
+
self.controller = controller_from_name(controller)
|
33 |
|
34 |
self.role = role
|
35 |
self.messages = []
|
|
|
41 |
"id": self.id,
|
42 |
"name": self.name,
|
43 |
"role": self.role,
|
44 |
+
"controller": {
|
45 |
+
"name": controller,
|
46 |
+
"type": self.controller_type
|
47 |
+
}
|
48 |
}
|
49 |
log(player_info, log_filepath)
|
50 |
|
51 |
+
# initialize the runnables
|
52 |
+
self.generate = RunnableLambda(self._generate)
|
53 |
+
self.format_output = RunnableLambda(self._output_formatter)
|
54 |
+
|
55 |
+
async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries=3):
|
56 |
+
"""Makes the player respond to a prompt. Returns the response in the specified format."""
|
57 |
+
message = HumanMessage(content=prompt)
|
58 |
+
output = await self.generate.ainvoke(message)
|
59 |
+
if self.controller_type == "ai":
|
60 |
+
retries = 0
|
61 |
+
try:
|
62 |
+
output = await self.format_output.ainvoke({"output_format": output_format})
|
63 |
+
except ValueError as e:
|
64 |
+
if retries < max_retries:
|
65 |
+
self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
|
66 |
+
output = await self.format_output.ainvoke({"output_format": output_format})
|
67 |
+
retries += 1
|
68 |
+
else:
|
69 |
+
raise e
|
70 |
+
else:
|
71 |
+
# Convert the human message to the pydantic object format
|
72 |
+
field_name = output_format.model_fields.copy().popitem()[0] # only works because current outputs have only 1 field
|
73 |
+
output = output_format.model_validate({field_name: output.content})
|
74 |
+
|
75 |
+
return output
|
76 |
+
|
77 |
+
def add_to_history(self, message):
|
78 |
+
self.messages.append(message)
|
79 |
+
# log(message.model_dump_json(), self.log_filepath)
|
80 |
+
|
81 |
+
def _generate(self, message: HumanMessage):
|
82 |
+
"""Entry point for the Runnable generating responses, automatically logs the message."""
|
83 |
+
self.add_to_history(message)
|
84 |
+
|
85 |
+
# AI's need to be fed the whole message history, but humans can just go back and look at it
|
86 |
+
if self.controller_type == "human":
|
87 |
+
response = self.controller.invoke(message.content)
|
88 |
+
else:
|
89 |
+
response = self.controller.invoke(self.messages)
|
90 |
+
|
91 |
+
self.add_to_history(response)
|
92 |
+
|
93 |
+
return response
|
94 |
+
|
95 |
+
def _output_formatter(self, inputs: dict):
|
96 |
+
"""Formats the output of the response."""
|
97 |
+
output_format: BaseModel = inputs["output_format"]
|
98 |
+
|
99 |
+
prompt_template = PromptTemplate.from_template(
|
100 |
+
"Please rewrite your previous response using the following format: \n\n{format_instructions}"
|
101 |
+
)
|
102 |
|
103 |
+
parser = PydanticOutputParser(pydantic_object=output_format)
|
104 |
|
105 |
+
prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
|
106 |
|
107 |
+
message = HumanMessage(content=prompt.text)
|
108 |
|
109 |
+
response = self.generate.invoke(message)
|
110 |
|
111 |
+
return parser.invoke(response)
|
src/prompts.py
CHANGED
@@ -1,12 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
def fetch_prompt(prompt_name):
|
2 |
"""Fetches a static prompt."""
|
3 |
return prompts[prompt_name]
|
4 |
-
|
5 |
def format_prompt(prompt_name, **kwargs):
|
6 |
"""Fetches a template prompt and populates it."""
|
7 |
return fetch_prompt(prompt_name).format(**kwargs)
|
8 |
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
_game_rules = '''\
|
11 |
GAME RULES: You are playing a social deduction game where every player pretends the be the same animal.
|
12 |
During the round players go around the room and make an "I"-statement as if they were the animal.
|
@@ -17,11 +71,11 @@ The other players must be careful not to give away too much information with the
|
|
17 |
|
18 |
_herd_animal = """\
|
19 |
You are a {animal}, keep this a secret at all costs.
|
20 |
-
In
|
21 |
-
|
22 |
Previously Mentioned Descriptions:
|
23 |
{player_responses}
|
24 |
-
"""
|
25 |
|
26 |
_chameleon_animal = """\
|
27 |
You are the Chameleon, keep this a secret at all costs.
|
@@ -31,7 +85,7 @@ If no one else has said anything try to say something generic that could be true
|
|
31 |
If the other players realize you are the Chameleon you will LOSE.
|
32 |
Previously Mentioned Descriptions:
|
33 |
{player_responses}
|
34 |
-
"""
|
35 |
|
36 |
_all_responses = """\
|
37 |
Below are the responses from all the other players.
|
@@ -42,6 +96,8 @@ _chameleon_guess_decision = """\
|
|
42 |
You now have the opportunity to guess what animal the other players are pretending to be.
|
43 |
If you guess correctly you will WIN, if you guess incorrectly you will LOSE.
|
44 |
If you believe you know what animal the other players are pretending to be make choose to GUESS, otherwise choose to PASS.
|
|
|
|
|
45 |
"""
|
46 |
|
47 |
_chameleon_guess_animal = """\
|
@@ -53,10 +109,9 @@ Now it is time to vote. Choose from the players above who you think the Chameleo
|
|
53 |
"""
|
54 |
|
55 |
prompts = {
|
56 |
-
"herd_animal":
|
57 |
"chameleon_animal": _game_rules + _chameleon_animal,
|
58 |
"chameleon_guess_decision": _all_responses + _chameleon_guess_decision,
|
59 |
"chameleon_guess_animal": _chameleon_guess_animal,
|
60 |
"vote": _all_responses + _vote_prompt
|
61 |
}
|
62 |
-
|
|
|
1 |
+
from models import *
|
2 |
+
from langchain.prompts.few_shot import FewShotPromptTemplate
|
3 |
+
from langchain.prompts.prompt import PromptTemplate
|
4 |
+
|
5 |
+
|
6 |
def fetch_prompt(prompt_name):
|
7 |
"""Fetches a static prompt."""
|
8 |
return prompts[prompt_name]
|
|
|
9 |
def format_prompt(prompt_name, **kwargs):
|
10 |
"""Fetches a template prompt and populates it."""
|
11 |
return fetch_prompt(prompt_name).format(**kwargs)
|
12 |
|
13 |
|
14 |
+
class Task:
|
15 |
+
def __init__(self, prompt: str, response_format: Type[BaseModel], few_shot_examples: List[dict] = None):
|
16 |
+
self.prompt = prompt
|
17 |
+
self.response_format = response_format
|
18 |
+
self.few_shot_examples = few_shot_examples
|
19 |
+
|
20 |
+
def full_prompt(self, **kwargs):
|
21 |
+
prompt = self.prompt.format(**kwargs)
|
22 |
+
|
23 |
+
format_instructions = self.get_format_instructions()
|
24 |
+
if self.few_shot_examples:
|
25 |
+
few_shot = self.get_few_shot()
|
26 |
+
|
27 |
+
def get_format_instructions(self):
|
28 |
+
schema = self.get_input_schema()
|
29 |
+
format_instructions = FORMAT_INSTRUCTIONS.format(schema=schema)
|
30 |
+
|
31 |
+
return format_instructions
|
32 |
+
|
33 |
+
def get_input_schema(self):
|
34 |
+
schema = self.response_format.model_json_schema()
|
35 |
+
|
36 |
+
reduced_schema = schema
|
37 |
+
if "title" in reduced_schema:
|
38 |
+
del reduced_schema["title"]
|
39 |
+
if "type" in reduced_schema:
|
40 |
+
del reduced_schema["type"]
|
41 |
+
|
42 |
+
schema_str = json.dumps(reduced_schema, indent=4)
|
43 |
+
|
44 |
+
return schema_str
|
45 |
+
|
46 |
+
def get_few_shot(self, max_examples=3):
|
47 |
+
if len(self.few_shot_examples) <= max_examples:
|
48 |
+
examples = self.few_shot_examples
|
49 |
+
else:
|
50 |
+
examples = random.sample(self.few_shot_examples, max_examples)
|
51 |
+
|
52 |
+
few_shot = "\n\n".join([self.format_example(ex) for ex in examples])
|
53 |
+
|
54 |
+
return few_shot
|
55 |
+
|
56 |
+
def format_example(self, example):
|
57 |
+
ex_prompt = self.prompt.format(**example['inputs'])
|
58 |
+
ex_response = example['response']
|
59 |
+
|
60 |
+
return f"Prompt: {ex_prompt}\nResponse: {ex_response}"
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
_game_rules = '''\
|
65 |
GAME RULES: You are playing a social deduction game where every player pretends the be the same animal.
|
66 |
During the round players go around the room and make an "I"-statement as if they were the animal.
|
|
|
71 |
|
72 |
_herd_animal = """\
|
73 |
You are a {animal}, keep this a secret at all costs.
|
74 |
+
In as few words as possible describe of yourself starting with "I". Your description should be vague but true, \
|
75 |
+
since if the Chameleon can guess animal you are, you will LOSE. Do not repeat responses from other players.
|
76 |
Previously Mentioned Descriptions:
|
77 |
{player_responses}
|
78 |
+
Your Response: """
|
79 |
|
80 |
_chameleon_animal = """\
|
81 |
You are the Chameleon, keep this a secret at all costs.
|
|
|
85 |
If the other players realize you are the Chameleon you will LOSE.
|
86 |
Previously Mentioned Descriptions:
|
87 |
{player_responses}
|
88 |
+
Your Response: """
|
89 |
|
90 |
_all_responses = """\
|
91 |
Below are the responses from all the other players.
|
|
|
96 |
You now have the opportunity to guess what animal the other players are pretending to be.
|
97 |
If you guess correctly you will WIN, if you guess incorrectly you will LOSE.
|
98 |
If you believe you know what animal the other players are pretending to be make choose to GUESS, otherwise choose to PASS.
|
99 |
+
Your response should be one of ("GUESS", "PASS")
|
100 |
+
Your Response:
|
101 |
"""
|
102 |
|
103 |
_chameleon_guess_animal = """\
|
|
|
109 |
"""
|
110 |
|
111 |
prompts = {
|
112 |
+
"herd_animal": _game_rules + _herd_animal,
|
113 |
"chameleon_animal": _game_rules + _chameleon_animal,
|
114 |
"chameleon_guess_decision": _all_responses + _chameleon_guess_decision,
|
115 |
"chameleon_guess_animal": _chameleon_guess_animal,
|
116 |
"vote": _all_responses + _vote_prompt
|
117 |
}
|
|