File size: 4,052 Bytes
39c930a d4ec913 39c930a d4ec913 39c930a d4ec913 39c930a 9c55a42 5ddb3df 39c930a d4ec913 39c930a 5ddb3df 9c55a42 39c930a 9c55a42 39c930a 9c55a42 39c930a 9c55a42 39c930a d4ec913 39c930a 9c55a42 39c930a 9c55a42 39c930a 9c55a42 39c930a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
"""Backend for OpenAI API."""
import json
import logging
import time
from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
import openai
logger = logging.getLogger("aide")

# Module-wide OpenAI client. It stays None until _setup_openai_client()
# assigns it; query() calls that setup function before every request.
_client: openai.OpenAI = None  # type: ignore

# Exception types handed to backoff_create together with the completion call
# in query(); presumably these are the API errors it retries with backoff.
OPENAI_TIMEOUT_EXCEPTIONS = (
    openai.RateLimitError,
    openai.APIConnectionError,
    openai.APITimeoutError,
    openai.InternalServerError,
)
# Model names treated as supporting OpenAI function calling (tools).
# Lookups against this set are exact string matches, so model names that are
# not listed verbatim (e.g. newer dated snapshots) count as unsupported.
# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
SUPPORTED_FUNCTION_CALL_MODELS = {
    # GPT-4o family
    "gpt-4o",
    "gpt-4o-2024-08-06",
    "gpt-4o-2024-05-13",
    "gpt-4o-mini",
    "gpt-4o-mini-2024-07-18",
    # GPT-4 Turbo and GPT-4 preview snapshots
    "gpt-4-turbo",
    "gpt-4-turbo-2024-04-09",
    "gpt-4-turbo-preview",
    "gpt-4-0125-preview",
    "gpt-4-1106-preview",
    # GPT-3.5 Turbo family
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
}
@once
def _setup_openai_client():
    """
    Create the shared OpenAI client and store it in the module-level ``_client``.

    Wrapped in ``funcy.once`` so only the first call builds the client; later
    calls are no-ops, which lets ``query`` invoke this before every request.
    """
    global _client
    # SDK-level retries are disabled here; retrying is presumably left to
    # ``backoff_create`` in query().
    client = openai.OpenAI(max_retries=0)
    _client = client
def is_function_call_supported(model_name: str) -> bool:
    """
    Tell whether *model_name* is listed as supporting OpenAI function calling.

    This is an exact, case-sensitive lookup in ``SUPPORTED_FUNCTION_CALL_MODELS``;
    any name not listed verbatim is reported as unsupported.
    """
    known_models = SUPPORTED_FUNCTION_CALL_MODELS
    return model_name in known_models
def _configure_function_call(
    filtered_kwargs: dict, func_spec: FunctionSpec, model_name: str
) -> None:
    """Set or strip the ``tools``/``tool_choice`` kwargs for *func_spec*, in place.

    Support is only checked for the review operation (``submit_review``): if
    the model is not known to support function calling, any tool kwargs are
    removed so the request falls back to plain text generation. In every other
    case (non-review specs, or a supported model) the tool is attached.
    """
    if func_spec.name == "submit_review" and not is_function_call_supported(
        model_name
    ):
        logger.warning(
            "Review function calling was requested, but model '%s' does not "
            "support function calling. Falling back to plain text generation.",
            model_name,
        )
        filtered_kwargs.pop("tools", None)
        filtered_kwargs.pop("tool_choice", None)
        return

    filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
    filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict


def _extract_output(
    choice, func_spec: FunctionSpec | None, tools_sent: bool
) -> OutputType:
    """Return the decoded tool-call arguments, or the message text as a fallback.

    Text is returned when no function spec was used for the request, or when
    the model answered without making a tool call.

    Raises:
        AssertionError: if the model called a function other than ``func_spec``.
        json.JSONDecodeError: if the tool-call arguments are not valid JSON.
    """
    message = choice.message
    if func_spec is None or not tools_sent:
        return message.content

    tool_calls = getattr(message, "tool_calls", None)
    if not tool_calls:
        logger.warning(
            "No function call used despite function spec. Fallback to text. "
            "Message content: %s",
            message.content,
        )
        return message.content

    first_call = tool_calls[0]
    if first_call.function.name != func_spec.name:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept unchanged.
        raise AssertionError(
            f"Function name mismatch: expected {func_spec.name}, "
            f"got {first_call.function.name}"
        )

    try:
        return json.loads(first_call.function.arguments)
    except json.JSONDecodeError:
        logger.error(
            "Error decoding function arguments:\n%s", first_call.function.arguments
        )
        raise


def query(
    system_message: str | None,
    user_message: str | None,
    func_spec: FunctionSpec | None = None,
    **model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
    """
    Query the OpenAI chat-completions API, optionally with function calling.

    Function-calling support is only checked for the review operation
    (``func_spec.name == "submit_review"``): if the requested model is not in
    ``SUPPORTED_FUNCTION_CALL_MODELS``, the request degrades to plain text
    generation. Any other ``func_spec`` is always sent as a tool.

    Args:
        system_message: Optional system prompt.
        user_message: Optional user prompt.
        func_spec: Optional function spec to expose to the model as a tool.
        **model_kwargs: Extra arguments forwarded to
            ``client.chat.completions.create``; ``None`` values are dropped.

    Returns:
        ``(output, req_time, in_tokens, out_tokens, info)``. ``output`` is the
        JSON-decoded function-call arguments when a tool call was made,
        otherwise the message text. ``req_time`` is the elapsed time in seconds
        measured around the ``backoff_create`` call. ``in_tokens``/``out_tokens``
        come from the response usage. ``info`` holds the response's
        ``system_fingerprint``, ``model`` and ``created`` fields.

    Raises:
        AssertionError: if the model called a function other than ``func_spec``.
        json.JSONDecodeError: if the tool-call arguments are not valid JSON.
    """
    _setup_openai_client()
    filtered_kwargs: dict = select_values(notnone, model_kwargs)
    model_name = filtered_kwargs.get("model", "")
    logger.debug("OpenAI query called with model='%s'", model_name)

    messages = opt_messages_to_list(system_message, user_message)

    if func_spec is not None:
        _configure_function_call(filtered_kwargs, func_spec, model_name)

    # perf_counter is monotonic, so the duration is immune to wall-clock jumps.
    t0 = time.perf_counter()
    completion = backoff_create(
        _client.chat.completions.create,
        OPENAI_TIMEOUT_EXCEPTIONS,
        messages=messages,
        **filtered_kwargs,
    )
    req_time = time.perf_counter() - t0

    output = _extract_output(
        completion.choices[0], func_spec, "tools" in filtered_kwargs
    )

    in_tokens = completion.usage.prompt_tokens
    out_tokens = completion.usage.completion_tokens
    info = {
        "system_fingerprint": completion.system_fingerprint,
        "model": completion.model,
        "created": completion.created,
    }
    return output, req_time, in_tokens, out_tokens, info
|