import { ToolResultStatus, type ToolCall, type ToolResult } from "$lib/types/Tool"; import { v4 as uuidV4 } from "uuid"; import type { BackendTool, BackendToolContext } from "../tools"; import { MessageToolUpdateType, MessageUpdateStatus, MessageUpdateType, type MessageUpdate, } from "$lib/types/MessageUpdate"; import type { TextGenerationContext } from "./types"; import { allTools } from "../tools"; import directlyAnswer from "../tools/directlyAnswer"; import websearch from "../tools/web/search"; import { z } from "zod"; import { logger } from "../logger"; import { extractJson, toolHasName } from "../tools/utils"; import type { MessageFile } from "$lib/types/Message"; import { mergeAsyncGenerators } from "$lib/utils/mergeAsyncGenerators"; import { MetricsServer } from "../metrics"; import { stringifyError } from "$lib/utils/stringifyError"; function makeFilesPrompt(files: MessageFile[], fileMessageIndex: number): string { if (files.length === 0) { return "The user has not uploaded any files. Do not attempt to use any tools that require files"; } const stringifiedFiles = files .map( (file, fileIndex) => ` - fileMessageIndex ${fileMessageIndex} | fileIndex ${fileIndex} | ${file.name} (${file.mime})` ) .join("\n"); return `Attached ${files.length} file${files.length === 1 ? "" : "s"}:\n${stringifiedFiles}`; } export function pickTools( toolsPreference: Record, isAssistant: boolean ): BackendTool[] { // if it's an assistant, only support websearch for now if (isAssistant) return [directlyAnswer, websearch]; // filter based on tool preferences, add the tools that are on by default return allTools.filter((el) => { if (el.isLocked && el.isOnByDefault) return true; return toolsPreference?.[el.name] ?? el.isOnByDefault; }); } async function* callTool( ctx: BackendToolContext, tools: BackendTool[], call: ToolCall ): AsyncGenerator { const uuid = uuidV4(); const tool = tools.find((el) => toolHasName(call.name, el)); if (!tool) { return { call, status: ToolResultStatus.Error, message: `Could not find tool "${call.name}"` }; } // Special case for directly_answer tool where we ignore if (toolHasName(directlyAnswer.name, tool)) return; const startTime = Date.now(); MetricsServer.getMetrics().tool.toolUseCount.inc({ tool: call.name }); yield { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.Call, uuid, call, }; try { const toolResult = yield* tool.call(call.parameters, ctx); yield { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.Result, uuid, result: { ...toolResult, call } as ToolResult, }; MetricsServer.getMetrics().tool.toolUseDuration.observe( { tool: call.name }, Date.now() - startTime ); return { ...toolResult, call } as ToolResult; } catch (error) { MetricsServer.getMetrics().tool.toolUseCountError.inc({ tool: call.name }); logger.error(error, `Failed while running tool ${call.name}. ${stringifyError(error)}`); yield { type: MessageUpdateType.Tool, subtype: MessageToolUpdateType.Error, uuid, message: "Error occurred", }; return { call, status: ToolResultStatus.Error, message: "Error occurred", }; } } export async function* runTools( ctx: TextGenerationContext, tools: BackendTool[], preprompt?: string ): AsyncGenerator { const { endpoint, conv, messages, assistant, ip, username } = ctx; const calls: ToolCall[] = []; const messagesWithFilesPrompt = messages.map((message, idx) => { if (!message.files?.length) return message; return { ...message, content: `${message.content}\n${makeFilesPrompt(message.files, idx)}`, }; }); const pickToolStartTime = Date.now(); // do the function calling bits here for await (const output of await endpoint({ messages: messagesWithFilesPrompt, preprompt, generateSettings: assistant?.generateSettings, tools, })) { // model natively supports tool calls if (output.token.toolCalls) { calls.push(...output.token.toolCalls); continue; } // look for a code blocks of ```json and parse them // if they're valid json, add them to the calls array if (output.generated_text) { try { const rawCalls = await extractJson(output.generated_text); const newCalls = rawCalls .filter(isExternalToolCall) .map(externalToToolCall) .filter((call) => call !== undefined) as ToolCall[]; calls.push(...newCalls); } catch (e) { logger.error(e, "Error while parsing tool calls, please retry"); // error parsing the calls yield { type: MessageUpdateType.Status, status: MessageUpdateStatus.Error, message: "Error while parsing tool calls, please retry", }; } } } MetricsServer.getMetrics().tool.timeToChooseTools.observe( { model: conv.model }, Date.now() - pickToolStartTime ); const toolContext: BackendToolContext = { conv, messages, preprompt, assistant, ip, username }; const toolResults: (ToolResult | undefined)[] = yield* mergeAsyncGenerators( calls.map((call) => callTool(toolContext, tools, call)) ); return toolResults.filter((result): result is ToolResult => result !== undefined); } const externalToolCall = z.object({ tool_name: z.string(), parameters: z.record(z.any()), }); type ExternalToolCall = z.infer; function isExternalToolCall(call: unknown): call is ExternalToolCall { return externalToolCall.safeParse(call).success; } function externalToToolCall(call: ExternalToolCall): ToolCall | undefined { // Convert - to _ since some models insist on using _ instead of - const tool = allTools.find((tool) => toolHasName(call.tool_name, tool)); if (!tool) { logger.debug(`Model requested tool that does not exist: "${call.tool_name}". Skipping tool...`); return; } const parametersWithDefaults: Record = {}; for (const [key, definition] of Object.entries(tool.parameterDefinitions)) { const value = call.parameters[key]; // Required so ensure it's there, otherwise return undefined if (definition.required) { if (value === undefined) { logger.debug( `Model requested tool "${call.tool_name}" but was missing required parameter "${key}". Skipping tool...` ); return; } parametersWithDefaults[key] = value; continue; } // Optional so use default if not there parametersWithDefaults[key] = value ?? definition.default; } return { name: call.tool_name, parameters: parametersWithDefaults, }; }