nsarrazin HF Staff committed on
Commit
15baec4
·
unverified ·
1 Parent(s): 63edbab

feat(tools): use gradio ETA for tools (#1411)

Browse files
src/lib/components/chat/ToolUpdate.svelte CHANGED
@@ -9,7 +9,7 @@
9
  import CarbonTools from "~icons/carbon/tools";
10
  import { ToolResultStatus, type ToolFront } from "$lib/types/Tool";
11
  import { page } from "$app/stores";
12
- import { onMount } from "svelte";
13
  import { browser } from "$app/environment";
14
 
15
  export let tool: MessageToolUpdate[];
@@ -19,6 +19,8 @@
19
  $: toolError = tool.some(isMessageToolErrorUpdate);
20
  $: toolDone = tool.some(isMessageToolResultUpdate);
21
 
 
 
22
  const availableTools: ToolFront[] = $page.data.tools;
23
 
24
  let loadingBarEl: HTMLDivElement;
@@ -26,16 +28,24 @@
26
 
27
  let isShowingLoadingBar = false;
28
 
29
- onMount(() => {
30
- if (!toolError && !toolDone && loading && loadingBarEl) {
 
 
 
 
31
  loadingBarEl.classList.remove("hidden");
32
  isShowingLoadingBar = true;
33
  animation = loadingBarEl.animate([{ width: "0%" }, { width: "calc(100%+1rem)" }], {
34
- duration: availableTools.find((tool) => tool.name === toolFnName)?.timeToUseMS,
35
  fill: "forwards",
36
  });
 
 
 
 
 
37
  }
38
- return () => animation?.cancel();
39
  });
40
 
41
  // go to 100% quickly if loading is done
 
9
  import CarbonTools from "~icons/carbon/tools";
10
  import { ToolResultStatus, type ToolFront } from "$lib/types/Tool";
11
  import { page } from "$app/stores";
12
+ import { onDestroy } from "svelte";
13
  import { browser } from "$app/environment";
14
 
15
  export let tool: MessageToolUpdate[];
 
19
  $: toolError = tool.some(isMessageToolErrorUpdate);
20
  $: toolDone = tool.some(isMessageToolResultUpdate);
21
 
22
+ $: eta = tool.find((el) => el.subtype === MessageToolUpdateType.ETA)?.eta;
23
+
24
  const availableTools: ToolFront[] = $page.data.tools;
25
 
26
  let loadingBarEl: HTMLDivElement;
 
28
 
29
  let isShowingLoadingBar = false;
30
 
31
+ $: !toolError &&
32
+ !toolDone &&
33
+ loading &&
34
+ loadingBarEl &&
35
+ eta &&
36
+ (() => {
37
  loadingBarEl.classList.remove("hidden");
38
  isShowingLoadingBar = true;
39
  animation = loadingBarEl.animate([{ width: "0%" }, { width: "calc(100%+1rem)" }], {
40
+ duration: eta * 1000,
41
  fill: "forwards",
42
  });
43
+ })();
44
+
45
+ onDestroy(() => {
46
+ if (animation) {
47
+ animation.cancel();
48
  }
 
49
  });
50
 
51
  // go to 100% quickly if loading is done
src/lib/server/textGeneration/tools.ts CHANGED
@@ -70,7 +70,7 @@ async function* callTool(
70
  };
71
 
72
  try {
73
- const toolResult = yield* tool.call(call.parameters, ctx);
74
 
75
  yield {
76
  type: MessageUpdateType.Tool,
 
70
  };
71
 
72
  try {
73
+ const toolResult = yield* tool.call(call.parameters, ctx, uuid);
74
 
75
  yield {
76
  type: MessageUpdateType.Tool,
src/lib/server/tools/index.ts CHANGED
@@ -119,7 +119,7 @@ export const configTools = z
119
  .transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
120
 
121
  export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
122
- return async function* (params, ctx) {
123
  if (
124
  tool.endpoint === null ||
125
  !tool.baseUrl ||
@@ -203,11 +203,12 @@ export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
203
  }
204
  });
205
 
206
- const outputs = await callSpace(
207
  tool.baseUrl,
208
  tool.endpoint,
209
  await Promise.all(inputs),
210
- ipToken
 
211
  );
212
 
213
  if (!isValidOutputComponent(tool.outputComponent)) {
 
119
  .transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
120
 
121
  export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
122
+ return async function* (params, ctx, uuid) {
123
  if (
124
  tool.endpoint === null ||
125
  !tool.baseUrl ||
 
203
  }
204
  });
205
 
206
+ const outputs = yield* callSpace(
207
  tool.baseUrl,
208
  tool.endpoint,
209
  await Promise.all(inputs),
210
+ ipToken,
211
+ uuid
212
  );
213
 
214
  if (!isValidOutputComponent(tool.outputComponent)) {
src/lib/server/tools/utils.ts CHANGED
@@ -1,27 +1,20 @@
1
  import { env } from "$env/dynamic/private";
2
  import { Client } from "@gradio/client";
3
  import { SignJWT } from "jose";
4
- import { logger } from "../logger";
5
  import JSON5 from "json5";
 
 
 
 
 
6
 
7
- export type GradioImage = {
8
- path: string;
9
- url: string;
10
- orig_name: string;
11
- is_stream: boolean;
12
- meta: Record<string, unknown>;
13
- };
14
-
15
- type GradioResponse = {
16
- data: unknown[];
17
- };
18
-
19
- export async function callSpace<TInput extends unknown[], TOutput extends unknown[]>(
20
  name: string,
21
  func: string,
22
  parameters: TInput,
23
- ipToken: string | undefined
24
- ): Promise<TOutput> {
 
25
  class CustomClient extends Client {
26
  fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
27
  init = init || {};
@@ -34,15 +27,32 @@ export async function callSpace<TInput extends unknown[], TOutput extends unknow
34
  }
35
  const client = await CustomClient.connect(name, {
36
  hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
 
37
  });
38
 
39
- return await client
40
- .predict(func, parameters)
41
- .then((res) => (res as unknown as GradioResponse).data as TOutput)
42
- .catch((e) => {
43
- logger.error(e);
44
- throw e;
45
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  }
47
 
48
  export async function getIpToken(ip: string, username?: string) {
 
1
  import { env } from "$env/dynamic/private";
2
  import { Client } from "@gradio/client";
3
  import { SignJWT } from "jose";
 
4
  import JSON5 from "json5";
5
+ import {
6
+ MessageToolUpdateType,
7
+ MessageUpdateType,
8
+ type MessageToolUpdate,
9
+ } from "$lib/types/MessageUpdate";
10
 
11
+ export async function* callSpace<TInput extends unknown[], TOutput extends unknown[]>(
 
 
 
 
 
 
 
 
 
 
 
 
12
  name: string,
13
  func: string,
14
  parameters: TInput,
15
+ ipToken: string | undefined,
16
+ uuid: string
17
+ ): AsyncGenerator<MessageToolUpdate, TOutput, undefined> {
18
  class CustomClient extends Client {
19
  fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
20
  init = init || {};
 
27
  }
28
  const client = await CustomClient.connect(name, {
29
  hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
30
+ events: ["status", "data"],
31
  });
32
 
33
+ const job = client.submit(func, parameters);
34
+
35
+ let data;
36
+ for await (const output of job) {
37
+ console.log({ output });
38
+ if (output.type === "data") {
39
+ data = output.data as TOutput;
40
+ }
41
+ if (output.type === "status" && output.eta) {
42
+ yield {
43
+ type: MessageUpdateType.Tool,
44
+ subtype: MessageToolUpdateType.ETA,
45
+ eta: output.eta,
46
+ uuid,
47
+ };
48
+ }
49
+ }
50
+
51
+ if (!data) {
52
+ throw new Error("No data found in tool call");
53
+ }
54
+
55
+ return data;
56
  }
57
 
58
  export async function getIpToken(ip: string, username?: string) {
src/lib/types/MessageUpdate.ts CHANGED
@@ -75,7 +75,10 @@ export enum MessageToolUpdateType {
75
  Result = "result",
76
  /** Error while running tool */
77
  Error = "error",
 
 
78
  }
 
79
  interface MessageToolBaseUpdate<TSubType extends MessageToolUpdateType> {
80
  type: MessageUpdateType.Tool;
81
  subtype: TSubType;
@@ -91,10 +94,16 @@ export interface MessageToolResultUpdate
91
  export interface MessageToolErrorUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.Error> {
92
  message: string;
93
  }
 
 
 
 
 
94
  export type MessageToolUpdate =
95
  | MessageToolCallUpdate
96
  | MessageToolResultUpdate
97
- | MessageToolErrorUpdate;
 
98
 
99
  // Everything else
100
  export interface MessageTitleUpdate {
 
75
  Result = "result",
76
  /** Error while running tool */
77
  Error = "error",
78
+ /** ETA update */
79
+ ETA = "eta",
80
  }
81
+
82
  interface MessageToolBaseUpdate<TSubType extends MessageToolUpdateType> {
83
  type: MessageUpdateType.Tool;
84
  subtype: TSubType;
 
94
  export interface MessageToolErrorUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.Error> {
95
  message: string;
96
  }
97
+
98
+ export interface MessageToolETAUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.ETA> {
99
+ eta: number;
100
+ }
101
+
102
  export type MessageToolUpdate =
103
  | MessageToolCallUpdate
104
  | MessageToolResultUpdate
105
+ | MessageToolErrorUpdate
106
+ | MessageToolETAUpdate;
107
 
108
  // Everything else
109
  export interface MessageTitleUpdate {
src/lib/types/Tool.ts CHANGED
@@ -177,5 +177,6 @@ export interface ToolCall {
177
 
178
  export type BackendCall = (
179
  params: Record<string, string | number | boolean>,
180
- context: BackendToolContext
 
181
  ) => AsyncGenerator<MessageUpdate, Omit<ToolResultSuccess, "status" | "call" | "type">, undefined>;
 
177
 
178
  export type BackendCall = (
179
  params: Record<string, string | number | boolean>,
180
+ context: BackendToolContext,
181
+ uuid: string
182
  ) => AsyncGenerator<MessageUpdate, Omit<ToolResultSuccess, "status" | "call" | "type">, undefined>;