"""ParserKani: a kani agent that reformats free-text model output into JSON validated against a pydantic model."""
import asyncio
import json
from typing import Optional, Type

from kani.engines.openai import OpenAIEngine
from pydantic import BaseModel, ValidationError

from agents import LogMessagesKani
# Instructions appended to the reformatting request; {schema} is filled with
# the (title/type-reduced) JSON schema of the target pydantic model by
# ParserKani.get_format_instructions.
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
Here is the output schema:
```
{schema}
```
"""
# Template for the reformatting request sent to the model; ParserKani.parse
# fills {prompt}, {message}, and {format_instructions}.
parser_prompt = """\
The user gave the following output to the prompt:
Prompt:
{prompt}
Output:
{message}
{format_instructions}
"""
class ParserKani(LogMessagesKani):
    """Kani agent that asks the model to restate a prior answer as JSON and
    validates the result against a pydantic model."""

    def __init__(self, engine, *args, **kwargs):
        super().__init__(engine, *args, **kwargs)

    async def parse(
        self,
        prompt: str,
        message: str,
        format_model: Type[BaseModel],
        max_retries: int = 3,
        **kwargs,
    ) -> BaseModel:
        """Reformat ``message`` (an answer to ``prompt``) into an instance of
        ``format_model``.

        Args:
            prompt: The original prompt the message answered.
            message: The free-text output to be reformatted as JSON.
            format_model: Pydantic model the JSON must conform to.
            max_retries: Maximum number of chat rounds to attempt before
                giving up (at least one round is always attempted).
            **kwargs: Forwarded to ``chat_round_str``.

        Raises:
            ValidationError: If no attempt produced output conforming to the
                model's schema.
        """
        format_instructions = self.get_format_instructions(format_model)
        parser_instructions = parser_prompt.format(
            prompt=prompt,
            message=message,
            format_instructions=format_instructions,
        )
        # BUG FIX: max_retries was previously accepted but never used — the
        # first ValidationError raised immediately. Now each failed attempt
        # re-asks the model, up to max_retries rounds.
        last_error: Optional[ValidationError] = None
        for _attempt in range(max(1, max_retries)):
            response = await self.chat_round_str(parser_instructions, **kwargs)
            try:
                output = format_model.model_validate_json(response)
            except ValidationError as e:
                print(f"Output did not conform to the expected format: {e}")
                last_error = e
                continue
            # Clear the chat history after a successful parse so later
            # parse() calls start from a clean context.
            self.chat_history = []
            return output
        raise last_error

    @staticmethod
    def get_format_instructions(format_model: Type[BaseModel]) -> str:
        """Render FORMAT_INSTRUCTIONS with ``format_model``'s JSON schema,
        dropping the redundant top-level "title" and "type" keys."""
        # Copy before mutating so repeated calls can never corrupt a schema
        # dict that pydantic might cache/share (the original deleted keys on
        # the returned dict in place).
        reduced_schema = dict(format_model.model_json_schema())
        reduced_schema.pop("title", None)
        reduced_schema.pop("type", None)
        schema_str = json.dumps(reduced_schema, indent=4)
        return FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @classmethod
    def default(cls, log_filepath: Optional[str] = None):
        """Default ParserKani backed by an OpenAI gpt-3.5-turbo engine."""
        engine = OpenAIEngine(model="gpt-3.5-turbo")
        return cls(engine, log_filepath=log_filepath)
# Testing
# parser = ParserKani(engine=OpenAIEngine(model="gpt-3.5-turbo"))
#
# sample_prompt = """\
# Below are the responses from all players. Now it is time to vote. Choose from the players below who you think the Chameleon is.
# - Mallory: I am tall and have a long neck.
# - Jack: I am a herbivore and have a long neck.
# - Jill: I am a herbivore and have a long neck.
# - Bob: I am tall and have a long neck.
# - Courtney: I am tall and have a long neck.
# """
#
# sample_message = """\
# I think the Chameleon is Mallory.
# """
#
# test_output = asyncio.run(parser.parse(prompt=sample_prompt, message=sample_message, format_model=VoteModel))
#
# print(test_output)
#
# print(VoteModel.model_validate_json(test_output))