File size: 7,360 Bytes
5c924a1
8d942c4
dfdde45
 
 
 
8d942c4
dfdde45
87c03a0
dfdde45
5dbe83d
dfdde45
47b6f03
8d942c4
dfdde45
 
 
 
 
8d942c4
 
2f9cbc8
 
 
8d942c4
 
 
 
dfdde45
8d942c4
dfdde45
81e1c72
 
 
 
dfdde45
 
2f9cbc8
dfdde45
81e1c72
 
dfdde45
81e1c72
dfdde45
587e98a
81e1c72
dfdde45
 
47b6f03
81e1c72
 
 
 
 
 
 
 
 
 
 
 
c6447fa
81e1c72
c6447fa
 
 
 
2f9cbc8
c6447fa
 
 
81e1c72
 
47b6f03
f7ce19f
81e1c72
47b6f03
 
81e1c72
 
dfdde45
250cc97
dfdde45
 
 
 
81e1c72
dfdde45
 
81e1c72
 
 
 
 
 
 
 
 
dfdde45
 
81e1c72
250cc97
47b6f03
250cc97
dfdde45
 
81e1c72
 
 
 
 
 
 
 
 
 
 
 
 
dfdde45
 
81e1c72
dfdde45
 
 
 
 
47b6f03
dfdde45
 
8d942c4
47b6f03
2f9cbc8
8d942c4
 
 
dfdde45
81e1c72
dfdde45
 
 
 
 
 
 
 
 
 
 
 
8d942c4
dfdde45
81e1c72
 
 
 
 
c6447fa
dfdde45
81e1c72
c6447fa
 
 
672c019
 
 
 
 
 
 
 
 
 
 
c6447fa
 
dfdde45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81e1c72
dfdde45
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from json import JSONDecodeError
from typing import Type, NewType, List, Any
import json

from openai import OpenAI
from colorama import Fore, Style
from pydantic import BaseModel, ValidationError, Field, ConfigDict

from output_formats import OutputFormatModel
from message import Message, AgentMessage
from data_collection import save


class BaseAgentInterface(BaseModel):
    """
    The interface that agents use to receive info from and interact with the game.
    This is the base class and should not be used directly.

    Subclasses supply the actual response logic by overriding `_generate`.
    """

    agent_id: str
    """The id of the agent."""
    game_id: str
    """The id of the game the agent is in."""

    log_messages: bool = True
    """Whether to log messages or not."""
    messages: List[Message] = []
    """The message history of the agent."""
    is_human: bool = False
    """Whether the agent is human or not."""

    @property
    def is_ai(self) -> bool:
        """Whether the agent is an AI, i.e. the negation of `is_human`."""
        return not self.is_human

    def add_message(self, message: Message):
        """Adds a message to the message history, without generating a response."""
        self.messages.append(message)

    # Respond To methods - These take a message as input and generate a response

    def respond_to(self, message: Message) -> Message | None:
        """
        Take a message as input and return a response.

        The message is added to the history and saved, then a response is generated from the
        full history. Returns None when no response content was generated.
        """
        self.add_message(message)
        save(AgentMessage.from_message(message, [self.agent_id], self.game_id))
        return self.generate_response()

    def respond_to_formatted(
            self, message: Message,
            output_format: Type[OutputFormatModel],
            additional_fields: dict = None,
            **kwargs
    ) -> OutputFormatModel:
        """
        Adds a message to the history and returns a response in the requested format.

        Any extra keyword arguments are forwarded to `generate_formatted_response`.
        """
        # NOTE(review): unlike `respond_to`, the incoming message is not passed to `save` here -
        # confirm whether that is intentional.
        self.add_message(message)
        return self.generate_formatted_response(output_format, additional_fields, **kwargs)

    # Generate response methods - These do not take a message as input and only use the current message history

    def generate_response(self) -> Message | None:
        """
        Generates a response based on the current messages in the history.

        A non-empty response is wrapped in an "agent" message, added to the history and saved.
        Returns None when `_generate` produced no content.
        """
        content = self._generate()
        if not content:
            return None

        response = Message(type="agent", content=content)
        self.add_message(response)
        save(AgentMessage.from_message(response, [self.agent_id], self.game_id))
        return response

    def generate_formatted_response(
            self,
            output_format: Type[OutputFormatModel],
            additional_fields: dict = None,
            max_retries=3,
    ) -> OutputFormatModel:
        """
        Generates a response matching the provided format.

        A free-form response is generated first, then the agent is asked to restate it using the
        format instructions of `output_format`. The reply is parsed as JSON, merged with
        `additional_fields` (which take precedence) and validated against `output_format`.
        When a reply is empty, is not valid JSON or fails validation, the agent is told what
        went wrong and asked again, at most `max_retries` times.

        Raises:
            ValidationError: the last reply did not validate and no retries are left.
            JSONDecodeError: the last reply was not valid JSON and no retries are left.
            ValueError: the last reply was empty and no retries are left.
        """
        # The free-form answer is only needed in the message history, so the return value is unused.
        self.generate_response()

        request = Message(type="format", content=output_format.get_format_instructions())
        retries = 0

        while True:
            try:
                formatted_response = self.respond_to(request)

                if formatted_response is None:
                    # Nothing to parse; treat it like any other formatting failure.
                    failure = ValueError("The agent returned an empty response to the format request.")
                    retry_content = "Your response was empty. \n\n Please try again."
                else:
                    fields = json.loads(formatted_response.content)
                    # Valid JSON is not necessarily an object; anything else is left untouched so
                    # that validation below rejects it and triggers a retry.
                    if additional_fields and isinstance(fields, dict):
                        fields.update(additional_fields)

                    return output_format.model_validate(fields)

            except ValidationError as e:
                # If the response doesn't match the format, we ask the agent to try again
                failure = e
                retry_content = f"Error formatting response: {e} \n\n Please try again."

            except JSONDecodeError as e:
                # Occasionally models will output json as a code block, which will cause a JSONDecodeError
                failure = e
                retry_content = (
                    "There was an Error with your JSON format. Make sure you are not using code blocks. "
                    "i.e. your response should be:\n{...}\n"
                    "Instead of:\n```json\n{...}\n```\n\n Please try again."
                )

            if retries >= max_retries:
                raise failure

            retries += 1
            request = Message(type="retry", content=retry_content)

    def _generate(self) -> str:
        """Generates a response from the Agent."""
        # This is the BaseAgent class, and thus has no response logic
        # Subclasses should implement this method to generate a response using the message history
        raise NotImplementedError


class OpenAIAgentInterface(BaseAgentInterface):
    """Agent interface whose responses come from the OpenAI chat completions API (or a compatible 3rd party)."""
    # Empty protected namespaces so that the `model_name` field below is accepted as-is.
    model_config = ConfigDict(protected_namespaces=())

    model_name: str = "gpt-3.5-turbo"
    """Name of the model that is asked for completions."""
    client: Any = Field(default_factory=OpenAI, exclude=True)
    """Client object used to request completions; excluded from serialization."""

    def _generate(self) -> str:
        """Sends the whole message history to the model and returns the content of the first choice."""
        history = []
        for entry in self.messages:
            history.append(entry.to_openai())

        result = self.client.chat.completions.create(model=self.model_name, messages=history)
        first_choice = result.choices[0]
        return first_choice.message.content


class HumanAgentInterface(BaseAgentInterface):
    """An agent interface for human players, whose raw answer is used directly as the formatted output."""
    is_human: bool = Field(default=True, frozen=True)

    def generate_formatted_response(
            self,
            output_format: Type[OutputFormatModel],
            additional_fields: dict = None,
            max_retries: int = 3
    ) -> OutputFormatModel | None:
        """
        For Human agents, we can trust them enough to format their own responses... for now

        The human's answer is stored under a single field of `output_format`: the last declared
        field that is not already supplied by `additional_fields`. For formats with one field
        this is simply that field.

        `max_retries` is accepted for signature compatibility and is not used: when validation
        fails, a retry message is added to the history and None is returned.

        Returns:
            The validated output, or None when there was no response or validation failed.

        Raises:
            KeyError: `output_format` declares no fields at all.
        """
        response = self.generate_response()
        if not response:
            return None

        supplied = additional_fields or {}
        field_names = list(output_format.model_fields)
        if not field_names:
            raise KeyError(f"{output_format.__name__} declares no fields to hold the response")

        # Prefer a field the caller did not pre-fill; otherwise the update below would
        # overwrite the human's answer. Falls back to the last declared field.
        unclaimed = [name for name in field_names if name not in supplied]
        response_field = (unclaimed or field_names)[-1]

        fields = {response_field: response.content}
        fields.update(supplied)

        try:
            return output_format.model_validate(fields)
        except ValidationError as e:
            retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
            self.add_message(retry_message)
            return None


class HumanAgentCLI(HumanAgentInterface):
    """A Human agent that uses the command line interface to generate responses."""
    def add_message(self, message: Message):
        super().add_message(message)
        if message.type == "verbose":
            print(Fore.GREEN + message.content + Style.RESET_ALL)
        elif message.type == "debug":
            print(Fore.YELLOW + "DEBUG: " + message.content + Style.RESET_ALL)
        elif message.type != "agent":
            # Prevents the agent from seeing its own messages on the command line
            print(message.content)

    def _generate(self) -> str:
        """Generates a response using the message history"""
        response = input()
        return response