File size: 4,052 Bytes
39c930a
 
 
 
 
 
d4ec913
 
 
39c930a
 
 
d4ec913
39c930a
d4ec913
 
 
 
 
 
39c930a
9c55a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ddb3df
39c930a
 
 
d4ec913
39c930a
5ddb3df
9c55a42
 
 
 
 
39c930a
 
 
 
 
 
9c55a42
 
 
 
39c930a
9c55a42
 
 
39c930a
 
 
 
9c55a42
 
 
 
 
 
 
 
 
 
 
 
39c930a
 
d4ec913
 
 
 
 
 
39c930a
 
 
 
9c55a42
39c930a
 
9c55a42
 
 
 
 
 
 
 
 
 
 
 
 
39c930a
9c55a42
 
 
 
 
 
 
39c930a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Backend for OpenAI API."""

import json
import logging
import time

from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
import openai

# Logger shared with the rest of the "aide" package; all diagnostics from this
# backend go through it.
logger = logging.getLogger("aide")

# Module-level OpenAI client. It stays None until _setup_openai_client() assigns
# it, which query() triggers on every call before touching the client.
_client: openai.OpenAI = None  # type: ignore

# Exception types handed to backoff_create() by query(). Despite the name, the
# tuple covers rate limiting, connection failures and server-side errors as
# well as timeouts.
# NOTE(review): presumably backoff_create retries on exactly these — confirm
# against its implementation in .utils.
OPENAI_TIMEOUT_EXCEPTIONS = (
    openai.RateLimitError,
    openai.APIConnectionError,
    openai.APITimeoutError,
    openai.InternalServerError,
)

# Model names for which is_function_call_supported() returns True. Matching is
# by exact string membership, so an alias or dated snapshot that is not listed
# here is treated as not supporting function calling.
# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
SUPPORTED_FUNCTION_CALL_MODELS = {
    "gpt-4o",
    "gpt-4o-2024-08-06",
    "gpt-4o-2024-05-13",
    "gpt-4o-mini",
    "gpt-4o-mini-2024-07-18",
    "gpt-4-turbo",
    "gpt-4-turbo-2024-04-09",
    "gpt-4-turbo-preview",
    "gpt-4-0125-preview",
    "gpt-4-1106-preview",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
}


@once
def _setup_openai_client() -> None:
    """Create the module-level OpenAI client and store it in ``_client``.

    The funcy ``once`` decorator lets the body execute only on the first call;
    later calls are no-ops, so ``query`` can invoke this unconditionally.
    """
    global _client
    # max_retries=0 turns off the SDK's built-in retry loop.
    # NOTE(review): presumably because retries are handled by backoff_create
    # in query() instead — confirm.
    _client = openai.OpenAI(max_retries=0)


def is_function_call_supported(model_name: str) -> bool:
    """Return True if the model supports function calling.

    Support is decided purely by exact membership of ``model_name`` in
    ``SUPPORTED_FUNCTION_CALL_MODELS``; no prefix or alias matching is done,
    so an empty or unlisted name yields False.
    """
    return model_name in SUPPORTED_FUNCTION_CALL_MODELS


def query(
    system_message: str | None,
    user_message: str | None,
    func_spec: FunctionSpec | None = None,
    **model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
    """
    Query the OpenAI chat completions API, optionally with function calling.

    Function calling is only wired up for the review spec: tools are attached
    to the request when ``func_spec.name == "submit_review"`` and
    ``is_function_call_supported`` accepts the requested model. If the model is
    not supported, a warning is logged, any ``tools`` / ``tool_choice`` entries
    are removed from the request and plain text is generated instead.

    NOTE(review): a ``func_spec`` with any other name is not attached to the
    request by this function; unless the caller supplied ``tools`` through
    ``model_kwargs``, the result is plain text. Confirm this is intended.

    Args:
        system_message: Optional system prompt.
        user_message: Optional user prompt. Both messages are converted to the
            request's ``messages`` value by ``opt_messages_to_list``.
        func_spec: Optional function specification (see above).
        **model_kwargs: Forwarded to ``chat.completions.create`` after entries
            whose value is ``None`` have been dropped. The ``"model"`` entry is
            also used for the function-calling support check.

    Returns:
        A tuple ``(output, req_time, in_tokens, out_tokens, info)`` where
        ``output`` is either the message content or, when a tool call was
        returned, the JSON-decoded arguments of the first tool call;
        ``req_time`` is the duration in seconds of the ``backoff_create`` call;
        ``in_tokens`` / ``out_tokens`` are the prompt and completion token
        counts (0 when the response carries no usage data); and ``info`` holds
        ``system_fingerprint``, ``model`` and ``created`` from the completion.

    Raises:
        AssertionError: If the first tool call names a function other than
            ``func_spec.name``.
        json.JSONDecodeError: If the tool call arguments are not valid JSON.
    """
    _setup_openai_client()
    filtered_kwargs: dict = select_values(notnone, model_kwargs)
    model_name = filtered_kwargs.get("model", "")
    logger.debug(f"OpenAI query called with model='{model_name}'")

    messages = opt_messages_to_list(system_message, user_message)

    # Function calling is only set up (and its model support only checked)
    # for the review spec; every other spec leaves the request untouched.
    if func_spec is not None and func_spec.name == "submit_review":
        if is_function_call_supported(model_name):
            filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
            filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
        else:
            logger.warning(
                f"Review function calling was requested, but model '{model_name}' "
                "does not support function calling. Falling back to plain text generation."
            )
            filtered_kwargs.pop("tools", None)
            filtered_kwargs.pop("tool_choice", None)

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or go negative) if the system clock is adjusted mid-request.
    t0 = time.perf_counter()
    completion = backoff_create(
        _client.chat.completions.create,
        OPENAI_TIMEOUT_EXCEPTIONS,
        messages=messages,
        **filtered_kwargs,
    )
    req_time = time.perf_counter() - t0

    choice = completion.choices[0]
    message = choice.message

    # A tool call is only expected when a spec was given AND tools were
    # actually part of the request that went out.
    expects_tool_call = func_spec is not None and "tools" in filtered_kwargs
    tool_calls = getattr(message, "tool_calls", None) if expects_tool_call else None

    if not tool_calls:
        if expects_tool_call:
            logger.warning(
                "No function call used despite function spec. Fallback to text. "
                f"Message content: {message.content}"
            )
        output = message.content
    else:
        first_call = tool_calls[0]
        # Kept as an assert so callers keep seeing AssertionError on mismatch.
        assert first_call.function.name == func_spec.name, (
            f"Function name mismatch: expected {func_spec.name}, "
            f"got {first_call.function.name}"
        )
        try:
            output = json.loads(first_call.function.arguments)
        except json.JSONDecodeError:
            logger.error(
                f"Error decoding function arguments:\n{first_call.function.arguments}"
            )
            raise

    # `usage` is optional on the completion object; some OpenAI-compatible
    # servers omit it, so report zero tokens rather than crashing.
    usage = completion.usage
    in_tokens = usage.prompt_tokens if usage is not None else 0
    out_tokens = usage.completion_tokens if usage is not None else 0

    info = {
        "system_fingerprint": completion.system_fingerprint,
        "model": completion.model,
        "created": completion.created,
    }

    return output, req_time, in_tokens, out_tokens, info