Mishig committed on
Commit
719022a
·
unverified ·
1 Parent(s): aee936e

Bypass spaces tools rate limit (#1167)

Browse files

* Bypass spaces tools rate limit

* DRY

* optimize

* use `context` for userId & ip

* lint

* use `userName`

* fix merge conflict

* rename `userName` -> `username`

chart/env/prod.yaml CHANGED
@@ -356,6 +356,7 @@ externalSecrets:
356
  ADMIN_API_SECRET: "hub-prod-chat-ui-admin-api-secret"
357
  USAGE_LIMITS: "hub-prod-chat-ui-usage-limits"
358
  MESSAGES_BEFORE_LOGIN: "hub-prod-chat-ui-messages-before-login"
 
359
 
360
  autoscaling:
361
  enabled: true
 
356
  ADMIN_API_SECRET: "hub-prod-chat-ui-admin-api-secret"
357
  USAGE_LIMITS: "hub-prod-chat-ui-usage-limits"
358
  MESSAGES_BEFORE_LOGIN: "hub-prod-chat-ui-messages-before-login"
359
+ IP_TOKEN_SECRET: "hub-prod-chat-ui-ip-token-secret"
360
 
361
  autoscaling:
362
  enabled: true
package-lock.json CHANGED
@@ -26,6 +26,7 @@
26
  "highlight.js": "^11.7.0",
27
  "image-size": "^1.0.2",
28
  "ip-address": "^9.0.5",
 
29
  "jsdom": "^22.0.0",
30
  "json5": "^2.2.3",
31
  "marked": "^12.0.1",
@@ -5725,9 +5726,9 @@
5725
  }
5726
  },
5727
  "node_modules/jose": {
5728
- "version": "4.15.5",
5729
- "resolved": "https://registry.npmjs.org/jose/-/jose-4.15.5.tgz",
5730
- "integrity": "sha512-jc7BFxgKPKi94uOvEmzlSWFFe2+vASyXaKUpdQKatWAESU2MWjDfFf0fdfc83CDKcA5QecabZeNLyfhe3yKNkg==",
5731
  "funding": {
5732
  "url": "https://github.com/sponsors/panva"
5733
  }
@@ -6803,6 +6804,14 @@
6803
  "url": "https://github.com/sponsors/panva"
6804
  }
6805
  },
 
 
 
 
 
 
 
 
6806
  "node_modules/openid-client/node_modules/object-hash": {
6807
  "version": "2.2.0",
6808
  "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz",
 
26
  "highlight.js": "^11.7.0",
27
  "image-size": "^1.0.2",
28
  "ip-address": "^9.0.5",
29
+ "jose": "^5.3.0",
30
  "jsdom": "^22.0.0",
31
  "json5": "^2.2.3",
32
  "marked": "^12.0.1",
 
5726
  }
5727
  },
5728
  "node_modules/jose": {
5729
+ "version": "5.3.0",
5730
+ "resolved": "https://registry.npmjs.org/jose/-/jose-5.3.0.tgz",
5731
+ "integrity": "sha512-IChe9AtAE79ru084ow8jzkN2lNrG3Ntfiv65Cvj9uOCE2m5LNsdHG+9EbxWxAoWRF9TgDOqLN5jm08++owDVRg==",
5732
  "funding": {
5733
  "url": "https://github.com/sponsors/panva"
5734
  }
 
6804
  "url": "https://github.com/sponsors/panva"
6805
  }
6806
  },
6807
+ "node_modules/openid-client/node_modules/jose": {
6808
+ "version": "4.15.5",
6809
+ "resolved": "https://registry.npmjs.org/jose/-/jose-4.15.5.tgz",
6810
+ "integrity": "sha512-jc7BFxgKPKi94uOvEmzlSWFFe2+vASyXaKUpdQKatWAESU2MWjDfFf0fdfc83CDKcA5QecabZeNLyfhe3yKNkg==",
6811
+ "funding": {
6812
+ "url": "https://github.com/sponsors/panva"
6813
+ }
6814
+ },
6815
  "node_modules/openid-client/node_modules/object-hash": {
6816
  "version": "2.2.0",
6817
  "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-2.2.0.tgz",
package.json CHANGED
@@ -71,6 +71,7 @@
71
  "highlight.js": "^11.7.0",
72
  "image-size": "^1.0.2",
73
  "ip-address": "^9.0.5",
 
74
  "jsdom": "^22.0.0",
75
  "json5": "^2.2.3",
76
  "marked": "^12.0.1",
 
71
  "highlight.js": "^11.7.0",
72
  "image-size": "^1.0.2",
73
  "ip-address": "^9.0.5",
74
+ "jose": "^5.3.0",
75
  "jsdom": "^22.0.0",
76
  "json5": "^2.2.3",
77
  "marked": "^12.0.1",
src/lib/server/textGeneration/tools.ts CHANGED
@@ -49,7 +49,7 @@ export function pickTools(
49
  }
50
 
51
  async function* runTool(
52
- { conv, messages, preprompt, assistant }: BackendToolContext,
53
  tools: BackendTool[],
54
  call: ToolCall
55
  ): AsyncGenerator<MessageUpdate, ToolResult | undefined, undefined> {
@@ -74,12 +74,7 @@ async function* runTool(
74
  };
75
  try {
76
  try {
77
- const toolResult = yield* tool.call(call.parameters, {
78
- conv,
79
- messages,
80
- preprompt,
81
- assistant,
82
- });
83
  if (toolResult.status === ToolResultStatus.Error) {
84
  yield {
85
  type: MessageUpdateType.Tool,
@@ -123,10 +118,11 @@ async function* runTool(
123
  }
124
 
125
  export async function* runTools(
126
- { endpoint, conv, messages, assistant }: TextGenerationContext,
127
  tools: BackendTool[],
128
  preprompt?: string
129
  ): AsyncGenerator<MessageUpdate, ToolResult[], undefined> {
 
130
  const calls: ToolCall[] = [];
131
 
132
  const messagesWithFilesPrompt = messages.map((message, idx) => {
@@ -181,7 +177,7 @@ export async function* runTools(
181
  Date.now() - pickToolStartTime
182
  );
183
 
184
- const toolContext: BackendToolContext = { conv, messages, preprompt, assistant };
185
  const toolResults: (ToolResult | undefined)[] = yield* mergeAsyncGenerators(
186
  calls.map((call) => runTool(toolContext, tools, call))
187
  );
 
49
  }
50
 
51
  async function* runTool(
52
+ ctx: BackendToolContext,
53
  tools: BackendTool[],
54
  call: ToolCall
55
  ): AsyncGenerator<MessageUpdate, ToolResult | undefined, undefined> {
 
74
  };
75
  try {
76
  try {
77
+ const toolResult = yield* tool.call(call.parameters, ctx);
 
 
 
 
 
78
  if (toolResult.status === ToolResultStatus.Error) {
79
  yield {
80
  type: MessageUpdateType.Tool,
 
118
  }
119
 
120
  export async function* runTools(
121
+ ctx: TextGenerationContext,
122
  tools: BackendTool[],
123
  preprompt?: string
124
  ): AsyncGenerator<MessageUpdate, ToolResult[], undefined> {
125
+ const { endpoint, conv, messages, assistant, ip, username } = ctx;
126
  const calls: ToolCall[] = [];
127
 
128
  const messagesWithFilesPrompt = messages.map((message, idx) => {
 
177
  Date.now() - pickToolStartTime
178
  );
179
 
180
+ const toolContext: BackendToolContext = { conv, messages, preprompt, assistant, ip, username };
181
  const toolResults: (ToolResult | undefined)[] = yield* mergeAsyncGenerators(
182
  calls.map((call) => runTool(toolContext, tools, call))
183
  );
src/lib/server/textGeneration/types.ts CHANGED
@@ -14,4 +14,6 @@ export interface TextGenerationContext {
14
  webSearch: boolean;
15
  toolsPreference: Record<string, boolean>;
16
  promptedAt: Date;
 
 
17
  }
 
14
  webSearch: boolean;
15
  toolsPreference: Record<string, boolean>;
16
  promptedAt: Date;
17
+ ip: string;
18
+ username?: string;
19
  }
src/lib/server/tools/documentParser.ts CHANGED
@@ -1,6 +1,6 @@
1
  import type { BackendTool } from ".";
2
  import { ToolResultStatus } from "$lib/types/Tool";
3
- import { callSpace } from "./utils";
4
  import { downloadFile } from "$lib/server/files/downloadFile";
5
 
6
  type PdfParserInput = [Blob /* pdf */, string /* filename */];
@@ -23,7 +23,7 @@ const documentParser: BackendTool = {
23
  required: true,
24
  },
25
  },
26
- async *call({ fileMessageIndex, fileIndex }, { conv, messages }) {
27
  fileMessageIndex = Number(fileMessageIndex);
28
  fileIndex = Number(fileIndex);
29
 
@@ -47,10 +47,13 @@ const documentParser: BackendTool = {
47
  .then((file) => fetch(`data:${file.mime};base64,${file.value}`))
48
  .then((res) => res.blob());
49
 
 
 
50
  const outputs = await callSpace<PdfParserInput, PdfParserOutput>(
51
  "huggingchat/document-parser",
52
  "predict",
53
- [fileBlob, file.name]
 
54
  );
55
 
56
  let documentMarkdown = outputs[0];
 
1
  import type { BackendTool } from ".";
2
  import { ToolResultStatus } from "$lib/types/Tool";
3
+ import { callSpace, getIpToken } from "./utils";
4
  import { downloadFile } from "$lib/server/files/downloadFile";
5
 
6
  type PdfParserInput = [Blob /* pdf */, string /* filename */];
 
23
  required: true,
24
  },
25
  },
26
+ async *call({ fileMessageIndex, fileIndex }, { conv, messages, ip, username }) {
27
  fileMessageIndex = Number(fileMessageIndex);
28
  fileIndex = Number(fileIndex);
29
 
 
47
  .then((file) => fetch(`data:${file.mime};base64,${file.value}`))
48
  .then((res) => res.blob());
49
 
50
+ const ipToken = await getIpToken(ip, username);
51
+
52
  const outputs = await callSpace<PdfParserInput, PdfParserOutput>(
53
  "huggingchat/document-parser",
54
  "predict",
55
+ [fileBlob, file.name],
56
+ ipToken
57
  );
58
 
59
  let documentMarkdown = outputs[0];
src/lib/server/tools/images/editing.ts CHANGED
@@ -2,7 +2,7 @@ import type { BackendTool } from "..";
2
  import { uploadFile } from "../../files/uploadFile";
3
  import { ToolResultStatus } from "$lib/types/Tool";
4
  import { MessageUpdateType } from "$lib/types/MessageUpdate";
5
- import { callSpace, type GradioImage } from "../utils";
6
  import { downloadFile } from "$lib/server/files/downloadFile";
7
 
8
  type ImageEditingInput = [
@@ -37,7 +37,7 @@ const imageEditing: BackendTool = {
37
  required: true,
38
  },
39
  },
40
- async *call({ prompt, fileMessageIndex, fileIndex }, { conv, messages }) {
41
  prompt = String(prompt);
42
  fileMessageIndex = Number(fileMessageIndex);
43
  fileIndex = Number(fileIndex);
@@ -68,6 +68,8 @@ const imageEditing: BackendTool = {
68
  .then((file) => fetch(`data:${file.mime};base64,${file.value}`))
69
  .then((res) => res.blob());
70
 
 
 
71
  const outputs = await callSpace<ImageEditingInput, ImageEditingOutput>(
72
  "multimodalart/cosxl",
73
  "run_edit",
@@ -77,7 +79,8 @@ const imageEditing: BackendTool = {
77
  "", // negative prompt
78
  7, // guidance scale
79
  20, // steps
80
- ]
 
81
  );
82
 
83
  const outputImage = await fetch(outputs[0].url)
 
2
  import { uploadFile } from "../../files/uploadFile";
3
  import { ToolResultStatus } from "$lib/types/Tool";
4
  import { MessageUpdateType } from "$lib/types/MessageUpdate";
5
+ import { callSpace, getIpToken, type GradioImage } from "../utils";
6
  import { downloadFile } from "$lib/server/files/downloadFile";
7
 
8
  type ImageEditingInput = [
 
37
  required: true,
38
  },
39
  },
40
+ async *call({ prompt, fileMessageIndex, fileIndex }, { conv, messages, ip, username }) {
41
  prompt = String(prompt);
42
  fileMessageIndex = Number(fileMessageIndex);
43
  fileIndex = Number(fileIndex);
 
68
  .then((file) => fetch(`data:${file.mime};base64,${file.value}`))
69
  .then((res) => res.blob());
70
 
71
+ const ipToken = await getIpToken(ip, username);
72
+
73
  const outputs = await callSpace<ImageEditingInput, ImageEditingOutput>(
74
  "multimodalart/cosxl",
75
  "run_edit",
 
79
  "", // negative prompt
80
  7, // guidance scale
81
  20, // steps
82
+ ],
83
+ ipToken
84
  );
85
 
86
  const outputImage = await fetch(outputs[0].url)
src/lib/server/tools/images/generation.ts CHANGED
@@ -2,7 +2,7 @@ import type { BackendTool } from "..";
2
  import { uploadFile } from "../../files/uploadFile";
3
  import { ToolResultStatus } from "$lib/types/Tool";
4
  import { MessageUpdateType } from "$lib/types/MessageUpdate";
5
- import { callSpace, type GradioImage } from "../utils";
6
 
7
  type ImageGenerationInput = [
8
  number /* number (numeric value between 1 and 8) in 'Number of Images' Slider component */,
@@ -44,7 +44,9 @@ const imageGeneration: BackendTool = {
44
  default: 1024,
45
  },
46
  },
47
- async *call({ prompt, numberOfImages }, { conv }) {
 
 
48
  const outputs = await callSpace<ImageGenerationInput, ImageGenerationOutput>(
49
  "ByteDance/Hyper-SDXL-1Step-T2I",
50
  "/process_image",
@@ -54,7 +56,8 @@ const imageGeneration: BackendTool = {
54
  512, // number in 'Image Width' Number component
55
  String(prompt), // prompt
56
  Math.floor(Math.random() * 1000), // seed random
57
- ]
 
58
  );
59
  const imageBlobs = await Promise.all(
60
  outputs[0].map((output) =>
 
2
  import { uploadFile } from "../../files/uploadFile";
3
  import { ToolResultStatus } from "$lib/types/Tool";
4
  import { MessageUpdateType } from "$lib/types/MessageUpdate";
5
+ import { callSpace, getIpToken, type GradioImage } from "../utils";
6
 
7
  type ImageGenerationInput = [
8
  number /* number (numeric value between 1 and 8) in 'Number of Images' Slider component */,
 
44
  default: 1024,
45
  },
46
  },
47
+ async *call({ prompt, numberOfImages }, { conv, ip, username }) {
48
+ const ipToken = await getIpToken(ip, username);
49
+
50
  const outputs = await callSpace<ImageGenerationInput, ImageGenerationOutput>(
51
  "ByteDance/Hyper-SDXL-1Step-T2I",
52
  "/process_image",
 
56
  512, // number in 'Image Width' Number component
57
  String(prompt), // prompt
58
  Math.floor(Math.random() * 1000), // seed random
59
+ ],
60
+ ipToken
61
  );
62
  const imageBlobs = await Promise.all(
63
  outputs[0].map((output) =>
src/lib/server/tools/index.ts CHANGED
@@ -1,6 +1,3 @@
1
- import type { Assistant } from "$lib/types/Assistant";
2
- import type { Conversation } from "$lib/types/Conversation";
3
- import type { Message } from "$lib/types/Message";
4
  import type { MessageUpdate } from "$lib/types/MessageUpdate";
5
  import type { Tool, ToolResultError, ToolResultSuccess } from "$lib/types/Tool";
6
 
@@ -11,13 +8,12 @@ import imageGeneration from "./images/generation";
11
  import documentParser from "./documentParser";
12
  import fetchUrl from "./web/url";
13
  import websearch from "./web/search";
 
14
 
15
- export interface BackendToolContext {
16
- conv: Conversation;
17
- messages: Message[];
18
- preprompt?: string;
19
- assistant?: Pick<Assistant, "rag" | "dynamicPrompt" | "generateSettings">;
20
- }
21
 
22
  // typescript can't narrow a discriminated union after applying a generic like Omit to it
23
  // so we have to define the omitted types and create a new union
 
 
 
 
1
  import type { MessageUpdate } from "$lib/types/MessageUpdate";
2
  import type { Tool, ToolResultError, ToolResultSuccess } from "$lib/types/Tool";
3
 
 
8
  import documentParser from "./documentParser";
9
  import fetchUrl from "./web/url";
10
  import websearch from "./web/search";
11
+ import type { TextGenerationContext } from "../textGeneration/types";
12
 
13
+ export type BackendToolContext = Pick<
14
+ TextGenerationContext,
15
+ "conv" | "messages" | "assistant" | "ip" | "username"
16
+ > & { preprompt?: string };
 
 
17
 
18
  // typescript can't narrow a discriminated union after applying a generic like Omit to it
19
  // so we have to define the omitted types and create a new union
src/lib/server/tools/utils.ts CHANGED
@@ -1,5 +1,6 @@
1
  import { env } from "$env/dynamic/private";
2
  import { Client } from "@gradio/client";
 
3
 
4
  export type GradioImage = {
5
  path: string;
@@ -16,9 +17,21 @@ type GradioResponse = {
16
  export async function callSpace<TInput extends unknown[], TOutput extends unknown[]>(
17
  name: string,
18
  func: string,
19
- parameters: TInput
 
20
  ): Promise<TOutput> {
21
- const client = await Client.connect(name, {
 
 
 
 
 
 
 
 
 
 
 
22
  hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
23
  });
24
  return await client
@@ -26,4 +39,16 @@ export async function callSpace<TInput extends unknown[], TOutput extends unknow
26
  .then((res) => (res as unknown as GradioResponse).data as TOutput);
27
  }
28
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  export { toolHasName } from "$lib/utils/tools";
 
1
  import { env } from "$env/dynamic/private";
2
  import { Client } from "@gradio/client";
3
+ import { SignJWT } from "jose";
4
 
5
  export type GradioImage = {
6
  path: string;
 
17
  export async function callSpace<TInput extends unknown[], TOutput extends unknown[]>(
18
  name: string,
19
  func: string,
20
+ parameters: TInput,
21
+ ipToken: string | undefined
22
  ): Promise<TOutput> {
23
+ class CustomClient extends Client {
24
+ fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
25
+ init = init || {};
26
+ init.headers = {
27
+ ...(init.headers || {}),
28
+ ...(ipToken ? { "X-IP-Token": ipToken } : {}),
29
+ };
30
+ return super.fetch(input, init);
31
+ }
32
+ }
33
+
34
+ const client = await CustomClient.connect(name, {
35
  hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
36
  });
37
  return await client
 
39
  .then((res) => (res as unknown as GradioResponse).data as TOutput);
40
  }
41
 
42
+ export async function getIpToken(ip: string, username?: string) {
43
+ const ipTokenSecret = env.IP_TOKEN_SECRET;
44
+ if (!ipTokenSecret) {
45
+ return;
46
+ }
47
+ return await new SignJWT({ ip, user: username })
48
+ .setProtectedHeader({ alg: "HS256" })
49
+ .setIssuedAt()
50
+ .setExpirationTime("1m")
51
+ .sign(new TextEncoder().encode(ipTokenSecret));
52
+ }
53
+
54
  export { toolHasName } from "$lib/utils/tools";
src/routes/conversation/[id]/+server.ts CHANGED
@@ -407,6 +407,8 @@ export async function POST({ request, locals, params, getClientAddress }) {
407
  webSearch: webSearch ?? false,
408
  toolsPreference: toolsPreferences ?? {},
409
  promptedAt,
 
 
410
  };
411
  // run the text generation and send updates to the client
412
  for await (const event of textGeneration(ctx)) await update(event);
 
407
  webSearch: webSearch ?? false,
408
  toolsPreference: toolsPreferences ?? {},
409
  promptedAt,
410
+ ip: getClientAddress(),
411
+ username: locals.user?.username,
412
  };
413
  // run the text generation and send updates to the client
414
  for await (const event of textGeneration(ctx)) await update(event);