chameleon / src /output_formats.py
Eric Botti
enhanced output formatting
bee27cc
raw
history blame
3.36 kB
import random
from typing import Annotated, NewType, List, Optional, Type, Literal
import json
from pydantic import BaseModel, field_validator, Field
MAX_DESCRIPTION_LEN = 10
FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
Here is the output format:
{schema}
"""
FEW_SHOT_INSTRUCTIONS = """Here are a few examples of correctly formatted responses: \n
{examples}
"""
OutputFormatModel = NewType("OutputFormatModel", BaseModel)
class OutputFormat:
"""The base class for all output formats."""
format_instructions: str = FORMAT_INSTRUCTIONS
"""Instructions for formatting the output, it is combined with the JSON schema of the output format."""
few_shot_instructions: str = FEW_SHOT_INSTRUCTIONS
"""Instructions for the few shot examples, it is combined with the few shot examples."""
few_shot_examples: Optional[List[dict]] = None
"""A list of examples to be shown to the agent to help them understand the desired format of the output."""
def __init__(self, output_format_model: Type[OutputFormatModel], player_names: List[str] = None):
self.output_format_model = output_format_model
self.output_format_model.player_names = player_names
def get_format_instructions(self) -> str:
json_format = self.output_format_model().model_dump_json()
return self.format_instructions.format(schema=json_format)
def get_few_shot(self, max_examples=3):
if len(self.few_shot_examples) <= max_examples:
examples = self.few_shot_examples
else:
examples = random.sample(self.few_shot_examples, max_examples)
few_shot = "\n\n".join([f"Example Response:\n{json.dumps(example)}" for example in examples])
return self.few_shot_instructions.format(examples=few_shot)
class AnimalDescriptionFormat(BaseModel):
# Define fields of our class here
description: str = Field("A brief description of the animal")
"""A brief description of the animal"""
@field_validator('description')
@classmethod
def check_starting_character(cls, v) -> str:
if not v[0].upper() == 'I':
raise ValueError("Description must begin with 'I'")
return v
@field_validator('description')
@classmethod
def wordcount(cls, v) -> str:
count = len(v.split())
if count > MAX_DESCRIPTION_LEN:
raise ValueError(f"Animal Description must be {MAX_DESCRIPTION_LEN} words or less")
return v
class ChameleonGuessFormat(BaseModel):
animal: str = Field("The name of the animal you think the chameleon is")
@field_validator('animal')
@classmethod
def is_one_word(cls, v) -> str:
if len(v.split()) > 1:
raise ValueError("Animal's name must be one word")
return v
class HerdVoteFormat(BaseModel):
player_names: List[str] = Field([], exclude=True)
"""The names of the players in the game"""
vote: str = Field("The name of the player you are voting for")
"""The name of the player you are voting for"""
@field_validator('vote')
@classmethod
def check_player_exists(cls, v) -> str:
if v.lower() not in [player.lower() for player in cls.player_names]:
raise ValueError(f"Player {v} does not exist")
return v