// Liam Dyer
// Generic Multimodal Support Fixed (#1147)
// 5b1a9aa unverified
// raw
// history blame
// 5.16 kB
import { z } from "zod";
import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream";
import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream";
import type { CompletionCreateParamsStreaming } from "openai/resources/completions";
import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions";
import { buildPrompt } from "$lib/buildPrompt";
import { env } from "$env/dynamic/private";
import type { Endpoint } from "../endpoints";
import type OpenAI from "openai";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import type { MessageFile } from "$lib/types/Message";
import type { EndpointMessage } from "../endpoints";
/**
 * Zod schema for the configuration of an "openai"-type endpoint.
 *
 * Parsed by `endpointOai`, which uses the result to construct the OpenAI
 * client and to choose between the completions and chat-completions APIs.
 */
export const endpointOAIParametersSchema = z.object({
// Positive integer, defaults to 1. Not read by `endpointOai` itself.
weight: z.number().int().positive().default(1),
// Not validated here; `endpointOai` reads `id`, `name` and `parameters` from it.
model: z.any(),
// Discriminator: this schema only matches endpoints declared as "openai".
type: z.literal("openai"),
// Base URL handed to the OpenAI client; must be a valid URL.
baseURL: z.string().url().default("https://api.openai.com/v1"),
// Falls back to the OPENAI_API_KEY environment value, then to the placeholder "sk-".
apiKey: z.string().default(env.OPENAI_API_KEY ?? "sk-"),
// Which API `endpointOai` calls: text completions or chat completions (default).
completion: z
.union([z.literal("completions"), z.literal("chat_completions")])
.default("chat_completions"),
// Optional string maps forwarded to the OpenAI client constructor.
defaultHeaders: z.record(z.string()).optional(),
defaultQuery: z.record(z.string()).optional(),
// Optional extra fields merged over the request body of every call.
extraBody: z.record(z.any()).optional(),
// Image handling options; the parsed `image` value is passed to `makeImageProcessor`.
// NOTE(review): the values below are presumably used as defaults by
// `createImageProcessorOptionsValidator` — confirm in ../images.
multimodal: z
.object({
image: createImageProcessorOptionsValidator({
supportedMimeTypes: [
"image/png",
"image/jpeg",
"image/webp",
"image/avif",
"image/tiff",
"image/gif",
],
preferredMimeType: "image/webp",
maxSizeInMB: Infinity,
maxWidth: 4096,
maxHeight: 4096,
}),
})
.default({}),
});
/**
 * Creates an endpoint that streams generations from an OpenAI-compatible API.
 *
 * The input is validated with `endpointOAIParametersSchema`, the `openai`
 * package is loaded lazily, and the returned handler targets either the text
 * completions API or the chat completions API depending on `completion`.
 *
 * @param input - Raw endpoint configuration (pre-validation shape).
 * @returns A handler that performs the streaming request and adapts the
 *   response into a text generation stream.
 * @throws If the configuration fails schema validation, if the `openai`
 *   package cannot be imported, or if `completion` has an unsupported value.
 */
export async function endpointOai(
  input: z.input<typeof endpointOAIParametersSchema>
): Promise<Endpoint> {
  const config = endpointOAIParametersSchema.parse(input);
  const { completion, model, extraBody } = config;

  // Load the SDK on demand; the local binding shadows the type-only import.
  /* eslint-disable-next-line no-shadow */
  let OpenAI;
  try {
    OpenAI = (await import("openai")).OpenAI;
  } catch (e) {
    throw new Error("Failed to import OpenAI", { cause: e });
  }

  const client = new OpenAI({
    apiKey: config.apiKey ?? "sk-",
    baseURL: config.baseURL,
    defaultHeaders: config.defaultHeaders,
    defaultQuery: config.defaultQuery,
  });

  const imageProcessor = makeImageProcessor(config.multimodal.image);

  // Text completions: the conversation is flattened into a single prompt.
  if (completion === "completions") {
    return async ({ messages, preprompt, continueMessage, generateSettings }) => {
      const prompt = await buildPrompt({
        messages,
        continueMessage,
        preprompt,
        model,
      });

      // Per-request settings take precedence over the model's defaults.
      const settings = { ...model.parameters, ...generateSettings };

      const request: CompletionCreateParamsStreaming = {
        model: model.id ?? model.name,
        prompt,
        stream: true,
        max_tokens: settings.max_new_tokens,
        stop: settings.stop,
        temperature: settings.temperature,
        top_p: settings.top_p,
        frequency_penalty: settings.repetition_penalty,
      };

      // The body override lets `extraBody` add or replace request fields.
      const completionStream = await client.completions.create(request, {
        body: { ...request, ...extraBody },
      });

      return openAICompletionToTextGenerationStream(completionStream);
    };
  }

  // Chat completions: messages are sent structured, led by a system message.
  if (completion === "chat_completions") {
    return async ({ messages, preprompt, generateSettings }) => {
      const prepared = await prepareMessages(messages, imageProcessor);

      // Guarantee that the conversation starts with a system message.
      const chatMessages: OpenAI.Chat.Completions.ChatCompletionMessageParam[] =
        prepared[0]?.role === "system"
          ? prepared
          : [{ role: "system", content: "" }, ...prepared];

      // The leading system message always carries the preprompt (or "").
      const [systemMessage] = chatMessages;
      if (systemMessage) {
        systemMessage.content = preprompt ?? "";
      }

      // Per-request settings take precedence over the model's defaults.
      const settings = { ...model.parameters, ...generateSettings };

      const request: ChatCompletionCreateParamsStreaming = {
        model: model.id ?? model.name,
        messages: chatMessages,
        stream: true,
        max_tokens: settings.max_new_tokens,
        stop: settings.stop,
        temperature: settings.temperature,
        top_p: settings.top_p,
        frequency_penalty: settings.repetition_penalty,
      };

      // The body override lets `extraBody` add or replace request fields.
      const chatStream = await client.chat.completions.create(request, {
        body: { ...request, ...extraBody },
      });

      return openAIChatToTextGenerationStream(chatStream);
    };
  }

  throw new Error("Invalid completion type");
}
/**
 * Converts conversation messages into OpenAI chat message params.
 *
 * User messages that carry files are sent as an array of content parts
 * (the processed images first, then the text). All other messages, including
 * user messages without files, are sent with plain string content, because
 * the content-part array is only needed for images and text-only backends
 * that implement the OpenAI API may reject it.
 *
 * @param messages - Conversation messages to convert.
 * @param imageProcessor - Processor applied to each attached file.
 * @returns One chat message param per input message, in the same order.
 */
async function prepareMessages(
  messages: EndpointMessage[],
  imageProcessor: ReturnType<typeof makeImageProcessor>
): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
  return Promise.all(
    messages.map(async (message) => {
      const files = message.files ?? [];

      // Only use the multi-part form when there is actually an image to send.
      if (message.from === "user" && files.length > 0) {
        return {
          role: message.from,
          content: [
            ...(await prepareFiles(imageProcessor, files)),
            { type: "text", text: message.content },
          ],
        };
      }

      return {
        role: message.from,
        content: message.content,
      };
    })
  );
}
/**
 * Runs every file through the image processor and wraps each result as an
 * `image_url` content part whose URL is a base64 data URI.
 *
 * @param imageProcessor - Processor applied to each file; its result supplies
 *   the `mime` type and the `image` bytes used to build the data URI.
 * @param files - Files attached to a message.
 * @returns One image content part per file, in input order.
 */
async function prepareFiles(
  imageProcessor: ReturnType<typeof makeImageProcessor>,
  files: MessageFile[]
): Promise<OpenAI.Chat.Completions.ChatCompletionContentPartImage[]> {
  // All files are processed concurrently; Promise.all keeps input order.
  const processed = await Promise.all(files.map(imageProcessor));

  const imageParts: OpenAI.Chat.Completions.ChatCompletionContentPartImage[] = [];
  for (const { mime, image } of processed) {
    const encoded = image.toString("base64");
    imageParts.push({
      type: "image_url" as const,
      image_url: {
        url: `data:${mime};base64,${encoded}`,
      },
    });
  }
  return imageParts;
}