Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import { MessageUpdateType } from "$lib/types/MessageUpdate"; | |
import { | |
ToolColor, | |
ToolIcon, | |
ToolOutputComponents, | |
type BackendCall, | |
type BaseTool, | |
type ConfigTool, | |
type ToolInput, | |
} from "$lib/types/Tool"; | |
import type { TextGenerationContext } from "../textGeneration/types"; | |
import { z } from "zod"; | |
import JSON5 from "json5"; | |
import { env } from "$env/dynamic/private"; | |
import jp from "jsonpath"; | |
import calculator from "./calculator"; | |
import directlyAnswer from "./directlyAnswer"; | |
import fetchUrl from "./web/url"; | |
import websearch from "./web/search"; | |
import { callSpace, getIpToken } from "./utils"; | |
import { uploadFile } from "../files/uploadFile"; | |
import type { MessageFile } from "$lib/types/Message"; | |
import { sha256 } from "$lib/utils/sha256"; | |
import { ObjectId } from "mongodb"; | |
import { isValidOutputComponent, ToolOutputPaths } from "./outputs"; | |
import { downloadFile } from "../files/downloadFile"; | |
import { fileTypeFromBlob } from "file-type"; | |
/** Keys of the text generation context that are forwarded to backend tool calls. */
type BackendToolContextKey = "conv" | "messages" | "assistant" | "ip" | "username";

/**
 * Context object received by a tool's backend call: the selected fields of
 * `TextGenerationContext` plus an optional `preprompt`.
 */
export type BackendToolContext = Pick<TextGenerationContext, BackendToolContextKey> & {
	preprompt?: string;
};
/** Primitive (non-file) value types a tool input may declare. */
const IOType = z.union([
	z.literal("str"),
	z.literal("int"),
	z.literal("float"),
	z.literal("bool"),
]);
// Field schemas shared by the input variants below.
const toolInputName = z.string().min(1).max(80);
const toolInputDescription = z.string().max(200).optional();
// A literal value (string up to 300 chars, number or boolean); a missing value becomes "".
const toolInputLiteralValue = z
	.union([z.string().max(300), z.number(), z.boolean(), z.undefined()])
	.transform((literal) => (literal === undefined ? "" : literal));

/**
 * Type-independent part of a tool input, discriminated by `paramType`:
 * - "required": name and optional description
 * - "optional": same, plus a `default` value
 * - "fixed": name and a `value` (no description field)
 */
const toolInputBaseSchema = z.union([
	z.object({
		name: toolInputName,
		description: toolInputDescription,
		paramType: z.literal("required"),
	}),
	z.object({
		name: toolInputName,
		description: toolInputDescription,
		paramType: z.literal("optional"),
		default: toolInputLiteralValue,
	}),
	z.object({
		name: toolInputName,
		paramType: z.literal("fixed"),
		value: toolInputLiteralValue,
	}),
]);
// An input typed as one of the primitive IO types.
const primitiveToolInputType = z.object({ type: IOType });
// An input typed as a file, which must also declare its accepted mime types.
const fileToolInputType = z.object({
	type: z.literal("file"),
	mimeTypes: z.string().min(1),
});

/** Full tool input: the `paramType` variant combined with either a primitive or a file type. */
const toolInputSchema = toolInputBaseSchema.and(primitiveToolInputType.or(fileToolInputType));
/**
 * Validation schema for a tool definition in its editable form.
 *
 * `outputComponent` is received as a single `"<index>;<component>"` string and is
 * split by the final transform into a numeric `outputComponentIdx` and a parsed
 * `outputComponent`.
 *
 * `baseUrl` accepts either a `namespace/name` pair or a full
 * `https://huggingface.co/spaces/<namespace>/<name>` URL; the URL form is reduced
 * to `namespace/name`.
 */
export const editableToolSchema = z
	.object({
		name: z
			.string()
			.regex(/^[a-zA-Z_][a-zA-Z0-9_]*$/) // only allow letters, numbers, and underscores, and start with a letter or underscore
			.min(1)
			.max(40),
		// only allow huggingface spaces either through namespace or direct URLs
		baseUrl: z.union([
			z.string().regex(/^[^/]+\/[^/]+$/),
			z
				.string()
				// Names may contain "_" and "." in addition to alphanumerics and "-";
				// the namespace form above already accepts them, so the URL form should too.
				.regex(/^https:\/\/huggingface\.co\/spaces\/[\w.-]+\/[\w.-]+$/)
				.transform((url) => url.split("/").slice(-2).join("/")),
		]),
		endpoint: z.string().min(1).max(100),
		inputs: z.array(toolInputSchema),
		outputComponent: z
			.string()
			.min(1)
			.max(100)
			// Validate the "<index>;<component>" shape here so that malformed values are
			// reported as a regular validation issue instead of making the transform
			// below throw (or silently produce a NaN index).
			.refine(
				(value) => {
					const [idxPart, componentPart] = value.split(";");
					// `NaN >= 0` is false, so this also rejects a non-numeric index.
					return (
						parseInt(idxPart) >= 0 && ToolOutputComponents.safeParse(componentPart).success
					);
				},
				{
					message:
						'Expected "<index>;<component>" with a non-negative index and a supported output component',
				}
			),
		showOutput: z.boolean(),
		displayName: z.string().min(1).max(40),
		color: ToolColor,
		icon: ToolIcon,
		description: z.string().min(1).max(100),
	})
	.transform((tool) => {
		const [idxPart, componentPart] = tool.outputComponent.split(";");
		return {
			...tool,
			outputComponentIdx: parseInt(idxPart),
			outputComponent: ToolOutputComponents.parse(componentPart),
		};
	});
/**
 * Schema for one tool entry of the tools configuration. After validation the entry is
 * tagged with `type: "config"` and given a `call` implementation from `getCallMethod`.
 */
const configToolEntrySchema = z
	.object({
		name: z.string(),
		description: z.string(),
		endpoint: z.union([z.string(), z.null()]),
		inputs: z.array(toolInputSchema),
		outputComponent: ToolOutputComponents.or(z.null()),
		outputComponentIdx: z.number().int().default(0),
		showOutput: z.boolean(),
		// 24 hex characters, converted to a MongoDB ObjectId
		_id: z
			.string()
			.length(24)
			.regex(/^[0-9a-fA-F]{24}$/)
			.transform((hex) => new ObjectId(hex)),
		baseUrl: z.string().optional(),
		displayName: z.string(),
		color: ToolColor,
		icon: ToolIcon,
		isOnByDefault: z.optional(z.literal(true)),
		isLocked: z.optional(z.literal(true)),
		isHidden: z.optional(z.literal(true)),
	})
	.transform((entry) => ({
		type: "config" as const,
		...entry,
		call: getCallMethod(entry),
	}));

/**
 * Schema for the whole tools configuration: an array of config entries, to which the
 * hardcoded tools (calculator, directlyAnswer, fetchUrl, websearch) are appended.
 */
export const configTools = z
	.array(configToolEntrySchema)
	// add the extra hardcoded tools
	.transform((parsedTools) => [...parsedTools, calculator, directlyAnswer, fetchUrl, websearch]);
/**
 * Builds the backend `call` implementation for a tool.
 *
 * The returned async generator:
 *  1. resolves every declared input (fixed value, optional with default, required
 *     value, or a file taken from the conversation messages),
 *  2. forwards them to `callSpace` (re-yielding whatever it yields),
 *  3. extracts the configured output component with the JSONPath registered for it
 *     in `ToolOutputPaths`,
 *  4. for file outputs, fetches and uploads each file and yields one
 *     `MessageUpdateType.File` update per uploaded file.
 *
 * @param tool - Tool definition without its `call` member.
 * @returns The generator function; its return value is `{ outputs, display }`.
 * @throws Error when the tool is not fully configured, an input is missing or
 *   invalid, a referenced file cannot be found or downloaded, or the output
 *   component / index is unusable.
 */
export function getCallMethod(tool: Omit<BaseTool, "call">): BackendCall {
	return async function* (params, ctx, uuid) {
		if (
			tool.endpoint === null ||
			!tool.baseUrl ||
			!tool.outputComponent ||
			tool.outputComponentIdx === null
		) {
			throw new Error(`Tool function ${tool.name} has no endpoint`);
		}

		const ipToken = await getIpToken(ctx.ip, ctx.username);

		// Converts a raw parameter into the primitive type declared for the input.
		// Note: unparsable numbers yield NaN rather than an error.
		function coerceInput(value: unknown, type: ToolInput["type"]) {
			const valueStr = String(value);
			switch (type) {
				case "str":
					return valueStr;
				case "int":
					return parseInt(valueStr);
				case "float":
					return parseFloat(valueStr);
				case "bool":
					return valueStr === "true";
				default:
					throw new Error(`Unsupported type ${type}`);
			}
		}

		// One promise per declared input, in declaration order.
		const inputs = tool.inputs.map(async (input) => {
			if (input.type === "file" && input.paramType !== "required") {
				throw new Error("File inputs are always required and cannot be optional or fixed");
			}

			if (input.paramType === "fixed") {
				return coerceInput(input.value, input.type);
			} else if (input.paramType === "optional") {
				return coerceInput(params[input.name] ?? input.default, input.type);
			} else if (input.paramType === "required") {
				if (params[input.name] === undefined) {
					throw new Error(`Missing required input ${input.name}`);
				}

				if (input.type === "file") {
					// The file reference is split on "_": segment 1 is the message index and
					// segment 2 the file index within that message, i.e. it is expected to look
					// like `<prefix>_<msgIdx>_<fileIdx>_<filename>`.
					// NOTE(review): the prefix is presumably "input"/"output" — confirm with the caller.
					const filename = params[input.name];

					if (!filename || typeof filename !== "string") {
						throw new Error(`Filename is not a string`);
					}

					const messages = ctx.messages;
					const segments = filename.split("_");
					const msgIdx = parseInt(segments[1]);
					const fileIdx = parseInt(segments[2]);

					if (Number.isNaN(msgIdx) || Number.isNaN(fileIdx)) {
						throw Error(`Message index or file index is missing`);
					}

					// Reject negative indices too: `messages[-1]` is undefined and would
					// otherwise surface as an unrelated TypeError on `.files`.
					if (msgIdx < 0 || msgIdx >= messages.length) {
						throw Error(`Message index ${msgIdx} is out of bounds`);
					}

					const file = messages[msgIdx].files?.[fileIdx];

					if (!file) {
						throw Error(`File index ${fileIdx} is out of bounds`);
					}

					try {
						const stored = await downloadFile(file.value, ctx.conv._id);
						// Decode the base64 payload into a Blob by fetching it as a data URL.
						const res = await fetch(`data:${stored.mime};base64,${stored.value}`);
						return await res.blob();
					} catch (err) {
						throw Error("Failed to download file", { cause: err });
					}
				} else {
					return coerceInput(params[input.name], input.type);
				}
			}
		});

		const outputs = yield* callSpace(
			tool.baseUrl,
			tool.endpoint,
			await Promise.all(inputs),
			ipToken,
			uuid
		);

		if (!isValidOutputComponent(tool.outputComponent)) {
			throw new Error(`Tool output component is not defined`);
		}

		const { type, path } = ToolOutputPaths[tool.outputComponent];

		if (!path || !type) {
			throw new Error(`Tool output type ${tool.outputComponent} is not supported`);
		}

		// Negative or non-integer indices would select `undefined` and only fail later
		// inside the JSONPath query, so they are rejected here as out of bounds.
		if (
			!Number.isInteger(tool.outputComponentIdx) ||
			tool.outputComponentIdx < 0 ||
			outputs.length <= tool.outputComponentIdx
		) {
			throw new Error(`Tool output component index is out of bounds`);
		}

		const selectedOutput = outputs[tool.outputComponentIdx];

		// if its not an object, return directly
		if (selectedOutput !== undefined && typeof selectedOutput !== "object") {
			return {
				outputs: [{ [tool.name + "-0"]: selectedOutput }],
				display: tool.showOutput,
			};
		}

		// Each JSONPath match is processed concurrently, but its results are collected
		// per match and merged afterwards so that the order of files and outputs follows
		// the query order instead of network completion order.
		const perMatch = await Promise.all(
			jp.query(selectedOutput, path).map(async (output: string | string[], idx) => {
				const arrayedOutput = Array.isArray(output) ? output : [output];
				const outputKey = tool.name + "-" + idx.toString();
				const matchFiles: MessageFile[] = [];
				const matchOutputs: Array<Record<string, string>> = [];

				if (type === "file") {
					// output files are actually URLs
					const uploaded = await Promise.all(
						arrayedOutput.map(async (url, fileIdx) => {
							const blob = await (await fetch(url)).blob();
							const { ext, mime } = (await fileTypeFromBlob(blob)) ?? { ext: "octet-stream" };
							// Name: index within this match + hash of the call parameters + detected extension.
							const file = new File(
								[blob],
								`${fileIdx}-${await sha256(JSON.stringify(params))}.${ext}`,
								{
									type: mime,
								}
							);
							return uploadFile(file, ctx.conv);
						})
					);
					matchFiles.push(...uploaded);
					// The file itself is delivered through File updates; the textual output only
					// instructs the model how to refer to it.
					matchOutputs.push({
						[outputKey]: `Only and always answer: 'I used the tool ${tool.displayName}, here is the result.' Don't add anything else.`,
					});
				} else {
					for (const value of arrayedOutput) {
						matchOutputs.push({
							[outputKey]: value,
						});
					}
				}

				return { matchFiles, matchOutputs };
			})
		);

		const files: MessageFile[] = [];
		const toolOutputs: Array<Record<string, string>> = [];
		for (const match of perMatch) {
			files.push(...match.matchFiles);
			toolOutputs.push(...match.matchOutputs);
		}

		for (const file of files) {
			yield {
				type: MessageUpdateType.File,
				name: file.name,
				sha: file.value,
				mime: file.mime,
			};
		}

		return { outputs: toolOutputs, display: tool.showOutput };
	};
}
// Raw, not yet validated tool definitions read from the TOOLS environment variable (JSON5).
const rawToolConfig: unknown = JSON5.parse(env.TOOLS);

/** Tools parsed from the TOOLS environment variable, validated by `configTools` at module load. */
export const toolFromConfigs = configTools.parse(rawToolConfig) satisfies ConfigTool[];