diff --git a/setup.py b/setup.py
index a2a83c56..2193bdfa 100644
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
     python_requires=">=3.8",
     install_requires=[
         "grpcio==1.62.2",
+        "grpcio-status==1.62.2",
        "googleapis-common-protos==1.63.0",
         "protobuf==4.25.3",
         "flask==3.0.3",
diff --git a/testbench/grpc_server.py b/testbench/grpc_server.py
index adf369f0..1bb41be1 100644
--- a/testbench/grpc_server.py
+++ b/testbench/grpc_server.py
@@ -21,11 +21,15 @@
 import uuid
 from collections.abc import Iterable
 from concurrent import futures
+from queue import PriorityQueue
 
 import crc32c
+import google.protobuf.any_pb2 as any_pb2
 import google.protobuf.empty_pb2 as empty_pb2
 import grpc
 from google.protobuf import field_mask_pb2, json_format, text_format
+from google.rpc import status_pb2
+from grpc_status import rpc_status
 
 import gcs
 import testbench
@@ -581,8 +585,6 @@ def BidiReadObject(self, request_iterator, context):
         # A PriorityQueue holds tuples (offset, response) where the offset serves as the priority predicate.
         # A response is removed from the queue with the lowest offset. This will help guarantee that
         # responses with the same read_id will be delivered in increasing offset order.
-        from queue import PriorityQueue
-
         responses = PriorityQueue()
         # TBD: yield_size is configurable and is used to emulate interleaved responses.
         yield_size = 1
@@ -597,14 +599,14 @@ def process_read_range(range, metadata=None):
             read_end = len(blob.media)
             read_id = range.read_id
-            if start > read_end:
-                # TODO: Error handling, return a list of read_range_errors in BidiReadObjectError.
-                return testbench.error.range_not_satisfiable(context)
-            if range.read_limit < 0:
-                # TODO: Error handling, return a list of read_range_errors in BidiReadObjectError.
-                # A negative read_limit will cause an error.
-                return testbench.error.range_not_satisfiable(context)
-            elif range.read_limit > 0:
+            if start > read_end or range.read_limit < 0:
+                status_msg = self._pack_bidiread_error_details(
+                    range, grpc.StatusCode.OUT_OF_RANGE
+                )
+                grpc_status = rpc_status.to_status(status_msg)
+                return context.abort(grpc_status)
+
+            if range.read_limit > 0:
                 # read_limit is the maximum number of data bytes the server is allowed to return across all response messages with the same read_id.
                 read_end = min(read_end, start + range.read_limit)
@@ -671,6 +673,31 @@ def process_read_range(range, metadata=None):
             item = responses.get()
             yield item[1]
 
+    def _to_read_range_error_proto(self, range, status_code):
+        return storage_pb2.ReadRangeError(
+            read_id=range.read_id,
+            status={
+                "code": status_code.value[0],
+                "message": status_code.value[1],
+            },
+        )
+
+    def _pack_bidiread_error_details(self, range, status_code):
+        range_read_error = self._to_read_range_error_proto(range, status_code)
+        code = status_code.value[0]
+        message = status_code.value[1]
+        detail = any_pb2.Any()
+        detail.Pack(
+            storage_pb2.BidiReadObjectError(
+                read_range_errors=[range_read_error],
+            )
+        )
+        return status_pb2.Status(
+            code=code,
+            message=message,
+            details=[detail],
+        )
+
     @retry_test(method="storage.objects.patch")
     def UpdateObject(self, request, context):
         intersection = field_mask_pb2.FieldMask(
diff --git a/tests/test_grpc_server.py b/tests/test_grpc_server.py
index 8b74a2be..709ba60b 100755
--- a/tests/test_grpc_server.py
+++ b/tests/test_grpc_server.py
@@ -24,6 +24,7 @@
 import crc32c
 import grpc
+import grpc_status
 from google.protobuf import field_mask_pb2, timestamp_pb2
 
 import gcs
@@ -2441,6 +2442,128 @@ def test_bidi_read_object_generation_precondition_failed(self):
             grpc.StatusCode.FAILED_PRECONDITION, unittest.mock.ANY
         )
 
+    def test_bidi_read_object_out_of_order(self):
+        # Create object in database to read.
+        media = TestGrpc._create_block(6 * 1024 * 1024).encode("utf-8")
+        request = testbench.common.FakeRequest(
+            args={"name": "object-name"}, data=media, headers={}, environ={}
+        )
+        blob, _ = gcs.object.Object.init_media(request, self.bucket.metadata)
+        self.db.insert_object("bucket-name", blob, None)
+
+        # Test n ranges in 1 stream, where n=3. Test range requests offsets are out of order.
+        offset_1 = 0
+        limit_1 = 3 * 1024 * 1024
+        read_id_1 = int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1000)
+        offset_2 = 4 * 1024 * 1024
+        limit_2 = 1024
+        read_id_2 = read_id_1 - 1
+        offset_3 = 5 * 1024 * 1024
+        limit_3 = 1
+        read_id_3 = read_id_1 + 1
+
+        r1 = storage_pb2.BidiReadObjectRequest(
+            read_object_spec=storage_pb2.BidiReadObjectSpec(
+                bucket="projects/_/buckets/bucket-name",
+                object="object-name",
+            ),
+            read_ranges=[
+                storage_pb2.ReadRange(
+                    read_offset=offset_1,
+                    read_limit=limit_1,
+                    read_id=read_id_1,
+                ),
+                storage_pb2.ReadRange(
+                    read_offset=offset_3,
+                    read_limit=limit_3,
+                    read_id=read_id_3,
+                ),
+            ],
+        )
+        r2 = storage_pb2.BidiReadObjectRequest(
+            read_ranges=[
+                storage_pb2.ReadRange(
+                    read_offset=offset_2,
+                    read_limit=limit_2,
+                    read_id=read_id_2,
+                ),
+            ],
+        )
+
+        streamer = self.grpc.BidiReadObject([r1, r2], context=self.mock_context())
+        responses = list(streamer)
+        read_range_1 = responses[0].object_data_ranges[0].read_range
+        data_1 = responses[0].object_data_ranges[0].checksummed_data
+        self.assertEqual(read_id_1, read_range_1.read_id)
+        self.assertEqual(offset_1, read_range_1.read_offset)
+        self.assertEqual(limit_1, read_range_1.read_limit)
+        self.assertEqual(crc32c.crc32c(data_1.content), data_1.crc32c)
+        read_range_last = responses[-1].object_data_ranges[-1].read_range
+        data_last = responses[-1].object_data_ranges[-1].checksummed_data
+        self.assertEqual(read_id_3, read_range_last.read_id)
+        self.assertEqual(offset_3, read_range_last.read_offset)
+        self.assertEqual(limit_3, read_range_last.read_limit)
+        self.assertEqual(crc32c.crc32c(data_last.content), data_last.crc32c)
+
+    def test_bidi_read_out_of_range_error(self):
+        # Create object in database to read.
+        media = TestGrpc._create_block(1024 * 1024).encode("utf-8")
+        request = testbench.common.FakeRequest(
+            args={"name": "object-name"}, data=media, headers={}, environ={}
+        )
+        blob, _ = gcs.object.Object.init_media(request, self.bucket.metadata)
+        self.db.insert_object("bucket-name", blob, None)
+
+        # Test out-of-range offset.
+        offset_1 = 8 * 1024 * 1024
+        limit_1 = 1024
+        read_id_1 = int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1000)
+        r1 = storage_pb2.BidiReadObjectRequest(
+            read_object_spec=storage_pb2.BidiReadObjectSpec(
+                bucket="projects/_/buckets/bucket-name",
+                object="object-name",
+            ),
+            read_ranges=[
+                storage_pb2.ReadRange(
+                    read_offset=offset_1,
+                    read_limit=limit_1,
+                    read_id=read_id_1,
+                ),
+            ],
+        )
+        # Test out-of-range with negative read limit.
+        limit_2 = -2048
+        offset_2 = 10
+        read_id_2 = read_id_1 + 1
+        r2 = storage_pb2.BidiReadObjectRequest(
+            read_object_spec=storage_pb2.BidiReadObjectSpec(
+                bucket="projects/_/buckets/bucket-name",
+                object="object-name",
+            ),
+            read_ranges=[
+                storage_pb2.ReadRange(
+                    read_offset=offset_2,
+                    read_limit=limit_2,
+                    read_id=read_id_2,
+                ),
+            ],
+        )
+
+        for request in [r1, r2]:
+            context = unittest.mock.Mock()
+            context.abort = unittest.mock.MagicMock()
+            context.abort.side_effect = grpc.RpcError()
+            with self.assertRaises(grpc.RpcError):
+                streamer = self.grpc.BidiReadObject([request], context=context)
+                list(streamer)
+
+            context.abort.assert_called_once()
+            abort_status = context.abort.call_args[0][0]
+            grpc_status_details_bin = abort_status.trailing_metadata[0][1]
+            self.assertIsInstance(abort_status, grpc_status.rpc_status._Status)
+            self.assertIn(grpc.StatusCode.OUT_OF_RANGE, abort_status)
+            self.assertIn(b"BidiReadObjectError", grpc_status_details_bin)
+
 
 if __name__ == "__main__":
     unittest.main()
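For reference, a minimal client-side sketch of how the packed error details could be consumed. This is not part of the patch: the helper name extract_read_range_errors, the caught rpc_error variable, and the google.storage.v2 import path for the generated protos are illustrative assumptions; rpc_status.from_call and the protobuf Any Is/Unpack calls are the standard grpcio-status / protobuf APIs.

    import grpc
    from grpc_status import rpc_status
    from google.storage.v2 import storage_pb2  # assumed import path for the generated protos


    def extract_read_range_errors(rpc_error: grpc.RpcError):
        # Decode the google.rpc.Status carried in the grpc-status-details-bin trailer.
        status = rpc_status.from_call(rpc_error)
        if status is None:
            return []
        errors = []
        for detail in status.details:
            # Each detail is a google.protobuf.Any; unpack only BidiReadObjectError payloads.
            if detail.Is(storage_pb2.BidiReadObjectError.DESCRIPTOR):
                bidi_error = storage_pb2.BidiReadObjectError()
                detail.Unpack(bidi_error)
                errors.extend(bidi_error.read_range_errors)
        return errors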