import { env } from "$env/dynamic/private";
import { buildPrompt } from "$lib/buildPrompt";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import type { Endpoint } from "../endpoints";
import { z } from "zod";
import { logger } from "$lib/server/logger";
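
// Zod schema for the llama.cpp endpoint configuration: server URL, optional
// access token, and the model definition this endpoint serves.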
export const endpointLlamacppParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("llamacpp"),
	url: z.string().url().default("http://127.0.0.1:8080"), // legacy, feel free to remove in breaking change update
	baseURL: z.string().url().optional(),
	accessToken: z.string().default(env.HF_TOKEN ?? env.HF_ACCESS_TOKEN),
});
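
// Creates an Endpoint backed by a llama.cpp server: builds the prompt for the
// configured model, calls the server's /completion route, and streams generated
// tokens back as TextGenerationStreamOutput objects.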
export function endpointLlamacpp(
	input: z.input<typeof endpointLlamacppParametersSchema>
): Endpoint {
	const { baseURL, url, model } = endpointLlamacppParametersSchema.parse(input);

	return async ({ messages, preprompt, continueMessage, generateSettings }) => {
		const prompt = await buildPrompt({
			messages,
			continueMessage,
			preprompt,
			model,
		});
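
		// Merge the model's default generation parameters with any per-request overrides.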
		const parameters = { ...model.parameters, ...generateSettings };
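
		// POST the prompt to the llama.cpp server's /completion route with streaming enabled.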
		const r = await fetch(`${baseURL ?? url}/completion`, {
			method: "POST",
			headers: {
				"Content-Type": "application/json",
			},
			body: JSON.stringify({
				prompt,
				stream: true,
				temperature: parameters.temperature,
				top_p: parameters.top_p,
				top_k: parameters.top_k,
				stop: parameters.stop,
				repeat_penalty: parameters.repetition_penalty,
				n_predict: parameters.max_new_tokens,
				cache_prompt: true,
			}),
		});
		if (!r.ok) {
			throw new Error(`Failed to generate text: ${await r.text()}`);
		}
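
		// Decode the raw response bytes into text so the SSE lines can be parsed below.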
		const decoder = new TextDecoderStream();
		const reader = r.body?.pipeThrough(decoder).getReader();
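
		// Async generator that parses each "data: {...}" line emitted by llama.cpp
		// and yields one token per chunk until a stop chunk arrives.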
		return (async function* () {
			let stop = false;
			let generatedText = "";
			let tokenId = 0;
			let accumulatedData = ""; // Buffer to accumulate data chunks

			while (!stop) {
				// Read the next chunk from the stream
				const out = (await reader?.read()) ?? { done: false, value: undefined };

				// If the stream is done, cancel the reader and stop generating
				if (out.done) {
					reader?.cancel();
					return;
				}

				if (!out.value) {
					return;
				}

				// Accumulate the data chunk
				accumulatedData += out.value;

				// Process each complete JSON object in the accumulated data
				while (accumulatedData.includes("\n")) {
					// Assuming each JSON object ends with a newline
					const endIndex = accumulatedData.indexOf("\n");
					let jsonString = accumulatedData.substring(0, endIndex).trim();

					// Remove the processed part from the buffer
					accumulatedData = accumulatedData.substring(endIndex + 1);

					if (jsonString.startsWith("data: ")) {
						jsonString = jsonString.slice(6);
						let data = null;

						try {
							data = JSON.parse(jsonString);
						} catch (e) {
							logger.error(e, "Failed to parse JSON");
							logger.error(jsonString, "Problematic JSON string:");
							continue; // Skip this chunk and try the next one
						}

						// Handle the parsed data
						if (data.content || data.stop) {
							generatedText += data.content ?? "";
							const output: TextGenerationStreamOutput = {
								token: {
									id: tokenId++,
									text: data.content ?? "",
									logprob: 0,
									special: false,
								},
								generated_text: data.stop ? generatedText : null,
								details: null,
							};

							if (data.stop) {
								stop = true;
								output.token.special = true;
								reader?.cancel();
							}

							yield output;
						}
					}
				}
			}
		})();
	};
}
export default endpointLlamacpp;
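
// Illustrative usage sketch (assumptions: a chat-ui style `model` config object and a
// llama.cpp server already listening on the given URL; not part of this module):
//
//   const endpoint = endpointLlamacpp({
//     type: "llamacpp",
//     model,                              // chat-ui model config (parameters, chat template, ...)
//     baseURL: "http://127.0.0.1:8080",   // llama.cpp server address
//   });
//
//   const stream = await endpoint({ messages, preprompt: "", generateSettings: {} });
//   for await (const output of stream) {
//     process.stdout.write(output.token.text);
//   }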