from typing import Optional, Type

from langchain_core.runnables import RunnableLambda
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from pydantic import BaseModel

from game_utils import log
from controllers import controller_from_name


class Player:
    def __init__(
        self,
        name: str,
        controller: str,
        role: str,
        id: Optional[str] = None,
        log_filepath: Optional[str] = None,
    ):
        self.name = name
        self.id = id

        if controller == "human":
            self.controller_type = "human"
        else:
            self.controller_type = "ai"
        self.controller = controller_from_name(controller)

        self.role = role
        self.messages = []

        self.log_filepath = log_filepath
        if log_filepath:
            player_info = {
                "id": self.id,
                "name": self.name,
                "role": self.role,
                "controller": {
                    "name": controller,
                    "type": self.controller_type
                }
            }
            log(player_info, log_filepath)

        # initialize the runnables
        self.generate = RunnableLambda(self._generate)
        self.format_output = RunnableLambda(self._output_formatter)

    async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries: int = 3):
        """Makes the player respond to a prompt. Returns the response in the specified format."""
        message = HumanMessage(content=prompt)
        output = await self.generate.ainvoke(message)

        if self.controller_type == "ai":
            # Ask the model to reformat its answer, retrying up to max_retries times on parse errors.
            retries = 0
            while True:
                try:
                    output = await self.format_output.ainvoke({"output_format": output_format})
                    break
                except ValueError as e:
                    retries += 1
                    if retries > max_retries:
                        raise e
                    self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
        else:
            # Convert the human response to the pydantic object format
            field_name = output_format.model_fields.copy().popitem()[0]  # only works because current outputs have only 1 field
            output = output_format.model_validate({field_name: output.content})

        return output

    def add_to_history(self, message):
        self.messages.append(message)
        # log(message.model_dump_json(), self.log_filepath)

    def _generate(self, message: HumanMessage):
        """Entry point for the Runnable generating responses, automatically logs the message."""
        self.add_to_history(message)

        # AIs need to be fed the whole message history, but humans can just go back and look at it
        if self.controller_type == "human":
            response = self.controller.invoke(message.content)
        else:
            response = self.controller.invoke(self.messages)

        self.add_to_history(response)
        return response

    def _output_formatter(self, inputs: dict):
        """Formats the output of the response."""
        output_format: Type[BaseModel] = inputs["output_format"]
        prompt_template = PromptTemplate.from_template(
            "Please rewrite your previous response using the following format: \n\n{format_instructions}"
        )
        parser = PydanticOutputParser(pydantic_object=output_format)
        prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})
        message = HumanMessage(content=prompt.text)

        response = self.generate.invoke(message)
        return parser.invoke(response)
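
# Illustrative usage sketch (kept commented out so the module has no side effects).
# The "openai" controller name, the "villager" role string, and the VoteResponse
# model below are assumptions made for this example, not values defined in this repo.
#
# import asyncio
# from pydantic import BaseModel
#
# class VoteResponse(BaseModel):
#     vote: str
#
# async def demo():
#     player = Player(name="Ava", controller="openai", role="villager", id="player-1")
#     response = await player.respond_to("Which player do you vote for?", VoteResponse)
#     print(response.vote)
#
# asyncio.run(demo())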