File size: 1,839 Bytes
6655689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b6d44
 
 
 
6655689
 
 
 
 
 
 
 
 
 
28b6d44
 
 
6655689
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import type { Tool } from "$lib/types/Tool";
import { extractJson } from "./utils";
import { externalToToolCall } from "../textGeneration/tools";
import { logger } from "../logger";
import type { Endpoint, EndpointMessage } from "../endpoints/endpoints";

/**
 * Arguments accepted by `getToolOutput`.
 */
interface GetToolOutputOptions {
	/** Endpoint that is invoked to produce the generation stream. */
	endpoint: Endpoint;
	/** Messages forwarded unchanged to the endpoint. */
	messages: EndpointMessage[];
	/** The single tool offered to the endpoint; its `name` is used to pick the matching call. */
	tool: Tool;
	/** Optional preprompt; an instruction to only use `tool` is appended before it is sent. */
	preprompt?: string;
	/** Generation settings forwarded to the endpoint; arbitrary extra keys are allowed. */
	generateSettings?: {
		[key: string]: unknown;
		max_new_tokens?: number;
	};
}

/**
 * Asks the endpoint to call a single tool and returns the first parameter of
 * the resulting call.
 *
 * The endpoint is invoked with `tool` as the only available tool and with an
 * instruction to only use that tool appended to the preprompt. The stream is
 * consumed until at least one tool call has been collected, either from
 * `token.toolCalls` or by extracting JSON from `generated_text`.
 *
 * @param options.messages - Messages forwarded to the endpoint.
 * @param options.preprompt - Optional preprompt; when omitted only the tool instruction is sent.
 * @param options.tool - The tool the model is asked to call.
 * @param options.endpoint - Endpoint used to run the generation.
 * @param options.generateSettings - Generation settings; defaults to `{ max_new_tokens: 64 }`.
 * @returns The first parameter value of the call whose name matches `tool.name`
 * when that value is a string; `undefined` when no matching call was produced,
 * the value is not a string, or any error occurred (errors are logged, not thrown).
 */
export async function getToolOutput<T = string>({
	messages,
	preprompt,
	tool,
	endpoint,
	generateSettings = { max_new_tokens: 64 },
}: GetToolOutputOptions): Promise<T | undefined> {
	try {
		const toolInstruction = `Only use tool ${tool.name}.`;

		const stream = await endpoint({
			messages,
			// `preprompt` is optional: concatenating it blindly would send the
			// literal text "undefined" to the model.
			preprompt: preprompt == null ? toolInstruction : `${preprompt}\n\n ${toolInstruction}`,
			tools: [tool],
			generateSettings,
		});

		const calls = [];

		for await (const output of stream) {
			if (output.token.toolCalls) {
				calls.push(...output.token.toolCalls);
			}

			if (output.generated_text) {
				// Fall back to parsing tool calls out of the generated text.
				const extracted = await extractJson(output.generated_text);
				for (const candidate of extracted) {
					const call = externalToToolCall(candidate, [tool]);
					if (call !== undefined) {
						calls.push(call);
					}
				}
			}

			// Stop consuming the stream as soon as any call has been collected.
			if (calls.length > 0) {
				break;
			}
		}

		// Find the tool call matching our tool
		const toolCall = calls.find((call) => call.name === tool.name);
		if (!toolCall?.parameters) {
			return undefined;
		}

		// Get the first parameter value since most tools have a single main parameter
		const firstParamValue = Object.values(toolCall.parameters)[0];
		return typeof firstParamValue === "string" ? (firstParamValue as T) : undefined;
	} catch (error) {
		logger.warn(error, "Error getting tool output");
		return undefined;
	}
}