Eric Botti
commited on
Commit
·
abc228d
1
Parent(s):
3ac8e2a
fixed prompt queue and messages appending to class attribute
Browse files- src/player.py +35 -14
src/player.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
import os
|
2 |
-
from typing import Type, Literal
|
3 |
import logging
|
4 |
|
5 |
from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
|
6 |
|
7 |
from langchain.output_parsers import PydanticOutputParser
|
8 |
from langchain_core.prompts import PromptTemplate
|
9 |
-
from langchain_core.messages import HumanMessage, AnyMessage
|
10 |
|
11 |
from langchain_core.exceptions import OutputParserException
|
12 |
|
@@ -18,6 +17,25 @@ from controllers import controller_from_name
|
|
18 |
Role = Literal["chameleon", "herd"]
|
19 |
|
20 |
logging.basicConfig(level=logging.WARNING)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
class Player:
|
23 |
|
@@ -29,10 +47,6 @@ class Player:
|
|
29 |
"""The number of times the player has been in the herd."""
|
30 |
points: int = 0
|
31 |
"""The number of points the player has."""
|
32 |
-
messages: list[AnyMessage] = []
|
33 |
-
"""The messages the player has sent and received."""
|
34 |
-
prompt_queue: list[str] = []
|
35 |
-
"""A queue of prompts to be added to the next prompt."""
|
36 |
|
37 |
def __init__(
|
38 |
self,
|
@@ -50,7 +64,13 @@ class Player:
|
|
50 |
self.controller_type = "ai"
|
51 |
|
52 |
self.controller = controller_from_name(controller)
|
|
|
53 |
self.log_filepath = log_filepath
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
if log_filepath:
|
56 |
player_info = {
|
@@ -83,7 +103,7 @@ class Player:
|
|
83 |
# Clear the prompt queue
|
84 |
self.prompt_queue = []
|
85 |
|
86 |
-
message =
|
87 |
output = await self.generate.ainvoke(message)
|
88 |
if self.controller_type == "ai":
|
89 |
retries = 0
|
@@ -92,7 +112,7 @@ class Player:
|
|
92 |
except OutputParserException as e:
|
93 |
if retries < max_retries:
|
94 |
retries += 1
|
95 |
-
|
96 |
self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
|
97 |
output = await self.format_output.ainvoke({"output_format": output_format})
|
98 |
|
@@ -106,9 +126,9 @@ class Player:
|
|
106 |
|
107 |
return output
|
108 |
|
109 |
-
def add_to_history(self, message:
|
110 |
self.messages.append(message)
|
111 |
-
log(message.
|
112 |
|
113 |
def is_human(self):
|
114 |
return self.controller_type == "human"
|
@@ -116,7 +136,7 @@ class Player:
|
|
116 |
def is_ai(self):
|
117 |
return not self.is_human()
|
118 |
|
119 |
-
def _generate(self, message:
|
120 |
"""Entry point for the Runnable generating responses, automatically logs the message."""
|
121 |
self.add_to_history(message)
|
122 |
|
@@ -124,9 +144,10 @@ class Player:
|
|
124 |
if self.controller_type == "human":
|
125 |
response = self.controller.invoke(message.content)
|
126 |
else:
|
127 |
-
|
|
|
128 |
|
129 |
-
self.add_to_history(response)
|
130 |
|
131 |
return response
|
132 |
|
@@ -142,7 +163,7 @@ class Player:
|
|
142 |
|
143 |
prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
|
144 |
|
145 |
-
message =
|
146 |
|
147 |
response = self.generate.invoke(message)
|
148 |
|
|
|
1 |
import os
|
2 |
+
from typing import Type, Literal, List
|
3 |
import logging
|
4 |
|
5 |
from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
|
6 |
|
7 |
from langchain.output_parsers import PydanticOutputParser
|
8 |
from langchain_core.prompts import PromptTemplate
|
|
|
9 |
|
10 |
from langchain_core.exceptions import OutputParserException
|
11 |
|
|
|
17 |
Role = Literal["chameleon", "herd"]
|
18 |
|
19 |
logging.basicConfig(level=logging.WARNING)
|
20 |
+
logger = logging.getLogger("chameleon")
|
21 |
+
|
22 |
+
|
23 |
+
# Lots of AI Libraries use HumanMessage and AIMessage as the base classes for their messages.
|
24 |
+
# This doesn't make sense for our as Humans and AIs are both players in the game, meaning they have the same role.
|
25 |
+
# The Langchain type field is used to convert to that syntax.
|
26 |
+
class Message(BaseModel):
|
27 |
+
type: Literal["prompt", "player"]
|
28 |
+
"""The type of the message. Can be "prompt" or "player"."""
|
29 |
+
content: str
|
30 |
+
"""The content of the message."""
|
31 |
+
@property
|
32 |
+
def langchain_type(self):
|
33 |
+
"""Returns the langchain message type for the message."""
|
34 |
+
if self.type == "prompt":
|
35 |
+
return "human"
|
36 |
+
else:
|
37 |
+
return "ai"
|
38 |
+
|
39 |
|
40 |
class Player:
|
41 |
|
|
|
47 |
"""The number of times the player has been in the herd."""
|
48 |
points: int = 0
|
49 |
"""The number of points the player has."""
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def __init__(
|
52 |
self,
|
|
|
64 |
self.controller_type = "ai"
|
65 |
|
66 |
self.controller = controller_from_name(controller)
|
67 |
+
"""The controller for the player."""
|
68 |
self.log_filepath = log_filepath
|
69 |
+
"""The filepath to the log file. If None, no logs will be written."""
|
70 |
+
self.messages: list[Message] = []
|
71 |
+
"""The messages the player has sent and received."""
|
72 |
+
self.prompt_queue: List[str] = []
|
73 |
+
"""A queue of prompts to be added to the next prompt."""
|
74 |
|
75 |
if log_filepath:
|
76 |
player_info = {
|
|
|
103 |
# Clear the prompt queue
|
104 |
self.prompt_queue = []
|
105 |
|
106 |
+
message = Message(type="prompt", content=prompt)
|
107 |
output = await self.generate.ainvoke(message)
|
108 |
if self.controller_type == "ai":
|
109 |
retries = 0
|
|
|
112 |
except OutputParserException as e:
|
113 |
if retries < max_retries:
|
114 |
retries += 1
|
115 |
+
logger.warning(f"Player {self.id} failed to format response: {output} due to an exception: {e} \n\n Retrying {retries}/{max_retries}")
|
116 |
self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
|
117 |
output = await self.format_output.ainvoke({"output_format": output_format})
|
118 |
|
|
|
126 |
|
127 |
return output
|
128 |
|
129 |
+
def add_to_history(self, message: Message):
|
130 |
self.messages.append(message)
|
131 |
+
log(message.model_dump(), self.log_filepath)
|
132 |
|
133 |
def is_human(self):
|
134 |
return self.controller_type == "human"
|
|
|
136 |
def is_ai(self):
|
137 |
return not self.is_human()
|
138 |
|
139 |
+
def _generate(self, message: Message):
|
140 |
"""Entry point for the Runnable generating responses, automatically logs the message."""
|
141 |
self.add_to_history(message)
|
142 |
|
|
|
144 |
if self.controller_type == "human":
|
145 |
response = self.controller.invoke(message.content)
|
146 |
else:
|
147 |
+
formatted_messages = [(message.langchain_type, message.content) for message in self.messages]
|
148 |
+
response = self.controller.invoke(formatted_messages)
|
149 |
|
150 |
+
self.add_to_history(Message(type="player", content=response.content))
|
151 |
|
152 |
return response
|
153 |
|
|
|
163 |
|
164 |
prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
|
165 |
|
166 |
+
message = Message(type="player", content=prompt.text)
|
167 |
|
168 |
response = self.generate.invoke(message)
|
169 |
|