from . import backend_anthropic, backend_openai, backend_openrouter
from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md

import logging
import re

logger = logging.getLogger("aide")


def determine_provider(model: str) -> str:
    """Infer which backend provider serves the given model name."""
    if model.startswith("gpt-") or re.match(r"^o\d", model):
        return "openai"
    elif model.startswith("claude-"):
        return "anthropic"
    # Everything else is routed through OpenRouter.
    else:
        return "openrouter"

provider_to_query_func = {
    "openai": backend_openai.query,
    "anthropic": backend_anthropic.query,
    "openrouter": backend_openrouter.query,
}
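
# Each backend's query function is assumed to share a common signature and to
# return an (output, req_time, in_tok_count, out_tok_count, info) tuple, which
# is what the dispatch in query() below relies on.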


def query(
    system_message: PromptType | None,
    user_message: PromptType | None,
    model: str,
    temperature: float | None = None,
    max_tokens: int | None = None,
    func_spec: FunctionSpec | None = None,
    **model_kwargs,
) -> OutputType:
    """
    General LLM query for various backends with a single system and user message.
    Supports function calling for some backends.

    Args:
        system_message (PromptType | None): Uncompiled system message; compiled to markdown and sent in the OpenAI/Anthropic message format.
        user_message (PromptType | None): Uncompiled user message; compiled to markdown and sent in the OpenAI/Anthropic message format.
        model (str): String identifier of the model to use (e.g. "gpt-4-turbo").
        temperature (float | None, optional): Sampling temperature. Defaults to the model-specific default.
        max_tokens (int | None, optional): Maximum number of tokens to generate. Defaults to the model-specific maximum.
        func_spec (FunctionSpec | None, optional): Optional FunctionSpec object defining a function call. If given, the return value will be a dict.

    Returns:
        OutputType: A string completion if func_spec is None, otherwise a dict with the function call details.
    """
    model_kwargs = model_kwargs | {
        "model": model,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    # OpenAI o-series reasoning models (o1, o3, ...) do not accept a system
    # message and only support temperature=1, so move the system message into
    # the user slot and pin the temperature. This assumes callers pass their
    # prompt via system_message when user_message is unused; a user_message
    # supplied alongside a system_message would be overwritten here.
    if re.match(r"^o\d", model):
        if system_message:
            user_message = system_message
            system_message = None
        model_kwargs["temperature"] = 1

    provider = determine_provider(model)
    query_func = provider_to_query_func[provider]
    output, req_time, in_tok_count, out_tok_count, info = query_func(
        system_message=compile_prompt_to_md(system_message) if system_message else None,
        user_message=compile_prompt_to_md(user_message) if user_message else None,
        func_spec=func_spec,
        **model_kwargs,
    )

    # Only the completion itself is surfaced; request time, token counts and
    # backend-specific info are currently discarded.
    return output
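

# Usage sketch (hypothetical prompt and settings; assumes the matching backend
# credentials, e.g. OPENAI_API_KEY, are configured in the environment):
#
#   answer = query(
#       system_message="You are a concise assistant.",
#       user_message="Summarise the design doc in three bullet points.",
#       model="gpt-4-turbo",
#       temperature=0.5,
#   )
#
# Passing a func_spec instead asks the backend for a structured function call,
# and query() returns a dict of the call's arguments rather than a string.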