Dixing (Dex) Xu
commited on
:bug: handle missing function calls for openai (#35) (#38)
Browse files* :bug: handle missing function calls for openai (#35)
* fix: black format
- aide/backend/backend_openai.py +62 -17
aide/backend/backend_openai.py
CHANGED
@@ -19,6 +19,23 @@ OPENAI_TIMEOUT_EXCEPTIONS = (
|
|
19 |
openai.InternalServerError,
|
20 |
)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
@once
|
24 |
def _setup_openai_client():
|
@@ -26,21 +43,41 @@ def _setup_openai_client():
|
|
26 |
_client = openai.OpenAI(max_retries=0)
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
29 |
def query(
|
30 |
system_message: str | None,
|
31 |
user_message: str | None,
|
32 |
func_spec: FunctionSpec | None = None,
|
33 |
**model_kwargs,
|
34 |
) -> tuple[OutputType, float, int, int, dict]:
|
|
|
|
|
|
|
|
|
35 |
_setup_openai_client()
|
36 |
-
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
|
|
|
|
37 |
|
38 |
messages = opt_messages_to_list(system_message, user_message)
|
39 |
|
40 |
if func_spec is not None:
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
t0 = time.time()
|
46 |
completion = backoff_create(
|
@@ -53,22 +90,30 @@ def query(
|
|
53 |
|
54 |
choice = completion.choices[0]
|
55 |
|
56 |
-
if func_spec is None:
|
57 |
output = choice.message.content
|
58 |
else:
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
output =
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
70 |
)
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
in_tokens = completion.usage.prompt_tokens
|
74 |
out_tokens = completion.usage.completion_tokens
|
|
|
19 |
openai.InternalServerError,
|
20 |
)
|
21 |
|
22 |
+
# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
|
23 |
+
SUPPORTED_FUNCTION_CALL_MODELS = {
|
24 |
+
"gpt-4o",
|
25 |
+
"gpt-4o-2024-08-06",
|
26 |
+
"gpt-4o-2024-05-13",
|
27 |
+
"gpt-4o-mini",
|
28 |
+
"gpt-4o-mini-2024-07-18",
|
29 |
+
"gpt-4-turbo",
|
30 |
+
"gpt-4-turbo-2024-04-09",
|
31 |
+
"gpt-4-turbo-preview",
|
32 |
+
"gpt-4-0125-preview",
|
33 |
+
"gpt-4-1106-preview",
|
34 |
+
"gpt-3.5-turbo",
|
35 |
+
"gpt-3.5-turbo-0125",
|
36 |
+
"gpt-3.5-turbo-1106",
|
37 |
+
}
|
38 |
+
|
39 |
|
40 |
@once
|
41 |
def _setup_openai_client():
|
|
|
43 |
_client = openai.OpenAI(max_retries=0)
|
44 |
|
45 |
|
46 |
+
def is_function_call_supported(model_name: str) -> bool:
|
47 |
+
"""Return True if the model supports function calling."""
|
48 |
+
return model_name in SUPPORTED_FUNCTION_CALL_MODELS
|
49 |
+
|
50 |
+
|
51 |
def query(
|
52 |
system_message: str | None,
|
53 |
user_message: str | None,
|
54 |
func_spec: FunctionSpec | None = None,
|
55 |
**model_kwargs,
|
56 |
) -> tuple[OutputType, float, int, int, dict]:
|
57 |
+
"""
|
58 |
+
Query the OpenAI API, optionally with function calling.
|
59 |
+
Function calling support is only checked for feedback/review operations.
|
60 |
+
"""
|
61 |
_setup_openai_client()
|
62 |
+
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
63 |
+
model_name = filtered_kwargs.get("model", "")
|
64 |
+
logger.debug(f"OpenAI query called with model='{model_name}'")
|
65 |
|
66 |
messages = opt_messages_to_list(system_message, user_message)
|
67 |
|
68 |
if func_spec is not None:
|
69 |
+
# Only check function call support for feedback/search operations
|
70 |
+
if func_spec.name == "submit_review":
|
71 |
+
if not is_function_call_supported(model_name):
|
72 |
+
logger.warning(
|
73 |
+
f"Review function calling was requested, but model '{model_name}' "
|
74 |
+
"does not support function calling. Falling back to plain text generation."
|
75 |
+
)
|
76 |
+
filtered_kwargs.pop("tools", None)
|
77 |
+
filtered_kwargs.pop("tool_choice", None)
|
78 |
+
else:
|
79 |
+
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
|
80 |
+
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
81 |
|
82 |
t0 = time.time()
|
83 |
completion = backoff_create(
|
|
|
90 |
|
91 |
choice = completion.choices[0]
|
92 |
|
93 |
+
if func_spec is None or "tools" not in filtered_kwargs:
|
94 |
output = choice.message.content
|
95 |
else:
|
96 |
+
tool_calls = getattr(choice.message, "tool_calls", None)
|
97 |
+
|
98 |
+
if not tool_calls:
|
99 |
+
logger.warning(
|
100 |
+
f"No function call used despite function spec. Fallback to text. "
|
101 |
+
f"Message content: {choice.message.content}"
|
102 |
+
)
|
103 |
+
output = choice.message.content
|
104 |
+
else:
|
105 |
+
first_call = tool_calls[0]
|
106 |
+
assert first_call.function.name == func_spec.name, (
|
107 |
+
f"Function name mismatch: expected {func_spec.name}, "
|
108 |
+
f"got {first_call.function.name}"
|
109 |
)
|
110 |
+
try:
|
111 |
+
output = json.loads(first_call.function.arguments)
|
112 |
+
except json.JSONDecodeError as e:
|
113 |
+
logger.error(
|
114 |
+
f"Error decoding function arguments:\n{first_call.function.arguments}"
|
115 |
+
)
|
116 |
+
raise e
|
117 |
|
118 |
in_tokens = completion.usage.prompt_tokens
|
119 |
out_tokens = completion.usage.completion_tokens
|