File size: 2,770 Bytes
975158e
 
 
 
 
 
 
a92f249
 
975158e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a92f249
975158e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c31321
975158e
 
0c31321
975158e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import Type
import asyncio
import json

from kani.engines.openai import OpenAIEngine
from pydantic import BaseModel, ValidationError

from agents import LogMessagesKani

# Format-instructions template. ParserKani.get_format_instructions fills
# ``{schema}`` with the (reduced) JSON schema of the target pydantic model,
# serialized with json.dumps.
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
Here is the output schema:
```
{schema}
```
"""

# Prompt sent to the engine by ParserKani.parse. ``{prompt}`` is the original
# prompt, ``{message}`` the free-form output to reformat, and
# ``{format_instructions}`` the filled-in FORMAT_INSTRUCTIONS text. Values are
# substituted with str.format, which does not re-interpret braces inside them.
parser_prompt = """\
The user gave the following output to the prompt:
Prompt: 
{prompt}
Output:
{message}

{format_instructions}
"""


class ParserKani(LogMessagesKani):
    """Kani that reformats a free-form answer into an instance of a pydantic model.

    The engine is shown the original prompt, the answer given to it, and the
    JSON schema of the target model, and is asked to reply with a conforming
    JSON instance, which is then validated with pydantic.
    """

    def __init__(self, engine, *args, **kwargs):
        super().__init__(engine, *args, **kwargs)

    async def parse(self, prompt: str, message: str, format_model: Type[BaseModel], max_retries: int = 3, **kwargs):
        """Restate ``message`` as JSON and validate it against ``format_model``.

        Args:
            prompt: The prompt that ``message`` was written in response to.
            message: The free-form output to be reformatted.
            format_model: Pydantic model class the reformatted output must match.
            max_retries: Number of extra attempts made after the first reply
                fails validation. Negative values are treated as zero, which
                gives a single attempt.
            **kwargs: Forwarded unchanged to every ``chat_round_str`` call.

        Returns:
            An instance of ``format_model``.

        Raises:
            ValidationError: If no attempt produced output that conforms to
                ``format_model``.

        The chat history is cleared when this method returns or raises, so
        consecutive ``parse`` calls do not share context.
        """
        format_instructions = self.get_format_instructions(format_model)

        query = parser_prompt.format(
            prompt=prompt,
            message=message,
            format_instructions=format_instructions,
        )

        attempts = 1 + max(0, max_retries)
        try:
            for attempt in range(1, attempts + 1):
                response = await self.chat_round_str(query, **kwargs)
                try:
                    return format_model.model_validate_json(self._strip_code_fence(response or ""))
                except ValidationError as e:
                    print(f"Output did not conform to the expected format: {e}")
                    if attempt == attempts:
                        raise
                    # History is only cleared once parse() finishes, so the
                    # follow-up message is sent within the same conversation.
                    query = (
                        "Your previous output did not conform to the expected format:\n"
                        f"{e}\n\n"
                        "Reply again with only a JSON instance that conforms to the schema."
                    )
        finally:
            # Clear the chat history on success AND failure so a later parse()
            # call never inherits this conversation.
            self.chat_history = []

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without a surrounding Markdown code fence, if it has one.

        The format instructions display the schema inside ``` fences, so
        replies often come back fenced (optionally tagged, e.g. ```json).
        Unfenced text is returned unchanged.
        """
        stripped = text.strip()
        if not stripped.startswith("```"):
            return text
        # Drop the opening fence line (it may carry a language tag) ...
        body = stripped.splitlines()[1:]
        # ... and the closing fence line, when present.
        if body and body[-1].strip() == "```":
            body = body[:-1]
        return "\n".join(body)

    @staticmethod
    def get_format_instructions(format_model: Type[BaseModel]):
        """Build the format-instructions text for ``format_model``.

        The model's JSON schema is serialized with 4-space indentation and
        inserted into FORMAT_INSTRUCTIONS.
        """
        schema = format_model.model_json_schema()

        # Build a filtered copy rather than mutating the generated schema.
        # Only the top-level "title" and "type" keys are dropped; everything
        # else, including nested definitions, is kept.
        reduced_schema = {key: value for key, value in schema.items() if key not in ("title", "type")}

        schema_str = json.dumps(reduced_schema, indent=4)

        return FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @classmethod
    def default(cls, log_filepath: str = None, model: str = "gpt-3.5-turbo"):
        """Default ParserKani with OpenAIEngine.

        Args:
            log_filepath: Passed to the constructor as ``log_filepath``.
            model: OpenAI model name for the engine. Defaults to
                ``"gpt-3.5-turbo"``, the previously hard-coded value.
        """
        engine = OpenAIEngine(model=model)
        return cls(engine, log_filepath=log_filepath)



# Manual smoke test (uncomment to run; a pydantic `VoteModel` must be defined or imported first)
# parser = ParserKani(engine=OpenAIEngine(model="gpt-3.5-turbo"))
#
# sample_prompt = """\
# Below are the responses from all players. Now it is time to vote. Choose from the players below who you think the Chameleon is.
# - Mallory: I am tall and have a long neck.
# - Jack: I am a herbivore and have a long neck.
# - Jill: I am a herbivore and have a long neck.
# - Bob: I am tall and have a long neck.
# - Courtney: I am tall and have a long neck.
# """
#
# sample_message = """\
# I think the Chameleon is Mallory.
# """
#
# test_output = asyncio.run(parser.parse(prompt=sample_prompt, message=sample_message, format_model=VoteModel))
#
# print(test_output)
#
# print(test_output.model_dump_json(indent=4))  # parse() already returns a VoteModel instance, not a JSON string