Eric Botti committed
Commit abc228d · 1 Parent(s): 3ac8e2a

fixed prompt queue and messages appending to class attribute
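The bug being fixed: messages and prompt_queue were declared as class attributes, so every Player instance shared (and appended into) the same two list objects. A minimal standalone sketch of the pitfall and the fix follows; Broken and Fixed are hypothetical stand-ins, not the project's classes.

    # Class-level mutable default: one list object shared by all instances.
    class Broken:
        messages: list[str] = []

    # Instance-level list created in __init__: each object gets its own.
    class Fixed:
        def __init__(self) -> None:
            self.messages: list[str] = []

    a, b = Broken(), Broken()
    a.messages.append("hi")
    print(b.messages)   # ['hi'] -- a's message leaked into b

    c, d = Fixed(), Fixed()
    c.messages.append("hi")
    print(d.messages)   # [] -- isolated, which is what this commit achieves for Player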

Files changed (1)
  1. src/player.py +35 -14
src/player.py CHANGED
@@ -1,12 +1,11 @@
 import os
-from typing import Type, Literal
+from typing import Type, Literal, List
 import logging
 
 from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
 
 from langchain.output_parsers import PydanticOutputParser
 from langchain_core.prompts import PromptTemplate
-from langchain_core.messages import HumanMessage, AnyMessage
 
 from langchain_core.exceptions import OutputParserException
 
@@ -18,6 +17,25 @@ from controllers import controller_from_name
 Role = Literal["chameleon", "herd"]
 
 logging.basicConfig(level=logging.WARNING)
+logger = logging.getLogger("chameleon")
+
+
+# Lots of AI Libraries use HumanMessage and AIMessage as the base classes for their messages.
+# This doesn't make sense for our as Humans and AIs are both players in the game, meaning they have the same role.
+# The Langchain type field is used to convert to that syntax.
+class Message(BaseModel):
+    type: Literal["prompt", "player"]
+    """The type of the message. Can be "prompt" or "player"."""
+    content: str
+    """The content of the message."""
+    @property
+    def langchain_type(self):
+        """Returns the langchain message type for the message."""
+        if self.type == "prompt":
+            return "human"
+        else:
+            return "ai"
+
 
 class Player:
 
@@ -29,10 +47,6 @@ class Player:
     """The number of times the player has been in the herd."""
     points: int = 0
     """The number of points the player has."""
-    messages: list[AnyMessage] = []
-    """The messages the player has sent and received."""
-    prompt_queue: list[str] = []
-    """A queue of prompts to be added to the next prompt."""
 
     def __init__(
         self,
@@ -50,7 +64,13 @@
             self.controller_type = "ai"
 
         self.controller = controller_from_name(controller)
+        """The controller for the player."""
         self.log_filepath = log_filepath
+        """The filepath to the log file. If None, no logs will be written."""
+        self.messages: list[Message] = []
+        """The messages the player has sent and received."""
+        self.prompt_queue: List[str] = []
+        """A queue of prompts to be added to the next prompt."""
 
         if log_filepath:
             player_info = {
@@ -83,7 +103,7 @@
         # Clear the prompt queue
         self.prompt_queue = []
 
-        message = HumanMessage(content=prompt)
+        message = Message(type="prompt", content=prompt)
         output = await self.generate.ainvoke(message)
         if self.controller_type == "ai":
             retries = 0
@@ -92,7 +112,7 @@
             except OutputParserException as e:
                 if retries < max_retries:
                     retries += 1
-                    logging.warning(f"Player {self.id} failed to format response: {output} due to an exception: {e} \n\n Retrying {retries}/{max_retries}")
+                    logger.warning(f"Player {self.id} failed to format response: {output} due to an exception: {e} \n\n Retrying {retries}/{max_retries}")
                     self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
                     output = await self.format_output.ainvoke({"output_format": output_format})
 
@@ -106,9 +126,9 @@
 
         return output
 
-    def add_to_history(self, message: AnyMessage):
+    def add_to_history(self, message: Message):
         self.messages.append(message)
-        log(message.dict(), self.log_filepath)
+        log(message.model_dump(), self.log_filepath)
 
     def is_human(self):
         return self.controller_type == "human"
@@ -116,7 +136,7 @@
     def is_ai(self):
         return not self.is_human()
 
-    def _generate(self, message: HumanMessage):
+    def _generate(self, message: Message):
         """Entry point for the Runnable generating responses, automatically logs the message."""
         self.add_to_history(message)
 
@@ -124,9 +144,10 @@
         if self.controller_type == "human":
             response = self.controller.invoke(message.content)
         else:
-            response = self.controller.invoke(self.messages)
+            formatted_messages = [(message.langchain_type, message.content) for message in self.messages]
+            response = self.controller.invoke(formatted_messages)
 
-        self.add_to_history(response)
+        self.add_to_history(Message(type="player", content=response.content))
 
         return response
 
@@ -142,7 +163,7 @@
 
         prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
 
-        message = HumanMessage(content=prompt.text)
+        message = Message(type="player", content=prompt.text)
 
         response = self.generate.invoke(message)
 
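For reference, the new Message type replaces LangChain's HumanMessage/AIMessage split with game-level "prompt"/"player" roles, and langchain_type maps them back when the history is handed to the model. A rough standalone sketch of that conversion (illustrative only; the real class lives in src/player.py, and the example strings are made up):

    from typing import Literal
    from pydantic import BaseModel

    class Message(BaseModel):
        type: Literal["prompt", "player"]
        content: str

        @property
        def langchain_type(self) -> str:
            # Prompts shown to the player map to "human" turns; the player's own replies map to "ai" turns.
            return "human" if self.type == "prompt" else "ai"

    history = [
        Message(type="prompt", content="Give a clue about the secret animal."),
        Message(type="player", content="It has a long neck."),
    ]

    # LangChain chat models accept a list of (role, content) tuples, which is what
    # _generate builds before calling controller.invoke(...) in the diff above.
    formatted = [(m.langchain_type, m.content) for m in history]
    print(formatted)   # [('human', 'Give a clue about the secret animal.'), ('ai', 'It has a long neck.')]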