|
from typing import Type |
|
import asyncio |
|
import json |
|
|
|
from kani.engines.openai import OpenAIEngine |
|
from pydantic import BaseModel, ValidationError |
|
|
|
from agents import LogMessagesKani |
|
|
|
# Template describing the required JSON output format. The `{schema}`
# placeholder is filled by ParserKani.get_format_instructions with the
# (reduced) JSON schema of the target pydantic model.
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.

Here is the output schema:

```

{schema}

```

"""
|
|
|
# Query sent to the parser model by ParserKani.parse. Placeholders:
#   {prompt}              - the prompt that originally produced the output
#   {message}             - the raw output to be reformatted
#   {format_instructions} - a rendered FORMAT_INSTRUCTIONS block
parser_prompt = """\

The user gave the following output to the prompt:

Prompt:

{prompt}

Output:

{message}



{format_instructions}

"""
|
|
|
|
|
class ParserKani(LogMessagesKani):
    """Kani that reformats free-form output into an instance of a pydantic model.

    ``parse`` sends the original prompt, the raw output, and a JSON-schema
    description of the target model through ``chat_round_str``. It then
    validates the reply against that model.
    """

    async def parse(self, prompt: str, message: str, format_model: Type[BaseModel], max_retries: int = 3, **kwargs):
        """Reformat ``message`` into a validated ``format_model`` instance.

        Args:
            prompt: The prompt that originally produced ``message``; shown to the
                parser model for context.
            message: The raw output to reformat.
            format_model: Pydantic model class the reply must validate against.
            max_retries: Number of additional attempts after the first reply fails
                validation. Each retry continues the conversation and includes the
                validation error. Negative values are treated as 0.
            **kwargs: Forwarded to ``chat_round_str`` on every round.

        Returns:
            The validated ``format_model`` instance.

        Raises:
            ValidationError: If no reply validates within the allowed attempts.
                The error from the final attempt is re-raised.
        """
        format_instructions = self.get_format_instructions(format_model)

        query = parser_prompt.format(
            prompt=prompt,
            message=message,
            format_instructions=format_instructions,
        )

        attempts = 1 + max(0, max_retries)
        try:
            for attempt in range(attempts):
                response = await self.chat_round_str(query, **kwargs)
                try:
                    return format_model.model_validate_json(self._strip_code_fence(response))
                except ValidationError as e:
                    print(f"Output did not conform to the expected format: {e}")
                    if attempt == attempts - 1:
                        raise
                    # Stay in the same conversation so the model sees its previous
                    # (invalid) answer together with the concrete validation error.
                    query = (
                        "That output did not conform to the expected format. "
                        f"Validation errors:\n{e}\n\n"
                        "Please respond again with only the corrected JSON instance."
                    )
        finally:
            # Always clear the history, including when validation ultimately
            # fails, so one parse never leaks into the context of the next.
            self.chat_history = []

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without surrounding whitespace or a Markdown code fence.

        The format template shows the schema inside ``` fences, so replies may
        come back fenced (e.g. ```json ... ```). Unfenced text is returned stripped.
        """
        stripped = text.strip()
        if not stripped.startswith("```"):
            return stripped
        # Drop the opening fence line, which may carry a language tag such as "json".
        _, newline, rest = stripped.partition("\n")
        body = rest if newline else stripped[3:]
        body = body.rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    @staticmethod
    def get_format_instructions(format_model: Type[BaseModel]) -> str:
        """Render FORMAT_INSTRUCTIONS with the JSON schema of ``format_model``.

        The top-level "title" and "type" keys are removed from the schema before
        rendering. Presumably this keeps the prompt focused on the fields; confirm
        if it matters.
        """
        schema = format_model.model_json_schema()

        # pop() with a default tolerates schemas that lack either key.
        schema.pop("title", None)
        schema.pop("type", None)

        # ensure_ascii=False keeps non-ASCII titles/descriptions readable in the
        # prompt instead of emitting \uXXXX escapes.
        schema_str = json.dumps(schema, indent=4, ensure_ascii=False)

        return FORMAT_INSTRUCTIONS.format(schema=schema_str)

    @classmethod
    def default(cls, log_filepath: str = None, model: str = "gpt-3.5-turbo"):
        """Build a ParserKani backed by an ``OpenAIEngine``.

        Args:
            log_filepath: Passed to the constructor as the ``log_filepath`` keyword.
            model: OpenAI model name for the engine (default "gpt-3.5-turbo").
        """
        engine = OpenAIEngine(model=model)
        return cls(engine, log_filepath=log_filepath)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|