thomas-pocreau goupilew nsarrazin HF Staff committed on
Commit
5559ab7
·
unverified ·
1 Parent(s): 2ac7abd

Add support for model version in Vertex AI (#1496)

Browse files

* Add support for model version in Vertex AI

* Update src/lib/server/endpoints/google/endpointVertex.ts

Co-authored-by: goupilew <[email protected]>

* fix: optional chaining on extraBody

---------

Co-authored-by: goupilew <[email protected]>
Co-authored-by: Nathan Sarrazin <[email protected]>

docs/source/configuration/models/providers/google.md CHANGED
@@ -33,8 +33,9 @@ MODELS=`[
33
  "type": "vertex",
34
  "project": "abc-xyz",
35
  "location": "europe-west3",
36
- "model": "gemini-1.5-pro-preview-0409", // model-name
37
-
 
38
  // Optional
39
  "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
40
  "apiEndpoint": "", // alternative api endpoint url,
 
33
  "type": "vertex",
34
  "project": "abc-xyz",
35
  "location": "europe-west3",
36
+ "extraBody": {
37
+ "model_version": "gemini-1.5-pro-002",
38
+ },
39
  // Optional
40
  "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
41
  "apiEndpoint": "", // alternative api endpoint url,
src/lib/server/endpoints/google/endpointVertex.ts CHANGED
@@ -16,6 +16,7 @@ export const endpointVertexParametersSchema = z.object({
16
  model: z.any(), // allow optional and validate against emptiness
17
  type: z.literal("vertex"),
18
  location: z.string().default("europe-west1"),
 
19
  project: z.string(),
20
  apiEndpoint: z.string().optional(),
21
  safetyThreshold: z
@@ -49,7 +50,7 @@ export const endpointVertexParametersSchema = z.object({
49
  });
50
 
51
  export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
52
- const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal } =
53
  endpointVertexParametersSchema.parse(input);
54
 
55
  const vertex_ai = new VertexAI({
@@ -64,7 +65,7 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
64
  const hasFiles = messages.some((message) => message.files && message.files.length > 0);
65
 
66
  const generativeModel = vertex_ai.getGenerativeModel({
67
- model: model.id ?? model.name,
68
  safetySettings: safetyThreshold
69
  ? [
70
  {
 
16
  model: z.any(), // allow optional and validate against emptiness
17
  type: z.literal("vertex"),
18
  location: z.string().default("europe-west1"),
19
+ extraBody: z.object({ model_version: z.string() }).optional(),
20
  project: z.string(),
21
  apiEndpoint: z.string().optional(),
22
  safetyThreshold: z
 
50
  });
51
 
52
  export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
53
+ const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal, extraBody } =
54
  endpointVertexParametersSchema.parse(input);
55
 
56
  const vertex_ai = new VertexAI({
 
65
  const hasFiles = messages.some((message) => message.files && message.files.length > 0);
66
 
67
  const generativeModel = vertex_ai.getGenerativeModel({
68
+ model: extraBody?.model_version ?? model.id ?? model.name,
69
  safetySettings: safetyThreshold
70
  ? [
71
  {