File size: 3,895 Bytes
778c3d7
f596e58
a92f249
f596e58
a92f249
f596e58
 
 
778c3d7
f596e58
a92f249
f596e58
 
a92f249
5de0b8a
f596e58
 
 
 
 
 
 
 
5de0b8a
172af0f
f596e58
 
 
 
 
 
 
a92f249
5de0b8a
 
 
ea658a2
 
172af0f
 
 
 
 
f596e58
 
 
 
172af0f
 
 
f596e58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5de0b8a
f596e58
5de0b8a
f596e58
5de0b8a
f596e58
5de0b8a
f596e58
ea658a2
f596e58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
from typing import Type

from langchain_core.runnables import Runnable, RunnableParallel, RunnableLambda, chain

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage

from pydantic import BaseModel

from game_utils import log
from controllers import controller_from_name

class Player:
    def __init__(
            self,
            name: str,
            controller: str,
            role: str,
            id: str = None,
            log_filepath: str = None
    ):
        self.name = name
        self.id = id

        if controller == "human":
            self.controller_type = "human"
        else:
            self.controller_type = "ai"

        self.controller = controller_from_name(controller)

        self.role = role
        self.messages = []

        self.log_filepath = log_filepath

        if log_filepath:
            player_info = {
                "id": self.id,
                "name": self.name,
                "role": self.role,
                "controller": {
                    "name": controller,
                    "type": self.controller_type
                }
            }
            log(player_info, log_filepath)

        # initialize the runnables
        self.generate = RunnableLambda(self._generate)
        self.format_output = RunnableLambda(self._output_formatter)

    async def respond_to(self, prompt: str, output_format: Type[BaseModel], max_retries=3):
        """Makes the player respond to a prompt. Returns the response in the specified format."""
        message = HumanMessage(content=prompt)
        output = await self.generate.ainvoke(message)
        if self.controller_type == "ai":
            retries = 0
            try:
                output = await self.format_output.ainvoke({"output_format": output_format})
            except ValueError as e:
                if retries < max_retries:
                    self.add_to_history(HumanMessage(content=f"Error formatting response: {e} \n\n Please try again."))
                    output = await self.format_output.ainvoke({"output_format": output_format})
                    retries += 1
                else:
                    raise e
        else:
            # Convert the human message to the pydantic object format
            field_name = output_format.model_fields.copy().popitem()[0]  # only works because current outputs have only 1 field
            output = output_format.model_validate({field_name: output.content})

        return output

    def add_to_history(self, message):
        self.messages.append(message)
        # log(message.model_dump_json(), self.log_filepath)

    def _generate(self, message: HumanMessage):
        """Entry point for the Runnable generating responses, automatically logs the message."""
        self.add_to_history(message)

        # AI's need to be fed the whole message history, but humans can just go back and look at it
        if self.controller_type == "human":
            response = self.controller.invoke(message.content)
        else:
            response = self.controller.invoke(self.messages)

        self.add_to_history(response)

        return response

    def _output_formatter(self, inputs: dict):
        """Formats the output of the response."""
        output_format: BaseModel = inputs["output_format"]

        prompt_template = PromptTemplate.from_template(
            "Please rewrite your previous response using the following format: \n\n{format_instructions}"
        )

        parser = PydanticOutputParser(pydantic_object=output_format)

        prompt = prompt_template.invoke({"format_instructions": parser.get_format_instructions()})

        message = HumanMessage(content=prompt.text)

        response = self.generate.invoke(message)

        return parser.invoke(response)