from json import JSONDecodeError
from typing import Optional, Type, NewType
import json

from openai import OpenAI
from colorama import Fore, Style
from pydantic import BaseModel, ValidationError

from output_formats import OutputFormatModel
from message import Message, AgentMessage
from data_collection import save


class BaseAgentInterface:
    """
    The interface that agents use to receive info from and interact with the game.
    This is the base class and should not be used directly.
    """

    is_human: bool = False

    def __init__(
            self,
            agent_id: Optional[str] = None
    ):
        self.id = agent_id
        self.messages = []

    @property
    def is_ai(self):
        return not self.is_human

    def add_message(self, message: Message):
        """Adds a message to the message history, without generating a response."""
        bound_message = AgentMessage.from_message(message, self.id, len(self.messages))
        save(bound_message)
        self.messages.append(bound_message)

    # Respond To methods - These take a message as input and generate a response

    def respond_to(self, message: Message) -> Message:
        """Take a message as input and return a response. Both the message and the response are added to history."""
        self.add_message(message)
        response = self.generate_response()
        return response

    def respond_to_formatted(
            self, message: Message,
            output_format: Type[OutputFormatModel],
            additional_fields: dict = None,
            **kwargs
    ) -> OutputFormatModel:
        """Responds to a message and logs the response."""
        self.add_message(message)
        output = self.generate_formatted_response(output_format, additional_fields, **kwargs)
        return output

    # Generate response methods - These do not take a message as input and only use the current message history

    def generate_response(self) -> Message:
        """Generates a response based on the current messages in the history."""
        response = Message(type="agent", content=self._generate())
        self.add_message(response)
        return response

    def generate_formatted_response(
            self,
            output_format: Type[OutputFormatModel],
            additional_fields: dict = None,
            max_retries: int = 3,
    ) -> OutputFormatModel:
        """Generates a response matching the provided format."""
        # Generate a free-form response first so it lands in the message history,
        # then ask the agent to reformat it according to the format instructions
        self.generate_response()

        reformat_message = Message(type="format", content=output_format.get_format_instructions())

        output = None
        retries = 0

        while not output:
            try:
                formatted_response = self.respond_to(reformat_message)

                fields = json.loads(formatted_response.content)
                if additional_fields:
                    fields.update(additional_fields)

                output = output_format.model_validate(fields)

            except ValidationError as e:
                # If the response doesn't match the format, we ask the agent to try again
                if retries > max_retries:
                    raise e

                retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
                reformat_message = retry_message

                retries += 1

            except JSONDecodeError as e:
                # Occasionally models wrap their JSON in a Markdown code block, which causes a JSONDecodeError
                if retries > max_retries:
                    raise e

                retry_message = Message(
                    type="retry",
                    content="There was an error with your JSON format. Make sure you are not using code blocks, "
                            "i.e. your response should be:\n{...}\n"
                            "Instead of:\n```json\n{...}\n```\n\nPlease try again."
                )
                reformat_message = retry_message

                retries += 1

        return output

    # How agents actually generate responses

    def _generate(self) -> str:
        """Generates a response from the Agent."""
        # This is the BaseAgent class, and thus has no response logic
        # Subclasses should implement this method to generate a response using the message history
        raise NotImplementedError


AgentInterface = NewType("AgentInterface", BaseAgentInterface)
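

# A minimal sketch of a non-interactive agent, illustrating the BaseAgentInterface
# contract above: subclasses only need to implement _generate() and return a string.
# ScriptedAgentInterface is a hypothetical helper (not part of the game itself) that
# replays canned responses, which can be handy for deterministic tests.
class ScriptedAgentInterface(BaseAgentInterface):
    """Replays a fixed list of responses instead of calling a model; useful for tests."""

    def __init__(self, agent_id: str, responses: list[str]):
        super().__init__(agent_id)
        self._responses = list(responses)

    def _generate(self) -> str:
        # Pop the next canned response; fall back to an empty string once exhausted
        return self._responses.pop(0) if self._responses else ""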


class OpenAIAgentInterface(BaseAgentInterface):
    """An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""

    def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
        super().__init__(agent_id)
        self.model_name = model_name
        self.client = OpenAI()

    def _generate(self) -> str:
        """Generates a response using the message history"""
        open_ai_messages = [message.to_openai() for message in self.messages]

        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=open_ai_messages
        )

        return completion.choices[0].message.content


class HumanAgentInterface(BaseAgentInterface):
    is_human = True

    def generate_formatted_response(
            self,
            output_format: Type[OutputFormatModel],
            additional_fields: dict = None,
            max_retries: int = 3
    ) -> OutputFormatModel:
        """For Human agents, we can trust them enough to format their own responses... for now"""
        response = self.generate_response()
        # only works because current outputs have only 1 field...
        fields = {output_format.model_fields.copy().popitem()[0]: response.content}
        if additional_fields:
            fields.update(additional_fields)
        output = output_format.model_validate(fields)

        return output


class HumanAgentCLI(HumanAgentInterface):
    """A Human agent that uses the command line interface to generate responses."""

    def __init__(self, agent_id: str):
        super().__init__(agent_id)

    def add_message(self, message: Message):
        super().add_message(message)
        if message.type == "verbose":
            print(Fore.GREEN + message.content + Style.RESET_ALL)
        elif message.type == "debug":
            print(Fore.YELLOW + "DEBUG: " + message.content + Style.RESET_ALL)
        elif message.type != "agent":
            # Skip agent messages so the human's own input isn't echoed back to the command line
            print(message.content)

    def _generate(self) -> str:
        """Generates a response using the message history"""
        response = input()
        return response
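

# Usage sketch (illustrative only): a tiny driver for HumanAgentCLI. The agent_id
# value and the "verbose" message type are assumptions made for this demo; the real
# game loop may construct agents and messages differently, and data_collection.save
# may need its own setup before add_message is called.
if __name__ == "__main__":
    demo_agent = HumanAgentCLI("human-demo")
    demo_agent.add_message(Message(type="verbose", content="Demo: your reply will be recorded in the message history."))
    reply = demo_agent.respond_to(Message(type="verbose", content="Say something:"))
    print(Fore.CYAN + f"Recorded reply: {reply.content}" + Style.RESET_ALL)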