nsarrazin HF Staff commited on
Commit
8019701
·
unverified ·
1 Parent(s): 97f1bcf

Add sampling parameter support to cloudflare (#1233)

Browse files
src/lib/server/endpoints/cloudflare/endpointCloudflare.ts CHANGED
@@ -18,7 +18,7 @@ export async function endpointCloudflare(
18
  const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
19
  const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;
20
 
21
- return async ({ messages, preprompt }) => {
22
  let messagesFormatted = messages.map((message) => ({
23
  role: message.from,
24
  content: message.content,
@@ -28,9 +28,16 @@ export async function endpointCloudflare(
28
  messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
29
  }
30
 
 
 
31
  const payload = JSON.stringify({
32
  messages: messagesFormatted,
33
  stream: true,
 
 
 
 
 
34
  });
35
 
36
  const res = await fetch(apiURL, {
 
18
  const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
19
  const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;
20
 
21
+ return async ({ messages, preprompt, generateSettings }) => {
22
  let messagesFormatted = messages.map((message) => ({
23
  role: message.from,
24
  content: message.content,
 
28
  messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
29
  }
30
 
31
+ const parameters = { ...model.parameters, ...generateSettings };
32
+
33
  const payload = JSON.stringify({
34
  messages: messagesFormatted,
35
  stream: true,
36
+ max_tokens: parameters?.max_new_tokens,
37
+ temperature: parameters?.temperature,
38
+ top_p: parameters?.top_p,
39
+ top_k: parameters?.top_k,
40
+ repetition_penalty: parameters?.repetition_penalty,
41
  });
42
 
43
  const res = await fetch(apiURL, {