Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [google-cloud-batch] Add support for opt-in debug logging #13317

Merged
merged 9 commits into from
Dec 12, 2024

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = ()
metadata: Sequence[Tuple[str, Union[str, bytes]]] = ()
):
"""Instantiate the pager.

Expand All @@ -81,8 +81,10 @@ def __init__(
retry (google.api_core.retry.Retry): Designation of what errors,
if any, should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
sent along with the request as metadata. Normally, each value must be of type `str`,
but for metadata keys ending with the suffix `-bin`, the corresponding values must
be of type `bytes`.
"""
self._method = method
self._request = batch.ListJobsRequest(request)
Expand Down Expand Up @@ -141,7 +143,7 @@ def __init__(
*,
retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT,
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = ()
metadata: Sequence[Tuple[str, Union[str, bytes]]] = ()
):
"""Instantiates the pager.

Expand All @@ -155,8 +157,10 @@ def __init__(
retry (google.api_core.retry.AsyncRetry): Designation of what errors,
if any, should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
sent along with the request as metadata. Normally, each value must be of type `str`,
but for metadata keys ending with the suffix `-bin`, the corresponding values must
be of type `bytes`.
"""
self._method = method
self._request = batch.ListJobsRequest(request)
Expand Down Expand Up @@ -219,7 +223,7 @@ def __init__(
*,
retry: OptionalRetry = gapic_v1.method.DEFAULT,
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = ()
metadata: Sequence[Tuple[str, Union[str, bytes]]] = ()
):
"""Instantiate the pager.

Expand All @@ -233,8 +237,10 @@ def __init__(
retry (google.api_core.retry.Retry): Designation of what errors,
if any, should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
sent along with the request as metadata. Normally, each value must be of type `str`,
but for metadata keys ending with the suffix `-bin`, the corresponding values must
be of type `bytes`.
"""
self._method = method
self._request = batch.ListTasksRequest(request)
Expand Down Expand Up @@ -293,7 +299,7 @@ def __init__(
*,
retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT,
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, str]] = ()
metadata: Sequence[Tuple[str, Union[str, bytes]]] = ()
):
"""Instantiates the pager.

Expand All @@ -307,8 +313,10 @@ def __init__(
retry (google.api_core.retry.AsyncRetry): Designation of what errors,
if any, should be retried.
timeout (float): The timeout for this request.
metadata (Sequence[Tuple[str, str]]): Strings which should be
sent along with the request as metadata.
metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
sent along with the request as metadata. Normally, each value must be of type `str`,
but for metadata keys ending with the suffix `-bin`, the corresponding values must
be of type `bytes`.
"""
self._method = method
self._request = batch.ListTasksRequest(request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging as std_logging
import pickle
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import warnings

Expand All @@ -22,7 +25,10 @@
from google.auth.transport.grpc import SslCredentials # type: ignore
from google.cloud.location import locations_pb2 # type: ignore
from google.longrunning import operations_pb2 # type: ignore
from google.protobuf.json_format import MessageToJson
import google.protobuf.message
import grpc # type: ignore
import proto # type: ignore

from google.cloud.batch_v1.types import batch
from google.cloud.batch_v1.types import job
Expand All @@ -31,6 +37,81 @@

from .base import DEFAULT_CLIENT_INFO, BatchServiceTransport

try:
from google.api_core import client_logging # type: ignore

CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER
except ImportError: # pragma: NO COVER
CLIENT_LOGGING_SUPPORTED = False

_LOGGER = std_logging.getLogger(__name__)


class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER
def intercept_unary_unary(self, continuation, client_call_details, request):
logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor(
std_logging.DEBUG
)
if logging_enabled: # pragma: NO COVER
request_metadata = client_call_details.metadata
if isinstance(request, proto.Message):
request_payload = type(request).to_json(request)
elif isinstance(request, google.protobuf.message.Message):
request_payload = MessageToJson(request)
else:
request_payload = f"{type(request).__name__}: {pickle.dumps(request)}"

request_metadata = {
key: value.decode("utf-8") if isinstance(value, bytes) else value
for key, value in request_metadata
}
grpc_request = {
"payload": request_payload,
"requestMethod": "grpc",
"metadata": dict(request_metadata),
}
_LOGGER.debug(
f"Sending request for {client_call_details.method}",
extra={
"serviceName": "google.cloud.batch.v1.BatchService",
"rpcName": client_call_details.method,
"request": grpc_request,
"metadata": grpc_request["metadata"],
},
)

response = continuation(client_call_details, request)
if logging_enabled: # pragma: NO COVER
response_metadata = response.trailing_metadata()
# Convert gRPC metadata `<class 'grpc.aio._metadata.Metadata'>` to list of tuples
metadata = (
dict([(k, str(v)) for k, v in response_metadata])
if response_metadata
else None
)
result = response.result()
if isinstance(result, proto.Message):
response_payload = type(result).to_json(result)
elif isinstance(result, google.protobuf.message.Message):
response_payload = MessageToJson(result)
else:
response_payload = f"{type(result).__name__}: {pickle.dumps(result)}"
grpc_response = {
"payload": response_payload,
"metadata": metadata,
"status": "OK",
}
_LOGGER.debug(
f"Received response for {client_call_details.method}.",
extra={
"serviceName": "google.cloud.batch.v1.BatchService",
"rpcName": client_call_details.method,
"response": grpc_response,
"metadata": grpc_response["metadata"],
},
)
return response


class BatchServiceGrpcTransport(BatchServiceTransport):
"""gRPC backend transport for BatchService.
Expand Down Expand Up @@ -187,7 +268,12 @@ def __init__(
],
)

# Wrap messages. This must be done after self._grpc_channel exists
self._interceptor = _LoggingClientInterceptor()
self._logged_channel = grpc.intercept_channel(
self._grpc_channel, self._interceptor
)

# Wrap messages. This must be done after self._logged_channel exists
self._prep_wrapped_messages(client_info)

@classmethod
Expand Down Expand Up @@ -251,7 +337,9 @@ def operations_client(self) -> operations_v1.OperationsClient:
"""
# Quick check: Only create a new client if we do not already have one.
if self._operations_client is None:
self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
self._operations_client = operations_v1.OperationsClient(
self._logged_channel
)

# Return the client from cache.
return self._operations_client
Expand All @@ -273,7 +361,7 @@ def create_job(self) -> Callable[[batch.CreateJobRequest], gcb_job.Job]:
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "create_job" not in self._stubs:
self._stubs["create_job"] = self.grpc_channel.unary_unary(
self._stubs["create_job"] = self._logged_channel.unary_unary(
"/google.cloud.batch.v1.BatchService/CreateJob",
request_serializer=batch.CreateJobRequest.serialize,
response_deserializer=gcb_job.Job.deserialize,
Expand All @@ -297,7 +385,7 @@ def get_job(self) -> Callable[[batch.GetJobRequest], job.Job]:
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_job" not in self._stubs:
self._stubs["get_job"] = self.grpc_channel.unary_unary(
self._stubs["get_job"] = self._logged_channel.unary_unary(
"/google.cloud.batch.v1.BatchService/GetJob",
request_serializer=batch.GetJobRequest.serialize,
response_deserializer=job.Job.deserialize,
Expand All @@ -323,7 +411,7 @@ def delete_job(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "delete_job" not in self._stubs:
self._stubs["delete_job"] = self.grpc_channel.unary_unary(
self._stubs["delete_job"] = self._logged_channel.unary_unary(
"/google.cloud.batch.v1.BatchService/DeleteJob",
request_serializer=batch.DeleteJobRequest.serialize,
response_deserializer=operations_pb2.Operation.FromString,
Expand All @@ -347,7 +435,7 @@ def list_jobs(self) -> Callable[[batch.ListJobsRequest], batch.ListJobsResponse]
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "list_jobs" not in self._stubs:
self._stubs["list_jobs"] = self.grpc_channel.unary_unary(
self._stubs["list_jobs"] = self._logged_channel.unary_unary(
"/google.cloud.batch.v1.BatchService/ListJobs",
request_serializer=batch.ListJobsRequest.serialize,
response_deserializer=batch.ListJobsResponse.deserialize,
Expand All @@ -371,7 +459,7 @@ def get_task(self) -> Callable[[batch.GetTaskRequest], task.Task]:
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_task" not in self._stubs:
self._stubs["get_task"] = self.grpc_channel.unary_unary(
self._stubs["get_task"] = self._logged_channel.unary_unary(
"/google.cloud.batch.v1.BatchService/GetTask",
request_serializer=batch.GetTaskRequest.serialize,
response_deserializer=task.Task.deserialize,
Expand All @@ -395,15 +483,15 @@ def list_tasks(self) -> Callable[[batch.ListTasksRequest], batch.ListTasksRespon
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "list_tasks" not in self._stubs:
self._stubs["list_tasks"] = self.grpc_channel.unary_unary(
self._stubs["list_tasks"] = self._logged_channel.unary_unary(
"/google.cloud.batch.v1.BatchService/ListTasks",
request_serializer=batch.ListTasksRequest.serialize,
response_deserializer=batch.ListTasksResponse.deserialize,
)
return self._stubs["list_tasks"]

def close(self):
self.grpc_channel.close()
self._logged_channel.close()

@property
def delete_operation(
Expand All @@ -415,7 +503,7 @@ def delete_operation(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "delete_operation" not in self._stubs:
self._stubs["delete_operation"] = self.grpc_channel.unary_unary(
self._stubs["delete_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/DeleteOperation",
request_serializer=operations_pb2.DeleteOperationRequest.SerializeToString,
response_deserializer=None,
Expand All @@ -432,7 +520,7 @@ def cancel_operation(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "cancel_operation" not in self._stubs:
self._stubs["cancel_operation"] = self.grpc_channel.unary_unary(
self._stubs["cancel_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/CancelOperation",
request_serializer=operations_pb2.CancelOperationRequest.SerializeToString,
response_deserializer=None,
Expand All @@ -449,7 +537,7 @@ def get_operation(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_operation" not in self._stubs:
self._stubs["get_operation"] = self.grpc_channel.unary_unary(
self._stubs["get_operation"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/GetOperation",
request_serializer=operations_pb2.GetOperationRequest.SerializeToString,
response_deserializer=operations_pb2.Operation.FromString,
Expand All @@ -468,7 +556,7 @@ def list_operations(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "list_operations" not in self._stubs:
self._stubs["list_operations"] = self.grpc_channel.unary_unary(
self._stubs["list_operations"] = self._logged_channel.unary_unary(
"/google.longrunning.Operations/ListOperations",
request_serializer=operations_pb2.ListOperationsRequest.SerializeToString,
response_deserializer=operations_pb2.ListOperationsResponse.FromString,
Expand All @@ -487,7 +575,7 @@ def list_locations(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "list_locations" not in self._stubs:
self._stubs["list_locations"] = self.grpc_channel.unary_unary(
self._stubs["list_locations"] = self._logged_channel.unary_unary(
"/google.cloud.location.Locations/ListLocations",
request_serializer=locations_pb2.ListLocationsRequest.SerializeToString,
response_deserializer=locations_pb2.ListLocationsResponse.FromString,
Expand All @@ -504,7 +592,7 @@ def get_location(
# gRPC handles serialization and deserialization, so we just need
# to pass in the functions for each.
if "get_location" not in self._stubs:
self._stubs["get_location"] = self.grpc_channel.unary_unary(
self._stubs["get_location"] = self._logged_channel.unary_unary(
"/google.cloud.location.Locations/GetLocation",
request_serializer=locations_pb2.GetLocationRequest.SerializeToString,
response_deserializer=locations_pb2.Location.FromString,
Expand Down
Loading
Loading