from . import backend_anthropic, backend_openai, backend_openrouter from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md import re import logging logger = logging.getLogger("aide") def determine_provider(model: str) -> str: if model.startswith("gpt-") or re.match(r"^o\d", model): return "openai" elif model.startswith("claude-"): return "anthropic" # all other models are handle by openrouter else: return "openrouter" provider_to_query_func = { "openai": backend_openai.query, "anthropic": backend_anthropic.query, "openrouter": backend_openrouter.query, } def query( system_message: PromptType | None, user_message: PromptType | None, model: str, temperature: float | None = None, max_tokens: int | None = None, func_spec: FunctionSpec | None = None, **model_kwargs, ) -> OutputType: """ General LLM query for various backends with a single system and user message. Supports function calling for some backends. Args: system_message (PromptType | None): Uncompiled system message (will generate a message following the OpenAI/Anthropic format) user_message (PromptType | None): Uncompiled user message (will generate a message following the OpenAI/Anthropic format) model (str): string identifier for the model to use (e.g. "gpt-4-turbo") temperature (float | None, optional): Temperature to sample at. Defaults to the model-specific default. max_tokens (int | None, optional): Maximum number of tokens to generate. Defaults to the model-specific max tokens. func_spec (FunctionSpec | None, optional): Optional FunctionSpec object defining a function call. If given, the return value will be a dict. Returns: OutputType: A string completion if func_spec is None, otherwise a dict with the function call details. """ model_kwargs = model_kwargs | { "model": model, "temperature": temperature, "max_tokens": max_tokens, } # Handle models with beta limitations # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations if re.match(r"^o\d", model): if system_message: user_message = system_message system_message = None model_kwargs["temperature"] = 1 provider = determine_provider(model) query_func = provider_to_query_func[provider] output, req_time, in_tok_count, out_tok_count, info = query_func( system_message=compile_prompt_to_md(system_message) if system_message else None, user_message=compile_prompt_to_md(user_message) if user_message else None, func_spec=func_spec, **model_kwargs, ) return output