Dixing (Dex) Xu
committed on
:bug: better handling for function calling errors (#44)
Browse files
* remove the supported model lists
* update error handling #43
* update code block display for html
- aide/backend/backend_openai.py +56 -56
- aide/utils/tree_export.py +14 -1
aide/backend/backend_openai.py
CHANGED
@@ -19,23 +19,6 @@ OPENAI_TIMEOUT_EXCEPTIONS = (
|
|
19 |
openai.InternalServerError,
|
20 |
)
|
21 |
|
22 |
-
# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
|
23 |
-
SUPPORTED_FUNCTION_CALL_MODELS = {
|
24 |
-
"gpt-4o",
|
25 |
-
"gpt-4o-2024-08-06",
|
26 |
-
"gpt-4o-2024-05-13",
|
27 |
-
"gpt-4o-mini",
|
28 |
-
"gpt-4o-mini-2024-07-18",
|
29 |
-
"gpt-4-turbo",
|
30 |
-
"gpt-4-turbo-2024-04-09",
|
31 |
-
"gpt-4-turbo-preview",
|
32 |
-
"gpt-4-0125-preview",
|
33 |
-
"gpt-4-1106-preview",
|
34 |
-
"gpt-3.5-turbo",
|
35 |
-
"gpt-3.5-turbo-0125",
|
36 |
-
"gpt-3.5-turbo-1106",
|
37 |
-
}
|
38 |
-
|
39 |
|
40 |
@once
|
41 |
def _setup_openai_client():
|
@@ -43,11 +26,6 @@ def _setup_openai_client():
|
|
43 |
_client = openai.OpenAI(max_retries=0)
|
44 |
|
45 |
|
46 |
-
def is_function_call_supported(model_name: str) -> bool:
|
47 |
-
"""Return True if the model supports function calling."""
|
48 |
-
return model_name in SUPPORTED_FUNCTION_CALL_MODELS
|
49 |
-
|
50 |
-
|
51 |
def query(
|
52 |
system_message: str | None,
|
53 |
user_message: str | None,
|
@@ -56,64 +34,86 @@ def query(
|
|
56 |
) -> tuple[OutputType, float, int, int, dict]:
|
57 |
"""
|
58 |
Query the OpenAI API, optionally with function calling.
|
59 |
-
|
60 |
"""
|
61 |
_setup_openai_client()
|
62 |
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
63 |
-
model_name = filtered_kwargs.get("model", "")
|
64 |
-
logger.debug(f"OpenAI query called with model='{model_name}'")
|
65 |
|
|
|
66 |
messages = opt_messages_to_list(system_message, user_message)
|
67 |
|
|
|
68 |
if func_spec is not None:
|
69 |
-
|
70 |
-
|
71 |
-
if not is_function_call_supported(model_name):
|
72 |
-
logger.warning(
|
73 |
-
f"Review function calling was requested, but model '{model_name}' "
|
74 |
-
"does not support function calling. Falling back to plain text generation."
|
75 |
-
)
|
76 |
-
filtered_kwargs.pop("tools", None)
|
77 |
-
filtered_kwargs.pop("tool_choice", None)
|
78 |
-
else:
|
79 |
-
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
|
80 |
-
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
81 |
|
|
|
82 |
t0 = time.time()
|
83 |
-
completion = backoff_create(
|
84 |
-
_client.chat.completions.create,
|
85 |
-
OPENAI_TIMEOUT_EXCEPTIONS,
|
86 |
-
messages=messages,
|
87 |
-
**filtered_kwargs,
|
88 |
-
)
|
89 |
-
req_time = time.time() - t0
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
choice = completion.choices[0]
|
92 |
|
|
|
93 |
if func_spec is None or "tools" not in filtered_kwargs:
|
|
|
94 |
output = choice.message.content
|
95 |
else:
|
|
|
96 |
tool_calls = getattr(choice.message, "tool_calls", None)
|
97 |
-
|
98 |
if not tool_calls:
|
99 |
logger.warning(
|
100 |
-
|
101 |
f"Message content: {choice.message.content}"
|
102 |
)
|
103 |
output = choice.message.content
|
104 |
else:
|
105 |
first_call = tool_calls[0]
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
output = json.loads(first_call.function.arguments)
|
112 |
-
except json.JSONDecodeError as e:
|
113 |
-
logger.error(
|
114 |
-
f"Error decoding function arguments:\n{first_call.function.arguments}"
|
115 |
)
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
in_tokens = completion.usage.prompt_tokens
|
119 |
out_tokens = completion.usage.completion_tokens
|
|
|
19 |
openai.InternalServerError,
|
20 |
)
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
@once
|
24 |
def _setup_openai_client():
|
|
|
26 |
_client = openai.OpenAI(max_retries=0)
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
29 |
def query(
|
30 |
system_message: str | None,
|
31 |
user_message: str | None,
|
|
|
34 |
) -> tuple[OutputType, float, int, int, dict]:
|
35 |
"""
|
36 |
Query the OpenAI API, optionally with function calling.
|
37 |
+
If the model doesn't support function calling, gracefully degrade to text generation.
|
38 |
"""
|
39 |
_setup_openai_client()
|
40 |
filtered_kwargs: dict = select_values(notnone, model_kwargs)
|
|
|
|
|
41 |
|
42 |
+
# Convert system/user messages to the format required by the client
|
43 |
messages = opt_messages_to_list(system_message, user_message)
|
44 |
|
45 |
+
# If function calling is requested, attach the function spec
|
46 |
if func_spec is not None:
|
47 |
+
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
|
48 |
+
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
completion = None
|
51 |
t0 = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
+
# Attempt the API call
|
54 |
+
try:
|
55 |
+
completion = backoff_create(
|
56 |
+
_client.chat.completions.create,
|
57 |
+
OPENAI_TIMEOUT_EXCEPTIONS,
|
58 |
+
messages=messages,
|
59 |
+
**filtered_kwargs,
|
60 |
+
)
|
61 |
+
except openai.error.InvalidRequestError as e:
|
62 |
+
# Check whether the error indicates that function calling is not supported
|
63 |
+
if "function calling" in str(e).lower() or "tools" in str(e).lower():
|
64 |
+
logger.warning(
|
65 |
+
"Function calling was attempted but is not supported by this model. "
|
66 |
+
"Falling back to plain text generation."
|
67 |
+
)
|
68 |
+
# Remove function-calling parameters and retry
|
69 |
+
filtered_kwargs.pop("tools", None)
|
70 |
+
filtered_kwargs.pop("tool_choice", None)
|
71 |
+
|
72 |
+
# Retry without function calling
|
73 |
+
completion = backoff_create(
|
74 |
+
_client.chat.completions.create,
|
75 |
+
OPENAI_TIMEOUT_EXCEPTIONS,
|
76 |
+
messages=messages,
|
77 |
+
**filtered_kwargs,
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
# If it's some other error, re-raise
|
81 |
+
raise
|
82 |
+
|
83 |
+
req_time = time.time() - t0
|
84 |
choice = completion.choices[0]
|
85 |
|
86 |
+
# Decide how to parse the response
|
87 |
if func_spec is None or "tools" not in filtered_kwargs:
|
88 |
+
# No function calling was ultimately used
|
89 |
output = choice.message.content
|
90 |
else:
|
91 |
+
# Attempt to extract tool calls
|
92 |
tool_calls = getattr(choice.message, "tool_calls", None)
|
|
|
93 |
if not tool_calls:
|
94 |
logger.warning(
|
95 |
+
"No function call was used despite function spec. Fallback to text.\n"
|
96 |
f"Message content: {choice.message.content}"
|
97 |
)
|
98 |
output = choice.message.content
|
99 |
else:
|
100 |
first_call = tool_calls[0]
|
101 |
+
# Optional: verify that the function name matches
|
102 |
+
if first_call.function.name != func_spec.name:
|
103 |
+
logger.warning(
|
104 |
+
f"Function name mismatch: expected {func_spec.name}, "
|
105 |
+
f"got {first_call.function.name}. Fallback to text."
|
|
|
|
|
|
|
|
|
106 |
)
|
107 |
+
output = choice.message.content
|
108 |
+
else:
|
109 |
+
try:
|
110 |
+
output = json.loads(first_call.function.arguments)
|
111 |
+
except json.JSONDecodeError as ex:
|
112 |
+
logger.error(
|
113 |
+
"Error decoding function arguments:\n"
|
114 |
+
f"{first_call.function.arguments}"
|
115 |
+
)
|
116 |
+
raise ex
|
117 |
|
118 |
in_tokens = completion.usage.prompt_tokens
|
119 |
out_tokens = completion.usage.completion_tokens
|
aide/utils/tree_export.py
CHANGED
@@ -38,6 +38,19 @@ def normalize_layout(layout: np.ndarray):
|
|
38 |
return layout
|
39 |
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def cfg_to_tree_struct(cfg, jou: Journal):
|
42 |
edges = list(get_edges(jou))
|
43 |
layout = normalize_layout(generate_layout(len(jou), edges))
|
@@ -52,7 +65,7 @@ def cfg_to_tree_struct(cfg, jou: Journal):
|
|
52 |
edges=edges,
|
53 |
layout=layout.tolist(),
|
54 |
plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
|
55 |
-
code=[n.code for n in jou],
|
56 |
term_out=[n.term_out for n in jou],
|
57 |
analysis=[n.analysis for n in jou],
|
58 |
exp_name=cfg.exp_name,
|
|
|
38 |
return layout
|
39 |
|
40 |
|
41 |
+
def strip_code_markers(code: str) -> str:
    """Remove markdown code block markers if present.

    Surrounding whitespace is trimmed first. If the text then starts with a
    triple-backtick fence, the fence is removed: for a multi-line block the
    whole opening line (fence plus optional language identifier) is dropped,
    while for a single-line snippet only the three backticks are dropped.
    A trailing triple-backtick fence is removed as well.

    Args:
        code: Source text that may be wrapped in a fenced code block.

    Returns:
        The text without opening/closing fence markers, with leading and
        trailing whitespace stripped.
    """
    fence = "```"
    code = code.strip()
    if code.startswith(fence):
        first_newline = code.find("\n")
        if first_newline != -1:
            # Multi-line block: everything up to the first newline is the
            # opening fence and its optional language identifier.
            code = code[first_newline:].strip()
        else:
            # Single-line snippet such as ```print(1)```: there is no separate
            # fence line, so drop only the backticks; otherwise the opening
            # marker would survive while the closing one is removed below.
            code = code[len(fence):].strip()
    if code.endswith(fence):
        code = code[: -len(fence)].strip()
    return code
|
52 |
+
|
53 |
+
|
54 |
def cfg_to_tree_struct(cfg, jou: Journal):
|
55 |
edges = list(get_edges(jou))
|
56 |
layout = normalize_layout(generate_layout(len(jou), edges))
|
|
|
65 |
edges=edges,
|
66 |
layout=layout.tolist(),
|
67 |
plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
|
68 |
+
code=[strip_code_markers(n.code) for n in jou],
|
69 |
term_out=[n.term_out for n in jou],
|
70 |
analysis=[n.analysis for n in jou],
|
71 |
exp_name=cfg.exp_name,
|