|
from dataclasses import dataclass |
|
|
|
import jsonschema |
|
from dataclasses_json import DataClassJsonMixin |
|
|
|
# A prompt is plain text, a list of items (rendered as a bullet list by
# compile_prompt_to_md), or a dict whose keys become section headers and whose
# values are nested prompts.
PromptType = str | dict | list

# Function-call payloads are represented as plain dicts.
FunctionCallType = dict

# Presumably the result of a model call: either text or a function-call dict
# (NOTE(review): confirm against the code that produces outputs).
OutputType = str | FunctionCallType
|
|
|
|
|
def opt_messages_to_list( |
|
system_message: str | None, user_message: str | None |
|
) -> list[dict[str, str]]: |
|
messages = [] |
|
if system_message: |
|
messages.append({"role": "system", "content": system_message}) |
|
if user_message: |
|
messages.append({"role": "user", "content": user_message}) |
|
return messages |
|
|
|
|
|
def compile_prompt_to_md(prompt: PromptType, _header_depth: int = 1) -> str:
    """Render a (possibly nested) prompt structure as Markdown text.

    - A string is stripped and terminated with a single newline.
    - A list or tuple becomes a bullet list, one ``- item`` line per element.
      Each element is converted with ``str()`` before stripping, so non-string
      scalars such as numbers are accepted. The result ends with a blank line.
    - A dict (or any object with a callable ``items()``) becomes one Markdown
      header per key, followed by that key's value rendered recursively one
      header level deeper.

    Args:
        prompt: the prompt structure to render.
        _header_depth: internal recursion depth; the number of ``#`` characters
            used for headers at this level.

    Returns:
        The rendered Markdown string.

    Raises:
        TypeError: if ``prompt``, or any value nested inside a dict, is not a
            str, list, tuple, or mapping-like object.
    """
    if isinstance(prompt, str):
        return prompt.strip() + "\n"
    if isinstance(prompt, (list, tuple)):
        bullets = [f"- {str(item).strip()}" for item in prompt]
        # The extra "\n" element leaves a trailing blank line after the list;
        # kept unchanged for output compatibility.
        return "\n".join(bullets + ["\n"])

    # Duck-typed mapping support: anything exposing items() is rendered as
    # sections, exactly as before. Everything else gets a clear error instead
    # of an opaque AttributeError from `.items()`.
    items = getattr(prompt, "items", None)
    if not callable(items):
        raise TypeError(
            "compile_prompt_to_md expects a str, list, tuple or dict, "
            f"got {type(prompt).__name__}"
        )

    header_prefix = "#" * _header_depth
    sections = []
    for key, value in items():
        sections.append(f"{header_prefix} {key}\n")
        sections.append(compile_prompt_to_md(value, _header_depth=_header_depth + 1))
    return "\n".join(sections)
|
|
|
|
|
@dataclass
class FunctionSpec(DataClassJsonMixin):
    """Specification of a function (tool) the model may call.

    Fields:
        name: function name, used in both the tool definition and the
            tool-choice dict.
        json_schema: JSON Schema for the function's parameters; validated
            against the Draft 7 meta-schema when the instance is created.
        description: human-readable description sent with the tool definition.
    """

    name: str
    json_schema: dict
    description: str

    def __post_init__(self):
        # Fail fast on a malformed schema at construction time rather than when
        # the spec is first sent; check_schema raises
        # jsonschema.exceptions.SchemaError if the schema is invalid.
        validator = jsonschema.Draft7Validator
        validator.check_schema(self.json_schema)

    @property
    def as_openai_tool_dict(self):
        """Return this spec as an OpenAI-style ``{"type": "function", ...}`` tool entry."""
        function_def = {
            "name": self.name,
            "description": self.description,
            "parameters": self.json_schema,
        }
        return {"type": "function", "function": function_def}

    @property
    def openai_tool_choice_dict(self):
        """Return a dict shaped like OpenAI's ``tool_choice`` object naming this function."""
        return dict(type="function", function=dict(name=self.name))
|
|