import { env } from "$env/dynamic/private";
import { buildPrompt } from "$lib/buildPrompt";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";
import { logger } from "$lib/server/logger";

export const endpointLlamacppParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("llamacpp"),
	url: z.string().url().default("http://127.0.0.1:8080"), // legacy, feel free to remove in breaking change update
	baseURL: z.string().url().optional(),
	accessToken: z.string().default(env.HF_TOKEN ?? env.HF_ACCESS_TOKEN),
});
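
// Illustrative sketch (not part of the original module): a minimal config that
// should satisfy the schema above. `myModelConfig` is a hypothetical placeholder,
// and the URL shown is simply the schema default.
//
//   endpointLlamacppParametersSchema.parse({
//   	type: "llamacpp",
//   	model: myModelConfig,
//   	baseURL: "http://127.0.0.1:8080",
//   });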

export function endpointLlamacpp(
	input: z.input<typeof endpointLlamacppParametersSchema>
): Endpoint {
	const { baseURL, url, model } = endpointLlamacppParametersSchema.parse(input);
	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
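		// buildPrompt renders the chat history (and any preprompt) into a single
		// prompt string using the model's prompt template.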
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});

		const parameters = { ...model.parameters, ...generateSettings };

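		// The request body follows llama.cpp's server /completion API, so chat-ui's
		// generic parameter names are translated: repetition_penalty becomes
		// repeat_penalty and max_new_tokens becomes n_predict.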
		const r = await fetch(`${baseURL ?? url}/completion`, {
			method: "POST",
			headers: {
				"Content-Type": "application/json",
			},
			body: JSON.stringify({
				prompt,
				stream: true,
				temperature: parameters.temperature,
				top_p: parameters.top_p,
				top_k: parameters.top_k,
				stop: parameters.stop,
				repeat_penalty: parameters.repetition_penalty,
				n_predict: parameters.max_new_tokens,
				cache_prompt: true,
			}),
		});

		if (!r.ok) {
			throw new Error(`Failed to generate text: ${await r.text()}`);
		}

		const decoder = new TextDecoderStream();
		const reader = r.body?.pipeThrough(decoder).getReader();

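		// The generator below re-frames llama.cpp's server-sent events (one
		// "data: {json}" payload per line) into TextGenerationStreamOutput tokens.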
		return (async function* () {
			let stop = false;
			let generatedText = "";
			let tokenId = 0;
			let accumulatedData = ""; // Buffer to accumulate data chunks

			while (!stop) {
				// Read the next chunk from the decoded stream
				const out = (await reader?.read()) ?? { done: false, value: undefined };

				// If it's done, we cancel
				if (out.done) {
					reader?.cancel();
					return;
				}

				if (!out.value) {
					return;
				}

				// Accumulate the data chunk
				accumulatedData += out.value;

				// Process each complete JSON object in the accumulated data
				while (accumulatedData.includes("\n")) {
					// Assuming each JSON object ends with a newline
					const endIndex = accumulatedData.indexOf("\n");
					let jsonString = accumulatedData.substring(0, endIndex).trim();

					// Remove the processed part from the buffer
					accumulatedData = accumulatedData.substring(endIndex + 1);

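					// Streamed llama.cpp responses are server-sent events, so each
					// payload line carries a "data: " prefix that must be stripped.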
					if (jsonString.startsWith("data: ")) {
						jsonString = jsonString.slice(6);
						let data = null;

						try {
							data = JSON.parse(jsonString);
						} catch (e) {
							logger.error(e, "Failed to parse JSON");
							logger.error({ jsonString }, "Problematic JSON string");
							continue; // Skip this iteration and try the next chunk
						}

						// Handle the parsed data
						if (data.content || data.stop) {
							generatedText += data.content ?? "";
							const output: TextGenerationStreamOutput = {
								token: {
									id: tokenId++,
									text: data.content ?? "",
									logprob: 0,
									special: false,
								},
								generated_text: data.stop ? generatedText : null,
								details: null,
							};
							if (data.stop) {
								stop = true;
								output.token.special = true;
								reader?.cancel();
							}
							yield output;
						}
					}
				}
			}
		})();
	};
}

export default endpointLlamacpp;
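
// Usage sketch (illustrative only, not from this repo): consuming the token
// stream returned by the endpoint. `model`, `messages`, and the empty
// generateSettings are hypothetical placeholders; the argument names mirror
// the destructured Endpoint signature above.
//
//   const endpoint = endpointLlamacpp({ type: "llamacpp", model, weight: 1 });
//   const stream = await endpoint({
//   	messages,
//   	preprompt: undefined,
//   	continueMessage: false,
//   	generateSettings: {},
//   });
//   for await (const output of stream) {
//   	process.stdout.write(output.token.text);
//   }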