Dixing (Dex) Xu commited on
Commit
f92d1a2
·
unverified ·
1 Parent(s): 9c55a42

:sparkles: Add anthropic tool use (#39)

Browse files

* :sparkles: Add anthropic tool use

* fix: ruff format

aide/backend/backend_anthropic.py CHANGED
@@ -1,11 +1,14 @@
1
  """Backend for Anthropic API."""
2
 
 
3
  import time
4
 
5
  from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
6
  from funcy import notnone, once, select_values
7
  import anthropic
8
 
 
 
9
  _client: anthropic.Anthropic = None # type: ignore
10
 
11
  ANTHROPIC_TIMEOUT_EXCEPTIONS = (
@@ -15,6 +18,10 @@ ANTHROPIC_TIMEOUT_EXCEPTIONS = (
15
  anthropic.InternalServerError,
16
  )
17
 
 
 
 
 
18
 
19
  @once
20
  def _setup_anthropic_client():
@@ -28,23 +35,32 @@ def query(
28
  func_spec: FunctionSpec | None = None,
29
  **model_kwargs,
30
  ) -> tuple[OutputType, float, int, int, dict]:
 
 
 
31
  _setup_anthropic_client()
32
 
33
  filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
34
  if "max_tokens" not in filtered_kwargs:
35
  filtered_kwargs["max_tokens"] = 4096 # default for Claude models
36
 
37
- if func_spec is not None:
38
- raise NotImplementedError(
39
- "Anthropic does not support function calling for now."
40
- )
 
 
 
 
 
 
41
 
42
- # Anthropic doesn't allow not having a user messages
43
  # if we only have system msg -> use it as user msg
44
  if system_message is not None and user_message is None:
45
  system_message, user_message = user_message, system_message
46
 
47
- # Anthropic passes the system messages as a separate argument
48
  if system_message is not None:
49
  filtered_kwargs["system"] = system_message
50
 
@@ -59,14 +75,33 @@ def query(
59
  )
60
  req_time = time.time() - t0
61
 
62
- assert len(message.content) == 1 and message.content[0].type == "text"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- output: str = message.content[0].text
65
  in_tokens = message.usage.input_tokens
66
  out_tokens = message.usage.output_tokens
67
 
68
  info = {
69
  "stop_reason": message.stop_reason,
 
70
  }
71
 
72
  return output, req_time, in_tokens, out_tokens, info
 
1
  """Backend for Anthropic API."""
2
 
3
+ import logging
4
  import time
5
 
6
  from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
7
  from funcy import notnone, once, select_values
8
  import anthropic
9
 
10
+ logger = logging.getLogger("aide")
11
+
12
  _client: anthropic.Anthropic = None # type: ignore
13
 
14
  ANTHROPIC_TIMEOUT_EXCEPTIONS = (
 
18
  anthropic.InternalServerError,
19
  )
20
 
21
+ ANTHROPIC_MODEL_ALIASES = {
22
+ "claude-3.5-sonnet": "claude-3-sonnet-20241022",
23
+ }
24
+
25
 
26
  @once
27
  def _setup_anthropic_client():
 
35
  func_spec: FunctionSpec | None = None,
36
  **model_kwargs,
37
  ) -> tuple[OutputType, float, int, int, dict]:
38
+ """
39
+ Query Anthropic's API, optionally with tool use (Anthropic's equivalent to function calling).
40
+ """
41
  _setup_anthropic_client()
42
 
43
  filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
44
  if "max_tokens" not in filtered_kwargs:
45
  filtered_kwargs["max_tokens"] = 4096 # default for Claude models
46
 
47
+ model_name = filtered_kwargs.get("model", "")
48
+ logger.debug(f"Anthropic query called with model='{model_name}'")
49
+
50
+ if model_name in ANTHROPIC_MODEL_ALIASES:
51
+ model_name = ANTHROPIC_MODEL_ALIASES[model_name]
52
+
53
+ if func_spec is not None and func_spec.name == "submit_review":
54
+ filtered_kwargs["tools"] = [func_spec.as_anthropic_tool_dict]
55
+ # Force tool use
56
+ filtered_kwargs["tool_choice"] = func_spec.anthropic_tool_choice_dict
57
 
58
+ # Anthropic doesn't allow not having user messages
59
  # if we only have system msg -> use it as user msg
60
  if system_message is not None and user_message is None:
61
  system_message, user_message = user_message, system_message
62
 
63
+ # Anthropic passes system messages as a separate argument
64
  if system_message is not None:
65
  filtered_kwargs["system"] = system_message
66
 
 
75
  )
76
  req_time = time.time() - t0
77
 
78
+ # Handle tool calls if present
79
+ if (
80
+ func_spec is not None
81
+ and "tools" in filtered_kwargs
82
+ and len(message.content) > 0
83
+ and message.content[0].type == "tool_use"
84
+ ):
85
+ block = message.content[0] # This is a "ToolUseBlock"
86
+ # block has attributes: type, id, name, input
87
+ assert (
88
+ block.name == func_spec.name
89
+ ), f"Function name mismatch: expected {func_spec.name}, got {block.name}"
90
+ output = block.input # Anthropic calls the parameters "input"
91
+ else:
92
+ # For non-tool responses, ensure we have text content
93
+ assert len(message.content) == 1, "Expected single content item"
94
+ assert (
95
+ message.content[0].type == "text"
96
+ ), f"Expected text response, got {message.content[0].type}"
97
+ output = message.content[0].text
98
 
 
99
  in_tokens = message.usage.input_tokens
100
  out_tokens = message.usage.output_tokens
101
 
102
  info = {
103
  "stop_reason": message.stop_reason,
104
+ "model": message.model,
105
  }
106
 
107
  return output, req_time, in_tokens, out_tokens, info
aide/backend/utils.py CHANGED
@@ -66,6 +66,7 @@ class FunctionSpec(DataClassJsonMixin):
66
 
67
  @property
68
  def as_openai_tool_dict(self):
 
69
  return {
70
  "type": "function",
71
  "function": {
@@ -81,3 +82,20 @@ class FunctionSpec(DataClassJsonMixin):
81
  "type": "function",
82
  "function": {"name": self.name},
83
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  @property
68
  def as_openai_tool_dict(self):
69
+ """Convert to OpenAI's function format."""
70
  return {
71
  "type": "function",
72
  "function": {
 
82
  "type": "function",
83
  "function": {"name": self.name},
84
  }
85
+
86
+ @property
87
+ def as_anthropic_tool_dict(self):
88
+ """Convert to Anthropic's tool format."""
89
+ return {
90
+ "name": self.name,
91
+ "description": self.description,
92
+ "input_schema": self.json_schema, # Anthropic uses input_schema instead of parameters
93
+ }
94
+
95
+ @property
96
+ def anthropic_tool_choice_dict(self):
97
+ """Convert to Anthropic's tool choice format."""
98
+ return {
99
+ "type": "tool", # Anthropic uses "tool" instead of "function"
100
+ "name": self.name,
101
+ }