Commit 69d5a1c (unverified) · Parent: c2502a3
Committed by evalstate and nsarrazin (HF Staff)

Anthropic Tool Support (#1594)


* support anthropic PDF beta

* upstream merge, remove commented out console log line

* Fixing type errors.
  The Anthropic API does not yet include a "DocumentBlock" for
  supporting PDFs, so an extended type has been added to the endpoint.

* changed document processor to async (matching image processor)

* use the beta api types rather than custom extension

* rudimentary tool testing

* interim commit (tool re-passing, file handling)

* remove merge error

* tidy up, isolate beta classes to utils

* anthropic tool calling support.

* improve handling of directlyAnswer tool

* fix streaming

* slight tidy up to tools flow handling

* fix: don't pass tools in the final generation; instead, deduce tools from tool results

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
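
At a glance, the flow these commits add: chat-ui's tools are converted into Anthropic's JSON Schema tool format, the request is streamed, and when the stream ends with stop_reason "tool_use" the tool calls are surfaced, executed, and the conversation is re-sent with the results appended. The sketch below illustrates that round trip against the @anthropic-ai/sdk streaming client; it is not the chat-ui code itself, and the model id and get_weather tool are hypothetical placeholders.

import Anthropic from "@anthropic-ai/sdk";

const anthropic = new Anthropic();

// A tool definition in Anthropic's JSON Schema format (hypothetical example).
const tools: Anthropic.Messages.Tool[] = [
	{
		name: "get_weather",
		description: "Look up the current weather for a city",
		input_schema: {
			type: "object",
			properties: { city: { type: "string", description: "City name" } },
			required: ["city"],
		},
	},
];

const stream = anthropic.messages.stream({
	model: "claude-3-5-sonnet-20240620", // placeholder model id
	max_tokens: 1024,
	tools,
	tool_choice: { type: "auto", disable_parallel_tool_use: false },
	messages: [{ role: "user", content: "What's the weather in Paris?" }],
});

const final = await stream.finalMessage();
if (final.stop_reason === "tool_use") {
	// The endpoint below maps these blocks to chat-ui ToolCalls; the caller runs
	// the tools and re-sends the conversation with tool_use/tool_result blocks
	// appended (see addToolResults in utils.ts). Per the last commit bullet, the
	// final generation passes no tool definitions; stub tools are deduced from
	// the tool results instead.
	const calls = final.content.filter((block) => block.type === "tool_use");
	console.log(calls);
}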

src/lib/server/endpoints/anthropic/endpointAnthropic.ts CHANGED
@@ -3,9 +3,19 @@ import type { Endpoint } from "../endpoints";
 import { env } from "$env/dynamic/private";
 import type { TextGenerationStreamOutput } from "@huggingface/inference";
 import { createImageProcessorOptionsValidator } from "../images";
-import { endpointMessagesToAnthropicMessages } from "./utils";
+import { endpointMessagesToAnthropicMessages, addToolResults } from "./utils";
 import { createDocumentProcessorOptionsValidator } from "../document";
+import type {
+	Tool,
+	ToolCall,
+	ToolInput,
+	ToolInputFile,
+	ToolInputFixed,
+	ToolInputOptional,
+} from "$lib/types/Tool";
+import type Anthropic from "@anthropic-ai/sdk";
 import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";
+import directlyAnswer from "$lib/server/tools/directlyAnswer";
 
 export const endpointAnthropicParametersSchema = z.object({
 	weight: z.number().int().positive().default(1),
@@ -52,23 +62,41 @@ export async function endpointAnthropic(
 		defaultQuery,
 	});
 
-	return async ({ messages, preprompt, generateSettings }) => {
+	return async ({
+		messages,
+		preprompt,
+		generateSettings,
+		conversationId,
+		tools = [],
+		toolResults = [],
+	}) => {
 		let system = preprompt;
 		if (messages?.[0]?.from === "system") {
 			system = messages[0].content;
 		}
 
 		let tokenId = 0;
+		if (tools.length === 0 && toolResults.length > 0) {
+			const toolNames = new Set(toolResults.map((tool) => tool.call.name));
+			tools = Array.from(toolNames).map((name) => ({
+				name,
+				description: "",
+				inputs: [],
+			})) as unknown as Tool[];
+		}
 
 		const parameters = { ...model.parameters, ...generateSettings };
 
 		return (async function* () {
 			const stream = anthropic.messages.stream({
 				model: model.id ?? model.name,
-				messages: (await endpointMessagesToAnthropicMessages(
-					messages,
-					multimodal
-				)) as MessageParam[],
+				tools: createAnthropicTools(tools),
+				tool_choice:
+					tools.length > 0 ? { type: "auto", disable_parallel_tool_use: false } : undefined,
+				messages: addToolResults(
+					await endpointMessagesToAnthropicMessages(messages, multimodal, conversationId),
+					toolResults
+				) as MessageParam[],
 				max_tokens: parameters?.max_new_tokens,
 				temperature: parameters?.temperature,
 				top_p: parameters?.top_p,
@@ -79,21 +107,40 @@ export async function endpointAnthropic(
 			while (true) {
 				const result = await Promise.race([stream.emitted("text"), stream.emitted("end")]);
 
-				// Stream end
 				if (result === undefined) {
-					yield {
-						token: {
-							id: tokenId++,
-							text: "",
-							logprob: 0,
-							special: true,
-						},
-						generated_text: await stream.finalText(),
-						details: null,
-					} satisfies TextGenerationStreamOutput;
+					if ("tool_use" === stream.receivedMessages[0].stop_reason) {
+						// this should really create a new "Assistant" message with the tool id in it.
+						const toolCalls: ToolCall[] = stream.receivedMessages[0].content
+							.filter(
+								(block): block is Anthropic.Messages.ContentBlock & { type: "tool_use" } =>
+									block.type === "tool_use"
+							)
+							.map((block) => ({
+								name: block.name,
+								parameters: block.input as Record<string, string | number | boolean>,
+								id: block.id,
+							}));
+
+						yield {
+							token: { id: tokenId, text: "", logprob: 0, special: false, toolCalls },
+							generated_text: null,
+							details: null,
+						};
+					} else {
+						yield {
+							token: {
+								id: tokenId++,
+								text: "",
+								logprob: 0,
+								special: true,
+							},
+							generated_text: await stream.finalText(),
+							details: null,
+						} satisfies TextGenerationStreamOutput;
+					}
+
 					return;
 				}
-
 				// Text delta
 				yield {
 					token: {
@@ -109,3 +156,66 @@ export async function endpointAnthropic(
 		})();
 	};
 }
+
+function createAnthropicTools(tools: Tool[]): Anthropic.Messages.Tool[] {
+	return tools
+		.filter((tool) => tool.name !== directlyAnswer.name)
+		.map((tool) => {
+			const properties = tool.inputs.reduce((acc, input) => {
+				acc[input.name] = convertToolInputToJSONSchema(input);
+				return acc;
+			}, {} as Record<string, unknown>);
+
+			const required = tool.inputs
+				.filter((input) => input.paramType === "required")
+				.map((input) => input.name);
+
+			return {
+				name: tool.name,
+				description: tool.description,
+				input_schema: {
+					type: "object",
+					properties,
+					required: required.length > 0 ? required : undefined,
+				},
+			};
+		});
+}
+
+function convertToolInputToJSONSchema(input: ToolInput): Record<string, unknown> {
+	const baseSchema: Record<string, unknown> = {};
+	if ("description" in input) {
+		baseSchema["description"] = input.description || "";
+	}
+	switch (input.paramType) {
+		case "optional":
+			baseSchema["default"] = (input as ToolInputOptional).default;
+			break;
+		case "fixed":
+			baseSchema["const"] = (input as ToolInputFixed).value;
+			break;
+	}
+
+	if (input.type === "file") {
+		baseSchema["type"] = "string";
+		baseSchema["format"] = "binary";
+		baseSchema["mimeTypes"] = (input as ToolInputFile).mimeTypes;
+	} else {
+		switch (input.type) {
+			case "str":
+				baseSchema["type"] = "string";
+				break;
+			case "int":
+				baseSchema["type"] = "integer";
+				break;
+			case "float":
+				baseSchema["type"] = "number";
+				break;
+			case "bool":
+				baseSchema["type"] = "boolean";
+				break;
+		}
+	}
+
+	return baseSchema;
+}
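
To make the schema mapping concrete, here is what createAnthropicTools should produce for a hypothetical chat-ui tool with a single required string input (the calculator tool and its fields are illustrative, following the $lib/types/Tool shapes imported above):

// A hypothetical chat-ui tool definition.
const calculator = {
	name: "calculator",
	description: "Evaluate a mathematical expression",
	inputs: [
		{
			name: "expression",
			type: "str",
			paramType: "required",
			description: "The expression to evaluate",
		},
	],
} as unknown as Tool;

// createAnthropicTools([calculator]) then yields:
// [{
// 	name: "calculator",
// 	description: "Evaluate a mathematical expression",
// 	input_schema: {
// 		type: "object",
// 		properties: {
// 			expression: { type: "string", description: "The expression to evaluate" },
// 		},
// 		required: ["expression"],
// 	},
// }]

Note that directlyAnswer is filtered out, so a tool list containing only that tool maps to an empty array.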
src/lib/server/endpoints/anthropic/utils.ts CHANGED
@@ -7,12 +7,16 @@ import type {
 	BetaMessageParam,
 	BetaBase64PDFBlock,
 } from "@anthropic-ai/sdk/resources/beta/messages/messages.mjs";
+import type { ToolResult } from "$lib/types/Tool";
+import { downloadFile } from "$lib/server/files/downloadFile";
+import type { ObjectId } from "mongodb";
 
 export async function fileToImageBlock(
 	file: MessageFile,
 	opts: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">
 ): Promise<BetaImageBlockParam> {
 	const processor = makeImageProcessor(opts);
+
 	const { image, mime } = await processor(file);
 
 	return {
@@ -48,7 +52,8 @@ export async function endpointMessagesToAnthropicMessages(
 	multimodal: {
 		image: ImageProcessorOptions<"image/png" | "image/jpeg" | "image/webp">;
 		document?: FileProcessorOptions<"application/pdf">;
-	}
+	},
+	conversationId?: ObjectId | undefined
 ): Promise<BetaMessageParam[]> {
 	return await Promise.all(
 		messages
@@ -57,20 +62,59 @@ export async function endpointMessagesToAnthropicMessages(
 			return {
 				role: message.from,
 				content: [
-					...(await Promise.all(
-						(message.files ?? []).map(async (file) => {
-							if (file.mime.startsWith("image/")) {
-								return fileToImageBlock(file, multimodal.image);
-							} else if (file.mime === "application/pdf" && multimodal.document) {
-								return fileToDocumentBlock(file, multimodal.document);
-							} else {
-								throw new Error(`Unsupported file type: ${file.mime}`);
-							}
-						})
-					)),
+					...(message.from === "user"
+						? await Promise.all(
+								(message.files ?? []).map(async (file) => {
+									if (file.type === "hash" && conversationId) {
+										file = await downloadFile(file.value, conversationId);
+									}
+
+									if (file.mime.startsWith("image/")) {
+										return fileToImageBlock(file, multimodal.image);
+									} else if (file.mime === "application/pdf" && multimodal.document) {
+										return fileToDocumentBlock(file, multimodal.document);
+									} else {
+										throw new Error(`Unsupported file type: ${file.mime}`);
+									}
+								})
+						  )
+						: []),
 					{ type: "text", text: message.content },
 				],
 			};
 		})
 	);
 }
+
+export function addToolResults(
+	messages: BetaMessageParam[],
+	toolResults: ToolResult[]
+): BetaMessageParam[] {
+	const id = crypto.randomUUID();
+	if (toolResults.length === 0) {
+		return messages;
+	}
+	return [
+		...messages,
+		{
+			role: "assistant",
+			content: toolResults.map((result, index) => ({
+				type: "tool_use",
+				id: `tool_${index}_${id}`,
+				name: result.call.name,
+				input: result.call.parameters,
+			})),
+		},
+		{
+			role: "user",
+			content: toolResults.map((result, index) => ({
+				type: "tool_result",
+				tool_use_id: `tool_${index}_${id}`,
+				is_error: result.status === "error",
+				content: JSON.stringify(
+					result.status === "error" ? result.message : "outputs" in result ? result.outputs : ""
+				),
+			})),
+		},
+	];
+}
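
For illustration, given one successful tool result, addToolResults appends an assistant turn carrying a synthetic tool_use block and a user turn carrying the matching tool_result. The values below are hypothetical and the UUID is abbreviated:

// A hypothetical successful result for a calculator tool.
const results = [
	{
		status: "success",
		call: { name: "calculator", parameters: { expression: "2 + 2" } },
		outputs: [{ value: "4" }],
	},
] as unknown as ToolResult[];

// addToolResults(messages, results) returns roughly:
// [
// 	...messages,
// 	{
// 		role: "assistant",
// 		content: [
// 			{ type: "tool_use", id: "tool_0_<uuid>", name: "calculator",
// 				input: { expression: "2 + 2" } },
// 		],
// 	},
// 	{
// 		role: "user",
// 		content: [
// 			{ type: "tool_result", tool_use_id: "tool_0_<uuid>",
// 				is_error: false, content: "[{\"value\":\"4\"}]" },
// 		],
// 	},
// ]

Matching each tool_result's tool_use_id to a tool_use id in the preceding assistant turn is what lets the API accept results for calls that were produced in an earlier request.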
src/lib/server/textGeneration/generate.ts CHANGED
@@ -1,4 +1,4 @@
-import type { ToolResult } from "$lib/types/Tool";
+import type { ToolResult, Tool } from "$lib/types/Tool";
 import {
 	MessageReasoningUpdateType,
 	MessageUpdateType,
@@ -16,7 +16,8 @@ type GenerateContext = Omit<TextGenerationContext, "messages"> & { messages: End
 export async function* generate(
 	{ model, endpoint, conv, messages, assistant, isContinue, promptedAt }: GenerateContext,
 	toolResults: ToolResult[],
-	preprompt?: string
+	preprompt?: string,
+	tools?: Tool[]
 ): AsyncIterable<MessageUpdate> {
 	// reasoning mode is false by default
 	let reasoning = false;
@@ -43,6 +44,7 @@ export async function* generate(
 		preprompt,
 		continueMessage: isContinue,
 		generateSettings: assistant?.generateSettings,
+		tools,
 		toolResults,
 		isMultimodal: model.multimodal,
 		conversationId: conv._id,
src/lib/server/textGeneration/index.ts CHANGED
@@ -20,6 +20,7 @@ import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
 import type { TextGenerationContext } from "./types";
 import type { ToolResult } from "$lib/types/Tool";
 import { toolHasName } from "../tools/utils";
+import directlyAnswer from "../tools/directlyAnswer";
 
 async function* keepAlive(done: AbortSignal): AsyncGenerator<MessageUpdate, undefined, undefined> {
 	while (!done.aborted) {
@@ -73,11 +74,13 @@ async function* textGenerationWithoutTitle(
 	}
 
 	let toolResults: ToolResult[] = [];
+	let tools = model.tools ? await getTools(toolsPreference, ctx.assistant) : undefined;
 
-	if (model.tools) {
-		const tools = await getTools(toolsPreference, ctx.assistant);
-		const toolCallsRequired = tools.some((tool) => !toolHasName("directly_answer", tool));
-		if (toolCallsRequired) toolResults = yield* runTools(ctx, tools, preprompt);
+	if (tools) {
+		const toolCallsRequired = tools.some((tool) => !toolHasName(directlyAnswer.name, tool));
+		if (toolCallsRequired) {
+			toolResults = yield* runTools(ctx, tools, preprompt);
+		} else tools = undefined;
 	}
 
 	const processedMessages = await preprocessMessages(messages, webSearchResult, convId);
src/lib/server/textGeneration/tools.ts CHANGED
@@ -213,7 +213,7 @@ export async function* runTools(
 	}
 
-	// if we don't see a tool call in the first 25 chars, something is going wrong and we abort
-	if (rawText.length > 25 && !(rawText.includes("```json") || rawText.includes("{"))) {
+	// if we don't see a tool call in the first 100 chars, something is going wrong and we abort
+	if (rawText.length > 100 && !(rawText.includes("```json") || rawText.includes("{"))) {
 		return [];
 	}
 