File size: 5,273 Bytes
70cdf7a
a8a9533
70cdf7a
 
 
 
564e576
06feee8
564e576
70cdf7a
 
 
 
 
a8a9533
75663fd
70cdf7a
46a8e79
70cdf7a
 
 
 
 
46a8e79
 
70cdf7a
 
 
 
 
 
75663fd
70cdf7a
 
fa62a1a
70cdf7a
 
564e576
70cdf7a
 
 
 
 
564e576
 
 
 
 
 
70cdf7a
 
 
 
 
 
 
 
564e576
70cdf7a
 
 
564e576
 
70cdf7a
 
 
46a8e79
70cdf7a
 
 
 
 
 
 
 
 
 
 
 
 
 
77ea6f2
70cdf7a
06feee8
70cdf7a
564e576
 
31170da
564e576
 
 
 
 
 
 
 
 
 
 
7da6dd6
 
 
 
 
 
 
a1a6daf
7da6dd6
564e576
 
 
 
 
 
 
 
 
 
70cdf7a
 
 
 
 
 
 
 
 
 
 
 
 
 
564e576
 
 
 
 
 
 
 
 
 
 
 
70cdf7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564e576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import { z } from "zod";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Cohere, CohereClient } from "cohere-ai";
import { buildPrompt } from "$lib/buildPrompt";
import { ToolResultStatus, type ToolCall } from "$lib/types/Tool";
import { pipeline, Writable, type Readable } from "node:stream";
import { toolHasName } from "$lib/utils/tools";

/**
 * Configuration accepted by `endpointCohere` for an endpoint with `type: "cohere"`.
 *
 * NOTE: the `apiKey` default is the value of `env.COHERE_API_TOKEN` at the moment this module
 * is evaluated (the argument to `.default()` is computed once, here). If that variable is unset,
 * the default is `undefined`. Presumably parsing then fails unless an explicit `apiKey` is
 * given, because `z.string()` rejects `undefined` — confirm for the zod version in use.
 */
export const endpointCohereParametersSchema = z.object({
	// Positive integer, defaults to 1. Not read in this file (presumably used by endpoint selection — confirm).
	weight: z.number().int().positive().default(1),
	// Model config. `endpointCohere` reads `model.parameters` and `model.id ?? model.name`, and passes it to `buildPrompt`.
	model: z.any(),
	// Discriminator selecting this endpoint implementation.
	type: z.literal("cohere"),
	// Cohere API token; falls back to env.COHERE_API_TOKEN.
	apiKey: z.string().default(env.COHERE_API_TOKEN),
	// Optional client name forwarded to the `CohereClient` constructor.
	clientName: z.string().optional(),
	// When true, the conversation is rendered with `buildPrompt` and sent with `rawPrompting: true`.
	raw: z.boolean().default(false),
	// Forwarded unchanged to `cohere.chatStream` in both raw and chat modes.
	forceSingleStep: z.boolean().default(true),
});

/**
 * Creates a chat-ui `Endpoint` that streams completions from Cohere's chat API.
 *
 * Two request modes are supported:
 * - `raw: true`: the conversation (including tools and tool results) is rendered into a single
 *   prompt with `buildPrompt` and sent with `rawPrompting: true`.
 * - `raw: false` (default): non-system messages are mapped onto Cohere chat history
 *   (`USER` / `CHATBOT`), the newest one is sent as `message`, and tools / tool results are
 *   passed natively.
 *
 * In both modes, SDK errors that carry a response `body` are re-thrown with the decoded body as
 * the error message (the original error is kept as `cause`).
 *
 * @param input - Endpoint configuration, validated with `endpointCohereParametersSchema`.
 * @returns An endpoint function whose result is an async generator of
 *   `TextGenerationStreamOutput`-shaped chunks.
 * @throws Error if `input` fails validation or if `cohere-ai` cannot be imported/instantiated.
 */
export async function endpointCohere(
	input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
	const { apiKey, clientName, model, raw, forceSingleStep } =
		endpointCohereParametersSchema.parse(input);

	let cohere: CohereClient;

	try {
		// Dynamic import: a missing/broken `cohere-ai` install surfaces as a descriptive error.
		cohere = new (await import("cohere-ai")).CohereClient({
			token: apiKey,
			clientName,
		});
	} catch (e) {
		throw new Error("Failed to import cohere-ai", { cause: e });
	}

	return async ({ messages, preprompt, generateSettings, continueMessage, tools, toolResults }) => {
		// A leading system message overrides the configured preprompt.
		let system = preprompt;
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		// Cohere rejects `-` in tool names and reserves `directly_answer`.
		// Converting `-` to `_` is safe because we treat - and _ the same when matching tool names.
		const cohereTools = tools
			?.filter((tool) => !toolHasName("directly_answer", tool))
			.map((tool) => ({ ...tool, name: tool.name.replaceAll("-", "_") }));

		// Per-request settings override the model's configured defaults.
		const parameters = { ...model.parameters, ...generateSettings };

		// Options shared by the raw and chat-history request variants.
		const sampling = {
			model: model.id ?? model.name,
			p: parameters.top_p,
			k: parameters.top_k,
			maxTokens: parameters.max_new_tokens,
			temperature: parameters.temperature,
			stopSequences: parameters.stop,
			frequencyPenalty: parameters.frequency_penalty,
		};

		return (async function* () {
			let stream;
			let tokenId = 0;

			if (raw) {
				const prompt = await buildPrompt({
					messages,
					model,
					preprompt: system,
					continueMessage,
					tools: cohereTools,
					toolResults,
				});

				stream = await cohere
					.chatStream({
						...sampling,
						forceSingleStep,
						message: prompt,
						rawPrompting: true,
					})
					.catch(rethrowDecodedCohereError);
			} else {
				const formattedMessages = messages
					.filter((message) => message.from !== "system")
					.map((message) => ({
						role: message.from === "user" ? "USER" : "CHATBOT",
						message: message.content,
					})) satisfies Cohere.Message[];

				// Cohere takes the newest turn as `message` and everything before it as history.
				const lastMessage = formattedMessages[formattedMessages.length - 1];
				if (!lastMessage) {
					throw new Error("Cohere endpoint requires at least one non-system message");
				}

				stream = await cohere
					.chatStream({
						...sampling,
						forceSingleStep,
						chatHistory: formattedMessages.slice(0, -1),
						message: lastMessage.message,
						preamble: system,
						tools: cohereTools,
						// Failed tool runs are reported to the model as an `{ error }` output.
						toolResults: toolResults?.length
							? toolResults.map((toolResult) =>
									toolResult.status === ToolResultStatus.Error
										? { call: toolResult.call, outputs: [{ error: toolResult.message }] }
										: { call: toolResult.call, outputs: toolResult.outputs }
								)
							: undefined,
					})
					.catch(rethrowDecodedCohereError);
			}

			for await (const output of stream) {
				if (output.eventType === "text-generation") {
					// Plain text delta.
					yield {
						token: {
							id: tokenId++,
							text: output.text,
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else if (output.eventType === "tool-calls-generation") {
					// Tool calls are surfaced on an empty special token.
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
							toolCalls: output.toolCalls as ToolCall[],
						},
						generated_text: null,
						details: null,
					};
				} else if (output.eventType === "stream-end") {
					// Error finish reasons abort the stream; otherwise emit the full text once.
					if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
						throw new Error(output.finishReason);
					}
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: output.response.text,
						details: null,
					};
				}
				// Any other Cohere event types are ignored.
			}
		})();
	};
}

/**
 * Rejection handler for Cohere SDK calls.
 *
 * If the error has a truthy `body`, it is read to completion and decoded as UTF-8, and an
 * `Error` with that text is thrown (the original error is kept as `cause`). Errors without a
 * body, or whose body cannot be read, are re-thrown unchanged.
 */
async function rethrowDecodedCohereError(err: unknown): Promise<never> {
	const body = typeof err === "object" && err !== null && "body" in err ? err.body : undefined;
	if (!body) throw err;

	// NOTE(review): assumes the SDK exposes `body` as a stream that `pipeline` can consume — confirm.
	const message = await convertStreamToBuffer(body as Readable).catch(() => {
		throw err;
	});
	throw new Error(message, { cause: err });
}

/**
 * Drains `webReadableStream` and resolves with its full contents decoded as UTF-8.
 *
 * All chunks are collected first and decoded once at the end. Multi-byte characters split
 * across chunk boundaries therefore decode correctly. The promise rejects with whatever error
 * `pipeline` reports for the source or the collecting sink.
 */
async function convertStreamToBuffer(webReadableStream: Readable) {
	return new Promise<string>((resolve, reject) => {
		const collected: Buffer[] = [];

		// Sink that just remembers every chunk it receives.
		const collector = new Writable({
			write(piece, _encoding, done) {
				collected.push(piece);
				done();
			},
		});

		pipeline(webReadableStream, collector, (error) => {
			if (!error) {
				resolve(Buffer.concat(collected).toString("utf-8"));
				return;
			}
			reject(error);
		});
	});
}