File size: 2,703 Bytes
a8a9533
b5ae065
 
5b1a9aa
b5ae065
5b1a9aa
 
 
 
 
b5ae065
 
 
 
 
 
a8a9533
51b0991
5b1a9aa
 
 
 
 
 
 
bd01335
 
5b1a9aa
 
 
b5ae065
 
51b0991
5b1a9aa
 
 
 
564e576
 
 
 
 
 
 
 
791e118
564e576
5b1a9aa
 
 
c51ecfb
2a808d7
5b1a9aa
2a808d7
b5ae065
2a808d7
564e576
 
b5ae065
 
2dcc8e6
 
4e43408
2dcc8e6
 
 
 
51b0991
 
 
 
b9ec522
51b0991
 
 
791e118
51b0991
 
 
 
 
2dcc8e6
b5ae065
 
 
5b1a9aa
 
 
 
 
 
bd01335
5b1a9aa
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import { env } from "$env/dynamic/private";
import { buildPrompt } from "$lib/buildPrompt";
import { textGenerationStream } from "@huggingface/inference";
import type { Endpoint, EndpointMessage } from "../endpoints";
import { z } from "zod";
import {
	createImageProcessorOptionsValidator,
	makeImageProcessor,
	type ImageProcessor,
} from "../images";

/**
 * Zod schema for the configuration of a TGI (Text Generation Inference) endpoint.
 * Parsed in `endpointTgi` below before the endpoint closure is built.
 */
export const endpointTgiParametersSchema = z.object({
	// Relative weight used when load-balancing across several endpoints.
	weight: z.number().int().positive().default(1),
	// Model descriptor object; intentionally unvalidated here (shape owned elsewhere).
	model: z.any(),
	type: z.literal("tgi"),
	// Base URL of the TGI server.
	url: z.string().url(),
	// NOTE(review): if both HF_TOKEN and HF_ACCESS_TOKEN are unset, the default is
	// `undefined`, which then fails `z.string()` at parse time — confirm this is the
	// intended "fail fast when no token configured" behavior.
	accessToken: z.string().default(env.HF_TOKEN ?? env.HF_ACCESS_TOKEN),
	// Optional raw Authorization header value, used only when no accessToken is set
	// (see the fetch override in `endpointTgi`).
	authorization: z.string().optional(),
	multimodal: z
		.object({
			// Assumes IDEFICS
			image: createImageProcessorOptionsValidator({
				supportedMimeTypes: ["image/jpeg", "image/webp"],
				preferredMimeType: "image/webp",
				maxSizeInMB: 5,
				maxWidth: 378,
				maxHeight: 980,
			}),
		})
		.default({}),
});

/**
 * Builds an `Endpoint` that streams completions from a TGI server.
 *
 * Validates `input` against `endpointTgiParametersSchema`, pre-processes any
 * image attachments for multimodal models, renders the chat into a single
 * prompt string, and returns the token stream from `textGenerationStream`.
 *
 * @param input raw endpoint configuration (validated here with zod)
 * @returns an `Endpoint` callback producing a text-generation stream
 * @throws ZodError if `input` fails schema validation
 */
export function endpointTgi(input: z.input<typeof endpointTgiParametersSchema>): Endpoint {
	const { url, accessToken, model, authorization, multimodal } =
		endpointTgiParametersSchema.parse(input);
	const imageProcessor = makeImageProcessor(multimodal.image);

	return async ({
		messages,
		preprompt,
		continueMessage,
		generateSettings,
		tools,
		toolResults,
		isMultimodal,
		conversationId,
	}) => {
		// Re-encode/resize image attachments (no-op when the model is not multimodal).
		const messagesWithResizedFiles = await Promise.all(
			messages.map((message) => prepareMessage(Boolean(isMultimodal), message, imageProcessor))
		);

		const prompt = await buildPrompt({
			messages: messagesWithResizedFiles,
			preprompt,
			model,
			continueMessage,
			tools,
			toolResults,
		});

		return textGenerationStream(
			{
				// Per-request settings override the model defaults.
				parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
				model: url,
				inputs: prompt,
				accessToken,
			},
			{
				use_cache: false,
				fetch: async (endpointUrl, info) => {
					if (info) {
						info.headers = {
							...info.headers,
							// Fix: always forward the conversation id. Previously this header
							// was only sent inside the custom-authorization branch below, so it
							// was dropped whenever an access token was configured.
							"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
							// Only inject a custom Authorization header when no HF access token
							// is set — otherwise textGenerationStream already authenticates via
							// `accessToken` and we must not clobber its header.
							...(authorization && !accessToken ? { Authorization: authorization } : {}),
						};
					}
					return fetch(endpointUrl, info);
				},
			}
		);
	};
}

/**
 * Prepares one chat message for prompt building.
 *
 * For non-multimodal models the message passes through untouched. Otherwise
 * every attached file is run through `imageProcessor` and appended to the
 * message text as inline base64 data-URI image markdown.
 *
 * @param isMultimodal whether the target model accepts images
 * @param message the message to prepare
 * @param imageProcessor converts an attachment into `{ image, mime }`
 * @returns the original message, or a copy with images inlined into `content`
 */
async function prepareMessage(
	isMultimodal: boolean,
	message: EndpointMessage,
	imageProcessor: ImageProcessor
): Promise<EndpointMessage> {
	if (!isMultimodal) return message;

	const processed = await Promise.all((message.files ?? []).map(imageProcessor));
	const imageMarkdown = processed
		.map(({ mime, image }) => `![](data:${mime};base64,${image.toString("base64")})`)
		.join("\n ");

	return { ...message, content: `${message.content}\n${imageMarkdown}` };
}