taeminlee nsarrazin HF Staff commited on
Commit
ec2a4ed
·
unverified ·
1 Parent(s): 61e5613

Extend endpointOai.ts to allow usage of extra sampling parameters (#1032)

Browse files

* Extend endpointOai.ts to allow usage of extra sampling parameters when calling vLLM as an OpenAI-compatible server

* refactor: prettier endpointOai.ts

* Fix: Corrected type imports in endpointOai.ts

* Simplifies code a bit and adds `extraBody` to OpenAI endpoint

* Update zod schema to allow any type in extraBody

---------

Co-authored-by: Nathan Sarrazin <[email protected]>

src/lib/server/endpoints/openai/endpointOai.ts CHANGED
@@ -1,6 +1,8 @@
1
  import { z } from "zod";
2
  import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
3
  import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
 
 
4
  import { buildPrompt } from "$lib/buildPrompt";
5
  import { env } from "$env/dynamic/private";
6
  import type { Endpoint } from "../endpoints";
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
16
  .default("chat_completions"),
17
  defaultHeaders: z.record(z.string()).optional(),
18
  defaultQuery: z.record(z.string()).optional(),
 
19
  });
20
 
21
  export async function endpointOai(
22
  input: z.input<typeof endpointOAIParametersSchema>
23
  ): Promise<Endpoint> {
24
- const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
25
  endpointOAIParametersSchema.parse(input);
26
  let OpenAI;
27
  try {
@@ -47,19 +50,22 @@ export async function endpointOai(
47
  });
48
 
49
  const parameters = { ...model.parameters, ...generateSettings };
 
 
 
 
 
 
 
 
 
 
50
 
51
- return openAICompletionToTextGenerationStream(
52
- await openai.completions.create({
53
- model: model.id ?? model.name,
54
- prompt,
55
- stream: true,
56
- max_tokens: parameters?.max_new_tokens,
57
- stop: parameters?.stop,
58
- temperature: parameters?.temperature,
59
- top_p: parameters?.top_p,
60
- frequency_penalty: parameters?.repetition_penalty,
61
- })
62
- );
63
  };
64
  } else if (completion === "chat_completions") {
65
  return async ({ messages, preprompt, generateSettings }) => {
@@ -77,19 +83,22 @@ export async function endpointOai(
77
  }
78
 
79
  const parameters = { ...model.parameters, ...generateSettings };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- return openAIChatToTextGenerationStream(
82
- await openai.chat.completions.create({
83
- model: model.id ?? model.name,
84
- messages: messagesOpenAI,
85
- stream: true,
86
- max_tokens: parameters?.max_new_tokens,
87
- stop: parameters?.stop,
88
- temperature: parameters?.temperature,
89
- top_p: parameters?.top_p,
90
- frequency_penalty: parameters?.repetition_penalty,
91
- })
92
- );
93
  };
94
  } else {
95
  throw new Error("Invalid completion type");
 
1
  import { z } from "zod";
2
  import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
3
  import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
4
+ import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
5
+ import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
6
  import { buildPrompt } from "$lib/buildPrompt";
7
  import { env } from "$env/dynamic/private";
8
  import type { Endpoint } from "../endpoints";
 
18
  .default("chat_completions"),
19
  defaultHeaders: z.record(z.string()).optional(),
20
  defaultQuery: z.record(z.string()).optional(),
21
+ extraBody: z.record(z.any()).optional(),
22
  });
23
 
24
  export async function endpointOai(
25
  input: z.input<typeof endpointOAIParametersSchema>
26
  ): Promise<Endpoint> {
27
+ const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
28
  endpointOAIParametersSchema.parse(input);
29
  let OpenAI;
30
  try {
 
50
  });
51
 
52
  const parameters = { ...model.parameters, ...generateSettings };
53
+ const body: CompletionCreateParamsStreaming = {
54
+ model: model.id ?? model.name,
55
+ prompt,
56
+ stream: true,
57
+ max_tokens: parameters?.max_new_tokens,
58
+ stop: parameters?.stop,
59
+ temperature: parameters?.temperature,
60
+ top_p: parameters?.top_p,
61
+ frequency_penalty: parameters?.repetition_penalty,
62
+ };
63
 
64
+ const openAICompletion = await openai.completions.create(body, {
65
+ body: { ...body, ...extraBody },
66
+ });
67
+
68
+ return openAICompletionToTextGenerationStream(openAICompletion);
 
 
 
 
 
 
 
69
  };
70
  } else if (completion === "chat_completions") {
71
  return async ({ messages, preprompt, generateSettings }) => {
 
83
  }
84
 
85
  const parameters = { ...model.parameters, ...generateSettings };
86
+ const body: ChatCompletionCreateParamsStreaming = {
87
+ model: model.id ?? model.name,
88
+ messages: messagesOpenAI,
89
+ stream: true,
90
+ max_tokens: parameters?.max_new_tokens,
91
+ stop: parameters?.stop,
92
+ temperature: parameters?.temperature,
93
+ top_p: parameters?.top_p,
94
+ frequency_penalty: parameters?.repetition_penalty,
95
+ };
96
+
97
+ const openChatAICompletion = await openai.chat.completions.create(body, {
98
+ body: { ...body, ...extraBody },
99
+ });
100
 
101
+ return openAIChatToTextGenerationStream(openChatAICompletion);
 
 
 
 
 
 
 
 
 
 
 
102
  };
103
  } else {
104
  throw new Error("Invalid completion type");