Spaces:
Sleeping
Sleeping
File size: 5,260 Bytes
b9ec522 43606a3 f58e466 d13f9cf b5ae065 b17a5c8 d13f9cf 20a343f 11983d2 c087a6b b5ae065 c087a6b b07f0b1 c087a6b b5ae065 01b06a3 c087a6b 241ba68 4a66e10 b17a5c8 c087a6b d13f9cf 20a343f c087a6b b5ae065 b9ec522 b5ae065 b07f0b1 b5ae065 01b06a3 c202241 20a343f c202241 c087a6b d885316 b5ae065 c087a6b b5ae065 20a343f b5ae065 c087a6b 4a66e10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import {
HF_TOKEN,
HF_API_ROOT,
MODELS,
OLD_MODELS,
TASK_MODEL,
HF_ACCESS_TOKEN,
} from "$env/static/private";
import type { ChatTemplateInput } from "$lib/types/Template";
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import endpointTgi from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
import JSON5 from "json5";
/** `T` with the keys in `K` made optional; every other key keeps its original modifiers. */
type Optional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
/**
 * Default chat prompt template (Handlebars-style syntax, compiled by `compileTemplate`):
 * the preprompt, then every message wrapped in the user/assistant start and end tokens,
 * ending with `assistantMessageToken`.
 */
const defaultChatPromptTemplate = [
	"{{preprompt}}",
	"{{#each messages}}",
	"{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}",
	"{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}",
	"{{/each}}",
	"{{assistantMessageToken}}",
].join("");

/** An example prompt attached to a model config: a non-empty title and prompt text. */
const promptExampleSchema = z.object({
	title: z.string().min(1),
	prompt: z.string().min(1),
});

/** Generation parameters for a model; keys not listed here are preserved by `passthrough()`. */
const generationParametersSchema = z
	.object({
		temperature: z.number().min(0).max(1),
		truncate: z.number().int().positive().optional(),
		max_new_tokens: z.number().int().positive(),
		stop: z.array(z.string()).optional(),
		top_p: z.number().positive().optional(),
		top_k: z.number().positive().optional(),
		repetition_penalty: z.number().min(-2).max(2).optional(),
	})
	.passthrough();

/** Schema for one entry of the `MODELS` env var (which is parsed as JSON5). */
const modelConfig = z.object({
	/** Used as an identifier in DB; `processModel` defaults it to `name`. */
	id: z.string().optional(),
	/** Used to link to the model page, and for inference */
	name: z.string().min(1),
	/** Human-readable name; `processModel` defaults it to `name`. */
	displayName: z.string().min(1).optional(),
	description: z.string().min(1).optional(),
	websiteUrl: z.string().url().optional(),
	modelUrl: z.string().url().optional(),
	datasetName: z.string().min(1).optional(),
	datasetUrl: z.string().url().optional(),
	/** Prompt tokens referenced by the default chat template. */
	userMessageToken: z.string().default(""),
	userMessageEndToken: z.string().default(""),
	assistantMessageToken: z.string().default(""),
	assistantMessageEndToken: z.string().default(""),
	/** Fallback for `userMessageEndToken` / `assistantMessageEndToken` when those are empty. */
	messageEndToken: z.string().default(""),
	preprompt: z.string().default(""),
	/** When set, `processModel` fetches the preprompt from this URL instead of using `preprompt`. */
	prepromptUrl: z.string().url().optional(),
	chatPromptTemplate: z.string().default(defaultChatPromptTemplate),
	promptExamples: z.array(promptExampleSchema).optional(),
	/** Inference endpoints; when omitted, a TGI endpoint URL is derived from `name`. */
	endpoints: z.array(endpointSchema).optional(),
	parameters: generationParametersSchema.optional(),
	multimodal: z.boolean().default(false),
	unlisted: z.boolean().default(false),
	embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
});
/** The `MODELS` env var, parsed as JSON5 and validated as a list of model configs. */
const modelsRaw = modelConfig.array().parse(JSON5.parse(MODELS));
/**
 * Downloads a preprompt from `url`.
 *
 * @throws Error when the server answers with a non-2xx status, so an error page is never
 *   silently used as the model's preprompt.
 */
const fetchPreprompt = async (url: string): Promise<string> => {
	const res = await fetch(url);
	if (!res.ok) {
		throw new Error(`Failed to fetch preprompt from ${url}: ${res.status} ${res.statusText}`);
	}
	return res.text();
};

/**
 * Normalizes a validated model config into the runtime model shape.
 *
 * - `userMessageEndToken` / `assistantMessageEndToken` fall back to `messageEndToken` when empty.
 * - The chat template is compiled against the normalized config, so the template's
 *   `@root.*EndToken` lookups see those fallbacks too.
 * - `id` and `displayName` default to `name`.
 * - When `prepromptUrl` is set, the preprompt is downloaded from it.
 * - `parameters.stop` is mirrored into `parameters.stop_sequences`.
 *
 * @throws Error if `prepromptUrl` responds with a non-OK HTTP status.
 */
const processModel = async (m: z.infer<typeof modelConfig>) => {
	const normalized = {
		...m,
		userMessageEndToken: m.userMessageEndToken || m.messageEndToken,
		assistantMessageEndToken: m.assistantMessageEndToken || m.messageEndToken,
	};
	return {
		...normalized,
		// Compile with `normalized` rather than `m`, otherwise the end-token fallback above
		// would be invisible to the template.
		chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, normalized),
		id: m.id || m.name,
		displayName: m.displayName || m.name,
		preprompt: m.prepromptUrl ? await fetchPreprompt(m.prepromptUrl) : m.preprompt,
		parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
	};
};
/**
 * Attaches a `getEndpoint` factory to a processed model.
 *
 * - Without configured `endpoints`, returns a TGI endpoint at `${HF_API_ROOT}/${m.name}`.
 * - Otherwise, each call picks one endpoint at random with probability proportional to its
 *   `weight` and builds it with the matching factory; unrecognized types fall back to TGI.
 *
 * @throws Error("Failed to select endpoint") when no endpoint is picked, e.g. if every weight is 0.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
	...m,
	getEndpoint: async (): Promise<Endpoint> => {
		if (!m.endpoints) {
			return endpointTgi({
				type: "tgi",
				url: `${HF_API_ROOT}/${m.name}`,
				// `||` rather than `??`: `$env/static/private` values are strings, so a blank
				// `HF_TOKEN=` in `.env` is "" and must still fall back to the legacy variable.
				accessToken: HF_TOKEN || HF_ACCESS_TOKEN,
				weight: 1,
				model: m,
			});
		}

		const totalWeight = sum(m.endpoints.map((e) => e.weight));
		// Walk the cumulative weights: the first endpoint whose slice contains `random` wins.
		let random = Math.random() * totalWeight;
		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				const args = { ...endpoint, model: m };
				switch (args.type) {
					case "tgi":
						return endpoints.tgi(args);
					case "aws":
						return await endpoints.aws(args);
					case "openai":
						return await endpoints.openai(args);
					case "llamacpp":
						return endpoints.llamacpp(args);
					case "ollama":
						return endpoints.ollama(args);
					default:
						// for legacy reason
						return endpoints.tgi(args);
				}
			}
			random -= endpoint.weight;
		}
		throw new Error(`Failed to select endpoint`);
	},
});
/** Every model from `MODELS`, normalized by `processModel` and given a `getEndpoint` factory. */
export const models = await Promise.all(
	modelsRaw.map(async (rawModel) => addEndpoint(await processModel(rawModel)))
);
/** The first entry of `MODELS`; also the fallback for `smallModel` when `TASK_MODEL` is unset. */
export const defaultModel = models[0];
/** Schema for one entry of `OLD_MODELS`: deprecated models keep only their identity fields. */
const oldModelConfig = z.object({
	id: z.string().optional(),
	name: z.string().min(1),
	displayName: z.string().min(1).optional(),
});

/**
 * Models that have been deprecated, parsed (JSON5) from the optional `OLD_MODELS` env var.
 * `id` and `displayName` fall back to `name` when omitted; empty when `OLD_MODELS` is unset.
 */
export const oldModels = OLD_MODELS
	? oldModelConfig
			.array()
			.parse(JSON5.parse(OLD_MODELS))
			.map((oldModel) => ({
				...oldModel,
				id: oldModel.id || oldModel.name,
				displayName: oldModel.displayName || oldModel.name,
			}))
	: [];
/**
 * Builds a zod enum schema that accepts exactly the `id`s of the given models.
 *
 * @param _models Models whose ids form the allowed values.
 * @throws Error if `_models` is empty, because `z.enum` needs at least one value.
 */
export const validateModel = (_models: BackendModel[]) => {
	if (_models.length === 0) {
		throw new Error("validateModel requires at least one model");
	}
	// `z.enum` takes a single non-empty tuple, so the first id is written out separately
	// to give the type checker the `[string, ...string[]]` shape it requires.
	return z.enum([_models[0].id, ..._models.slice(1).map((m) => m.id)]);
};
/**
 * Resolves the `TASK_MODEL` setting.
 *
 * If `taskModel` is the `name` of a model in `MODELS`, that model is reused. Otherwise it is
 * parsed as a JSON5 model config, validated and processed like the entries of `MODELS`.
 *
 * @throws Error if `taskModel` is neither a known model name nor valid JSON5. The raw value is
 *   deliberately not echoed, since a JSON config may contain endpoint credentials.
 */
const loadTaskModel = async (taskModel: string) => {
	const listed = models.find((m) => m.name === taskModel);
	if (listed) {
		return listed;
	}

	let parsed: unknown;
	try {
		parsed = JSON5.parse(taskModel);
	} catch (err) {
		const reason = err instanceof Error ? err.message : String(err);
		throw new Error(
			`TASK_MODEL is neither the name of a model in MODELS nor a valid JSON5 model config (${reason})`
		);
	}
	return addEndpoint(await processModel(modelConfig.parse(parsed)));
};

/** Model used for auxiliary tasks: resolved from `TASK_MODEL`, or `defaultModel` when unset. */
export const smallModel = TASK_MODEL ? await loadTaskModel(TASK_MODEL) : defaultModel;
/**
 * Shape of a processed model (the type of `defaultModel`), with `multimodal`, `parameters`,
 * `preprompt` and `unlisted` made optional.
 */
export type BackendModel = Optional<
	typeof defaultModel,
	"multimodal" | "parameters" | "preprompt" | "unlisted"
>;
|