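/**
 * Ollama endpoint for chat-ui: builds a raw prompt with buildPrompt, checks
 * that the requested model is available locally (triggering a pull when it
 * is not), and streams tokens from Ollama's /api/generate endpoint.
 */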
import { buildPrompt } from "$lib/buildPrompt";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";

export const endpointOllamaParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("ollama"),
	url: z.string().url().default("http://127.0.0.1:11434"),
	ollamaName: z.string().min(1).optional(),
});

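/**
 * Creates an Endpoint backed by a local Ollama server. `ollamaName` selects
 * the Ollama model tag to use; when omitted, `model.name` is used instead.
 */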
export function endpointOllama(input: z.input<typeof endpointOllamaParametersSchema>): Endpoint {
	const { url, model, ollamaName } = endpointOllamaParametersSchema.parse(input);

	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});

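		// per-request generation settings override the model's default parameters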
		const parameters = { ...model.parameters, ...generateSettings };

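		// ask the Ollama server which models are available locally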
		const requestInfo = await fetch(`${url}/api/tags`, {
			method: "GET",
			headers: {
				"Content-Type": "application/json",
			},
		});

		if (!requestInfo.ok) {
			throw new Error(`Failed to list Ollama models: ${await requestInfo.text()}`);
		}

		const tags = await requestInfo.json();

		if (!tags.models.some((m: { name: string }) => m.name === (ollamaName ?? model.name))) {
			// if it's not in the list of local models, pull it but don't wait for the answer
			fetch(`${url}/api/pull`, {
				method: "POST",
				headers: {
					"Content-Type": "application/json",
				},
				body: JSON.stringify({
					name: ollamaName ?? model.name,
					stream: false,
				}),
			});

			throw new Error("Currently pulling model from Ollama, please try again later.");
		}

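		// request a streamed completion; `raw: true` disables Ollama's own prompt
		// templating, since the chat template was already applied by buildPrompt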
		const r = await fetch(`${url}/api/generate`, {
			method: "POST",
			headers: {
				"Content-Type": "application/json",
			},
			body: JSON.stringify({
				prompt,
				model: ollamaName ?? model.name,
				raw: true,
				options: {
					top_p: parameters.top_p,
					top_k: parameters.top_k,
					temperature: parameters.temperature,
					repeat_penalty: parameters.repetition_penalty,
					stop: parameters.stop,
					num_predict: parameters.max_new_tokens,
				},
			}),
		});

		if (!r.ok) {
			throw new Error(`Failed to generate text: ${await r.text()}`);
		}

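		// Ollama streams newline-delimited JSON; decode the raw bytes to text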
		const decoder = new TextDecoderStream();
		const reader = r.body?.pipeThrough(decoder).getReader();

		return (async function* () {
			let generatedText = "";
			let tokenId = 0;
			let stop = false;
			while (!stop) {
				// read the next decoded chunk from the stream
				const out = (await reader?.read()) ?? { done: false, value: undefined };
				// once the stream reports done, release the reader and stop
				if (out.done) {
					reader?.cancel();
					return;
				}

				if (!out.value) {
					return;
				}

				// each chunk is expected to be one object from Ollama's
				// newline-delimited JSON stream; stop if it cannot be parsed
				let data = null;
				try {
					data = JSON.parse(out.value);
				} catch {
					return;
				}
				if (!data.done) {
					generatedText += data.response;

					yield {
						token: {
							id: tokenId++,
							text: data.response ?? "",
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else {
					stop = true;
					yield {
						token: {
							id: tokenId++,
							text: data.response ?? "",
							logprob: 0,
							special: true,
						},
						generated_text: generatedText,
						details: null,
					} satisfies TextGenerationStreamOutput;
				}
			}
		})();
	};
}

export default endpointOllama;
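
// A minimal usage sketch. The `model` object below is an assumption: this
// module only reads `model.name` and `model.parameters`, so any chat-ui model
// config carrying those fields should work. `messages` stands in for the
// conversation to send.
//
// const endpoint = endpointOllama({
// 	type: "ollama",
// 	url: "http://127.0.0.1:11434",
// 	ollamaName: "mistral:latest",
// 	model: { name: "mistral", parameters: { temperature: 0.7 } },
// });
//
// const stream = await endpoint({
// 	messages,
// 	preprompt: undefined,
// 	continueMessage: false,
// 	generateSettings: {},
// });
// for await (const output of stream) {
// 	// output satisfies TextGenerationStreamOutput
// }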