Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 1,839 Bytes
6655689 28b6d44 6655689 28b6d44 6655689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import type { Tool } from "$lib/types/Tool";
import { extractJson } from "./utils";
import { externalToToolCall } from "../textGeneration/tools";
import { logger } from "../logger";
import type { Endpoint, EndpointMessage } from "../endpoints/endpoints";
/**
 * Options accepted by `getToolOutput`.
 */
interface GetToolOutputOptions {
  /** Endpoint that is called to produce the generation stream. */
  endpoint: Endpoint;
  /** The single tool offered to the endpoint; its name is used to pick the matching call. */
  tool: Tool;
  /** Conversation messages passed to the endpoint unchanged. */
  messages: EndpointMessage[];
  /** Optional preprompt text; `getToolOutput` appends a "only use this tool" instruction to it. */
  preprompt?: string;
  /** Generation settings passed to the endpoint; arbitrary extra keys are allowed. */
  generateSettings?: {
    max_new_tokens?: number;
    [key: string]: unknown;
  };
}
/**
 * Asks the endpoint to call one specific tool and returns the first argument
 * of that call when it is a string.
 *
 * The endpoint is invoked with `tool` as the only available tool and with an
 * instruction appended to the preprompt telling the model to use it. Tool calls
 * are collected from the stream, either from `token.toolCalls` or by extracting
 * JSON from `generated_text`, and streaming stops as soon as at least one call
 * has been collected.
 *
 * @param options.messages - Conversation messages passed to the endpoint.
 * @param options.preprompt - Optional preprompt; treated as empty when omitted.
 * @param options.tool - The tool the model is asked to call.
 * @param options.endpoint - Endpoint used to generate the stream.
 * @param options.generateSettings - Generation settings; defaults to
 *   `{ max_new_tokens: 64 }`.
 * @returns The first parameter value of the call whose name matches
 *   `tool.name` if that value is a string, otherwise `undefined`. The value is
 *   cast to `T` without runtime validation. Any error raised while generating
 *   is logged as a warning and also results in `undefined`.
 */
export async function getToolOutput<T = string>({
  messages,
  preprompt,
  tool,
  endpoint,
  generateSettings = { max_new_tokens: 64 },
}: GetToolOutputOptions): Promise<T | undefined> {
  try {
    const stream = await endpoint({
      messages,
      // `preprompt` is optional: coalesce to "" so a missing value is not
      // stringified into the literal text "undefined" in the prompt.
      preprompt: `${preprompt ?? ""}\n\n Only use tool ${tool.name}.`,
      tools: [tool],
      generateSettings,
    });

    const calls = [];
    for await (const output of stream) {
      if (output.token.toolCalls) {
        calls.push(...output.token.toolCalls);
      }

      if (output.generated_text) {
        // Fallback path: look for tool calls embedded as JSON in the generated text.
        const rawCalls = await extractJson(output.generated_text);
        const extractedCalls = rawCalls
          .map((call) => externalToToolCall(call, [tool]))
          .filter((call) => call !== undefined);
        calls.push(...extractedCalls);
      }

      // One batch of calls is enough; stop consuming the stream.
      if (calls.length > 0) {
        break;
      }
    }

    // Find the tool call matching our tool.
    const toolCall = calls.find((call) => call.name === tool.name);
    if (!toolCall?.parameters) {
      return undefined;
    }

    // Use the first parameter value since most tools have a single main parameter.
    const firstParamValue = Object.values(toolCall.parameters)[0];
    return typeof firstParamValue === "string" ? (firstParamValue as T) : undefined;
  } catch (error) {
    logger.warn(error, "Error getting tool output");
    return undefined;
  }
}
|