Thomas G. Lopes committed on
Commit
b34bca6
·
1 Parent(s): 09f13ea

fix cohere & minimize tokenizer requests

Browse files
src/lib/components/inference-playground/provider-select.svelte CHANGED
@@ -45,7 +45,7 @@
45
  "nebius": "Nebius AI Studio",
46
  "hyperbolic": "Hyperbolic",
47
  "novita": "Novita",
48
- "cohere": "Nohere",
49
  "hf-inference": "HF Inference API",
50
  };
51
  const UPPERCASE_WORDS = ["hf", "ai"];
 
45
  "nebius": "Nebius AI Studio",
46
  "hyperbolic": "Hyperbolic",
47
  "novita": "Novita",
48
+ "cohere": "Cohere",
49
  "hf-inference": "HF Inference API",
50
  };
51
  const UPPERCASE_WORDS = ["hf", "ai"];
src/lib/components/inference-playground/utils.ts CHANGED
@@ -180,7 +180,9 @@ export async function handleNonStreamingResponse(
180
 
181
  export function isSystemPromptSupported(model: Model | CustomModel) {
182
  if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
183
- return model?.config.tokenizer_config?.chat_template?.includes("system");
 
 
184
  }
185
 
186
  export const defaultSystemMessage: { [key: string]: string } = {
@@ -288,19 +290,25 @@ export function hasInferenceSnippet(
288
  return getInferenceSnippet(model, provider, language, "").length > 0;
289
  }
290
 
291
- const tokenizers = new Map<string, PreTrainedTokenizer>();
292
 
293
  export async function getTokenizer(model: Model) {
294
  if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
295
- const tokenizer = await AutoTokenizer.from_pretrained(model.id);
296
- tokenizers.set(model.id, tokenizer);
297
- return tokenizer;
 
 
 
 
 
298
  }
299
 
300
  export async function getTokens(conversation: Conversation): Promise<number> {
301
  const model = conversation.model;
302
  if (isCustomModel(model)) return 0;
303
  const tokenizer = await getTokenizer(model);
 
304
 
305
  // This is a simplified version - you might need to adjust based on your exact needs
306
  let formattedText = "";
 
180
 
181
  export function isSystemPromptSupported(model: Model | CustomModel) {
182
  if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
183
+ const template = model?.config.tokenizer_config?.chat_template;
184
+ if (typeof template !== "string") return false;
185
+ return template.includes("system");
186
  }
187
 
188
  export const defaultSystemMessage: { [key: string]: string } = {
 
290
  return getInferenceSnippet(model, provider, language, "").length > 0;
291
  }
292
 
293
+ const tokenizers = new Map<string, PreTrainedTokenizer | null>();
294
 
295
  export async function getTokenizer(model: Model) {
296
  if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
297
+ try {
298
+ const tokenizer = await AutoTokenizer.from_pretrained(model.id);
299
+ tokenizers.set(model.id, tokenizer);
300
+ return tokenizer;
301
+ } catch {
302
+ tokenizers.set(model.id, null);
303
+ return null;
304
+ }
305
  }
306
 
307
  export async function getTokens(conversation: Conversation): Promise<number> {
308
  const model = conversation.model;
309
  if (isCustomModel(model)) return 0;
310
  const tokenizer = await getTokenizer(model);
311
+ if (tokenizer === null) return 0;
312
 
313
  // This is a simplified version - you might need to adjust based on your exact needs
314
  let formattedText = "";
src/lib/types.ts CHANGED
@@ -46,7 +46,7 @@ export type Session = {
46
  };
47
 
48
  interface TokenizerConfig {
49
- chat_template?: string;
50
  model_max_length?: number;
51
  }
52
 
@@ -156,7 +156,7 @@ export enum UnkTokenEnum {
156
  }
157
 
158
  export type InferenceProviderMapping = {
159
- provider: Provider;
160
  providerId: string;
161
  status: Status;
162
  task: Task;
@@ -173,6 +173,7 @@ export enum Provider {
173
  Replicate = "replicate",
174
  Sambanova = "sambanova",
175
  Together = "together",
 
176
  }
177
 
178
  export enum Status {
 
46
  };
47
 
48
  interface TokenizerConfig {
49
+ chat_template?: string | Array<{ name: string; template: string }>;
50
  model_max_length?: number;
51
  }
52
 
 
156
  }
157
 
158
  export type InferenceProviderMapping = {
159
+ provider: string;
160
  providerId: string;
161
  status: Status;
162
  task: Task;
 
173
  Replicate = "replicate",
174
  Sambanova = "sambanova",
175
  Together = "together",
176
+ Cohere = "cohere",
177
  }
178
 
179
  export enum Status {