Dixing Xu committed · Commit d4ec913 · unverified · 1 Parent(s): 21ab47d

:bug: fix model issues with beta limitation

* use backoff instead of funcy.retry
* fix issue with o1- models (beta-limitation)

aide/backend/__init__.py CHANGED
@@ -33,6 +33,14 @@ def query(
         "max_tokens": max_tokens,
     }
 
+    # Handle models with beta limitations
+    # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+    if model.startswith("o1-"):
+        if system_message:
+            user_message = system_message
+            system_message = None
+        model_kwargs["temperature"] = 1
+
     query_func = backend_anthropic.query if "claude-" in model else backend_openai.query
     output, req_time, in_tok_count, out_tok_count, info = query_func(
         system_message=compile_prompt_to_md(system_message) if system_message else None,
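
For context, a minimal standalone sketch (illustrative code, not part of the commit; the function name is made up) of what the new branch does before dispatch: per the beta-limitations page referenced above, o1- models reject system messages and non-default sampling settings, so the system prompt is re-sent as the user message and temperature is pinned to 1.

# Illustrative sketch only -- mirrors the added o1- branch on plain values.
def apply_o1_beta_limitations(model, system_message, user_message, model_kwargs):
    # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations
    if model.startswith("o1-"):
        if system_message:
            user_message = system_message   # system prompt is re-sent as the user message
            system_message = None           # o1 beta rejects system messages
        model_kwargs["temperature"] = 1     # o1 beta only accepts the default temperature
    return system_message, user_message, model_kwargs

print(apply_o1_beta_limitations("o1-mini", "You are a data scientist.", None, {"temperature": 0.5}))
# -> (None, 'You are a data scientist.', {'temperature': 1})
print(apply_o1_beta_limitations("gpt-4-turbo", "You are a data scientist.", None, {"temperature": 0.5}))
# -> ('You are a data scientist.', None, {'temperature': 0.5})
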
aide/backend/backend_anthropic.py CHANGED
@@ -2,23 +2,25 @@
 
 import time
 
-from anthropic import Anthropic, RateLimitError
-from .utils import FunctionSpec, OutputType, opt_messages_to_list
-from funcy import notnone, once, retry, select_values
+from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
+from funcy import notnone, once, select_values
+import anthropic
 
-_client: Anthropic = None  # type: ignore
+_client: anthropic.Anthropic = None  # type: ignore
 
-RATELIMIT_RETRIES = 5
-retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1))  # type: ignore
+ANTHROPIC_TIMEOUT_EXCEPTIONS = (
+    anthropic.RateLimitError,
+    anthropic.APIConnectionError,
+    anthropic.APITimeoutError,
+    anthropic.InternalServerError,
+)
 
 
 @once
 def _setup_anthropic_client():
     global _client
-    _client = Anthropic()
+    _client = anthropic.Anthropic(max_retries=0)
 
-
-@retry_exp
 def query(
     system_message: str | None,
     user_message: str | None,
@@ -48,7 +50,12 @@ def query(
     messages = opt_messages_to_list(None, user_message)
 
     t0 = time.time()
-    message = _client.messages.create(messages=messages, **filtered_kwargs)  # type: ignore
+    message = backoff_create(
+        _client.messages.create,
+        ANTHROPIC_TIMEOUT_EXCEPTIONS,
+        messages=messages,
+        **filtered_kwargs,
+    )
     req_time = time.time() - t0
 
     assert len(message.content) == 1 and message.content[0].type == "text"

aide/backend/backend_openai.py CHANGED
@@ -4,25 +4,26 @@ import json
 import logging
 import time
 
-from .utils import FunctionSpec, OutputType, opt_messages_to_list
-from funcy import notnone, once, retry, select_values
-from openai import OpenAI, RateLimitError
+from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
+from funcy import notnone, once, select_values
+import openai
 
 logger = logging.getLogger("aide")
 
-_client: OpenAI = None  # type: ignore
-
-RATELIMIT_RETRIES = 5
-retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1))  # type: ignore
+_client: openai.OpenAI = None  # type: ignore
 
+OPENAI_TIMEOUT_EXCEPTIONS = (
+    openai.RateLimitError,
+    openai.APIConnectionError,
+    openai.APITimeoutError,
+    openai.InternalServerError,
+)
 
 @once
 def _setup_openai_client():
     global _client
-    _client = OpenAI(max_retries=3)
-
+    _client = openai.OpenAI(max_retries=0)
 
-@retry_exp
 def query(
     system_message: str | None,
     user_message: str | None,
@@ -40,7 +41,12 @@ def query(
         filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
 
     t0 = time.time()
-    completion = _client.chat.completions.create(messages=messages, **filtered_kwargs)  # type: ignore
+    completion = backoff_create(
+        _client.chat.completions.create,
+        OPENAI_TIMEOUT_EXCEPTIONS,
+        messages=messages,
+        **filtered_kwargs,
+    )
     req_time = time.time() - t0
 
     choice = completion.choices[0]

aide/backend/utils.py CHANGED
@@ -8,6 +8,27 @@ FunctionCallType = dict
 OutputType = str | FunctionCallType
 
 
+import backoff
+import logging
+from typing import Callable
+
+logger = logging.getLogger("aide")
+
+
+@backoff.on_predicate(
+    wait_gen=backoff.expo,
+    max_value=60,
+    factor=1.5,
+)
+def backoff_create(
+    create_fn: Callable, retry_exceptions: list[Exception], *args, **kwargs
+):
+    try:
+        return create_fn(*args, **kwargs)
+    except retry_exceptions as e:
+        logger.info(f"Backoff exception: {e}")
+        return False
+
 def opt_messages_to_list(
     system_message: str | None, user_message: str | None
 ) -> list[dict[str, str]]:
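
A rough, self-contained sketch of how this helper behaves (illustrative only: the exception classes are stand-ins, and a max_tries bound is added so the demo terminates, which the helper above does not set). backoff.on_predicate re-invokes the wrapped function while its return value is falsy, so returning False from the except block schedules another attempt with exponential (jittered) waits capped at 60 seconds, while exceptions outside retry_exceptions propagate immediately. Unlike the removed funcy.retry decorators, which only retried RateLimitError, the tuples passed in by the backends also cover connection errors, timeouts, and 5xx responses.

# Illustrative sketch only -- stand-in exceptions, plus max_tries so it terminates.
import backoff
from typing import Callable

class FakeRateLimitError(Exception): ...
class FakeAuthError(Exception): ...

@backoff.on_predicate(wait_gen=backoff.expo, max_value=60, factor=1.5, max_tries=3)
def backoff_create(create_fn: Callable, retry_exceptions, *args, **kwargs):
    try:
        return create_fn(*args, **kwargs)
    except retry_exceptions as e:
        print(f"Backoff exception: {e}")
        return False                            # falsy -> backoff schedules a retry

attempts = {"n": 0}

def flaky_create(**kwargs):
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise FakeRateLimitError("try again")   # in the tuple -> caught and retried
    return {"attempt": attempts["n"], **kwargs}

print(backoff_create(flaky_create, (FakeRateLimitError,), model="fake"))
# -> {'attempt': 3, 'model': 'fake'} after two short backoff waits

def bad_key_create(**kwargs):
    raise FakeAuthError("invalid api key")      # not in the tuple -> not retried

try:
    backoff_create(bad_key_create, (FakeRateLimitError,))
except FakeAuthError as e:
    print("propagated immediately:", e)
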
requirements.txt CHANGED
@@ -88,4 +88,5 @@ pdf2image
 PyPDF
 pyocr
 pyarrow
-xlrd
+xlrd
+backoff