File size: 3,000 Bytes
61e5613
 
 
5b1a9aa
 
33d6f58
61e5613
 
 
 
 
 
 
 
 
5b1a9aa
 
 
 
 
 
 
 
 
 
 
 
61e5613
 
 
 
 
5b1a9aa
61e5613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33d6f58
 
 
 
61e5613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import { z } from "zod";
import type { Endpoint } from "../endpoints";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator } from "../images";
import { endpointMessagesToAnthropicMessages } from "./utils";
import type { MessageParam } from "@anthropic-ai/sdk/resources/messages.mjs";

/**
 * Image preprocessing options for the Anthropic Vertex endpoint.
 *
 * Raw images are capped at 3.75 MB: base64 encoding inflates a payload by about a
 * third, so 3.75 MB of raw data ends up as roughly 5 MB once encoded.
 * NOTE(review): the 5 MB target is presumably the upstream request limit — confirm.
 */
const anthropicVertexImageOptions = createImageProcessorOptionsValidator({
	supportedMimeTypes: ["image/png", "image/jpeg", "image/webp"],
	preferredMimeType: "image/webp",
	maxSizeInMB: (5 * 3) / 4,
	maxWidth: 4096,
	maxHeight: 4096,
});

/**
 * Configuration schema for an endpoint entry of type `"anthropic-vertex"`.
 *
 * `weight` defaults to 1, `region` defaults to `"us-central1"`, and `multimodal`
 * defaults to an empty object. `projectId` is the only required string, and
 * `model` is accepted as-is without validation.
 */
export const endpointAnthropicVertexParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("anthropic-vertex"),
	region: z.string().default("us-central1"),
	projectId: z.string(),
	defaultHeaders: z.record(z.string()).optional(),
	defaultQuery: z.record(z.string()).optional(),
	multimodal: z.object({ image: anthropicVertexImageOptions }).default({}),
});

/**
 * Builds a chat endpoint that streams completions from Anthropic models through
 * the Anthropic Vertex SDK.
 *
 * The configuration is validated with `endpointAnthropicVertexParametersSchema`.
 * `@anthropic-ai/vertex-sdk` is loaded with a dynamic import, so the package is
 * only required when this endpoint is actually constructed.
 *
 * @param input - Raw endpoint configuration; schema defaults are applied on parse.
 * @returns An endpoint function taking `{ messages, preprompt }` and resolving to
 *   an async generator of `TextGenerationStreamOutput` chunks: one non-special
 *   chunk per text delta, then a final special chunk whose `generated_text`
 *   holds the full text of the response.
 * @throws If `input` does not satisfy the schema, or if
 *   `@anthropic-ai/vertex-sdk` cannot be imported (the original error is
 *   attached as `cause`).
 */
export async function endpointAnthropicVertex(
	input: z.input<typeof endpointAnthropicVertexParametersSchema>
): Promise<Endpoint> {
	const { region, projectId, model, defaultHeaders, defaultQuery, multimodal } =
		endpointAnthropicVertexParametersSchema.parse(input);
	let AnthropicVertex;
	try {
		AnthropicVertex = (await import("@anthropic-ai/vertex-sdk")).AnthropicVertex;
	} catch (e) {
		throw new Error("Failed to import @anthropic-ai/vertex-sdk", { cause: e });
	}

	const anthropic = new AnthropicVertex({
		baseURL: `https://${region}-aiplatform.googleapis.com/v1`,
		region,
		projectId,
		defaultHeaders,
		defaultQuery,
	});

	return async ({ messages, preprompt }) => {
		// A leading system message takes precedence over the configured preprompt.
		let system = preprompt;
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		let tokenId = 0;
		return (async function* () {
			const stream = anthropic.messages.stream({
				model: model.id ?? model.name,
				messages: (await endpointMessagesToAnthropicMessages(
					messages,
					multimodal
				)) as MessageParam[],
				max_tokens: model.parameters?.max_new_tokens,
				temperature: model.parameters?.temperature,
				top_p: model.parameters?.top_p,
				top_k: model.parameters?.top_k,
				stop_sequences: model.parameters?.stop,
				system,
			});

			// Consume the stream through its async iterator rather than re-arming
			// one-shot `emitted("text")` / `emitted("end")` listeners on every pass.
			// One-shot listeners are only registered once the consumer asks for the
			// next chunk, so any delta (or the `end` event itself) fired while this
			// generator is suspended at `yield` would be missed — dropping text or
			// leaving the loop waiting forever. The iterator buffers events instead.
			for await (const event of stream) {
				if (event.type !== "content_block_delta" || event.delta.type !== "text_delta") {
					continue;
				}

				// Text delta
				yield {
					token: {
						id: tokenId++,
						text: event.delta.text,
						special: false,
						logprob: 0,
					},
					generated_text: null,
					details: null,
				} satisfies TextGenerationStreamOutput;
			}

			// Stream end: emit a closing special token carrying the complete text.
			yield {
				token: {
					id: tokenId++,
					text: "",
					logprob: 0,
					special: true,
				},
				generated_text: await stream.finalText(),
				details: null,
			} satisfies TextGenerationStreamOutput;
		})();
	};
}