Spaces:
Running
Running
from typing import Dict, List, Union | |
import os | |
from google.cloud import aiplatform | |
from google.protobuf import json_format | |
from google.protobuf.struct_pb2 import Value | |
from google.oauth2 import service_account | |
def predict_custom_trained_model(
    project: str,
    endpoint_id: str,
    instances: Union[Dict, List[Dict]],
    location: str = "us-central1",
    api_endpoint: Union[str, None] = None,
    parameters: Union[Dict, None] = None,
):
    """Send an online prediction request to a Vertex AI endpoint.

    Args:
        project: Google Cloud project that owns the endpoint.
        endpoint_id: ID of the deployed endpoint.
        instances: Either a single instance (a dict) or a list of instances.
            Each instance must conform to the deployed model's prediction
            input schema. Anything that is not a list is treated as one
            instance.
        location: Region the endpoint lives in.
        api_endpoint: Regional API host. When omitted it is derived from
            ``location`` as ``"{location}-aiplatform.googleapis.com"`` so
            the two always agree; pass it explicitly to override.
        parameters: Optional prediction parameters sent alongside the
            instances. Defaults to an empty dict.

    Returns:
        The ``predictions`` field of the prediction response (described by
        Google's sample as a google.protobuf.Value representation of the
        model's predictions -- confirm the exact element type against the
        installed client version).

    Credentials are loaded from the service-account key file named by the
    ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable when it is set;
    otherwise the client falls back to Application Default Credentials.
    Errors raised by the client are not caught and propagate to the caller.
    """
    # Vertex AI requires a regional API endpoint; derive it from `location`
    # by default so a non-default region does not silently hit us-central1.
    if api_endpoint is None:
        api_endpoint = f"{location}-aiplatform.googleapis.com"
    client_options = {"api_endpoint": api_endpoint}

    # os.getenv returns None when the variable is unset, and passing None to
    # from_service_account_file fails with an opaque error. With
    # credentials=None the client resolves Application Default Credentials.
    credentials_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    credentials = (
        service_account.Credentials.from_service_account_file(credentials_path)
        if credentials_path
        else None
    )

    # NOTE: a new client is created on every call. It can be reused across
    # requests, so hoist it out if this function is called frequently.
    client = aiplatform.gapic.PredictionServiceClient(
        credentials=credentials,
        client_options=client_options,
    )

    # A lone (non-list) instance is wrapped so the request always carries a
    # list of protobuf Values.
    if not isinstance(instances, list):
        instances = [instances]
    instance_values = [
        json_format.ParseDict(instance, Value()) for instance in instances
    ]
    parameters_value = json_format.ParseDict(
        parameters if parameters is not None else {}, Value()
    )

    endpoint = client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    response = client.predict(
        endpoint=endpoint,
        instances=instance_values,
        parameters=parameters_value,
    )
    return response.predictions