Commit 28c09f1 · unverified · 1 parent: 38fd55a
Authored by Tarek Becker · committed by nsarrazin (HF Staff)

Add support for Amazon Nova (#1629)

* Add support for Amazon Nova

* fix linting issue

* fix: replace model.id check with `isNova` flag and remove debug logs

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
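
In short: the commit adds an `isNova` flag to the Bedrock endpoint schema, builds Nova's request body (an `inferenceConfig` block and a `system` array of text objects) instead of the Anthropic Messages body when the flag is set, and accepts Nova's camelCase stream events (`contentBlockDelta`, `messageStop`) alongside Anthropic's snake_case ones. As a point of reference, here is a minimal standalone sketch of what the Nova path sends and reads back — not code from the PR; it assumes `@aws-sdk/client-bedrock-runtime` is installed, AWS credentials are available in the environment, and uses `amazon.nova-lite-v1:0` as an illustrative model ID:

```ts
import {
	BedrockRuntimeClient,
	InvokeModelWithResponseStreamCommand,
} from "@aws-sdk/client-bedrock-runtime";

const client = new BedrockRuntimeClient({ region: "us-east-1" });

// Nova request body, mirroring the shape built in the commit's `isNova` branch.
const body = {
	messages: [{ role: "user", content: [{ text: "Hello!" }] }],
	inferenceConfig: { maxTokens: 4096, topP: 0.1, temperature: 1.0 },
	system: [{ text: "You are a helpful assistant." }],
};

const response = await client.send(
	new InvokeModelWithResponseStreamCommand({
		modelId: "amazon.nova-lite-v1:0", // illustrative; any Bedrock-hosted Nova model ID
		contentType: "application/json",
		accept: "application/json",
		body: Buffer.from(JSON.stringify(body), "utf-8"),
	})
);

// Nova streams camelCase `contentBlockDelta` events rather than Anthropic's
// snake_case `content_block_delta` events, which is what the parsing change
// in the diff below accounts for.
for await (const item of response.body ?? []) {
	const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes));
	process.stdout.write(chunk.contentBlockDelta?.delta?.text ?? "");
}
```

Note that the diff hardcodes `topP: 0.1` and `temperature: 1.0` on the Nova path rather than threading them through from `parameters`, so per-model sampling settings only affect `maxTokens` there.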

src/lib/server/endpoints/aws/endpointBedrock.ts CHANGED
@@ -11,6 +11,7 @@ export const endpointBedrockParametersSchema = z.object({
 	region: z.string().default("us-east-1"),
 	model: z.any(),
 	anthropicVersion: z.string().default("bedrock-2023-05-31"),
+	isNova: z.boolean().default(false),
 	multimodal: z
 		.object({
 			image: createImageProcessorOptionsValidator({
@@ -34,7 +35,7 @@ export const endpointBedrockParametersSchema = z.object({
 export async function endpointBedrock(
 	input: z.input<typeof endpointBedrockParametersSchema>
 ): Promise<Endpoint> {
-	const { region, model, anthropicVersion, multimodal } =
+	const { region, model, anthropicVersion, multimodal, isNova } =
 		endpointBedrockParametersSchema.parse(input);
 
 	let BedrockRuntimeClient, InvokeModelWithResponseStreamCommand;
@@ -59,24 +60,42 @@ export async function endpointBedrock(
 			messages = messages.slice(1); // Remove the first system message from the array
 		}
 
-		const formattedMessages = await prepareMessages(messages, imageProcessor);
+		const formattedMessages = await prepareMessages(messages, isNova, imageProcessor);
 
 		let tokenId = 0;
 		const parameters = { ...model.parameters, ...generateSettings };
 		return (async function* () {
-			const command = new InvokeModelWithResponseStreamCommand({
-				body: Buffer.from(
-					JSON.stringify({
-						anthropic_version: anthropicVersion,
-						max_tokens: parameters.max_new_tokens ? parameters.max_new_tokens : 4096,
-						messages: formattedMessages,
-						system,
-					}),
-					"utf-8"
-				),
+			const baseCommandParams = {
 				contentType: "application/json",
 				accept: "application/json",
 				modelId: model.id,
+			};
+
+			const maxTokens = parameters.max_new_tokens || 4096;
+
+			let bodyContent;
+			if (isNova) {
+				bodyContent = {
+					messages: formattedMessages,
+					inferenceConfig: {
+						maxTokens,
+						topP: 0.1,
+						temperature: 1.0,
+					},
+					system: [{ text: system }],
+				};
+			} else {
+				bodyContent = {
+					anthropic_version: anthropicVersion,
+					max_tokens: maxTokens,
+					messages: formattedMessages,
+					system,
+				};
+			}
+
+			const command = new InvokeModelWithResponseStreamCommand({
+				...baseCommandParams,
+				body: Buffer.from(JSON.stringify(bodyContent), "utf-8"),
 				trace: "DISABLED",
 			});
 
@@ -86,21 +105,20 @@ export async function endpointBedrock(
 
 			for await (const item of response.body ?? []) {
 				const chunk = JSON.parse(new TextDecoder().decode(item.chunk?.bytes));
-				const chunk_type = chunk.type;
-
-				if (chunk_type === "content_block_delta") {
-					text += chunk.delta.text;
+				if ("contentBlockDelta" in chunk || chunk.type === "content_block_delta") {
+					const chunkText = chunk.contentBlockDelta?.delta?.text || chunk.delta?.text || "";
+					text += chunkText;
 					yield {
 						token: {
 							id: tokenId++,
-							text: chunk.delta.text,
+							text: chunkText,
 							logprob: 0,
 							special: false,
 						},
 						generated_text: null,
 						details: null,
 					} satisfies TextGenerationStreamOutput;
-				} else if (chunk_type === "message_stop") {
+				} else if ("messageStop" in chunk || chunk.type === "message_stop") {
 					yield {
 						token: {
 							id: tokenId++,
@@ -120,6 +138,7 @@ export async function endpointBedrock(
 // Prepare the messages excluding system prompts
 async function prepareMessages(
 	messages: EndpointMessage[],
+	isNova: boolean,
 	imageProcessor: ReturnType<typeof makeImageProcessor>
 ) {
 	const formattedMessages = [];
@@ -128,9 +147,13 @@ async function prepareMessages(
 		const content = [];
 
 		if (message.files?.length) {
-			content.push(...(await prepareFiles(imageProcessor, message.files)));
+			content.push(...(await prepareFiles(imageProcessor, isNova, message.files)));
+		}
+		if (isNova) {
+			content.push({ text: message.content });
+		} else {
+			content.push({ type: "text", text: message.content });
 		}
-		content.push({ type: "text", text: message.content });
 
 		const lastMessage = formattedMessages[formattedMessages.length - 1];
 		if (lastMessage && lastMessage.role === message.from) {
@@ -146,11 +169,22 @@ async function prepareMessages(
 // Process files and convert them to base64 encoded strings
 async function prepareFiles(
 	imageProcessor: ReturnType<typeof makeImageProcessor>,
+	isNova: boolean,
 	files: MessageFile[]
 ) {
 	const processedFiles = await Promise.all(files.map(imageProcessor));
-	return processedFiles.map((file) => ({
-		type: "image",
-		source: { type: "base64", media_type: "image/jpeg", data: file.image.toString("base64") },
-	}));
+
+	if (isNova) {
+		return processedFiles.map((file) => ({
+			image: {
+				format: file.mime.substring("image/".length),
+				source: { bytes: file.image.toString("base64") },
+			},
+		}));
+	} else {
+		return processedFiles.map((file) => ({
+			type: "image",
+			source: { type: "base64", media_type: file.mime, data: file.image.toString("base64") },
+		}));
+	}
 }
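
For completeness, a sketch of how the flag enters through the updated schema. This is hypothetical usage, not from the PR: the `$lib` import path follows chat-ui's alias convention, and the `type: "bedrock"` discriminator plus the model entry are assumptions about fields outside this excerpt; real deployments would set these through the endpoint configuration instead.

```ts
import { endpointBedrockParametersSchema } from "$lib/server/endpoints/aws/endpointBedrock";

// Hypothetical endpoint config; fields not shown in this diff (like `type`)
// are assumptions, and anything omitted falls back to the schema defaults.
const params = endpointBedrockParametersSchema.parse({
	type: "bedrock",
	model: { id: "amazon.nova-lite-v1:0", parameters: {} }, // illustrative model entry
	isNova: true, // opt this model into the Nova body/stream format
});

console.log(params.region); // "us-east-1" (default)
console.log(params.anthropicVersion); // "bedrock-2023-05-31" (unused when isNova is true)
console.log(params.isNova); // true
```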