Thomas G. Lopes committed
Commit · b34bca6 · 1 parent: 09f13ea

fix cohere & minimize tokenizer requests
src/lib/components/inference-playground/provider-select.svelte
CHANGED
@@ -45,7 +45,7 @@
 	"nebius": "Nebius AI Studio",
 	"hyperbolic": "Hyperbolic",
 	"novita": "Novita",
-	"cohere": "
+	"cohere": "Cohere",
 	"hf-inference": "HF Inference API",
 };
 const UPPERCASE_WORDS = ["hf", "ai"];
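For orientation (not part of the diff): the map above pairs provider ids with display labels, and UPPERCASE_WORDS suggests a fallback formatter for ids missing from the map. A minimal sketch of such a helper; the function name and exact logic are assumptions, not taken from the component:

// Assumed context from the hunk above; the helper itself is hypothetical.
const PROVIDERS: Record<string, string> = {
	"cohere": "Cohere",
	"hf-inference": "HF Inference API",
};
const UPPERCASE_WORDS = ["hf", "ai"];

function formatProviderName(id: string): string {
	if (PROVIDERS[id]) return PROVIDERS[id];
	// Fallback: title-case each hyphen-separated word, upper-casing acronyms.
	return id
		.split("-")
		.map(w => (UPPERCASE_WORDS.includes(w) ? w.toUpperCase() : w.charAt(0).toUpperCase() + w.slice(1)))
		.join(" ");
}

console.log(formatProviderName("cohere")); // "Cohere" (from the map)
console.log(formatProviderName("some-ai-provider")); // "Some AI Provider" (fallback)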
src/lib/components/inference-playground/utils.ts
CHANGED
@@ -180,7 +180,9 @@ export async function handleNonStreamingResponse(
 
 export function isSystemPromptSupported(model: Model | CustomModel) {
 	if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
-
+	const template = model?.config.tokenizer_config?.chat_template;
+	if (typeof template !== "string") return false;
+	return template.includes("system");
 }
 
 export const defaultSystemMessage: { [key: string]: string } = {
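A note on the added check: chat templates are Jinja strings shipped in a model's tokenizer_config.json, and a template that can render a system role generally mentions "system" somewhere, which is what the new code relies on. A sketch of the heuristic in isolation; the sample template fragment is illustrative, not from any particular model:

// Illustrative Jinja chat-template fragment (real ones come from
// tokenizer_config.json). A template that never mentions "system"
// presumably cannot render a system message.
const template: unknown = "{% for m in messages %}{% if m['role'] == 'system' %}...{% endif %}{% endfor %}";

const supported = typeof template === "string" && template.includes("system");
console.log(supported); // true

// Array-valued templates (see the types.ts hunk below) fail the typeof
// check, so the function conservatively reports no system-prompt support.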
@@ -288,19 +290,25 @@ export function hasInferenceSnippet(
 	return getInferenceSnippet(model, provider, language, "").length > 0;
 }
 
-const tokenizers = new Map<string, PreTrainedTokenizer>();
+const tokenizers = new Map<string, PreTrainedTokenizer | null>();
 
 export async function getTokenizer(model: Model) {
 	if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
-
-
-
+	try {
+		const tokenizer = await AutoTokenizer.from_pretrained(model.id);
+		tokenizers.set(model.id, tokenizer);
+		return tokenizer;
+	} catch {
+		tokenizers.set(model.id, null);
+		return null;
+	}
 }
 
 export async function getTokens(conversation: Conversation): Promise<number> {
 	const model = conversation.model;
 	if (isCustomModel(model)) return 0;
 	const tokenizer = await getTokenizer(model);
+	if (tokenizer === null) return 0;
 
 	// This is a simplified version - you might need to adjust based on your exact needs
 	let formattedText = "";
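getTokenizer now caches failures as well as successes: once AutoTokenizer.from_pretrained rejects for a model id, null is stored, so later calls return immediately instead of re-requesting the tokenizer. That is presumably the "minimize tokenizer requests" half of the commit message. The same idea as a generic, self-contained sketch; all names here are illustrative:

// Negative caching: remember failures (as null) alongside successes so a
// missing or failing resource is fetched at most once per key.
const cache = new Map<string, string | null>();

async function loadOnce(key: string, load: (k: string) => Promise<string>): Promise<string | null> {
	if (cache.has(key)) return cache.get(key)!; // hit: prior success or prior failure
	try {
		const value = await load(key);
		cache.set(key, value);
		return value;
	} catch {
		cache.set(key, null); // cache the failure too
		return null;
	}
}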
src/lib/types.ts
CHANGED
@@ -46,7 +46,7 @@ export type Session = {
 };
 
 interface TokenizerConfig {
-	chat_template?: string;
+	chat_template?: string | Array<{ name: string; template: string }>;
 	model_max_length?: number;
 }
 
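Context for the widened union (inferred from the commit message, not stated in the diff): some tokenizer configs, Cohere's Command R family among them, publish chat_template as a list of named templates rather than a single string. Values satisfying each arm of the new type; the template contents and names below are illustrative:

type TokenizerConfig = {
	chat_template?: string | Array<{ name: string; template: string }>;
	model_max_length?: number;
};

// Common case: one Jinja string (template text abbreviated).
const single: TokenizerConfig = { chat_template: "{% for message in messages %}...{% endfor %}" };

// Cohere-style case: several named templates.
const named: TokenizerConfig = {
	chat_template: [
		{ name: "default", template: "{% for message in messages %}...{% endfor %}" },
		{ name: "tool_use", template: "..." },
	],
};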
@@ -156,7 +156,7 @@ export enum UnkTokenEnum {
 }
 
 export type InferenceProviderMapping = {
-	provider:
+	provider: string;
 	providerId: string;
 	status: Status;
 	task: Task;
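Widening provider to string means mappings returned by the Hub can name providers the Provider enum has not caught up with yet. A sketch of a value that now type-checks; the type is trimmed to the two relevant fields and the concrete ids are made up for illustration:

// Trimmed copy of the type for illustration; the real one also has
// status and task fields.
type InferenceProviderMapping = {
	provider: string; // now a plain string rather than an enum
	providerId: string;
};

const mapping: InferenceProviderMapping = {
	provider: "cohere",
	providerId: "command-r-plus-04-2024", // hypothetical id
};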
@@ -173,6 +173,7 @@ export enum Provider {
 	Replicate = "replicate",
 	Sambanova = "sambanova",
 	Together = "together",
+	Cohere = "cohere",
 }
 
 export enum Status {
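The enum addition is likely the "fix cohere" half of the commit: with Cohere = "cohere" present, code that maps enum members to provider ids can cover Cohere too. A trimmed sketch (only two members kept) showing that TypeScript string enums carry their wire values at runtime:

// Trimmed copy of the enum for illustration; the real one has more members.
enum Provider {
	Together = "together",
	Cohere = "cohere",
}

console.log(Provider.Cohere); // prints "cohere": string enum members are their values
console.log(Object.values(Provider)); // ["together", "cohere"]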