File size: 2,287 Bytes
250cc97
 
 
 
47b6f03
250cc97
 
 
 
 
 
 
47b6f03
 
 
 
 
 
 
 
250cc97
47b6f03
63c7598
250cc97
 
47b6f03
250cc97
47b6f03
250cc97
 
ff62b8a
 
 
 
 
 
250cc97
 
47b6f03
63c7598
250cc97
 
 
 
 
 
 
 
 
47b6f03
250cc97
 
47b6f03
bee27cc
250cc97
47b6f03
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import random
from typing import Annotated, NewType, List, Optional, Type, Literal
import json

from pydantic import BaseModel, field_validator, Field, model_validator

# Prompt template asking the model to re-emit its previous answer as JSON.
# "{schema}" is filled via str.format by OutputFormatModel.get_format_instructions
# with a JSON object mapping each (non-excluded) field name to its description.
FORMAT_INSTRUCTIONS = """Please reformat your previous response as a JSON instance that conforms to the JSON structure below.
Here is the output format:
{schema}
"""


class OutputFormatModel(BaseModel):
    # Base for models describing the JSON shape an LLM reply must follow.
    # Subclasses declare their fields with Field(description=...).

    @classmethod
    def get_format_instructions(cls) -> str:
        """Returns a string with instructions on how to format the output.

        Each declared field that is not marked ``exclude`` contributes one
        ``name -> description`` entry (in declaration order). The mapping is
        serialized with ``json.dumps`` and substituted into the
        ``FORMAT_INSTRUCTIONS`` template.
        """
        # Excluded fields carry caller-supplied context, not something the
        # model should produce, so they are omitted from the schema shown.
        schema = {
            name: info.description
            for name, info in cls.model_fields.items()
            if not info.exclude
        }

        # In the future, we could instead use get_annotations() to get the field descriptions
        return FORMAT_INSTRUCTIONS.format(schema=json.dumps(schema))


class AnimalDescriptionFormat(OutputFormatModel):
    # Output format for a player's description of the animal; the text must
    # begin with the letter 'I' (either case).

    # Define fields of our class here
    description: str = Field(description="A brief description of the animal")
    """A brief description of the animal"""

    @field_validator('description')
    @classmethod
    def check_starting_character(cls, v) -> str:
        """Require the description to start with 'I' (case-insensitive).

        Returns:
            The description unchanged.

        Raises:
            ValueError: if ``v`` is empty or its first character does not
                upper-case to 'I'.
        """
        # Slice rather than index: v[0] raises IndexError on an empty string,
        # and pydantic only converts ValueError/AssertionError into a
        # ValidationError, so an empty reply would otherwise crash instead of
        # producing the rewrite request below.
        if v[:1].upper() != 'I':
            raise ValueError("Please rewrite your description so that it begins with 'I'")
        return v


class ChameleonGuessFormat(OutputFormatModel):
    # Output format for the Chameleon's guess of the Herd's animal.
    animal: str = Field(description="Name of the animal you think the Herd is in its singular form, e.g. 'animal' not 'animals'")

    @field_validator('animal')
    @classmethod
    def is_one_word(cls, v) -> str:
        """Require the guess to be exactly one whitespace-separated word.

        Returns:
            The guess unchanged.

        Raises:
            ValueError: if ``v`` is empty, whitespace-only, or has more than
                one word.
        """
        # "".split() == [], so checking only "> 1" let an empty guess through.
        if len(v.split()) != 1:
            raise ValueError("Animal's name must be one word")
        return v


class HerdVoteFormat(OutputFormatModel):
    # Output format for a Herd member's vote. `player_names` is context the
    # caller fills in (exclude=True, so get_format_instructions skips it);
    # `vote` is the field the responder provides.

    # The names of the players in the game
    player_names: List[str] = Field([], exclude=True)
    # The name of the player you are voting for
    vote: str = Field(description="The name of the player you are voting for")

    @model_validator(mode="after")
    def check_player_exists(self) -> "HerdVoteFormat":
        """Ensure ``vote`` matches one of ``player_names``, ignoring case.

        Returns:
            The validated model instance.

        Raises:
            ValueError: if no entry of ``player_names`` equals ``vote`` after
                lower-casing both (always so when ``player_names`` is empty).
        """
        target = self.vote.lower()
        found = any(name.lower() == target for name in self.player_names)
        if not found:
            raise ValueError(f"Player {self.vote} does not exist, please vote for one of {self.player_names}")
        return self