nsarrazin HF Staff committed on
Commit
cce2203
·
unverified ·
1 Parent(s): d947276

fix(tools): improve json parsing (#1356)

Browse files

* fix(tools): improve json parsing

* lint

src/lib/server/textGeneration/tools.ts CHANGED
@@ -1,6 +1,5 @@
1
  import { ToolResultStatus, type ToolCall, type ToolResult } from "$lib/types/Tool";
2
  import { v4 as uuidV4 } from "uuid";
3
- import JSON5 from "json5";
4
  import type { BackendTool, BackendToolContext } from "../tools";
5
  import {
6
  MessageToolUpdateType,
@@ -15,7 +14,7 @@ import directlyAnswer from "../tools/directlyAnswer";
15
  import websearch from "../tools/web/search";
16
  import { z } from "zod";
17
  import { logger } from "../logger";
18
- import { toolHasName } from "../tools/utils";
19
  import type { MessageFile } from "$lib/types/Message";
20
  import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
21
  import { MetricsServer } from "../metrics";
@@ -143,33 +142,23 @@ export async function* runTools(
143
  // look for a code blocks of ```json and parse them
144
  // if they're valid json, add them to the calls array
145
  if (output.generated_text) {
146
- if (!output.generated_text.endsWith("```")) {
147
- output.generated_text = output.generated_text + "```";
148
- }
149
- const codeBlocks = Array.from(output.generated_text.matchAll(/```json\n(.*?)```/gs))
150
- .map(([, block]) => block)
151
- // remove trailing comma
152
- .map((block) => block.trim().replace(/,$/, ""));
153
- if (codeBlocks.length === 0) continue;
154
- // grab only the capture group from the regex match
155
- for (const block of codeBlocks) {
156
- // make it an array if it's not already
157
- let call = JSON5.parse(block);
158
- if (!Array.isArray(call)) {
159
- call = [call];
160
- }
161
-
162
- try {
163
- calls.push(...call.filter(isExternalToolCall).map(externalToToolCall).filter(Boolean));
164
- } catch (e) {
165
- logger.error(e, "Error while parsing tool calls, please retry");
166
- // error parsing the calls
167
- yield {
168
- type: MessageUpdateType.Status,
169
- status: MessageUpdateStatus.Error,
170
- message: "Error while parsing tool calls, please retry",
171
- };
172
- }
173
  }
174
  }
175
  }
 
1
  import { ToolResultStatus, type ToolCall, type ToolResult } from "$lib/types/Tool";
2
  import { v4 as uuidV4 } from "uuid";
 
3
  import type { BackendTool, BackendToolContext } from "../tools";
4
  import {
5
  MessageToolUpdateType,
 
14
  import websearch from "../tools/web/search";
15
  import { z } from "zod";
16
  import { logger } from "../logger";
17
+ import { extractJson, toolHasName } from "../tools/utils";
18
  import type { MessageFile } from "$lib/types/Message";
19
  import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators";
20
  import { MetricsServer } from "../metrics";
 
142
  // look for a code blocks of ```json and parse them
143
  // if they're valid json, add them to the calls array
144
  if (output.generated_text) {
145
+ logger.info(output.generated_text);
146
+ try {
147
+ const rawCalls = await extractJson(output.generated_text);
148
+ const newCalls = rawCalls
149
+ .filter(isExternalToolCall)
150
+ .map(externalToToolCall)
151
+ .filter((call) => call !== undefined) as ToolCall[];
152
+
153
+ calls.push(...newCalls);
154
+ } catch (e) {
155
+ logger.error(e, "Error while parsing tool calls, please retry");
156
+ // error parsing the calls
157
+ yield {
158
+ type: MessageUpdateType.Status,
159
+ status: MessageUpdateStatus.Error,
160
+ message: "Error while parsing tool calls, please retry",
161
+ };
 
 
 
 
 
 
 
 
 
 
162
  }
163
  }
164
  }
src/lib/server/tools/utils.ts CHANGED
@@ -1,6 +1,7 @@
1
  import { env } from "$env/dynamic/private";
2
  import { Client } from "@gradio/client";
3
  import { SignJWT } from "jose";
 
4
 
5
  export type GradioImage = {
6
  path: string;
@@ -52,3 +53,41 @@ export async function getIpToken(ip: string, username?: string) {
52
  }
53
 
54
  export { toolHasName } from "$lib/utils/tools";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import { env } from "$env/dynamic/private";
2
  import { Client } from "@gradio/client";
3
  import { SignJWT } from "jose";
4
+ import JSON5 from "json5";
5
 
6
  export type GradioImage = {
7
  path: string;
 
53
  }
54
 
55
  export { toolHasName } from "$lib/utils/tools";
56
+
57
/**
 * Extract JSON values embedded in free-form model output.
 *
 * Strategy:
 * 1. Collect the content of every fenced ```json code block.
 * 2. If there is no such block, fall back to the widest bracketed span in the
 *    text: from the first `[` or `{` to the last `]` or `}`.
 *
 * Each candidate is parsed with JSON5. A candidate that parses to an array
 * contributes its elements (one level of flattening); any other value
 * contributes itself.
 *
 * @param text - Raw generated text that may contain JSON.
 * @returns The parsed values, or an empty array when no JSON-looking span exists.
 * @throws Propagates whatever `JSON5.parse` throws when a candidate is malformed.
 */
export async function extractJson(text: string): Promise<unknown[]> {
	let codeBlocks = Array.from(text.matchAll(/```json\n(.*?)```/gs)).map(([, block]) =>
		// remove a trailing comma left after the closing bracket
		block.trim().replace(/,$/, "")
	);

	// if there is no code block, try to find a bare json value in the text
	if (codeBlocks.length === 0) {
		const starts = [text.indexOf("["), text.indexOf("{")].filter((i) => i !== -1);
		const ends = [text.lastIndexOf("]"), text.lastIndexOf("}")].filter((i) => i !== -1);

		if (starts.length === 0 || ends.length === 0) {
			return [];
		}

		// Use the EARLIEST opener and the LATEST closer so the outermost value is
		// captured; the opposite choice would select an inner fragment such as
		// the `[1, 2]` inside `{"a": [1, 2]}`.
		const start = Math.min(...starts);
		const end = Math.max(...ends);

		// A closer that precedes the opener cannot delimit a JSON value.
		if (end < start) {
			return [];
		}

		codeBlocks = [text.substring(start, end + 1)];
	}

	const calls: unknown[] = [];
	for (const block of codeBlocks) {
		const parsed: unknown = JSON5.parse(block);
		// an array contributes its elements, anything else contributes itself
		if (Array.isArray(parsed)) {
			calls.push(...parsed);
		} else {
			calls.push(parsed);
		}
	}
	return calls;
}
+ }