Dixing (Dex) Xu committed
Commit 9616d52 · unverified · 1 Parent(s): 51d09b7

:bug: better handling for function calling errors (#44)


* remove the supported model lists
* update error handling #43
* update code block display for html

aide/backend/backend_openai.py CHANGED
@@ -19,23 +19,6 @@ OPENAI_TIMEOUT_EXCEPTIONS = (
     openai.InternalServerError,
 )
 
-# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
-SUPPORTED_FUNCTION_CALL_MODELS = {
-    "gpt-4o",
-    "gpt-4o-2024-08-06",
-    "gpt-4o-2024-05-13",
-    "gpt-4o-mini",
-    "gpt-4o-mini-2024-07-18",
-    "gpt-4-turbo",
-    "gpt-4-turbo-2024-04-09",
-    "gpt-4-turbo-preview",
-    "gpt-4-0125-preview",
-    "gpt-4-1106-preview",
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-0125",
-    "gpt-3.5-turbo-1106",
-}
-
 
 @once
 def _setup_openai_client():
@@ -43,11 +26,6 @@ def _setup_openai_client():
     _client = openai.OpenAI(max_retries=0)
 
 
-def is_function_call_supported(model_name: str) -> bool:
-    """Return True if the model supports function calling."""
-    return model_name in SUPPORTED_FUNCTION_CALL_MODELS
-
-
 def query(
     system_message: str | None,
     user_message: str | None,
@@ -56,64 +34,86 @@
 ) -> tuple[OutputType, float, int, int, dict]:
     """
     Query the OpenAI API, optionally with function calling.
-    Function calling support is only checked for feedback/review operations.
+    If the model doesn't support function calling, gracefully degrade to text generation.
     """
     _setup_openai_client()
     filtered_kwargs: dict = select_values(notnone, model_kwargs)
-    model_name = filtered_kwargs.get("model", "")
-    logger.debug(f"OpenAI query called with model='{model_name}'")
 
+    # Convert system/user messages to the format required by the client
     messages = opt_messages_to_list(system_message, user_message)
 
+    # If function calling is requested, attach the function spec
     if func_spec is not None:
-        # Only check function call support for feedback/search operations
-        if func_spec.name == "submit_review":
-            if not is_function_call_supported(model_name):
-                logger.warning(
-                    f"Review function calling was requested, but model '{model_name}' "
-                    "does not support function calling. Falling back to plain text generation."
-                )
-                filtered_kwargs.pop("tools", None)
-                filtered_kwargs.pop("tool_choice", None)
-            else:
-                filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
-                filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
+        filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
+        filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
 
+    completion = None
     t0 = time.time()
-    completion = backoff_create(
-        _client.chat.completions.create,
-        OPENAI_TIMEOUT_EXCEPTIONS,
-        messages=messages,
-        **filtered_kwargs,
-    )
-    req_time = time.time() - t0
 
+    # Attempt the API call
+    try:
+        completion = backoff_create(
+            _client.chat.completions.create,
+            OPENAI_TIMEOUT_EXCEPTIONS,
+            messages=messages,
+            **filtered_kwargs,
+        )
+    except openai.error.InvalidRequestError as e:
+        # Check whether the error indicates that function calling is not supported
+        if "function calling" in str(e).lower() or "tools" in str(e).lower():
+            logger.warning(
+                "Function calling was attempted but is not supported by this model. "
+                "Falling back to plain text generation."
+            )
+            # Remove function-calling parameters and retry
+            filtered_kwargs.pop("tools", None)
+            filtered_kwargs.pop("tool_choice", None)
+
+            # Retry without function calling
+            completion = backoff_create(
+                _client.chat.completions.create,
+                OPENAI_TIMEOUT_EXCEPTIONS,
+                messages=messages,
+                **filtered_kwargs,
+            )
+        else:
+            # If it's some other error, re-raise
+            raise
+
+    req_time = time.time() - t0
     choice = completion.choices[0]
 
+    # Decide how to parse the response
     if func_spec is None or "tools" not in filtered_kwargs:
+        # No function calling was ultimately used
         output = choice.message.content
     else:
+        # Attempt to extract tool calls
         tool_calls = getattr(choice.message, "tool_calls", None)
-
         if not tool_calls:
             logger.warning(
-                f"No function call used despite function spec. Fallback to text. "
+                "No function call was used despite function spec. Fallback to text.\n"
                 f"Message content: {choice.message.content}"
            )
             output = choice.message.content
         else:
             first_call = tool_calls[0]
-            assert first_call.function.name == func_spec.name, (
-                f"Function name mismatch: expected {func_spec.name}, "
-                f"got {first_call.function.name}"
-            )
-            try:
-                output = json.loads(first_call.function.arguments)
-            except json.JSONDecodeError as e:
-                logger.error(
-                    f"Error decoding function arguments:\n{first_call.function.arguments}"
+            # Optional: verify that the function name matches
+            if first_call.function.name != func_spec.name:
+                logger.warning(
+                    f"Function name mismatch: expected {func_spec.name}, "
+                    f"got {first_call.function.name}. Fallback to text."
                 )
-                raise e
+                output = choice.message.content
+            else:
+                try:
+                    output = json.loads(first_call.function.arguments)
+                except json.JSONDecodeError as ex:
+                    logger.error(
+                        "Error decoding function arguments:\n"
+                        f"{first_call.function.arguments}"
+                    )
+                    raise ex
 
     in_tokens = completion.usage.prompt_tokens
     out_tokens = completion.usage.completion_tokens
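
For orientation, here is a minimal, self-contained sketch of the same degrade-to-text pattern against the OpenAI Python client, without AIDE's backoff_create wrapper. The tool schema, model name, and messages below are made-up placeholders, and note that the 1.x SDK surfaces invalid-request failures as openai.BadRequestError rather than the pre-1.0 openai.error.InvalidRequestError:

import logging

import openai

logger = logging.getLogger(__name__)
client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

# Hypothetical tool spec; stands in for func_spec.as_openai_tool_dict
tools = [
    {
        "type": "function",
        "function": {
            "name": "submit_review",
            "description": "Submit a structured review of a solution.",
            "parameters": {
                "type": "object",
                "properties": {"is_bug": {"type": "boolean"}},
                "required": ["is_bug"],
            },
        },
    }
]
kwargs = {"model": "gpt-4o-mini", "tools": tools, "tool_choice": "auto"}
messages = [{"role": "user", "content": "Review this solution."}]

try:
    completion = client.chat.completions.create(messages=messages, **kwargs)
except openai.BadRequestError as e:  # 1.x equivalent of InvalidRequestError
    if "tools" in str(e).lower() or "function calling" in str(e).lower():
        logger.warning("Model rejected tool use; retrying as plain text.")
        # Drop the function-calling parameters and retry once
        kwargs.pop("tools", None)
        kwargs.pop("tool_choice", None)
        completion = client.chat.completions.create(messages=messages, **kwargs)
    else:
        raise

print(completion.choices[0].message)
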
aide/utils/tree_export.py CHANGED
@@ -38,6 +38,19 @@ def normalize_layout(layout: np.ndarray):
     return layout
 
 
+def strip_code_markers(code: str) -> str:
+    """Remove markdown code block markers if present."""
+    code = code.strip()
+    if code.startswith("```"):
+        # Remove opening backticks and optional language identifier
+        first_newline = code.find("\n")
+        if first_newline != -1:
+            code = code[first_newline:].strip()
+    if code.endswith("```"):
+        code = code[:-3].strip()
+    return code
+
+
 def cfg_to_tree_struct(cfg, jou: Journal):
     edges = list(get_edges(jou))
     layout = normalize_layout(generate_layout(len(jou), edges))
@@ -52,7 +65,7 @@ def cfg_to_tree_struct(cfg, jou: Journal):
         edges=edges,
         layout=layout.tolist(),
         plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
-        code=[n.code for n in jou],
+        code=[strip_code_markers(n.code) for n in jou],
         term_out=[n.term_out for n in jou],
         analysis=[n.analysis for n in jou],
         exp_name=cfg.exp_name,
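
For reference, a quick usage sketch of the new helper on made-up inputs, showing that fenced code is unwrapped while plain code passes through unchanged:

from aide.utils.tree_export import strip_code_markers

fenced = "```python\nprint('hello')\n```"
plain = "print('hello')"

# Opening/closing backticks (and the language tag) are stripped
assert strip_code_markers(fenced) == "print('hello')"
# Already-plain code is returned as-is
assert strip_code_markers(plain) == "print('hello')"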