File size: 6,374 Bytes
778c3d7 abc228d c6c2b98 a92f249 f596e58 a92f249 f596e58 778c3d7 7877562 f596e58 a92f249 f596e58 a92f249 7877562 c6c2b98 abc228d c2392fe 5de0b8a 7877562 f596e58 c2392fe f596e58 5de0b8a c2392fe f596e58 abc228d ea658a2 abc228d ea658a2 172af0f f596e58 172af0f f596e58 7877562 f596e58 c2392fe 7877562 c2392fe abc228d f596e58 c6c2b98 abc228d f596e58 c6c2b98 f596e58 c6c2b98 f596e58 abc228d f596e58 abc228d f596e58 7877562 abc228d f596e58 abc228d f596e58 abc228d f596e58 5de0b8a f596e58 5de0b8a f596e58 5de0b8a abc228d 5de0b8a f596e58 ea658a2 f596e58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import os
from typing import Type, Literal, List
import logging
from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.exceptions import OutputParserException
from pydantic import BaseModel
from game_utils import log
from controllers import controller_from_name
# Type alias: the two roles a player can hold in a round.
Role = Literal["chameleon", "herd"]
# Module-wide logging setup; WARNING keeps routine chatter out of game output.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("chameleon")
# Lots of AI Libraries use HumanMessage and AIMessage as the base classes for their messages.
# This doesn't make sense for our game, as Humans and AIs are both players, meaning they have the same role.
# The Langchain type field is used to convert to that syntax.
class Message(BaseModel):
    """A single message in a player's conversation history.

    Prompts sent to a player and responses produced by a player share this one
    type, since humans and AIs occupy the same seat in the game. The
    ``langchain_type`` property maps back to langchain's human/ai vocabulary.
    """

    type: Literal["prompt", "player"]
    """The type of the message. Can be "prompt" or "player"."""
    content: str
    """The content of the message."""

    @property
    def langchain_type(self):
        """Returns the langchain message type for the message."""
        return "human" if self.type == "prompt" else "ai"
class Player:
    """A participant in the game, driven by either a human or an AI controller."""

    role: Role | None = None
    """The role of the player in the game. Can be "chameleon" or "herd"."""
    rounds_played_as_chameleon: int = 0
    """The number of times the player has been the chameleon."""
    rounds_played_as_herd: int = 0
    """The number of times the player has been in the herd."""
    points: int = 0
    """The number of points the player has."""

    def __init__(
            self,
            name: str,
            controller: str,
            player_id: str | None = None,
            log_filepath: str | None = None
    ):
        self.name = name
        self.id = player_id

        if controller == "human":
            self.controller_type = "human"
        else:
            self.controller_type = "ai"

        self.controller = controller_from_name(controller)
        """The controller for the player."""
        self.log_filepath = log_filepath
        """The filepath to the log file. If None, no logs will be written."""
        self.messages: list[Message] = []
        """The messages the player has sent and received."""
        self.prompt_queue: List[str] = []
        """A queue of prompts to be added to the next prompt."""

        if log_filepath:
            player_info = {
                "id": self.id,
                "name": self.name,
                "role": self.role,
                "controller": {
                    "name": controller,
                    "type": self.controller_type
                }
            }
            log(player_info, log_filepath)

        # initialize the runnables
        self.generate = RunnableLambda(self._generate)
        self.format_output = RunnableLambda(self._output_formatter)

    def assign_role(self, role: Role):
        """Assign the role for the current round and update the per-role play counters."""
        self.role = role
        if role == "chameleon":
            self.rounds_played_as_chameleon += 1
        elif role == "herd":
            self.rounds_played_as_herd += 1

    async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries: int = 3):
        """Makes the player respond to a prompt. Returns the response in the specified format.

        For AI players the raw response is reformatted into ``output_format``; on an
        OutputParserException the model is asked to try again, up to ``max_retries``
        times, before the exception is re-raised. Human responses are wrapped
        directly into the pydantic object.
        """
        if self.prompt_queue:
            # If there are prompts in the queue, add them to the current prompt
            prompt = "\n".join(self.prompt_queue + [prompt])
            # Clear the prompt queue
            self.prompt_queue = []

        message = Message(type="prompt", content=prompt)
        output = await self.generate.ainvoke(message)

        if self.controller_type == "ai":
            retries = 0
            while True:
                try:
                    output = await self.format_output.ainvoke({"output_format": output_format})
                    break
                except OutputParserException as e:
                    # BUG FIX: the original referenced the undefined name HumanMessage here
                    # (never imported -> NameError) and, lacking a loop, could only ever
                    # retry once regardless of max_retries.
                    retries += 1
                    if retries > max_retries:
                        logger.error(f"Max retries reached due to Error: {e}")
                        raise
                    logger.warning(f"Player {self.id} failed to format response: {output} due to an exception: {e} \n\n Retrying {retries}/{max_retries}")
                    self.add_to_history(Message(type="prompt", content=f"Error formatting response: {e} \n\n Please try again."))
        else:
            # Convert the human message to the pydantic object format
            field_name = output_format.model_fields.copy().popitem()[0]  # only works because current outputs have only 1 field
            output = output_format.model_validate({field_name: output.content})

        return output

    def add_to_history(self, message: Message):
        """Append a message to the player's history, logging it if a log file is configured."""
        self.messages.append(message)
        # Honor the documented contract of log_filepath: None means no logs are written.
        if self.log_filepath:
            log(message.model_dump(), self.log_filepath)

    def is_human(self) -> bool:
        """True if this player is controlled by a human."""
        return self.controller_type == "human"

    def is_ai(self) -> bool:
        """True if this player is controlled by an AI."""
        return not self.is_human()

    def _generate(self, message: Message):
        """Entry point for the Runnable generating responses, automatically logs the message."""
        self.add_to_history(message)
        # AI's need to be fed the whole message history, but humans can just go back and look at it
        if self.controller_type == "human":
            response = self.controller.invoke(message.content)
        else:
            formatted_messages = [(m.langchain_type, m.content) for m in self.messages]
            response = self.controller.invoke(formatted_messages)
        self.add_to_history(Message(type="player", content=response.content))
        return response

    def _output_formatter(self, inputs: dict):
        """Formats the output of the response by asking the model to rewrite it per the parser's instructions."""
        output_format: BaseModel = inputs["output_format"]
        prompt_template = PromptTemplate.from_template(
            "Please rewrite your previous response using the following format: \n\n{format_instructions}"
        )
        parser = PydanticOutputParser(pydantic_object=output_format)
        prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
        message = Message(type="player", content=prompt.text)
        response = self.generate.invoke(message)
        return parser.invoke(response)
|