File size: 6,374 Bytes
778c3d7
abc228d
c6c2b98
a92f249
f596e58
a92f249
f596e58
 
778c3d7
7877562
 
f596e58
a92f249
f596e58
 
a92f249
7877562
 
c6c2b98
abc228d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2392fe
5de0b8a
7877562
 
 
 
 
 
 
 
 
 
f596e58
 
 
 
c2392fe
f596e58
 
5de0b8a
c2392fe
f596e58
 
 
 
 
 
 
abc228d
ea658a2
abc228d
 
 
 
 
ea658a2
172af0f
 
 
 
 
f596e58
 
 
 
172af0f
 
 
f596e58
 
 
 
7877562
 
 
 
 
 
 
f596e58
 
c2392fe
 
 
7877562
 
c2392fe
abc228d
f596e58
 
 
 
 
c6c2b98
 
 
abc228d
f596e58
 
c6c2b98
f596e58
c6c2b98
f596e58
 
 
 
 
 
 
 
abc228d
f596e58
abc228d
f596e58
7877562
 
 
 
 
 
abc228d
f596e58
 
 
 
 
 
 
abc228d
 
f596e58
abc228d
f596e58
 
 
 
 
 
 
 
 
 
5de0b8a
f596e58
5de0b8a
f596e58
5de0b8a
abc228d
5de0b8a
f596e58
ea658a2
f596e58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import os
from typing import Type, Literal, List
import logging

from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate

from langchain_core.exceptions import OutputParserException

from pydantic import BaseModel

from game_utils import log
from controllers import controller_from_name

# The two roles a player can be assigned: the chameleon or a member of the herd.
Role = Literal["chameleon", "herd"]

# Configure logging at WARNING level on import and create the logger used by this module.
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("chameleon")


# Lots of AI libraries use HumanMessage and AIMessage as the base classes for their messages.
# This doesn't make sense for our game, as humans and AIs are both players, meaning they have the same role.
# The langchain_type property is used to convert a message to that syntax.
class Message(BaseModel):
    # A single entry in a player's message history.
    type: Literal["prompt", "player"]
    """Kind of message: either "prompt" or "player"."""
    content: str
    """The text of the message."""

    @property
    def langchain_type(self):
        """Translate the message type into the role name used by langchain ("human" or "ai")."""
        # "prompt" messages are presented as human turns; everything else as ai turns.
        return "human" if self.type == "prompt" else "ai"


class Player:
    """A participant in the game, driven by either a human or an AI controller.

    A player keeps its own message history, can queue prompts to be prepended
    to the next prompt it receives, and produces responses validated against a
    pydantic output format.
    """

    role: Role | None = None
    """The role of the player in the game. Can be "chameleon" or "herd"."""
    rounds_played_as_chameleon: int = 0
    """The number of times the player has been the chameleon."""
    rounds_played_as_herd: int = 0
    """The number of times the player has been in the herd."""
    points: int = 0
    """The number of points the player has."""

    def __init__(
            self,
            name: str,
            controller: str,
            player_id: str | None = None,
            log_filepath: str | None = None
    ):
        """Create a player and, if a log file is given, log its basic info.

        Args:
            name: Display name of the player.
            controller: Controller name passed to ``controller_from_name``.
                ``"human"`` marks the player as human; any other value is
                treated as an AI controller.
            player_id: Optional identifier for the player.
            log_filepath: Path of the log file. If None, no logs are written.
        """
        self.name = name
        self.id = player_id

        self.controller_type = "human" if controller == "human" else "ai"

        self.controller = controller_from_name(controller)
        """The controller for the player."""
        self.log_filepath = log_filepath
        """The filepath to the log file. If None, no logs will be written."""
        self.messages: list[Message] = []
        """The messages the player has sent and received."""
        self.prompt_queue: list[str] = []
        """A queue of prompts to be added to the next prompt."""

        if log_filepath:
            player_info = {
                "id": self.id,
                "name": self.name,
                "role": self.role,
                "controller": {
                    "name": controller,
                    "type": self.controller_type
                }
            }
            log(player_info, log_filepath)

        # initialize the runnables
        self.generate = RunnableLambda(self._generate)
        self.format_output = RunnableLambda(self._output_formatter)

    def assign_role(self, role: Role):
        """Set the player's role and bump the matching rounds-played counter."""
        self.role = role
        if role == "chameleon":
            self.rounds_played_as_chameleon += 1
        elif role == "herd":
            self.rounds_played_as_herd += 1

    async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries=3):
        """Makes the player respond to a prompt. Returns the response in the specified format.

        Args:
            prompt: The prompt to respond to. Any queued prompts are prepended
                (newline separated) and the queue is cleared.
            output_format: Pydantic model class the response is converted to.
            max_retries: For AI players, how many additional formatting
                attempts are made after the first one fails to parse.

        Returns:
            An instance of ``output_format``.

        Raises:
            OutputParserException: If an AI player's response still cannot be
                parsed after ``max_retries`` retries.
        """
        if self.prompt_queue:
            # If there are prompts in the queue, add them to the current prompt
            prompt = "\n".join([*self.prompt_queue, prompt])
            # Clear the prompt queue
            self.prompt_queue = []

        message = Message(type="prompt", content=prompt)
        output = await self.generate.ainvoke(message)

        if self.controller_type != "ai":
            # Convert the human message to the pydantic object format
            field_name = output_format.model_fields.copy().popitem()[0]  # only works because current outputs have only 1 field
            return output_format.model_validate({field_name: output.content})

        retries = 0
        while True:
            try:
                return await self.format_output.ainvoke({"output_format": output_format})
            except OutputParserException as e:
                if retries >= max_retries:
                    logger.error(f"Max retries reached due to Error: {e}")
                    raise
                retries += 1
                logger.warning(f"Player {self.id} failed to format response: {output} due to an exception: {e} \n\n Retrying {retries}/{max_retries}")
                # Tell the player what went wrong so the next formatting attempt can correct it.
                self.add_to_history(Message(type="prompt", content=f"Error formatting response: {e} \n\n Please try again."))

    def add_to_history(self, message: Message):
        """Append a message to the player's history and log it if a log file is set."""
        self.messages.append(message)
        # Mirror __init__: only write to the log when a log file was provided.
        if self.log_filepath:
            log(message.model_dump(), self.log_filepath)

    def is_human(self):
        """Return True if the player is controlled by a human."""
        return self.controller_type == "human"

    def is_ai(self):
        """Return True if the player is not controlled by a human."""
        return not self.is_human()

    def _generate(self, message: Message):
        """Entry point for the Runnable generating responses, automatically logs the message.

        Adds the incoming message to the history, invokes the controller, adds
        the controller's response to the history as a "player" message and
        returns the raw response.
        """
        self.add_to_history(message)

        # AI's need to be fed the whole message history, but humans can just go back and look at it
        if self.controller_type == "human":
            response = self.controller.invoke(message.content)
        else:
            formatted_messages = [(past.langchain_type, past.content) for past in self.messages]
            response = self.controller.invoke(formatted_messages)

        self.add_to_history(Message(type="player", content=response.content))

        return response

    def _output_formatter(self, inputs: dict):
        """Formats the output of the response.

        Asks the player to rewrite its previous response following the format
        instructions for ``inputs["output_format"]`` and parses the reply.

        Raises:
            OutputParserException: If the rewritten response cannot be parsed.
        """
        output_format: Type[BaseModel] = inputs["output_format"]

        prompt_template = PromptTemplate.from_template(
            "Please rewrite your previous response using the following format: \n\n{format_instructions}"
        )

        parser = PydanticOutputParser(pydantic_object=output_format)

        prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})

        # The rewrite request is an instruction to the player, so it is a "prompt", not a "player" message.
        message = Message(type="prompt", content=prompt.text)

        response = self.generate.invoke(message)

        return parser.invoke(response)