import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import type { EndpointMessage } from "../endpoints";
import type { MessageFile } from "$lib/types/Message";

/**
 * Configuration schema for a `"bedrock"` endpoint.
 *
 * Every field except `type` and `model` has a default, so a minimal config only
 * needs those two. `isNova` selects the Nova request body shape instead of the
 * Anthropic-style one (see `endpointBedrock`).
 */
export const endpointBedrockParametersSchema = z.object({
  weight: z.number().int().positive().default(1),
  type: z.literal("bedrock"),
  region: z.string().default("us-east-1"),
  model: z.any(),
  anthropicVersion: z.string().default("bedrock-2023-05-31"),
  isNova: z.boolean().default(false),
  multimodal: z
    .object({
      image: createImageProcessorOptionsValidator({
        supportedMimeTypes: [
          "image/png",
          "image/jpeg",
          "image/webp",
          "image/avif",
          "image/tiff",
          "image/gif",
        ],
        preferredMimeType: "image/webp",
        maxSizeInMB: Infinity,
        maxWidth: 4096,
        maxHeight: 4096,
      }),
    })
    .default({}),
});

/**
 * Builds an endpoint that streams text generation through
 * `InvokeModelWithResponseStreamCommand` from `@aws-sdk/client-bedrock-runtime`.
 *
 * The AWS SDK is imported dynamically, so it is only loaded when a Bedrock
 * endpoint is actually created.
 *
 * @param input - Raw endpoint configuration; validated with
 *   `endpointBedrockParametersSchema.parse`.
 * @returns An endpoint function. Calling it returns an async generator that
 *   yields one non-special token per text delta and, on the stop event, one
 *   special token whose `generated_text` is the concatenation of all deltas.
 * @throws Error when `@aws-sdk/client-bedrock-runtime` cannot be imported.
 */
export async function endpointBedrock(
  input: z.input<typeof endpointBedrockParametersSchema>
): Promise<Endpoint> {
  const { region, model, anthropicVersion, multimodal, isNova } =
    endpointBedrockParametersSchema.parse(input);

  let BedrockRuntimeClient, InvokeModelWithResponseStreamCommand;
  try {
    ({ BedrockRuntimeClient, InvokeModelWithResponseStreamCommand } = await import(
      "@aws-sdk/client-bedrock-runtime"
    ));
  } catch (error) {
    throw new Error("Failed to import @aws-sdk/client-bedrock-runtime. Make sure it's installed.");
  }

  const client = new BedrockRuntimeClient({
    region,
  });
  const imageProcessor = makeImageProcessor(multimodal.image);

  return async ({ messages, preprompt, generateSettings }) => {
    let system = preprompt;
    // Use the first message as the system prompt if it's of type "system"
    if (messages?.[0]?.from === "system") {
      system = messages[0].content;
      messages = messages.slice(1); // Remove the first system message from the array
    }

    const formattedMessages = await prepareMessages(messages, isNova, imageProcessor);

    let tokenId = 0;
    // Per-request settings override the model's configured parameters.
    const parameters = { ...model.parameters, ...generateSettings };

    return (async function* () {
      const baseCommandParams = {
        contentType: "application/json",
        accept: "application/json",
        modelId: model.id,
      };

      const maxTokens = parameters.max_new_tokens || 4096;

      let bodyContent;
      if (isNova) {
        bodyContent = {
          messages: formattedMessages,
          inferenceConfig: {
            maxTokens,
            topP: 0.1,
            temperature: 1.0,
          },
          // Only send a system block when there is a prompt: an undefined prompt
          // would otherwise serialize as `[{}]` and an empty one as `[{ text: "" }]`.
          ...(system ? { system: [{ text: system }] } : {}),
        };
      } else {
        bodyContent = {
          anthropic_version: anthropicVersion,
          max_tokens: maxTokens,
          messages: formattedMessages,
          system,
        };
      }

      const command = new InvokeModelWithResponseStreamCommand({
        ...baseCommandParams,
        body: Buffer.from(JSON.stringify(bodyContent), "utf-8"),
        trace: "DISABLED",
      });

      const response = await client.send(command);

      // One decoder for the whole stream; each decode() call is independent.
      const decoder = new TextDecoder();
      let text = "";

      for await (const item of response.body ?? []) {
        const bytes = item.chunk?.bytes;
        if (!bytes) {
          // Nothing to parse: decoding `undefined` yields "" and JSON.parse("")
          // would throw an unhelpful SyntaxError.
          continue;
        }

        const chunk = JSON.parse(decoder.decode(bytes));

        // Two event spellings are accepted: a `contentBlockDelta` / `messageStop`
        // key on the chunk, or a `type` of `content_block_delta` / `message_stop`.
        if ("contentBlockDelta" in chunk || chunk.type === "content_block_delta") {
          const chunkText = chunk.contentBlockDelta?.delta?.text || chunk.delta?.text || "";
          text += chunkText;
          yield {
            token: {
              id: tokenId++,
              text: chunkText,
              logprob: 0,
              special: false,
            },
            generated_text: null,
            details: null,
          } satisfies TextGenerationStreamOutput;
        } else if ("messageStop" in chunk || chunk.type === "message_stop") {
          // Final event: emit an empty special token carrying the full text.
          yield {
            token: {
              id: tokenId++,
              text: "",
              logprob: 0,
              special: true,
            },
            generated_text: text,
            details: null,
          } satisfies TextGenerationStreamOutput;
        }
      }
    })();
  };
}

/**
 * Converts chat messages (system prompt already removed by the caller) into the
 * `{ role, content: [...] }` list sent in the request body.
 *
 * Any attached files come first in a message's content, followed by its text.
 * Consecutive messages from the same role are merged into a single entry.
 *
 * @param messages - Messages to convert.
 * @param isNova - When true, text parts are `{ text }`; otherwise `{ type: "text", text }`.
 * @param imageProcessor - Processor applied to each attached file.
 */
async function prepareMessages(
  messages: EndpointMessage[],
  isNova: boolean,
  imageProcessor: ReturnType<typeof makeImageProcessor>
) {
  const formattedMessages = [];

  for (const message of messages) {
    const content = [];

    if (message.files?.length) {
      content.push(...(await prepareFiles(imageProcessor, isNova, message.files)));
    }
    if (isNova) {
      content.push({ text: message.content });
    } else {
      content.push({ type: "text", text: message.content });
    }

    const lastMessage = formattedMessages[formattedMessages.length - 1];
    if (lastMessage && lastMessage.role === message.from) {
      // If the last message has the same role, merge the content
      lastMessage.content.push(...content);
    } else {
      formattedMessages.push({ role: message.from, content });
    }
  }
  return formattedMessages;
}

/**
 * Runs every file through the image processor and wraps each result as a
 * base64 image content part.
 *
 * Nova parts use `{ image: { format, source: { bytes } } }`, where `format` is
 * the MIME type with the leading `"image/"` removed. Other models use
 * `{ type: "image", source: { type: "base64", media_type, data } }`.
 *
 * NOTE(review): assumes the processor returns `{ mime, image }` with `image`
 * supporting `toString("base64")` (presumably a Buffer) — confirm in `../images`.
 */
async function prepareFiles(
  imageProcessor: ReturnType<typeof makeImageProcessor>,
  isNova: boolean,
  files: MessageFile[]
) {
  const processedFiles = await Promise.all(files.map(imageProcessor));

  if (isNova) {
    return processedFiles.map((file) => ({
      image: {
        format: file.mime.substring("image/".length),
        source: { bytes: file.image.toString("base64") },
      },
    }));
  } else {
    return processedFiles.map((file) => ({
      type: "image",
      source: { type: "base64", media_type: file.mime, data: file.image.toString("base64") },
    }));
  }
}