from typing import Type, NewType
import json
from openai import OpenAI
from colorama import Fore, Style
from pydantic import BaseModel, ValidationError
from output_formats import OutputFormat, OutputFormatModel
from message import Message, AgentMessage
from data_collection import save


class BaseAgentInterface:
    """
    The interface that agents use to receive info from and interact with the game.

    This is the base class and should not be used directly.
    """

    is_human: bool = False

    def __init__(self, agent_id: str = None):
        self.id = agent_id
        self.messages = []

    def add_message(self, message: Message):
        """Adds a message to the message history, without generating a response."""
        bound_message = AgentMessage.from_message(message, self.id, len(self.messages))
        save(bound_message)
        self.messages.append(bound_message)

    def respond_to(self, message: Message) -> Message:
        """Adds a message to the message history and generates a response message."""
        self.add_message(message)
        response = Message(type="agent", content=self._generate_response())
        self.add_message(response)
        return response

    def respond_to_formatted(self, message: Message, output_format: OutputFormat, max_retries: int = 3) -> OutputFormatModel:
        """Adds a message to the message history and generates a response matching the provided format."""
        # The free-form answer comes first; it stays in the history so the
        # follow-up formatting instructions have the full conversation as context.
        initial_response = self.respond_to(message)
        reformat_message = Message(type="format", content=output_format.get_format_instructions())
        output = None
        retries = 0
        while output is None and retries < max_retries:
            try:
                formatted_response = self.respond_to(reformat_message)
                output = output_format.output_format_model.model_validate_json(formatted_response.content)
            except ValidationError as e:
                if retries >= max_retries - 1:
                    # Out of retries: surface the last validation error instead of silently returning None.
                    raise e
                retry_message = Message(type="retry", content=f"Error formatting response: {e} \n\n Please try again.")
                if output_format.few_shot_examples:
                    self.add_message(retry_message)
                    reformat_message = Message(type="few_shot", content=output_format.get_few_shot())  # not implemented
                else:
                    reformat_message = retry_message
            retries += 1
        return output

    def _generate_response(self) -> str:
        """Generates a response from the agent."""
        # This is the base class, so it has no response logic.
        # Subclasses must implement this method to generate a response from the message history.
        raise NotImplementedError

    @property
    def is_ai(self):
        return not self.is_human


AgentInterface = NewType("AgentInterface", BaseAgentInterface)
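

# Illustrative only: a minimal sketch of the subclassing contract described in
# BaseAgentInterface._generate_response. EchoAgentInterface is not part of the
# game; it simply shows that a subclass only needs to turn the message history
# into a reply string. It assumes AgentMessage keeps the original `content` field.
class EchoAgentInterface(BaseAgentInterface):
    """A trivial agent that parrots back the last message it received."""

    def _generate_response(self) -> str:
        # self.messages holds the bound AgentMessage objects added so far;
        # the most recent one is the message this agent is responding to.
        return f"You said: {self.messages[-1].content}" if self.messages else "..."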


class OpenAIAgentInterface(BaseAgentInterface):
    """An interface that uses the OpenAI API (or compatible 3rd parties) to generate responses."""

    def __init__(self, agent_id: str, model_name: str = "gpt-3.5-turbo"):
        super().__init__(agent_id)
        self.model_name = model_name
        self.client = OpenAI()

    def _generate_response(self) -> str:
        """Generates a response using the message history."""
        open_ai_messages = [message.to_openai() for message in self.messages]
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=open_ai_messages
        )
        return completion.choices[0].message.content
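

# A sketch of how the "compatible 3rd parties" mentioned above can be reached:
# the OpenAI client accepts an alternative endpoint via `base_url`. The URL and
# API key below are placeholders, not values used by this project.
#
#     client = OpenAI(
#         base_url="http://localhost:8000/v1",  # any OpenAI-compatible server
#         api_key="not-needed-for-local-servers",
#     )
#
# Wiring that in would mean replacing `OpenAI()` in OpenAIAgentInterface.__init__
# (or subclassing and overriding it); the rest of the class stays unchanged.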


class HumanAgentInterface(BaseAgentInterface):
    is_human = True

    def respond_to_formatted(self, message: Message, output_format: OutputFormat, **kwargs) -> OutputFormatModel:
        """For human agents, we can trust them enough to format their own responses... for now."""
        response = super().respond_to(message)
        # Only works because current output formats have a single field: take that
        # field's name and validate the raw reply against it directly.
        field_name = output_format.output_format_model.model_fields.copy().popitem()[0]
        output = output_format.output_format_model.model_validate({field_name: response.content})
        return output


class HumanAgentCLI(HumanAgentInterface):
    """A human agent that uses the command line interface to generate responses."""

    def __init__(self, agent_id: str):
        super().__init__(agent_id)

    def add_message(self, message: Message):
        super().add_message(message)
        if message.type == "verbose":
            print(Fore.GREEN + message.content + Style.RESET_ALL)
        elif message.type == "debug":
            print(Fore.YELLOW + "DEBUG: " + message.content + Style.RESET_ALL)
        elif message.type != "agent":
            # Skipping "agent" messages prevents the agent from seeing its own
            # responses echoed back on the command line.
            print(message.content)

    def _generate_response(self) -> str:
        """Generates a response by reading a line from stdin."""
        response = input()
        return response
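

# A minimal usage sketch, not part of the game loop. It assumes the Message
# model accepts the `type` values used elsewhere in this file (here "verbose")
# and that data_collection.save works without extra setup; adjust as needed.
if __name__ == "__main__":
    human = HumanAgentCLI("human-1")
    prompt = Message(type="verbose", content="Describe your card in one sentence:")
    reply = human.respond_to(prompt)  # prints the prompt, then reads a line from stdin
    print(f"{human.id} answered: {reply.content}")

    # The OpenAI-backed agent works the same way but needs OPENAI_API_KEY set:
    # ai = OpenAIAgentInterface("ai-1", model_name="gpt-3.5-turbo")
    # ai_reply = ai.respond_to(prompt)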