Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
feat(tools): use gradio ETA for tools (#1411)
Browse files
src/lib/components/chat/ToolUpdate.svelte
CHANGED
@@ -9,7 +9,7 @@
|
|
9 |
import CarbonTools from "~icons/carbon/tools";
|
10 |
import { ToolResultStatus, type ToolFront } from "$lib/types/Tool";
|
11 |
import { page } from "$app/stores";
|
12 |
-
import {
|
13 |
import { browser } from "$app/environment";
|
14 |
|
15 |
export let tool: MessageToolUpdate[];
|
@@ -19,6 +19,8 @@
|
|
19 |
$: toolError = tool.some(isMessageToolErrorUpdate);
|
20 |
$: toolDone = tool.some(isMessageToolResultUpdate);
|
21 |
|
|
|
|
|
22 |
const availableTools: ToolFront[] = $page.data.tools;
|
23 |
|
24 |
let loadingBarEl: HTMLDivElement;
|
@@ -26,16 +28,24 @@
|
|
26 |
|
27 |
let isShowingLoadingBar = false;
|
28 |
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
31 |
loadingBarEl.classList.remove("hidden");
|
32 |
isShowingLoadingBar = true;
|
33 |
animation = loadingBarEl.animate([{ width: "0%" }, { width: "calc(100%+1rem)" }], {
|
34 |
-
duration:
|
35 |
fill: "forwards",
|
36 |
});
|
|
|
|
|
|
|
|
|
|
|
37 |
}
|
38 |
-
return () => animation?.cancel();
|
39 |
});
|
40 |
|
41 |
// go to 100% quickly if loading is done
|
|
|
9 |
import CarbonTools from "~icons/carbon/tools";
|
10 |
import { ToolResultStatus, type ToolFront } from "$lib/types/Tool";
|
11 |
import { page } from "$app/stores";
|
12 |
+
import { onDestroy } from "svelte";
|
13 |
import { browser } from "$app/environment";
|
14 |
|
15 |
export let tool: MessageToolUpdate[];
|
|
|
19 |
$: toolError = tool.some(isMessageToolErrorUpdate);
|
20 |
$: toolDone = tool.some(isMessageToolResultUpdate);
|
21 |
|
22 |
+
$: eta = tool.find((el) => el.subtype === MessageToolUpdateType.ETA)?.eta;
|
23 |
+
|
24 |
const availableTools: ToolFront[] = $page.data.tools;
|
25 |
|
26 |
let loadingBarEl: HTMLDivElement;
|
|
|
28 |
|
29 |
let isShowingLoadingBar = false;
|
30 |
|
31 |
+
$: !toolError &&
|
32 |
+
!toolDone &&
|
33 |
+
loading &&
|
34 |
+
loadingBarEl &&
|
35 |
+
eta &&
|
36 |
+
(() => {
|
37 |
loadingBarEl.classList.remove("hidden");
|
38 |
isShowingLoadingBar = true;
|
39 |
animation = loadingBarEl.animate([{ width: "0%" }, { width: "calc(100%+1rem)" }], {
|
40 |
+
duration: eta * 1000,
|
41 |
fill: "forwards",
|
42 |
});
|
43 |
+
})();
|
44 |
+
|
45 |
+
onDestroy(() => {
|
46 |
+
if (animation) {
|
47 |
+
animation.cancel();
|
48 |
}
|
|
|
49 |
});
|
50 |
|
51 |
// go to 100% quickly if loading is done
|
src/lib/server/textGeneration/tools.ts
CHANGED
@@ -70,7 +70,7 @@ async function* callTool(
|
|
70 |
};
|
71 |
|
72 |
try {
|
73 |
-
const toolResult = yield* tool.call(call.parameters, ctx);
|
74 |
|
75 |
yield {
|
76 |
type: MessageUpdateType.Tool,
|
|
|
70 |
};
|
71 |
|
72 |
try {
|
73 |
+
const toolResult = yield* tool.call(call.parameters, ctx, uuid);
|
74 |
|
75 |
yield {
|
76 |
type: MessageUpdateType.Tool,
|
src/lib/server/tools/index.ts
CHANGED
@@ -119,7 +119,7 @@ export const configTools = z
|
|
119 |
.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
|
120 |
|
121 |
export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
|
122 |
-
return async function* (params, ctx) {
|
123 |
if (
|
124 |
tool.endpoint === null ||
|
125 |
!tool.baseUrl ||
|
@@ -203,11 +203,12 @@ export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
|
|
203 |
}
|
204 |
});
|
205 |
|
206 |
-
const outputs =
|
207 |
tool.baseUrl,
|
208 |
tool.endpoint,
|
209 |
await Promise.all(inputs),
|
210 |
-
ipToken
|
|
|
211 |
);
|
212 |
|
213 |
if (!isValidOutputComponent(tool.outputComponent)) {
|
|
|
119 |
.transform((val) => [...val, calculator, directlyAnswer, fetchUrl, websearch]);
|
120 |
|
121 |
export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
|
122 |
+
return async function* (params, ctx, uuid) {
|
123 |
if (
|
124 |
tool.endpoint === null ||
|
125 |
!tool.baseUrl ||
|
|
|
203 |
}
|
204 |
});
|
205 |
|
206 |
+
const outputs = yield* callSpace(
|
207 |
tool.baseUrl,
|
208 |
tool.endpoint,
|
209 |
await Promise.all(inputs),
|
210 |
+
ipToken,
|
211 |
+
uuid
|
212 |
);
|
213 |
|
214 |
if (!isValidOutputComponent(tool.outputComponent)) {
|
src/lib/server/tools/utils.ts
CHANGED
@@ -1,27 +1,20 @@
|
|
1 |
import { env } from "$env/dynamic/private";
|
2 |
import { Client } from "@gradio/client";
|
3 |
import { SignJWT } from "jose";
|
4 |
-
import { logger } from "../logger";
|
5 |
import JSON5 from "json5";
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
export
|
8 |
-
path: string;
|
9 |
-
url: string;
|
10 |
-
orig_name: string;
|
11 |
-
is_stream: boolean;
|
12 |
-
meta: Record<string, unknown>;
|
13 |
-
};
|
14 |
-
|
15 |
-
type GradioResponse = {
|
16 |
-
data: unknown[];
|
17 |
-
};
|
18 |
-
|
19 |
-
export async function callSpace<TInput extends unknown[], TOutput extends unknown[]>(
|
20 |
name: string,
|
21 |
func: string,
|
22 |
parameters: TInput,
|
23 |
-
ipToken: string | undefined
|
24 |
-
|
|
|
25 |
class CustomClient extends Client {
|
26 |
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
|
27 |
init = init || {};
|
@@ -34,15 +27,32 @@ export async function callSpace<TInput extends unknown[], TOutput extends unknow
|
|
34 |
}
|
35 |
const client = await CustomClient.connect(name, {
|
36 |
hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
|
|
|
37 |
});
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
}
|
47 |
|
48 |
export async function getIpToken(ip: string, username?: string) {
|
|
|
1 |
import { env } from "$env/dynamic/private";
|
2 |
import { Client } from "@gradio/client";
|
3 |
import { SignJWT } from "jose";
|
|
|
4 |
import JSON5 from "json5";
|
5 |
+
import {
|
6 |
+
MessageToolUpdateType,
|
7 |
+
MessageUpdateType,
|
8 |
+
type MessageToolUpdate,
|
9 |
+
} from "$lib/types/MessageUpdate";
|
10 |
|
11 |
+
export async function* callSpace<TInput extends unknown[], TOutput extends unknown[]>(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
name: string,
|
13 |
func: string,
|
14 |
parameters: TInput,
|
15 |
+
ipToken: string | undefined,
|
16 |
+
uuid: string
|
17 |
+
): AsyncGenerator<MessageToolUpdate, TOutput, undefined> {
|
18 |
class CustomClient extends Client {
|
19 |
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
|
20 |
init = init || {};
|
|
|
27 |
}
|
28 |
const client = await CustomClient.connect(name, {
|
29 |
hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
|
30 |
+
events: ["status", "data"],
|
31 |
});
|
32 |
|
33 |
+
const job = client.submit(func, parameters);
|
34 |
+
|
35 |
+
let data;
|
36 |
+
for await (const output of job) {
|
37 |
+
console.log({ output });
|
38 |
+
if (output.type === "data") {
|
39 |
+
data = output.data as TOutput;
|
40 |
+
}
|
41 |
+
if (output.type === "status" && output.eta) {
|
42 |
+
yield {
|
43 |
+
type: MessageUpdateType.Tool,
|
44 |
+
subtype: MessageToolUpdateType.ETA,
|
45 |
+
eta: output.eta,
|
46 |
+
uuid,
|
47 |
+
};
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
if (!data) {
|
52 |
+
throw new Error("No data found in tool call");
|
53 |
+
}
|
54 |
+
|
55 |
+
return data;
|
56 |
}
|
57 |
|
58 |
export async function getIpToken(ip: string, username?: string) {
|
src/lib/types/MessageUpdate.ts
CHANGED
@@ -75,7 +75,10 @@ export enum MessageToolUpdateType {
|
|
75 |
Result = "result",
|
76 |
/** Error while running tool */
|
77 |
Error = "error",
|
|
|
|
|
78 |
}
|
|
|
79 |
interface MessageToolBaseUpdate<TSubType extends MessageToolUpdateType> {
|
80 |
type: MessageUpdateType.Tool;
|
81 |
subtype: TSubType;
|
@@ -91,10 +94,16 @@ export interface MessageToolResultUpdate
|
|
91 |
export interface MessageToolErrorUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.Error> {
|
92 |
message: string;
|
93 |
}
|
|
|
|
|
|
|
|
|
|
|
94 |
export type MessageToolUpdate =
|
95 |
| MessageToolCallUpdate
|
96 |
| MessageToolResultUpdate
|
97 |
-
| MessageToolErrorUpdate
|
|
|
98 |
|
99 |
// Everything else
|
100 |
export interface MessageTitleUpdate {
|
|
|
75 |
Result = "result",
|
76 |
/** Error while running tool */
|
77 |
Error = "error",
|
78 |
+
/** ETA update */
|
79 |
+
ETA = "eta",
|
80 |
}
|
81 |
+
|
82 |
interface MessageToolBaseUpdate<TSubType extends MessageToolUpdateType> {
|
83 |
type: MessageUpdateType.Tool;
|
84 |
subtype: TSubType;
|
|
|
94 |
export interface MessageToolErrorUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.Error> {
|
95 |
message: string;
|
96 |
}
|
97 |
+
|
98 |
+
export interface MessageToolETAUpdate extends MessageToolBaseUpdate<MessageToolUpdateType.ETA> {
|
99 |
+
eta: number;
|
100 |
+
}
|
101 |
+
|
102 |
export type MessageToolUpdate =
|
103 |
| MessageToolCallUpdate
|
104 |
| MessageToolResultUpdate
|
105 |
+
| MessageToolErrorUpdate
|
106 |
+
| MessageToolETAUpdate;
|
107 |
|
108 |
// Everything else
|
109 |
export interface MessageTitleUpdate {
|
src/lib/types/Tool.ts
CHANGED
@@ -177,5 +177,6 @@ export interface ToolCall {
|
|
177 |
|
178 |
export type BackendCall = (
|
179 |
params: Record<string, string | number | boolean>,
|
180 |
-
context: BackendToolContext
|
|
|
181 |
) => AsyncGenerator<MessageUpdate, Omit<ToolResultSuccess, "status" | "call" | "type">, undefined>;
|
|
|
177 |
|
178 |
export type BackendCall = (
|
179 |
params: Record<string, string | number | boolean>,
|
180 |
+
context: BackendToolContext,
|
181 |
+
uuid: string
|
182 |
) => AsyncGenerator<MessageUpdate, Omit<ToolResultSuccess, "status" | "call" | "type">, undefined>;
|