import {
HF_TOKEN,
HF_API_ROOT,
MODELS,
OLD_MODELS,
TASK_MODEL,
HF_ACCESS_TOKEN,
} from "$env/static/private";
import type { ChatTemplateInput } from "$lib/types/Template";
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
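// Make the listed keys of T optional while leaving the rest of the type unchanged.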
type Optional<T, K extends keyof T> = Pick<Partial<T>, K> & Omit<T, K>;
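// Zod schema for a single model entry; MODELS is expected to be a JSON array of these.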
const modelConfig = z.object({
/** Used as an identifier in the DB */
id: z.string().optional(),
/** Used to link to the model page, and for inference */
name: z.string().min(1),
displayName: z.string().min(1).optional(),
description: z.string().min(1).optional(),
websiteUrl: z.string().url().optional(),
modelUrl: z.string().url().optional(),
datasetName: z.string().min(1).optional(),
datasetUrl: z.string().url().optional(),
userMessageToken: z.string().default(""),
userMessageEndToken: z.string().default(""),
assistantMessageToken: z.string().default(""),
assistantMessageEndToken: z.string().default(""),
messageEndToken: z.string().default(""),
preprompt: z.string().default(""),
prepromptUrl: z.string().url().optional(),
chatPromptTemplate: z
.string()
.default(
"{{preprompt}}" +
"{{#each messages}}" +
"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}" +
"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}" +
"{{/each}}" +
"{{assistantMessageToken}}"
),
promptExamples: z
.array(
z.object({
title: z.string().min(1),
prompt: z.string().min(1),
})
)
.optional(),
endpoints: z.array(endpointSchema).optional(),
parameters: z
.object({
temperature: z.number().min(0).max(1),
truncate: z.number().int().positive().optional(),
max_new_tokens: z.number().int().positive(),
stop: z.array(z.string()).optional(),
top_p: z.number().positive().optional(),
top_k: z.number().positive().optional(),
repetition_penalty: z.number().min(-2).max(2).optional(),
})
.passthrough()
.optional(),
multimodal: z.boolean().default(false),
});
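// Illustrative MODELS value (a sketch, not shipped config; the model name, URL, and
// parameters below are assumptions to adapt to your deployment):
//
// MODELS=`[{
//   "name": "mistralai/Mistral-7B-Instruct-v0.2",
//   "parameters": { "temperature": 0.7, "max_new_tokens": 1024 },
//   "endpoints": [{ "type": "tgi", "url": "http://127.0.0.1:8080", "weight": 1 }]
// }]`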
const modelsRaw = z.array(modelConfig).parse(JSON.parse(MODELS));
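// Normalize a raw config: fill in id/displayName defaults, resolve a remote preprompt
// if one is referenced, and precompile the chat prompt template.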
const processModel = async (m: z.infer<typeof modelConfig>) => ({
...m,
userMessageEndToken: m?.userMessageEndToken || m?.messageEndToken,
assistantMessageEndToken: m?.assistantMessageEndToken || m?.messageEndToken,
chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
id: m.id || m.name,
displayName: m.displayName || m.name,
preprompt: m.prepromptUrl ? await fetch(m.prepromptUrl).then((r) => r.text()) : m.preprompt,
parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
});
const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
...m,
getEndpoint: async (): Promise<Endpoint> => {
if (!m.endpoints) {
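// No endpoints configured: fall back to the hosted inference API for this model.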
return endpointTgi({
type: "tgi",
url: `${HF_API_ROOT}/${m.name}`,
accessToken: HF_TOKEN ?? HF_ACCESS_TOKEN,
weight: 1,
model: m,
});
}
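// Weighted random selection: draw a number in [0, totalWeight) and walk the
// endpoints, subtracting each weight until the draw lands inside one's share.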
const totalWeight = sum(m.endpoints.map((e) => e.weight));
let random = Math.random() * totalWeight;
for (const endpoint of m.endpoints) {
if (random < endpoint.weight) {
const args = { ...endpoint, model: m };
switch (args.type) {
case "tgi":
return endpoints.tgi(args);
case "aws":
return await endpoints.aws(args);
case "openai":
return await endpoints.openai(args);
case "llamacpp":
return endpoints.llamacpp(args);
case "ollama":
return endpoints.ollama(args);
default:
// for legacy reasons, unknown endpoint types fall back to TGI
return endpoints.tgi(args);
}
}
random -= endpoint.weight;
}
throw new Error(`Failed to select endpoint`);
},
});
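// Parse, normalize, and wire up every model declared in MODELS.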
export const models = await Promise.all(modelsRaw.map((e) => processModel(e).then(addEndpoint)));
export const defaultModel = models[0];
// Models that have been deprecated
export const oldModels = OLD_MODELS
? z
.array(
z.object({
id: z.string().optional(),
name: z.string().min(1),
displayName: z.string().min(1).optional(),
})
)
.parse(JSON.parse(OLD_MODELS))
.map((m) => ({ ...m, id: m.id || m.name, displayName: m.displayName || m.name }))
: [];
export const validateModel = (_models: BackendModel[]) => {
// z.enum requires a statically non-empty array ([string, ...string[]]), so the first id is passed separately
return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
};
// If `TASK_MODEL` names one of the configured models, use it; otherwise try to parse `TASK_MODEL` as a standalone model config
export const smallModel = TASK_MODEL
? (models.find((m) => m.name === TASK_MODEL) ||
(await processModel(modelConfig.parse(JSON.parse(TASK_MODEL))).then((m) =>
addEndpoint(m)
))) ??
defaultModel
: defaultModel;
export type BackendModel = Optional<typeof defaultModel, "preprompt" | "parameters" | "multimodal">;