Extend endpointOai.ts to allow usage of extra sampling parameters (#1032)
* Extend endpointOai.ts to allow usage of extra sampling parameters when calling vLLM as an OpenAI-compatible server
* Refactor: run prettier on endpointOai.ts
* Fix: corrected type imports in endpointOai.ts
* Simplify the code a bit and add `extraBody` to the OpenAI endpoint
* Update the zod schema to allow any type in `extraBody`
---------
Co-authored-by: Nathan Sarrazin <[email protected]>
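For context, a minimal sketch of how the new `extraBody` field could be wired up from a chat-ui `MODELS` configuration, assuming a local vLLM server exposing the OpenAI-compatible API. The model name and the sampling parameters shown (`top_k`, `repetition_penalty`) are illustrative, not prescribed by this PR:

MODELS=`[
  {
    "name": "vllm-model",
    "endpoints": [
      {
        "type": "openai",
        "baseURL": "http://localhost:8000/v1",
        "extraBody": {
          "top_k": 50,
          "repetition_penalty": 1.2
        }
      }
    ]
  }
]`

Every key in `extraBody` is merged into the outgoing request body, so any sampling parameter the backend understands can be passed through without further changes to the endpoint code.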
src/lib/server/endpoints/openai/endpointOai.ts
CHANGED
@@ -1,6 +1,8 @@
 import { z } from "zod";
 import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
 import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
+import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
+import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
 import { buildPrompt } from "$lib/buildPrompt";
 import { env } from "$env/dynamic/private";
 import type { Endpoint } from "../endpoints";
@@ -16,12 +18,13 @@ export const endpointOAIParametersSchema = z.object({
 		.default("chat_completions"),
 	defaultHeaders: z.record(z.string()).optional(),
 	defaultQuery: z.record(z.string()).optional(),
+	extraBody: z.record(z.any()).optional(),
 });
 
 export async function endpointOai(
 	input: z.input<typeof endpointOAIParametersSchema>
 ): Promise<Endpoint> {
-	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery } =
+	const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, extraBody } =
 		endpointOAIParametersSchema.parse(input);
 	let OpenAI;
 	try {
@@ -47,19 +50,22 @@ export async function endpointOai(
 			});
 
 			const parameters = { ...model.parameters, ...generateSettings };
+			const body: CompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				prompt,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};
 
-			return openAICompletionToTextGenerationStream(
-				await openai.completions.create({
-					model: model.id ?? model.name,
-					prompt,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			const openAICompletion = await openai.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});
+
+			return openAICompletionToTextGenerationStream(openAICompletion);
 		};
 	} else if (completion === "chat_completions") {
 		return async ({ messages, preprompt, generateSettings }) => {
@@ -77,19 +83,22 @@ export async function endpointOai(
 			}
 
 			const parameters = { ...model.parameters, ...generateSettings };
+			const body: ChatCompletionCreateParamsStreaming = {
+				model: model.id ?? model.name,
+				messages: messagesOpenAI,
+				stream: true,
+				max_tokens: parameters?.max_new_tokens,
+				stop: parameters?.stop,
+				temperature: parameters?.temperature,
+				top_p: parameters?.top_p,
+				frequency_penalty: parameters?.repetition_penalty,
+			};
+
+			const openChatAICompletion = await openai.chat.completions.create(body, {
+				body: { ...body, ...extraBody },
+			});
 
-			return openAIChatToTextGenerationStream(
-				await openai.chat.completions.create({
-					model: model.id ?? model.name,
-					messages: messagesOpenAI,
-					stream: true,
-					max_tokens: parameters?.max_new_tokens,
-					stop: parameters?.stop,
-					temperature: parameters?.temperature,
-					top_p: parameters?.top_p,
-					frequency_penalty: parameters?.repetition_penalty,
-				})
-			);
+			return openAIChatToTextGenerationStream(openChatAICompletion);
 		};
 	} else {
 		throw new Error("Invalid completion type");
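The second argument to `openai.completions.create` / `openai.chat.completions.create` takes per-request options whose `body` field replaces the serialized request body, so spreading `extraBody` last lets its keys both extend and override the typed defaults. A self-contained sketch of that merge semantics (the values are illustrative; `top_k` is a vLLM extension, not an official OpenAI field):

// Object spread is last-key-wins: `extraBody` can add non-standard fields
// and override defaults already present in the typed body.
const body = { model: "some-model", temperature: 0.7, top_p: 0.95 };
const extraBody = { top_k: 50, temperature: 0.2 };

const merged = { ...body, ...extraBody };
console.log(merged);
// -> { model: "some-model", temperature: 0.2, top_p: 0.95, top_k: 50 }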