"""Request object."""

from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel
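
# Separator used in composite engine identifier strings
# (assumed to join provider and engine names, e.g. "provider::engine").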
ENGINE_SEP = "::"
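
# Request keys that should not contribute to the cache key
# (they affect how the request is sent, not the model output).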
NOT_CACHE_KEYS = {"client_timeout", "batch_size"}
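
# Default request keys, mapping attribute name -> (output key name, default value).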
DEFAULT_REQUEST_KEYS = {
    "client_timeout": ("client_timeout", 60),
    "batch_size": ("batch_size", 8),
    "run_id": ("run_id", None),
}


class Request(BaseModel):
    """Request object."""
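
    # Prompt: a single string or a batch of prompt strings.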
    prompt: Union[str, List[str]] = ""
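
    # Name of the engine/model to run the request against.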
    engine: str = "text-ada-001"
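
    # Number of completions to generate for the prompt.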
    n: int = 1
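
    # Timeout for the client call, in seconds.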
    client_timeout: int = 60
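
    # Optional run identifier (assumed to tag/group related requests).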
    run_id: Optional[str] = None
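
    # Number of prompts to send per batch.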
    batch_size: int = 8

    def to_dict(
        self,
        allowable_keys: Optional[Dict[str, Tuple[str, Any]]] = None,
        add_prompt: bool = True,
    ) -> Dict[str, Any]:
        """
        Convert the request to a dictionary.

        Handles parameter renaming but does not fill in default values.
        Any None values are dropped from the result.

        Setting add_prompt ensures the prompt is always included in the
        output dictionary.
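
        Example (illustrative sketch only; the renamed key "num_results" is a
        made-up mapping, not something defined by this module):

            request = Request(prompt="hello", n=2)
            request.to_dict(allowable_keys={"n": ("num_results", 1)})
            # -> {"prompt": "hello", "num_results": 2}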
        """
        if allowable_keys:
            include_keys = set(allowable_keys.keys())
            if add_prompt:
                include_keys.add("prompt")
        else:
            allowable_keys = {}
            include_keys = None
        # Rename keys according to allowable_keys and drop None values.
        request_dict = {
            allowable_keys.get(k, (k, None))[0]: v
            for k, v in self.dict(include=include_keys).items()
            if v is not None
        }
        return request_dict


class LMRequest(Request):
    """Language Model Request object."""

    temperature: float = 0.7

    max_tokens: int = 100

    top_p: float = 1.0

    top_k: int = 50

    logprobs: Optional[int] = None

    stop_sequences: Optional[List[str]] = None

    num_beams: int = 1

    do_sample: bool = False

    repetition_penalty: float = 1.0

    length_penalty: float = 1.0

    presence_penalty: float = 0

    frequency_penalty: float = 0


class LMChatRequest(LMRequest):
    """Language Model Chat Request object."""
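
    # Chat prompt: a list of message dicts (e.g. {"role": ..., "content": ...}).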
    prompt: List[Dict[str, str]] = []


class LMScoreRequest(LMRequest):
    """Language Model Score Request object."""

    pass


class EmbeddingRequest(Request):
    """Embedding Request object."""

    pass


class DiffusionRequest(Request):
    """Diffusion Model Request object."""
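
    # Number of denoising steps to run.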
    num_inference_steps: int = 50

    height: int = 512

    width: int = 512
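
    # Classifier-free guidance scale.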
    guidance_scale: float = 7.5
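
    # Eta noise parameter used by some schedulers (e.g. DDIM).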
    eta: float = 0.0