Dixing (Dex) Xu committed on
Commit
9c55a42
·
unverified ·
1 Parent(s): c92f194

:bug: handle missing function calls for openai (#35) (#38)

Browse files

* :bug: handle missing function calls for openai (#35)

* fix: black format

Files changed (1) hide show
  1. aide/backend/backend_openai.py +62 -17
aide/backend/backend_openai.py CHANGED
@@ -19,6 +19,23 @@ OPENAI_TIMEOUT_EXCEPTIONS = (
19
  openai.InternalServerError,
20
  )
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @once
24
  def _setup_openai_client():
@@ -26,21 +43,41 @@ def _setup_openai_client():
26
  _client = openai.OpenAI(max_retries=0)
27
 
28
 
 
 
 
 
 
29
  def query(
30
  system_message: str | None,
31
  user_message: str | None,
32
  func_spec: FunctionSpec | None = None,
33
  **model_kwargs,
34
  ) -> tuple[OutputType, float, int, int, dict]:
 
 
 
 
35
  _setup_openai_client()
36
- filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
 
 
37
 
38
  messages = opt_messages_to_list(system_message, user_message)
39
 
40
  if func_spec is not None:
41
- filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
42
- # force the model the use the function
43
- filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
 
 
 
 
 
 
 
 
 
44
 
45
  t0 = time.time()
46
  completion = backoff_create(
@@ -53,22 +90,30 @@ def query(
53
 
54
  choice = completion.choices[0]
55
 
56
- if func_spec is None:
57
  output = choice.message.content
58
  else:
59
- assert (
60
- choice.message.tool_calls
61
- ), f"function_call is empty, it is not a function call: {choice.message}"
62
- assert (
63
- choice.message.tool_calls[0].function.name == func_spec.name
64
- ), "Function name mismatch"
65
- try:
66
- output = json.loads(choice.message.tool_calls[0].function.arguments)
67
- except json.JSONDecodeError as e:
68
- logger.error(
69
- f"Error decoding the function arguments: {choice.message.tool_calls[0].function.arguments}"
 
 
70
  )
71
- raise e
 
 
 
 
 
 
72
 
73
  in_tokens = completion.usage.prompt_tokens
74
  out_tokens = completion.usage.completion_tokens
 
19
  openai.InternalServerError,
20
  )
21
 
22
+ # (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
23
+ SUPPORTED_FUNCTION_CALL_MODELS = {
24
+ "gpt-4o",
25
+ "gpt-4o-2024-08-06",
26
+ "gpt-4o-2024-05-13",
27
+ "gpt-4o-mini",
28
+ "gpt-4o-mini-2024-07-18",
29
+ "gpt-4-turbo",
30
+ "gpt-4-turbo-2024-04-09",
31
+ "gpt-4-turbo-preview",
32
+ "gpt-4-0125-preview",
33
+ "gpt-4-1106-preview",
34
+ "gpt-3.5-turbo",
35
+ "gpt-3.5-turbo-0125",
36
+ "gpt-3.5-turbo-1106",
37
+ }
38
+
39
 
40
  @once
41
  def _setup_openai_client():
 
43
  _client = openai.OpenAI(max_retries=0)
44
 
45
 
46
+ def is_function_call_supported(model_name: str) -> bool:
47
+ """Return True if the model supports function calling."""
48
+ return model_name in SUPPORTED_FUNCTION_CALL_MODELS
49
+
50
+
51
  def query(
52
  system_message: str | None,
53
  user_message: str | None,
54
  func_spec: FunctionSpec | None = None,
55
  **model_kwargs,
56
  ) -> tuple[OutputType, float, int, int, dict]:
57
+ """
58
+ Query the OpenAI API, optionally with function calling.
59
+ Function calling support is only checked for feedback/review operations.
60
+ """
61
  _setup_openai_client()
62
+ filtered_kwargs: dict = select_values(notnone, model_kwargs)
63
+ model_name = filtered_kwargs.get("model", "")
64
+ logger.debug(f"OpenAI query called with model='{model_name}'")
65
 
66
  messages = opt_messages_to_list(system_message, user_message)
67
 
68
  if func_spec is not None:
69
+ # Only check function call support for feedback/review operations
70
+ if func_spec.name == "submit_review":
71
+ if not is_function_call_supported(model_name):
72
+ logger.warning(
73
+ f"Review function calling was requested, but model '{model_name}' "
74
+ "does not support function calling. Falling back to plain text generation."
75
+ )
76
+ filtered_kwargs.pop("tools", None)
77
+ filtered_kwargs.pop("tool_choice", None)
78
+ else:
79
+ filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
80
+ filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
81
 
82
  t0 = time.time()
83
  completion = backoff_create(
 
90
 
91
  choice = completion.choices[0]
92
 
93
+ if func_spec is None or "tools" not in filtered_kwargs:
94
  output = choice.message.content
95
  else:
96
+ tool_calls = getattr(choice.message, "tool_calls", None)
97
+
98
+ if not tool_calls:
99
+ logger.warning(
100
+ f"No function call used despite function spec. Fallback to text. "
101
+ f"Message content: {choice.message.content}"
102
+ )
103
+ output = choice.message.content
104
+ else:
105
+ first_call = tool_calls[0]
106
+ assert first_call.function.name == func_spec.name, (
107
+ f"Function name mismatch: expected {func_spec.name}, "
108
+ f"got {first_call.function.name}"
109
  )
110
+ try:
111
+ output = json.loads(first_call.function.arguments)
112
+ except json.JSONDecodeError as e:
113
+ logger.error(
114
+ f"Error decoding function arguments:\n{first_call.function.arguments}"
115
+ )
116
+ raise e
117
 
118
  in_tokens = completion.usage.prompt_tokens
119
  out_tokens = completion.usage.completion_tokens