ABarLT, nsarrazin (HF Staff) committed
Commit bf78ac3 · unverified · 1 parent: d8a31e6

Add support for Anthropic models via AWS Bedrock (#1413)

* Add support for Anthropic models via AWS Bedrock

* deps

* Fixed type errors

* Temporary fix for continue button showing up on Claude

* Fix continue button issue by setting the last message token's special to true

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
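
For context, a minimal sketch of a model entry that would exercise the new endpoint type, assuming chat-ui's usual MODELS-style model configuration. The display name and Bedrock model id below are placeholders; the endpoint fields mirror the endpointBedrockParametersSchema added in this PR, and since the endpoint sends model.id to Bedrock as the modelId, the model's id would need to be a valid Bedrock model identifier.

// Hypothetical MODELS entry (sketch only; all values are placeholders)
const exampleBedrockModel = {
  name: "claude-3-sonnet-bedrock", // illustrative display name
  id: "anthropic.claude-3-sonnet-20240229-v1:0", // example Bedrock model id, substitute your own
  multimodal: true, // the endpoint accepts image parts (png/jpeg/webp/avif/tiff/gif)
  endpoints: [
    {
      type: "bedrock", // new discriminator introduced by this PR
      region: "us-east-1", // schema default
      anthropicVersion: "bedrock-2023-05-31", // schema default
    },
  ],
};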

package-lock.json CHANGED
The diff for this file is too large to render. See raw diff
 
package.json CHANGED
@@ -108,6 +108,7 @@
     "zod": "^3.22.3"
   },
   "optionalDependencies": {
+    "@aws-sdk/client-bedrock-runtime": "^3.631.0",
     "@anthropic-ai/sdk": "^0.25.0",
     "@anthropic-ai/vertex-sdk": "^0.4.1",
     "@google-cloud/vertexai": "^1.1.0",
src/lib/server/endpoints/aws/endpointBedrock.ts ADDED
@@ -0,0 +1,150 @@
+import { z } from "zod";
+import type { Endpoint } from "../endpoints";
+import type { TextGenerationStreamOutput } from "@huggingface/inference";
+import {
+  BedrockRuntimeClient,
+  InvokeModelWithResponseStreamCommand,
+} from "@aws-sdk/client-bedrock-runtime";
+import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";
+import type { EndpointMessage } from "../endpoints";
+import type { MessageFile } from "$lib/types/Message";
+
+export const endpointBedrockParametersSchema = z.object({
+  weight: z.number().int().positive().default(1),
+  type: z.literal("bedrock"),
+  region: z.string().default("us-east-1"),
+  model: z.any(),
+  anthropicVersion: z.string().default("bedrock-2023-05-31"),
+  multimodal: z
+    .object({
+      image: createImageProcessorOptionsValidator({
+        supportedMimeTypes: [
+          "image/png",
+          "image/jpeg",
+          "image/webp",
+          "image/avif",
+          "image/tiff",
+          "image/gif",
+        ],
+        preferredMimeType: "image/webp",
+        maxSizeInMB: Infinity,
+        maxWidth: 4096,
+        maxHeight: 4096,
+      }),
+    })
+    .default({}),
+});
+
+export async function endpointBedrock(
+  input: z.input<typeof endpointBedrockParametersSchema>
+): Promise<Endpoint> {
+  const { region, model, anthropicVersion, multimodal } =
+    endpointBedrockParametersSchema.parse(input);
+  const client = new BedrockRuntimeClient({
+    region,
+  });
+  const imageProcessor = makeImageProcessor(multimodal.image);
+
+  return async ({ messages, preprompt, generateSettings }) => {
+    let system = preprompt;
+    // Use the first message as the system prompt if it's of type "system"
+    if (messages?.[0]?.from === "system") {
+      system = messages[0].content;
+      messages = messages.slice(1); // Remove the first system message from the array
+    }
+
+    const formattedMessages = await prepareMessages(messages, imageProcessor);
+
+    let tokenId = 0;
+    const parameters = { ...model.parameters, ...generateSettings };
+    return (async function* () {
+      const command = new InvokeModelWithResponseStreamCommand({
+        body: Buffer.from(
+          JSON.stringify({
+            anthropic_version: anthropicVersion,
+            max_tokens: parameters.max_new_tokens ? parameters.max_new_tokens : 4096,
+            messages: formattedMessages,
+            system,
+          }),
+          "utf-8"
+        ),
+        contentType: "application/json",
+        accept: "application/json",
+        modelId: model.id,
+        trace: "DISABLED",
+      });
+
+      const response = await client.send(command);
+
+      let text = "";
+
+      for await (const item of response.body ?? []) {
+        const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes));
+        const chunk_type = chunk.type;
+
+        if (chunk_type === "content_block_delta") {
+          text += chunk.delta.text;
+          yield {
+            token: {
+              id: tokenId++,
+              text: chunk.delta.text,
+              logprob: 0,
+              special: false,
+            },
+            generated_text: null,
+            details: null,
+          } satisfies TextGenerationStreamOutput;
+        } else if (chunk_type === "message_stop") {
+          yield {
+            token: {
+              id: tokenId++,
+              text: "",
+              logprob: 0,
+              special: true,
+            },
+            generated_text: text,
+            details: null,
+          } satisfies TextGenerationStreamOutput;
+        }
+      }
+    })();
+  };
+}
+
+// Prepare the messages excluding system prompts
+async function prepareMessages(
+  messages: EndpointMessage[],
+  imageProcessor: ReturnType<typeof makeImageProcessor>
+) {
+  const formattedMessages = [];
+
+  for (const message of messages) {
+    const content = [];
+
+    if (message.files?.length) {
+      content.push(...(await prepareFiles(imageProcessor, message.files)));
+    }
+    content.push({ type: "text", text: message.content });
+
+    const lastMessage = formattedMessages[formattedMessages.length - 1];
+    if (lastMessage && lastMessage.role === message.from) {
+      // If the last message has the same role, merge the content
+      lastMessage.content.push(...content);
+    } else {
+      formattedMessages.push({ role: message.from, content });
+    }
+  }
+  return formattedMessages;
+}
+
+// Process files and convert them to base64 encoded strings
+async function prepareFiles(
+  imageProcessor: ReturnType<typeof makeImageProcessor>,
+  files: MessageFile[]
+) {
+  const processedFiles = await Promise.all(files.map(imageProcessor));
+  return processedFiles.map((file) => ({
+    type: "image",
+    source: { type: "base64", media_type: "image/jpeg", data: file.image.toString("base64") },
+  }));
+}
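
As a usage illustration, a minimal sketch of driving this endpoint directly, outside chat-ui's model plumbing. The Bedrock model id, the message literal (simplified relative to chat-ui's full message type), and the settings are placeholders; AWS credentials are resolved through the SDK's default provider chain, since the endpoint only passes region to BedrockRuntimeClient.

// Sketch only: instantiate the endpoint and consume the token stream it yields.
import { endpointBedrock } from "$lib/server/endpoints/aws/endpointBedrock";

async function demo() {
  const endpoint = await endpointBedrock({
    type: "bedrock",
    region: "us-east-1",
    // model.id becomes the Bedrock modelId; this value is a placeholder
    model: { id: "anthropic.claude-3-sonnet-20240229-v1:0", parameters: { max_new_tokens: 512 } },
  });

  const stream = await endpoint({
    messages: [{ from: "user", content: "Hello from Bedrock!", files: [] }],
    preprompt: "You are a helpful assistant.",
    generateSettings: {},
  });

  let answer = "";
  for await (const output of stream) {
    // content_block_delta chunks arrive as non-special tokens;
    // the final message_stop token is special and carries generated_text.
    if (!output.token.special) {
      answer += output.token.text;
    }
  }
  console.log(answer);
}

demo().catch(console.error);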
src/lib/server/endpoints/endpoints.ts CHANGED
@@ -9,6 +9,7 @@ import endpointLlamacpp, { endpointLlamacppParametersSchema } from "./llamacpp/e
 import endpointOllama, { endpointOllamaParametersSchema } from "./ollama/endpointOllama";
 import endpointVertex, { endpointVertexParametersSchema } from "./google/endpointVertex";
 import endpointGenAI, { endpointGenAIParametersSchema } from "./google/endpointGenAI";
+import { endpointBedrock, endpointBedrockParametersSchema } from "./aws/endpointBedrock";
 
 import {
   endpointAnthropic,
@@ -61,6 +62,7 @@ export const endpoints = {
   tgi: endpointTgi,
   anthropic: endpointAnthropic,
   anthropicvertex: endpointAnthropicVertex,
+  bedrock: endpointBedrock,
   aws: endpointAws,
   openai: endpointOai,
   llamacpp: endpointLlamacpp,
@@ -76,6 +78,7 @@ export const endpointSchema = z.discriminatedUnion("type", [
   endpointAnthropicParametersSchema,
   endpointAnthropicVertexParametersSchema,
   endpointAwsParametersSchema,
+  endpointBedrockParametersSchema,
   endpointOAIParametersSchema,
   endpointTgiParametersSchema,
   endpointLlamacppParametersSchema,
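
To illustrate the schema wiring, a small sketch of validating an endpoint entry with the new type against the extended discriminated union. Standalone usage like this is hypothetical; in chat-ui the parse happens while the models config is processed, and the defaults come from endpointBedrockParametersSchema.

// Sketch only: "bedrock" is now a valid discriminator for endpointSchema.
import { endpointSchema } from "$lib/server/endpoints/endpoints";

const parsed = endpointSchema.parse({ type: "bedrock", region: "eu-west-1" });
// weight (1) and anthropicVersion ("bedrock-2023-05-31") are filled in from
// the schema defaults; the provided region is kept as-is.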
src/lib/server/models.ts CHANGED
@@ -280,6 +280,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => ({
       return endpoints.anthropic(args);
     case "anthropic-vertex":
       return endpoints.anthropicvertex(args);
+    case "bedrock":
+      return endpoints.bedrock(args);
     case "aws":
       return await endpoints.aws(args);
     case "openai":