Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
fix(endpoints): fix for tool calling on hf inference with openai endpoint type (#1754)
Browse files
* fix(endpoints): fix for tool calling on hf inference with openai endpoint type
* moar fix
* fix: typechecks
src/lib/server/endpoints/openai/openAIChatToTextGenerationStream.ts
CHANGED
@@ -49,10 +49,47 @@ export async function* openAIChatToTextGenerationStream(
|
|
49 |
let generatedText = "";
|
50 |
let tokenId = 0;
|
51 |
const toolCalls: ToolCallWithParameters[] = [];
|
|
|
|
|
52 |
for await (const completion of completionStream) {
|
53 |
const { choices } = completion;
|
54 |
const content = choices[0]?.delta?.content ?? "";
|
55 |
const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
if (content) {
|
57 |
generatedText = generatedText + content;
|
58 |
}
|
|
|
49 |
let generatedText = "";
|
50 |
let tokenId = 0;
|
51 |
const toolCalls: ToolCallWithParameters[] = [];
|
52 |
+
let toolBuffer = ""; // XXX: hack because tools seem broken on tgi openai endpoints?
|
53 |
+
|
54 |
for await (const completion of completionStream) {
|
55 |
const { choices } = completion;
|
56 |
const content = choices[0]?.delta?.content ?? "";
|
57 |
const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";
|
58 |
+
|
59 |
+
// If the stream finishes with "stop" while tool-call arguments are still buffered, emit the buffer as the final generated_text and stop iterating.
|
60 |
+
if (choices[0]?.finish_reason === "stop" && toolBuffer.length > 0) {
|
61 |
+
yield {
|
62 |
+
token: {
|
63 |
+
id: tokenId++,
|
64 |
+
special: true,
|
65 |
+
logprob: 0,
|
66 |
+
text: "",
|
67 |
+
},
|
68 |
+
generated_text: toolBuffer,
|
69 |
+
details: null,
|
70 |
+
} as TextGenerationStreamOutput;
|
71 |
+
break;
|
72 |
+
}
|
73 |
+
|
74 |
+
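// Workaround: tool-call arguments can arrive as a single delta with index 0, an empty id and a null function name; accumulate those argument fragments in toolBuffer instead of handling them as a regular tool call.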
|
75 |
+
if (choices[0]?.delta?.tool_calls) {
|
76 |
+
const calls = Array.isArray(choices[0].delta.tool_calls)
|
77 |
+
? choices[0].delta.tool_calls
|
78 |
+
: [choices[0].delta.tool_calls];
|
79 |
+
|
80 |
+
if (
|
81 |
+
calls.length === 1 &&
|
82 |
+
calls[0].index === 0 &&
|
83 |
+
calls[0].id === "" &&
|
84 |
+
calls[0].type === "function" &&
|
85 |
+
!!calls[0].function &&
|
86 |
+
calls[0].function.name === null
|
87 |
+
) {
|
88 |
+
toolBuffer += calls[0].function.arguments;
|
89 |
+
continue;
|
90 |
+
}
|
91 |
+
}
|
92 |
+
|
93 |
if (content) {
|
94 |
generatedText = generatedText + content;
|
95 |
}
|
src/lib/server/textGeneration/tools.ts
CHANGED
@@ -314,6 +314,22 @@ function isValidCallObject(call: unknown): call is Record<string, unknown> {
|
|
314 |
}
|
315 |
|
316 |
function parseExternalCall(callObj: Record<string, unknown>) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
const nameFields = ["tool_name", "name"] as const;
|
318 |
const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;
|
319 |
|
@@ -323,14 +339,14 @@ function parseExternalCall(callObj: Record<string, unknown>) {
|
|
323 |
};
|
324 |
|
325 |
for (const name of nameFields) {
|
326 |
-
if (callObj[name]) {
|
327 |
-
groupedCall.tool_name = callObj[name] as string;
|
328 |
}
|
329 |
}
|
330 |
|
331 |
for (const name of parametersFields) {
|
332 |
-
if (callObj[name]) {
|
333 |
-
groupedCall.parameters = callObj[name] as Record<string, string>;
|
334 |
}
|
335 |
}
|
336 |
|
|
|
314 |
}
|
315 |
|
316 |
function parseExternalCall(callObj: Record<string, unknown>) {
|
317 |
+
let toolCall = callObj;
|
318 |
+
if (
|
319 |
+
isValidCallObject(callObj) &&
|
320 |
+
"function" in callObj &&
|
321 |
+
isValidCallObject(callObj.function) &&
|
322 |
+
"_name" in callObj.function
|
323 |
+
) {
|
324 |
+
toolCall = {
|
325 |
+
tool_name: callObj["function"]["_name"],
|
326 |
+
parameters: {
|
327 |
+
...callObj["function"],
|
328 |
+
_name: undefined,
|
329 |
+
},
|
330 |
+
};
|
331 |
+
}
|
332 |
+
|
333 |
const nameFields = ["tool_name", "name"] as const;
|
334 |
const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;
|
335 |
|
|
|
339 |
};
|
340 |
|
341 |
for (const name of nameFields) {
|
342 |
+
if (toolCall[name]) {
|
343 |
+
groupedCall.tool_name = toolCall[name] as string;
|
344 |
}
|
345 |
}
|
346 |
|
347 |
for (const name of parametersFields) {
|
348 |
+
if (toolCall[name]) {
|
349 |
+
groupedCall.parameters = toolCall[name] as Record<string, string>;
|
350 |
}
|
351 |
}
|
352 |
|