diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 7a3002faa4..1dac4a118d 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -2977,6 +2977,7 @@ def upload( serving_container_args: Optional[Sequence[str]] = None, serving_container_environment_variables: Optional[Dict[str, str]] = None, serving_container_ports: Optional[Sequence[int]] = None, + serving_container_grpc_ports: Optional[Sequence[int]] = None, local_model: Optional["LocalModel"] = None, instance_schema_uri: Optional[str] = None, parameters_schema_uri: Optional[str] = None, @@ -3083,6 +3084,14 @@ def upload( no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. + serving_container_grpc_ports (Optional[Sequence[int]]): + Declaration of ports that are exposed by the container. Vertex AI sends gRPC + prediction requests that it receives to the first port on this list. Vertex + AI also sends liveness and health checks to this port. + If you do not specify this field, gRPC requests to the container will be + disabled. + Vertex AI does not use ports other than the first one listed. This field + corresponds to the `ports` field of the Kubernetes Containers v1 core API. local_model (Optional[LocalModel]): Optional. A LocalModel instance that includes a `serving_container_spec`. 
If provided, the `serving_container_spec` of the LocalModel instance @@ -3238,6 +3247,7 @@ def upload( env = None ports = None + grpc_ports = None deployment_timeout = ( duration_pb2.Duration(seconds=serving_container_deployment_timeout) if serving_container_deployment_timeout @@ -3256,6 +3266,11 @@ def upload( gca_model_compat.Port(container_port=port) for port in serving_container_ports ] + if serving_container_grpc_ports: + grpc_ports = [ + gca_model_compat.Port(container_port=port) + for port in serving_container_grpc_ports + ] if ( serving_container_startup_probe_exec or serving_container_startup_probe_period_seconds @@ -3293,6 +3308,7 @@ def upload( args=serving_container_args, env=env, ports=ports, + grpc_ports=grpc_ports, predict_route=serving_container_predict_route, health_route=serving_container_health_route, deployment_timeout=deployment_timeout, diff --git a/google/cloud/aiplatform/prediction/local_model.py b/google/cloud/aiplatform/prediction/local_model.py index 20b52527e3..313c68c58b 100644 --- a/google/cloud/aiplatform/prediction/local_model.py +++ b/google/cloud/aiplatform/prediction/local_model.py @@ -60,6 +60,7 @@ def __init__( serving_container_args: Optional[Sequence[str]] = None, serving_container_environment_variables: Optional[Dict[str, str]] = None, serving_container_ports: Optional[Sequence[int]] = None, + serving_container_grpc_ports: Optional[Sequence[int]] = None, serving_container_deployment_timeout: Optional[int] = None, serving_container_shared_memory_size_mb: Optional[int] = None, serving_container_startup_probe_exec: Optional[Sequence[str]] = None, @@ -110,6 +111,14 @@ def __init__( no impact on whether the port is actually exposed, any port listening on the default "0.0.0.0" address inside a container will be accessible from the network. + serving_container_grpc_ports (Optional[Sequence[int]]): + Declaration of ports that are exposed by the container. 
Vertex AI sends gRPC + prediction requests that it receives to the first port on this list. Vertex + AI also sends liveness and health checks to this port. + If you do not specify this field, gRPC requests to the container will be + disabled. + Vertex AI does not use ports other than the first one listed. This field + corresponds to the `ports` field of the Kubernetes Containers v1 core API. serving_container_deployment_timeout (int): Optional. Deployment timeout in seconds. serving_container_shared_memory_size_mb (int): @@ -156,6 +165,7 @@ def __init__( env = None ports = None + grpc_ports = None deployment_timeout = ( duration_pb2.Duration(seconds=serving_container_deployment_timeout) if serving_container_deployment_timeout @@ -174,6 +184,11 @@ def __init__( gca_model_compat.Port(container_port=port) for port in serving_container_ports ] + if serving_container_grpc_ports: + grpc_ports = [ + gca_model_compat.Port(container_port=port) + for port in serving_container_grpc_ports + ] if ( serving_container_startup_probe_exec or serving_container_startup_probe_period_seconds @@ -211,6 +226,7 @@ def __init__( args=serving_container_args, env=env, ports=ports, + grpc_ports=grpc_ports, predict_route=serving_container_predict_route, health_route=serving_container_health_route, deployment_timeout=deployment_timeout, diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index af771962af..3031562eb4 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -113,6 +113,7 @@ "loss_fn": "mse", } _TEST_SERVING_CONTAINER_PORTS = [8888, 10000] +_TEST_SERVING_CONTAINER_GRPC_PORTS = [7777, 7000] _TEST_SERVING_CONTAINER_DEPLOYMENT_TIMEOUT = 100 _TEST_SERVING_CONTAINER_SHARED_MEMORY_SIZE_MB = 1000 _TEST_SERVING_CONTAINER_STARTUP_PROBE_EXEC = ["a", "b"] @@ -1606,6 +1607,7 @@ def test_upload_uploads_and_gets_model_with_all_args( serving_container_args=_TEST_SERVING_CONTAINER_ARGS, 
serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, + serving_container_grpc_ports=_TEST_SERVING_CONTAINER_GRPC_PORTS, explanation_metadata=_TEST_EXPLANATION_METADATA, explanation_parameters=_TEST_EXPLANATION_PARAMETERS, labels=_TEST_LABEL, @@ -1634,6 +1636,11 @@ def test_upload_uploads_and_gets_model_with_all_args( for port in _TEST_SERVING_CONTAINER_PORTS ] + grpc_ports = [ + gca_model.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_GRPC_PORTS + ] + deployment_timeout = duration_pb2.Duration( seconds=_TEST_SERVING_CONTAINER_DEPLOYMENT_TIMEOUT ) @@ -1662,6 +1669,7 @@ def test_upload_uploads_and_gets_model_with_all_args( args=_TEST_SERVING_CONTAINER_ARGS, env=env, ports=ports, + grpc_ports=grpc_ports, deployment_timeout=deployment_timeout, shared_memory_size_mb=_TEST_SERVING_CONTAINER_SHARED_MEMORY_SIZE_MB, startup_probe=startup_probe, diff --git a/tests/unit/aiplatform/test_prediction.py b/tests/unit/aiplatform/test_prediction.py index 9f27c81cb8..1cb6ca5875 100644 --- a/tests/unit/aiplatform/test_prediction.py +++ b/tests/unit/aiplatform/test_prediction.py @@ -111,6 +111,7 @@ "loss_fn": "mse", } _TEST_SERVING_CONTAINER_PORTS = [8888, 10000] +_TEST_SERVING_CONTAINER_GRPC_PORTS = [7777, 7000] _TEST_ID = "1028944691210842416" _TEST_LABEL = {"team": "experimentation", "trial_id": "x435"} _TEST_APPENDED_USER_AGENT = ["fake_user_agent"] @@ -1112,6 +1113,10 @@ def test_init_with_serving_container_spec(self): gca_model_compat.Port(container_port=port) for port in _TEST_SERVING_CONTAINER_PORTS ] + grpc_ports = [ + gca_model_compat.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_GRPC_PORTS + ] container_spec = gca_model_compat.ModelContainerSpec( image_uri=_TEST_SERVING_CONTAINER_IMAGE, predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, @@ -1120,6 +1125,7 @@ def test_init_with_serving_container_spec(self): args=_TEST_SERVING_CONTAINER_ARGS, 
env=env, ports=ports, + grpc_ports=grpc_ports, ) local_model = LocalModel( @@ -1139,6 +1145,9 @@ def test_init_with_serving_container_spec(self): assert local_model.serving_container_spec.args == container_spec.args assert local_model.serving_container_spec.env == container_spec.env assert local_model.serving_container_spec.ports == container_spec.ports + assert ( + local_model.serving_container_spec.grpc_ports == container_spec.grpc_ports + ) def test_init_with_serving_container_spec_but_not_image_uri_throws_exception(self): env = [ @@ -1149,6 +1158,10 @@ def test_init_with_serving_container_spec_but_not_image_uri_throws_exception(sel gca_model_compat.Port(container_port=port) for port in _TEST_SERVING_CONTAINER_PORTS ] + grpc_ports = [ + gca_model_compat.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_GRPC_PORTS + ] container_spec = gca_model_compat.ModelContainerSpec( predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, health_route=_TEST_SERVING_CONTAINER_HEALTH_ROUTE, @@ -1156,6 +1169,7 @@ def test_init_with_serving_container_spec_but_not_image_uri_throws_exception(sel args=_TEST_SERVING_CONTAINER_ARGS, env=env, ports=ports, + grpc_ports=grpc_ports, ) expected_message = "Image uri is required for the serving container spec to initialize a LocalModel instance." 
@@ -1175,6 +1189,7 @@ def test_init_with_separate_args(self): serving_container_args=_TEST_SERVING_CONTAINER_ARGS, serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, + serving_container_grpc_ports=_TEST_SERVING_CONTAINER_GRPC_PORTS, ) env = [ @@ -1187,6 +1202,11 @@ def test_init_with_separate_args(self): for port in _TEST_SERVING_CONTAINER_PORTS ] + grpc_ports = [ + gca_model_compat.Port(container_port=port) + for port in _TEST_SERVING_CONTAINER_GRPC_PORTS + ] + container_spec = gca_model_compat.ModelContainerSpec( image_uri=_TEST_SERVING_CONTAINER_IMAGE, predict_route=_TEST_SERVING_CONTAINER_PREDICTION_ROUTE, @@ -1195,6 +1215,7 @@ def test_init_with_separate_args(self): args=_TEST_SERVING_CONTAINER_ARGS, env=env, ports=ports, + grpc_ports=grpc_ports, ) assert local_model.serving_container_spec.image_uri == container_spec.image_uri @@ -1210,6 +1231,9 @@ def test_init_with_separate_args(self): assert local_model.serving_container_spec.args == container_spec.args assert local_model.serving_container_spec.env == container_spec.env assert local_model.serving_container_spec.ports == container_spec.ports + assert ( + local_model.serving_container_spec.grpc_ports == container_spec.grpc_ports + ) def test_init_with_separate_args_but_not_image_uri_throws_exception(self): expected_message = "Serving container image uri is required to initialize a LocalModel instance." @@ -1222,6 +1246,7 @@ def test_init_with_separate_args_but_not_image_uri_throws_exception(self): serving_container_args=_TEST_SERVING_CONTAINER_ARGS, serving_container_environment_variables=_TEST_SERVING_CONTAINER_ENVIRONMENT_VARIABLES, serving_container_ports=_TEST_SERVING_CONTAINER_PORTS, + serving_container_grpc_ports=_TEST_SERVING_CONTAINER_GRPC_PORTS, ) assert str(exception.value) == expected_message