import { z } from "zod"; import { openAICompletionToTextGenerationStream } from "./openAICompletionToTextGenerationStream"; import { openAIChatToTextGenerationStream } from "./openAIChatToTextGenerationStream"; import type { CompletionCreateParamsStreaming } from "openai/resources/completions"; import type { ChatCompletionCreateParamsStreaming } from "openai/resources/chat/completions"; import { buildPrompt } from "$lib/buildPrompt"; import { env } from "$env/dynamic/private"; import type { Endpoint } from "../endpoints"; import type OpenAI from "openai"; import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images"; import type { MessageFile } from "$lib/types/Message"; import type { EndpointMessage } from "../endpoints"; export const endpointOAIParametersSchema = z.object({ weight: z.number().int().positive().default(1), model: z.any(), type: z.literal("openai"), baseURL: z.string().url().default("https://api.openai.com/v1"), apiKey: z.string().default(env.OPENAI_API_KEY ?? "sk-"), completion: z .union([z.literal("completions"), z.literal("chat_completions")]) .default("chat_completions"), defaultHeaders: z.record(z.string()).optional(), defaultQuery: z.record(z.string()).optional(), extraBody: z.record(z.any()).optional(), multimodal: z .object({ image: createImageProcessorOptionsValidator({ supportedMimeTypes: [ "image/png", "image/jpeg", "image/webp", "image/avif", "image/tiff", "image/gif", ], preferredMimeType: "image/webp", maxSizeInMB: Infinity, maxWidth: 4096, maxHeight: 4096, }), }) .default({}), }); export async function endpointOai( input: z.input ): Promise { const { baseURL, apiKey, completion, model, defaultHeaders, defaultQuery, multimodal, extraBody, } = endpointOAIParametersSchema.parse(input); /* eslint-disable-next-line no-shadow */ let OpenAI; try { OpenAI = (await import("openai")).OpenAI; } catch (e) { throw new Error("Failed to import OpenAI", { cause: e }); } const openai = new OpenAI({ apiKey: apiKey ?? "sk-", baseURL, defaultHeaders, defaultQuery, }); const imageProcessor = makeImageProcessor(multimodal.image); if (completion === "completions") { return async ({ messages, preprompt, continueMessage, generateSettings }) => { const prompt = await buildPrompt({ messages, continueMessage, preprompt, model, }); const parameters = { ...model.parameters, ...generateSettings }; const body: CompletionCreateParamsStreaming = { model: model.id ?? model.name, prompt, stream: true, max_tokens: parameters?.max_new_tokens, stop: parameters?.stop, temperature: parameters?.temperature, top_p: parameters?.top_p, frequency_penalty: parameters?.repetition_penalty, }; const openAICompletion = await openai.completions.create(body, { body: { ...body, ...extraBody }, }); return openAICompletionToTextGenerationStream(openAICompletion); }; } else if (completion === "chat_completions") { return async ({ messages, preprompt, generateSettings }) => { let messagesOpenAI: OpenAI.Chat.Completions.ChatCompletionMessageParam[] = await prepareMessages(messages, imageProcessor); if (messagesOpenAI?.[0]?.role !== "system") { messagesOpenAI = [{ role: "system", content: "" }, ...messagesOpenAI]; } if (messagesOpenAI?.[0]) { messagesOpenAI[0].content = preprompt ?? ""; } const parameters = { ...model.parameters, ...generateSettings }; const body: ChatCompletionCreateParamsStreaming = { model: model.id ?? model.name, messages: messagesOpenAI, stream: true, max_tokens: parameters?.max_new_tokens, stop: parameters?.stop, temperature: parameters?.temperature, top_p: parameters?.top_p, frequency_penalty: parameters?.repetition_penalty, }; const openChatAICompletion = await openai.chat.completions.create(body, { body: { ...body, ...extraBody }, }); return openAIChatToTextGenerationStream(openChatAICompletion); }; } else { throw new Error("Invalid completion type"); } } async function prepareMessages( messages: EndpointMessage[], imageProcessor: ReturnType ): Promise { return Promise.all( messages.map(async (message) => { if (message.from === "user") { return { role: message.from, content: [ ...(await prepareFiles(imageProcessor, message.files ?? [])), { type: "text", text: message.content }, ], }; } return { role: message.from, content: message.content, }; }) ); } async function prepareFiles( imageProcessor: ReturnType, files: MessageFile[] ): Promise { const processedFiles = await Promise.all(files.map(imageProcessor)); return processedFiles.map((file) => ({ type: "image_url" as const, image_url: { url: `data:${file.mime};base64,${file.image.toString("base64")}`, }, })); }