nsarrazin HF Staff commited on
Commit
185c2ff
·
unverified ·
1 Parent(s): c7c1fe9

fix(endpoints): fix for tool calling on hf inference with openai endpoint type (#1754)

Browse files

* fix(endpoints): fix for tool calling on hf inference with openai endpoint type

* moar fix

* fix: typechecks

src/lib/server/endpoints/openai/openAIChatToTextGenerationStream.ts CHANGED
@@ -49,10 +49,47 @@ export async function* openAIChatToTextGenerationStream(
49
  let generatedText = "";
50
  let tokenId = 0;
51
  const toolCalls: ToolCallWithParameters[] = [];
 
 
52
  for await (const completion of completionStream) {
53
  const { choices } = completion;
54
  const content = choices[0]?.delta?.content ?? "";
55
  const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  if (content) {
57
  generatedText = generatedText + content;
58
  }
 
49
  let generatedText = "";
50
  let tokenId = 0;
51
  const toolCalls: ToolCallWithParameters[] = [];
52
+ let toolBuffer = ""; // XXX: hack because tools seem broken on tgi openai endpoints?
53
+
54
  for await (const completion of completionStream) {
55
  const { choices } = completion;
56
  const content = choices[0]?.delta?.content ?? "";
57
  const last = choices[0]?.finish_reason === "stop" || choices[0]?.finish_reason === "length";
58
+
59
+ // if the last token is a stop and the tool buffer is not empty, yield it as a generated_text
60
+ if (choices[0]?.finish_reason === "stop" && toolBuffer.length > 0) {
61
+ yield {
62
+ token: {
63
+ id: tokenId++,
64
+ special: true,
65
+ logprob: 0,
66
+ text: "",
67
+ },
68
+ generated_text: toolBuffer,
69
+ details: null,
70
+ } as TextGenerationStreamOutput;
71
+ break;
72
+ }
73
+
74
+ // weird bug where the parameters are streamed in like this
75
+ if (choices[0]?.delta?.tool_calls) {
76
+ const calls = Array.isArray(choices[0].delta.tool_calls)
77
+ ? choices[0].delta.tool_calls
78
+ : [choices[0].delta.tool_calls];
79
+
80
+ if (
81
+ calls.length === 1 &&
82
+ calls[0].index === 0 &&
83
+ calls[0].id === "" &&
84
+ calls[0].type === "function" &&
85
+ !!calls[0].function &&
86
+ calls[0].function.name === null
87
+ ) {
88
+ toolBuffer += calls[0].function.arguments;
89
+ continue;
90
+ }
91
+ }
92
+
93
  if (content) {
94
  generatedText = generatedText + content;
95
  }
src/lib/server/textGeneration/tools.ts CHANGED
@@ -314,6 +314,22 @@ function isValidCallObject(call: unknown): call is Record<string, unknown> {
314
  }
315
 
316
  function parseExternalCall(callObj: Record<string, unknown>) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  const nameFields = ["tool_name", "name"] as const;
318
  const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;
319
 
@@ -323,14 +339,14 @@ function parseExternalCall(callObj: Record<string, unknown>) {
323
  };
324
 
325
  for (const name of nameFields) {
326
- if (callObj[name]) {
327
- groupedCall.tool_name = callObj[name] as string;
328
  }
329
  }
330
 
331
  for (const name of parametersFields) {
332
- if (callObj[name]) {
333
- groupedCall.parameters = callObj[name] as Record<string, string>;
334
  }
335
  }
336
 
 
314
  }
315
 
316
  function parseExternalCall(callObj: Record<string, unknown>) {
317
+ let toolCall = callObj;
318
+ if (
319
+ isValidCallObject(callObj) &&
320
+ "function" in callObj &&
321
+ isValidCallObject(callObj.function) &&
322
+ "_name" in callObj.function
323
+ ) {
324
+ toolCall = {
325
+ tool_name: callObj["function"]["_name"],
326
+ parameters: {
327
+ ...callObj["function"],
328
+ _name: undefined,
329
+ },
330
+ };
331
+ }
332
+
333
  const nameFields = ["tool_name", "name"] as const;
334
  const parametersFields = ["parameters", "arguments", "parameter_definitions"] as const;
335
 
 
339
  };
340
 
341
  for (const name of nameFields) {
342
+ if (toolCall[name]) {
343
+ groupedCall.tool_name = toolCall[name] as string;
344
  }
345
  }
346
 
347
  for (const name of parametersFields) {
348
+ if (toolCall[name]) {
349
+ groupedCall.parameters = toolCall[name] as Record<string, string>;
350
  }
351
  }
352