Spaces:
Sleeping
Sleeping
File size: 4,802 Bytes
ee1ec85 0fc59f9 f58e466 d13f9cf 11983d2 81812fb 01b06a3 d13f9cf 01b06a3 c85ac8f 01b06a3 7a3250d 01b06a3 95af686 01b06a3 95af686 f58e466 5ebe152 f58e466 5ebe152 a18e88d 01b06a3 f58e466 0fc59f9 f58e466 0fc59f9 f58e466 01b06a3 81812fb 01b06a3 f16c630 01b06a3 c202241 01b06a3 d13f9cf 9891123 d13f9cf 01b06a3 5ebe152 f58e466 c85ac8f 01b06a3 d13f9cf 01b06a3 c202241 11983d2 81812fb 01b06a3 c202241 ee1ec85 c202241 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import { HF_ACCESS_TOKEN, MODELS, OLD_MODELS, TASK_MODEL } from "$env/static/private";
import type { ChatTemplateInput, WebSearchQueryTemplateInput } from "$lib/types/Template";
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";
/**
 * Variant of `T` in which the keys listed in `K` become optional,
 * while every other property keeps its original modifier and type.
 */
type Optional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
/** Non-empty string schema shared by the two mandatory SageMaker key fields. */
const sagemakerCredential = z.string().min(1);

/**
 * Schema for an endpoint entry whose `host` is exactly "sagemaker".
 * Requires a valid URL plus a non-empty access key and secret key;
 * the session token may be left out.
 */
const sagemakerEndpoint = z.object({
  host: z.literal("sagemaker"),
  url: z.string().url(),
  accessKey: sagemakerCredential,
  secretKey: sagemakerCredential,
  sessionToken: z.string().optional(),
});
/** Fallback `authorization` value, built from the HF_ACCESS_TOKEN private env variable. */
const defaultTgiAuthorization = `Bearer ${HF_ACCESS_TOKEN}`;

/**
 * Schema for an endpoint entry whose `host` is "tgi" or not specified at all.
 * `authorization` must be non-empty when given and otherwise falls back to
 * the token-based default above.
 */
const tgiEndpoint = z.object({
  // Accepts the literal "tgi" as well as a missing/undefined host.
  host: z.union([z.literal("tgi"), z.undefined()]),
  url: z.string().url(),
  authorization: z.string().min(1).default(defaultTgiAuthorization),
});
/** Positive integer weight; 1 when the configuration omits it. */
const endpointWeight = z.number().int().positive().default(1);

/**
 * Fields merged into every endpoint variant.
 * NOTE(review): `weight` is presumably used to weight endpoint selection — confirm against callers.
 */
const commonEndpoint = z.object({ weight: endpointWeight });
/**
 * Lazily-built union of the supported endpoint variants, each extended with
 * the common fields. The SageMaker variant is listed first in the union,
 * followed by the TGI variant.
 */
const endpoint = z.lazy(() => {
  const sagemakerWithCommon = sagemakerEndpoint.merge(commonEndpoint);
  const tgiWithCommon = tgiEndpoint.merge(commonEndpoint);
  return z.union([sagemakerWithCommon, tgiWithCommon]);
});
/**
 * Endpoint schema with a post-validation step: once the union has accepted
 * the value, it is parsed again with the concrete schema selected by `host`.
 *
 * @throws Error when `host` is neither "tgi", undefined, nor "sagemaker".
 */
const combinedEndpoint = endpoint.transform((data) => {
  const { host } = data;

  // A missing host is treated the same as an explicit "tgi".
  if (host === undefined || host === "tgi") {
    return tgiEndpoint.merge(commonEndpoint).parse(data);
  }

  if (host === "sagemaker") {
    return sagemakerEndpoint.merge(commonEndpoint).parse(data);
  }

  throw new Error(`Invalid host: ${host}`);
});
/**
 * Chat prompt template used when a model does not configure its own:
 * the preprompt, then every message wrapped in its role tokens through the
 * `ifUser` / `ifAssistant` block helpers, then the assistant start token.
 */
const defaultChatPromptTemplate = [
  "{{preprompt}}",
  "{{#each messages}}",
  "{{#ifUser}}{{@root.userMessageToken}}{{content}}{{@root.userMessageEndToken}}{{/ifUser}}",
  "{{#ifAssistant}}{{@root.assistantMessageToken}}{{content}}{{@root.assistantMessageEndToken}}{{/ifAssistant}}",
  "{{/each}}",
  "{{assistantMessageToken}}",
].join("");

/**
 * Web-search query template used when a model does not configure its own:
 * a single user turn asking for a search query, followed by the assistant start token.
 */
const defaultWebSearchQueryPromptTemplate = [
  "{{userMessageToken}}",
  'My question is: "{{message.content}}". ',
  "Based on the conversation history (my previous questions are: {{previousMessages}}), give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query. For the context, today is {{currentDate}}",
  "{{userMessageEndToken}}",
  "{{assistantMessageToken}}",
].join("");

/** A single prompt example: a non-empty title and a non-empty prompt. */
const promptExampleSchema = z.object({
  title: z.string().min(1),
  prompt: z.string().min(1),
});

/** Generation parameters; keys that are not listed here are passed through untouched. */
const modelParametersSchema = z
  .object({
    temperature: z.number().min(0).max(1),
    truncate: z.number().int().positive(),
    max_new_tokens: z.number().int().positive(),
    stop: z.array(z.string()).optional(),
  })
  .passthrough();

/** Schema of one entry of the MODELS configuration array. */
const modelConfigSchema = z.object({
  /** Used as an identifier in DB */
  id: z.string().optional(),
  /** Used to link to the model page, and for inference */
  name: z.string().min(1),
  displayName: z.string().min(1).optional(),
  description: z.string().min(1).optional(),
  websiteUrl: z.string().url().optional(),
  modelUrl: z.string().url().optional(),
  datasetName: z.string().min(1).optional(),
  datasetUrl: z.string().url().optional(),
  userMessageToken: z.string().default(""),
  userMessageEndToken: z.string().default(""),
  assistantMessageToken: z.string().default(""),
  assistantMessageEndToken: z.string().default(""),
  messageEndToken: z.string().default(""),
  preprompt: z.string().default(""),
  prepromptUrl: z.string().url().optional(),
  chatPromptTemplate: z.string().default(defaultChatPromptTemplate),
  webSearchQueryPromptTemplate: z.string().default(defaultWebSearchQueryPromptTemplate),
  promptExamples: z.array(promptExampleSchema).optional(),
  endpoints: z.array(combinedEndpoint).optional(),
  parameters: modelParametersSchema.optional(),
});

/**
 * Model entries exactly as validated from the MODELS env variable (a JSON array).
 * Throws at module load when MODELS is not valid JSON or fails schema validation.
 */
const modelsRaw = z.array(modelConfigSchema).parse(JSON.parse(MODELS));
/**
 * Downloads the preprompt text served at `url`.
 *
 * `fetch` only rejects on network failures; an HTTP error status still
 * resolves. Without the `ok` check, the body of a 404/500 page would be
 * used as the model's preprompt without any warning.
 *
 * @param url - Location of the preprompt text.
 * @returns The response body as text.
 * @throws Error when the server answers with a non-2xx status.
 */
async function fetchPreprompt(url: string): Promise<string> {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(
      `Failed to fetch preprompt from ${url}: ${response.status} ${response.statusText}`
    );
  }
  return response.text();
}

/**
 * Fully resolved model list, built once at module load from the validated
 * MODELS configuration:
 * - per-role end tokens fall back to `messageEndToken` when they are empty,
 * - prompt render functions are compiled from the configured templates,
 * - `id` and `displayName` fall back to `name`,
 * - `preprompt` is downloaded from `prepromptUrl` when one is configured.
 *
 * NOTE(review): the render functions are compiled against the raw entry `m`,
 * not the normalized object built here — confirm `compileTemplate` does not
 * need the fallback tokens or the downloaded preprompt.
 */
export const models = await Promise.all(
  modelsRaw.map(async (m) => ({
    ...m,
    // `||` (not `??`) is deliberate: the schema defaults these tokens to "",
    // and an empty token must fall back to `messageEndToken`.
    userMessageEndToken: m.userMessageEndToken || m.messageEndToken,
    assistantMessageEndToken: m.assistantMessageEndToken || m.messageEndToken,
    chatPromptRender: compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m),
    webSearchQueryPromptRender: compileTemplate<WebSearchQueryTemplateInput>(
      m.webSearchQueryPromptTemplate,
      m
    ),
    id: m.id || m.name,
    displayName: m.displayName || m.name,
    preprompt: m.prepromptUrl ? await fetchPreprompt(m.prepromptUrl) : m.preprompt,
  }))
);
/** Schema of one entry of the OLD_MODELS configuration array. */
const oldModelSchema = z.object({
  id: z.string().optional(),
  name: z.string().min(1),
  displayName: z.string().min(1).optional(),
});

/** Validates the OLD_MODELS JSON and fills in `id` / `displayName` from `name` when absent. */
const parseOldModels = (rawJson: string) =>
  z
    .array(oldModelSchema)
    .parse(JSON.parse(rawJson))
    .map((entry) => ({
      ...entry,
      id: entry.id || entry.name,
      displayName: entry.displayName || entry.name,
    }));

// Models that have been deprecated; an empty list when OLD_MODELS is not set.
export const oldModels = OLD_MODELS ? parseOldModels(OLD_MODELS) : [];
/** Element type of `models`, with `preprompt` made optional. */
export type BackendModel = Optional<(typeof models)[number], "preprompt">;

/** Validated shape of a single endpoint entry (either supported variant, including `weight`). */
export type Endpoint = z.infer<typeof endpoint>;
/**
 * First configured model.
 * NOTE: evaluates to `undefined` at runtime when MODELS is an empty array.
 */
export const defaultModel = models[0];

/** Model whose `name` equals TASK_MODEL, or `defaultModel` when no model matches. */
export const smallModel = models.find(({ name }) => name === TASK_MODEL) ?? defaultModel;
/**
 * Builds a zod enum that accepts exactly the ids of the given models.
 *
 * @param _models - Models whose ids are allowed; the first element is read unconditionally.
 * @returns A zod enum schema over the model ids.
 */
export const validateModel = (_models: BackendModel[]) => {
  // z.enum needs a non-empty tuple, so the first id is spelled out separately.
  const [first, ...rest] = _models;
  return z.enum([first.id, ...rest.map(({ id }) => id)]);
};
|