gladish jacobgladish nsarrazin HF Staff commited on
Commit
7dbac68
·
unverified ·
1 Parent(s): 94c06a8

feat(tools): Basic tool support for OpenAI models (#1447)

Browse files

* feat(tools): Basic tool support for OpenAI models

* feat(tools): Basic tool support for OpenAI models

* fix: tools with document input (not image)

* fix: double yield of streaming tokens

* fix: boolean parameter processing error

took me a while to find this one lol

* feat: throw error if using tools in `completions` mode

* feat: fix image processing in tools mode

* fix: remove correlationKey and clean up types

---------

Co-authored-by: Jacob Gladish <[email protected]>
Co-authored-by: Nathan Sarrazin <[email protected]>

src/lib/server/endpoints/openai/endpointOai.ts CHANGED
@@ -2,14 +2,84 @@ import { z } from "zod";
2
  import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
3
  import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
4
  import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
5
- import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
 
 
 
 
6
  import { buildPrompt } from "$lib/buildPrompt";
7
  import { env } from "$env/dynamic/private";
8
  import type { Endpoint } from "../endpoints";
9
  import type OpenAI from "openai";
10
  import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
11
  import type { MessageFile } from "$lib/types/Message";
 
12
  import type { EndpointMessage } from "../endpoints";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  export const endpointOAIParametersSchema = z.object({
15
  weight: z.number().int().positive().default(1),
@@ -57,7 +127,6 @@ export async function endpointOai(
57
  extraBody,
58
  } = endpointOAIParametersSchema.parse(input);
59
 
60
- /* eslint-disable-next-line no-shadow */
61
  let OpenAI;
62
  try {
63
  OpenAI = (await import("openai")).OpenAI;
@@ -75,6 +144,11 @@ export async function endpointOai(
75
  const imageProcessor = makeImageProcessor(multimodal.image);
76
 
77
  if (completion === "completions") {
 
 
 
 
 
78
  return async ({ messages, preprompt, continueMessage, generateSettings }) => {
79
  const prompt = await buildPrompt({
80
  messages,
@@ -102,9 +176,9 @@ export async function endpointOai(
102
  return openAICompletionToTextGenerationStream(openAICompletion);
103
  };
104
  } else if (completion === "chat_completions") {
105
- return async ({ messages, preprompt, generateSettings }) => {
106
  let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
107
- await prepareMessages(messages, imageProcessor);
108
 
109
  if (messagesOpenAI?.[0]?.role !== "system") {
110
  messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI];
@@ -114,7 +188,44 @@ export async function endpointOai(
114
  messagesOpenAI[0].content = preprompt ?? "";
115
  }
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  const parameters = { ...model.parameters, ...generateSettings };
 
118
  const body: ChatCompletionCreateParamsStreaming = {
119
  model: model.id ?? model.name,
120
  messages: messagesOpenAI,
@@ -124,6 +235,7 @@ export async function endpointOai(
124
  temperature: parameters?.temperature,
125
  top_p: parameters?.top_p,
126
  frequency_penalty: parameters?.repetition_penalty,
 
127
  };
128
 
129
  const openChatAICompletion = await openai.chat.completions.create(body, {
@@ -139,11 +251,12 @@ export async function endpointOai(
139
 
140
  async function prepareMessages(
141
  messages: EndpointMessage[],
142
- imageProcessor: ReturnType<typeof makeImageProcessor>
 
143
  ): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
144
  return Promise.all(
145
  messages.map(async (message) => {
146
- if (message.from === "user") {
147
  return {
148
  role: message.from,
149
  content: [
@@ -164,7 +277,9 @@ async function prepareFiles(
164
  imageProcessor: ReturnType<typeof makeImageProcessor>,
165
  files: MessageFile[]
166
  ): Promise<OpenAI.Chat.Completions.ChatCompletionContentPartImage[]> {
167
- const processedFiles = await Promise.all(files.map(imageProcessor));
 
 
168
  return processedFiles.map((file) => ({
169
  type: "image_url" as const,
170
  image_url: {
 
2
  import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
3
  import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
4
  import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
5
+ import type {
6
+ ChatCompletionCreateParamsStreaming,
7
+ ChatCompletionTool,
8
+ } from "openai/resources/chat/completions";
9
+ import type { FunctionDefinition, FunctionParameters } from "openai/resources/shared";
10
  import { buildPrompt } from "$lib/buildPrompt";
11
  import { env } from "$env/dynamic/private";
12
  import type { Endpoint } from "../endpoints";
13
  import type OpenAI from "openai";
14
  import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
15
  import type { MessageFile } from "$lib/types/Message";
16
+ import { type Tool } from "$lib/types/Tool";
17
  import type { EndpointMessage } from "../endpoints";
18
+ import { v4 as uuidv4 } from "uuid";
19
+ function createChatCompletionToolsArray(tools: Tool[] | undefined): ChatCompletionTool[] {
20
+ const toolChoices = [] as ChatCompletionTool[];
21
+ if (tools === undefined) {
22
+ return toolChoices;
23
+ }
24
+
25
+ for (const t of tools) {
26
+ const requiredProperties = [] as string[];
27
+
28
+ const properties = {} as Record<string, unknown>;
29
+ for (const idx in t.inputs) {
30
+ const parameterDefinition = t.inputs[idx];
31
+
32
+ const parameter = {} as Record<string, unknown>;
33
+ switch (parameterDefinition.type) {
34
+ case "str":
35
+ parameter.type = "string";
36
+ break;
37
+ case "float":
38
+ case "int":
39
+ parameter.type = "number";
40
+ break;
41
+ case "bool":
42
+ parameter.type = "boolean";
43
+ break;
44
+ case "file":
45
+ throw new Error("File type's currently not supported");
46
+ default:
47
+ throw new Error(`Unknown tool IO type: ${t}`);
48
+ }
49
+
50
+ if ("description" in parameterDefinition) {
51
+ parameter.description = parameterDefinition.description;
52
+ }
53
+
54
+ if (parameterDefinition.paramType == "required") {
55
+ requiredProperties.push(t.inputs[idx].name);
56
+ }
57
+
58
+ properties[t.inputs[idx].name] = parameter;
59
+ }
60
+
61
+ const functionParameters: FunctionParameters = {
62
+ type: "object",
63
+ ...(requiredProperties.length > 0 ? { required: requiredProperties } : {}),
64
+ properties,
65
+ };
66
+
67
+ const functionDefinition: FunctionDefinition = {
68
+ name: t.name,
69
+ description: t.description,
70
+ parameters: functionParameters,
71
+ };
72
+
73
+ const toolDefinition: ChatCompletionTool = {
74
+ type: "function",
75
+ function: functionDefinition,
76
+ };
77
+
78
+ toolChoices.push(toolDefinition);
79
+ }
80
+
81
+ return toolChoices;
82
+ }
83
 
84
  export const endpointOAIParametersSchema = z.object({
85
  weight: z.number().int().positive().default(1),
 
127
  extraBody,
128
  } = endpointOAIParametersSchema.parse(input);
129
 
 
130
  let OpenAI;
131
  try {
132
  OpenAI = (await import("openai")).OpenAI;
 
144
  const imageProcessor = makeImageProcessor(multimodal.image);
145
 
146
  if (completion === "completions") {
147
+ if (model.tools) {
148
+ throw new Error(
149
+ "Tools are not supported for 'completions' mode, switch to 'chat_completions' instead"
150
+ );
151
+ }
152
  return async ({ messages, preprompt, continueMessage, generateSettings }) => {
153
  const prompt = await buildPrompt({
154
  messages,
 
176
  return openAICompletionToTextGenerationStream(openAICompletion);
177
  };
178
  } else if (completion === "chat_completions") {
179
+ return async ({ messages, preprompt, generateSettings, tools, toolResults }) => {
180
  let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
181
+ await prepareMessages(messages, imageProcessor, !model.tools && model.multimodal);
182
 
183
  if (messagesOpenAI?.[0]?.role !== "system") {
184
  messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI];
 
188
  messagesOpenAI[0].content = preprompt ?? "";
189
  }
190
 
191
+ if (toolResults && toolResults.length > 0) {
192
+ const toolCallRequests: OpenAI.Chat.Completions.ChatCompletionAssistantMessageParam = {
193
+ role: "assistant",
194
+ content: null,
195
+ tool_calls: [],
196
+ };
197
+
198
+ const responses: Array<OpenAI.Chat.Completions.ChatCompletionToolMessageParam> = [];
199
+
200
+ for (const result of toolResults) {
201
+ const id = uuidv4();
202
+
203
+ const toolCallResult: OpenAI.Chat.Completions.ChatCompletionMessageToolCall = {
204
+ type: "function",
205
+ function: {
206
+ name: result.call.name,
207
+ arguments: JSON.stringify(result.call.parameters),
208
+ },
209
+ id,
210
+ };
211
+
212
+ toolCallRequests.tool_calls?.push(toolCallResult);
213
+ const toolCallResponse: OpenAI.Chat.Completions.ChatCompletionToolMessageParam = {
214
+ role: "tool",
215
+ content: "",
216
+ tool_call_id: id,
217
+ };
218
+ if ("outputs" in result) {
219
+ toolCallResponse.content = JSON.stringify(result.outputs);
220
+ }
221
+ responses.push(toolCallResponse);
222
+ }
223
+ messagesOpenAI.push(toolCallRequests);
224
+ messagesOpenAI.push(...responses);
225
+ }
226
+
227
  const parameters = { ...model.parameters, ...generateSettings };
228
+ const toolCallChoices = createChatCompletionToolsArray(tools);
229
  const body: ChatCompletionCreateParamsStreaming = {
230
  model: model.id ?? model.name,
231
  messages: messagesOpenAI,
 
235
  temperature: parameters?.temperature,
236
  top_p: parameters?.top_p,
237
  frequency_penalty: parameters?.repetition_penalty,
238
+ ...(toolCallChoices.length > 0 ? { tools: toolCallChoices, tool_choice: "auto" } : {}),
239
  };
240
 
241
  const openChatAICompletion = await openai.chat.completions.create(body, {
 
251
 
252
  async function prepareMessages(
253
  messages: EndpointMessage[],
254
+ imageProcessor: ReturnType<typeof makeImageProcessor>,
255
+ isMultimodal: boolean
256
  ): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
257
  return Promise.all(
258
  messages.map(async (message) => {
259
+ if (message.from === "user" && isMultimodal) {
260
  return {
261
  role: message.from,
262
  content: [
 
277
  imageProcessor: ReturnType<typeof makeImageProcessor>,
278
  files: MessageFile[]
279
  ): Promise<OpenAI.Chat.Completions.ChatCompletionContentPartImage[]> {
280
+ const processedFiles = await Promise.all(
281
+ files.filter((file) => file.mime.startsWith("image/")).map(imageProcessor)
282
+ );
283
  return processedFiles.map((file) => ({
284
  type: "image_url" as const,
285
  image_url: {
src/lib/server/endpoints/openai/openAIChatToTextGenerationStream.ts CHANGED
@@ -1,6 +1,44 @@
1
  import type { TextGenerationStreamOutput } from "@huggingface/inference";
2
  import type OpenAI from "openai";
3
  import type { Stream } from "openai/streaming";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  /**
6
  * Transform a stream of OpenAI.Chat.ChatCompletion into a stream of TextGenerationStreamOutput
@@ -10,6 +48,7 @@ export async function* openAIChatToTextGenerationStream(
10
  ) {
11
  let generatedText = "";
12
  let tokenId = 0;
 
13
  for await (const completion of completionStream) {
14
  const { choices } = completion;
15
  const content = choices[0]?.delta?.content ?? "";
@@ -28,5 +67,30 @@ export async function* openAIChatToTextGenerationStream(
28
  details: null,
29
  };
30
  yield output;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  }
32
  }
 
1
  import type { TextGenerationStreamOutput } from "@huggingface/inference";
2
  import type OpenAI from "openai";
3
  import type { Stream } from "openai/streaming";
4
+ import type { ToolCall } from "$lib/types/Tool";
5
+
6
+ type ToolCallWithParameters = {
7
+ toolCall: ToolCall;
8
+ parameterJsonString: string;
9
+ };
10
+
11
+ function prepareToolCalls(toolCallsWithParameters: ToolCallWithParameters[], tokenId: number) {
12
+ const toolCalls: ToolCall[] = [];
13
+
14
+ for (const toolCallWithParameters of toolCallsWithParameters) {
15
+ // HACK: sometimes gpt4 via azure returns the JSON with literal newlines in it
16
+ // like {\n "foo": "bar" }
17
+ const s = toolCallWithParameters.parameterJsonString.replace("\n", "");
18
+ const params = JSON.parse(s);
19
+
20
+ const toolCall = toolCallWithParameters.toolCall;
21
+ for (const name in params) {
22
+ toolCall.parameters[name] = params[name];
23
+ }
24
+
25
+ toolCalls.push(toolCall);
26
+ }
27
+
28
+ const output = {
29
+ token: {
30
+ id: tokenId,
31
+ text: "",
32
+ logprob: 0,
33
+ special: false,
34
+ toolCalls,
35
+ },
36
+ generated_text: null,
37
+ details: null,
38
+ };
39
+
40
+ return output;
41
+ }
42
 
43
  /**
44
  * Transform a stream of OpenAI.Chat.ChatCompletion into a stream of TextGenerationStreamOutput
 
48
  ) {
49
  let generatedText = "";
50
  let tokenId = 0;
51
+ const toolCalls: ToolCallWithParameters[] = [];
52
  for await (const completion of completionStream) {
53
  const { choices } = completion;
54
  const content = choices[0]?.delta?.content ?? "";
 
67
  details: null,
68
  };
69
  yield output;
70
+
71
+ const tools = completion.choices[0]?.delta?.tool_calls || [];
72
+ for (const tool of tools) {
73
+ if (tool.id) {
74
+ if (!tool.function?.name) {
75
+ throw new Error("Tool call without function name");
76
+ }
77
+ const toolCallWithParameters: ToolCallWithParameters = {
78
+ toolCall: {
79
+ name: tool.function.name,
80
+ parameters: {},
81
+ },
82
+ parameterJsonString: "",
83
+ };
84
+ toolCalls.push(toolCallWithParameters);
85
+ }
86
+
87
+ if (toolCalls.length > 0 && tool.function?.arguments) {
88
+ toolCalls[toolCalls.length - 1].parameterJsonString += tool.function.arguments;
89
+ }
90
+ }
91
+
92
+ if (choices[0]?.finish_reason === "tool_calls") {
93
+ yield prepareToolCalls(toolCalls, tokenId++);
94
+ }
95
  }
96
  }