Add support for Anthropic models via AWS Bedrock (#1413)
* Add support for Anthropic models via AWS Bedrock
* deps
* Fixed type errors
* Temporary fix for continue button showing up on Claude
* Fix continue button issue by setting the last message token's special to true
---------
Co-authored-by: Nathan Sarrazin <[email protected]>
- package-lock.json +0 -0
- package.json +1 -0
- src/lib/server/endpoints/aws/endpointBedrock.ts +150 -0
- src/lib/server/endpoints/endpoints.ts +3 -0
- src/lib/server/models.ts +2 -0
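With this PR, a Claude model served through Bedrock can be declared in chat-ui's `MODELS` environment variable. A minimal sketch of one entry, assuming the usual JSON list format; the Bedrock model ID is illustrative, and `region` and `anthropicVersion` can be omitted to take the schema defaults (`us-east-1` and `bedrock-2023-05-31`):

```json
{
  "name": "claude-3-sonnet",
  "id": "anthropic.claude-3-sonnet-20240229-v1:0",
  "endpoints": [
    {
      "type": "bedrock",
      "region": "us-east-1",
      "anthropicVersion": "bedrock-2023-05-31"
    }
  ]
}
```

Note that `BedrockRuntimeClient` is constructed with only a `region`, so AWS credentials are resolved by the SDK's default provider chain (environment variables, shared credentials file, or an attached IAM role).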
package-lock.json
CHANGED
The diff for this file is too large to render.
See raw diff
package.json
CHANGED
```diff
@@ -108,6 +108,7 @@
     "zod": "^3.22.3"
   },
   "optionalDependencies": {
+    "@aws-sdk/client-bedrock-runtime": "^3.631.0",
     "@anthropic-ai/sdk": "^0.25.0",
     "@anthropic-ai/vertex-sdk": "^0.4.1",
     "@google-cloud/vertexai": "^1.1.0",
```
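The AWS client lands under `optionalDependencies` next to the other provider SDKs, so a deployment that never configures a Bedrock endpoint still works even if `@aws-sdk/client-bedrock-runtime` fails to install; it can also be added explicitly with `npm install @aws-sdk/client-bedrock-runtime`.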
src/lib/server/endpoints/aws/endpointBedrock.ts
ADDED
@@ -0,0 +1,150 @@
```ts
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import {
	BedrockRuntimeClient,
	InvokeModelWithResponseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
import type { EndpointMessage } from "../endpoints";
import type { MessageFile } from "$lib/types/Message";

export const endpointBedrockParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	type: z.literal("bedrock"),
	region: z.string().default("us-east-1"),
	model: z.any(),
	anthropicVersion: z.string().default("bedrock-2023-05-31"),
	multimodal: z
		.object({
			image: createImageProcessorOptionsValidator({
				supportedMimeTypes: [
					"image/png",
					"image/jpeg",
					"image/webp",
					"image/avif",
					"image/tiff",
					"image/gif",
				],
				preferredMimeType: "image/webp",
				maxSizeInMB: Infinity,
				maxWidth: 4096,
				maxHeight: 4096,
			}),
		})
		.default({}),
});

export async function endpointBedrock(
	input: z.input<typeof endpointBedrockParametersSchema>
): Promise<Endpoint> {
	const { region, model, anthropicVersion, multimodal } =
		endpointBedrockParametersSchema.parse(input);
	const client = new BedrockRuntimeClient({
		region,
	});
	const imageProcessor = makeImageProcessor(multimodal.image);

	return async ({ messages, preprompt, generateSettings }) => {
		let system = preprompt;
		// Use the first message as the system prompt if it's of type "system"
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
			messages = messages.slice(1); // Remove the first system message from the array
		}

		const formattedMessages = await prepareMessages(messages, imageProcessor);

		let tokenId = 0;
		const parameters = { ...model.parameters, ...generateSettings };
		return (async function* () {
			const command = new InvokeModelWithResponseStreamCommand({
				body: Buffer.from(
					JSON.stringify({
						anthropic_version: anthropicVersion,
						max_tokens: parameters.max_new_tokens ? parameters.max_new_tokens : 4096,
						messages: formattedMessages,
						system,
					}),
					"utf-8"
				),
				contentType: "application/json",
				accept: "application/json",
				modelId: model.id,
				trace: "DISABLED",
			});

			const response = await client.send(command);

			let text = "";

			for await (const item of response.body ?? []) {
				const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes));
				const chunk_type = chunk.type;

				if (chunk_type === "content_block_delta") {
					text += chunk.delta.text;
					yield {
						token: {
							id: tokenId++,
							text: chunk.delta.text,
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else if (chunk_type === "message_stop") {
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: text,
						details: null,
					} satisfies TextGenerationStreamOutput;
				}
			}
		})();
	};
}

// Prepare the messages excluding system prompts
async function prepareMessages(
	messages: EndpointMessage[],
	imageProcessor: ReturnType<typeof makeImageProcessor>
) {
	const formattedMessages = [];

	for (const message of messages) {
		const content = [];

		if (message.files?.length) {
			content.push(...(await prepareFiles(imageProcessor, message.files)));
		}
		content.push({ type: "text", text: message.content });

		const lastMessage = formattedMessages[formattedMessages.length - 1];
		if (lastMessage && lastMessage.role === message.from) {
			// If the last message has the same role, merge the content
			lastMessage.content.push(...content);
		} else {
			formattedMessages.push({ role: message.from, content });
		}
	}
	return formattedMessages;
}

// Process files and convert them to base64 encoded strings
async function prepareFiles(
	imageProcessor: ReturnType<typeof makeImageProcessor>,
	files: MessageFile[]
) {
	const processedFiles = await Promise.all(files.map(imageProcessor));
	return processedFiles.map((file) => ({
		type: "image",
		source: { type: "base64", media_type: "image/jpeg", data: file.image.toString("base64") },
	}));
}
```
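The `message_stop` branch carries the continue-button fix from the commit list: the closing token is yielded with `special: true`, so the frontend sees a clean end of generation instead of offering to continue. A minimal consumption sketch, assuming a hypothetical model object and message shape (chat-ui wires these in from `models.ts`):

```ts
// Hypothetical driver: stream a completion and rely on the special
// end token, the same way the chat-ui generation pipeline does.
const generate = await endpointBedrock({
	type: "bedrock",
	model: { id: "anthropic.claude-3-sonnet-20240229-v1:0", parameters: {} },
});

const stream = await generate({
	messages: [{ from: "user", content: "Hello!" }],
	preprompt: "You are a helpful assistant.",
	generateSettings: { max_new_tokens: 256 },
});

for await (const output of stream) {
	if (output.token.special) {
		// The message_stop token: generated_text holds the full reply,
		// and special=true tells the UI the stream ended cleanly.
		console.log(output.generated_text);
	} else {
		process.stdout.write(output.token.text);
	}
}
```

Because ordinary deltas are emitted with `special: false` and only the terminal token is special, the stream no longer looks truncated to the frontend.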
src/lib/server/endpoints/endpoints.ts
CHANGED
```diff
@@ -9,6 +9,7 @@ import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/e
 import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
 import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
 import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
+import { endpointBedrock, endpointBedrockParametersSchema } from "./aws/endpointBedrock";
 
 import {
 	endpointAnthropic,
@@ -61,6 +62,7 @@ export const endpoints = {
 	tgi: endpointTgi,
 	anthropic: endpointAnthropic,
 	anthropicvertex: endpointAnthropicVertex,
+	bedrock: endpointBedrock,
 	aws: endpointAws,
 	openai: endpointOai,
 	llamacpp: endpointLlamacpp,
@@ -76,6 +78,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
 	endpointAnthropicParametersSchema,
 	endpointAnthropicVertexParametersSchema,
 	endpointAwsParametersSchema,
+	endpointBedrockParametersSchema,
 	endpointOAIParametersSchema,
 	endpointTgiParametersSchema,
 	endpointLlamacppParametersSchema,
```
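Registration is two-part: the `bedrock` key in the `endpoints` map supplies the factory, and `endpointBedrockParametersSchema` joins the `z.discriminatedUnion("type", [...])` so config validation dispatches on the `type` field. A self-contained zod sketch with stand-in schemas (not the real parameter schemas) shows the dispatch behavior:

```ts
import { z } from "zod";

// Stand-ins for the real endpoint parameter schemas.
const bedrockSchema = z.object({ type: z.literal("bedrock"), region: z.string().default("us-east-1") });
const tgiSchema = z.object({ type: z.literal("tgi"), url: z.string().url() });

const endpointSchema = z.discriminatedUnion("type", [bedrockSchema, tgiSchema]);

// The "type" value alone selects which branch validates the object.
endpointSchema.parse({ type: "bedrock" }); // -> { type: "bedrock", region: "us-east-1" }
endpointSchema.parse({ type: "tgi", url: "http://localhost:8080" }); // validated as tgi
```

This is why adding a provider only ever touches these two lists plus the `switch` in `models.ts` below.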
src/lib/server/models.ts
CHANGED
```diff
@@ -280,6 +280,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
 			return endpoints.anthropic(args);
 		case "anthropic-vertex":
 			return endpoints.anthropicvertex(args);
+		case "bedrock":
+			return endpoints.bedrock(args);
 		case "aws":
 			return await endpoints.aws(args);
 		case "openai":
```