From b8304dad10bc440e506fa6dd2e252628aa98be20 Mon Sep 17 00:00:00 2001
From: Dmytro Sadovnychyi
Date: Wed, 8 Jan 2025 20:51:35 +0100
Subject: [PATCH] Add retry logic to each batch method of the GCS IO

Transient errors can occur when writing many shards to GCS, and the GCS IO
currently has no retry logic in place for its batch requests:
https://github.com/apache/beam/blob/a06454a2/sdks/python/apache_beam/io/gcp/gcsio.py#L269

When that happens, the entire bundle of elements fails; Beam then retries
the whole bundle and fails the job once the bundle retry limit is exceeded.

This change retries only the requests that failed within a batch, using the
usual exponential backoff strategy.

Note that this change accesses a private member (`_predicate`) of the retry
object, which we could avoid only by copying that logic here. The existing
code already accesses the private `_responses` property, so this seems
acceptable:
https://github.com/apache/beam/blob/b4c3a4ff/sdks/python/apache_beam/io/gcp/gcsio.py#L297

There is an existing (unresolved) issue about this in the GCS client library:
https://github.com/googleapis/python-storage/issues/1277
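For context, a rough illustration of the retry machinery this change builds
on (not code from this patch; `DEFAULT_RETRY` stands in here for the retry
object that Beam configures as `self._storage_client_retry`):

    from google.api_core import exceptions
    from google.cloud.storage.retry import DEFAULT_RETRY

    # The retry object works as a decorator: the wrapped callable is invoked
    # again, with exponential backoff, whenever the exception it raised is
    # accepted by the retry's (private) predicate.
    print(DEFAULT_RETRY._predicate(exceptions.TooManyRequests('slow down')))  # True
    print(DEFAULT_RETRY._predicate(exceptions.NotFound('no such object')))    # False

    @DEFAULT_RETRY
    def run_once():
      # A retryable failure re-runs this body after a backoff sleep;
      # a non-retryable failure propagates immediately.
      ...

The new `_batch_with_retry` helper relies on exactly this mechanism: the
decorated inner function re-sends only the requests whose responses were
retryable and re-raises the last retryable error, so the decorator keeps
backing off until the batch succeeds or the retry deadline is exhausted.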
---
 sdks/python/apache_beam/io/gcp/gcsio.py      | 75 ++++++++++++++------
 sdks/python/apache_beam/io/gcp/gcsio_test.py | 53 ++++++++++++++
 2 files changed, 106 insertions(+), 22 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py
index 8056de51f43f..3d6588456023 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio.py
@@ -36,6 +36,8 @@
 from typing import Union
 
 from google.cloud import storage
+from google.cloud.exceptions import from_http_response
+from google.cloud.exceptions import GoogleCloudError
 from google.cloud.exceptions import NotFound
 from google.cloud.storage.fileio import BlobReader
 from google.cloud.storage.fileio import BlobWriter
@@ -264,9 +266,45 @@ def delete(self, path):
     except NotFound:
       return
 
+  def _batch_with_retry(self, requests, fn):
+    current_requests = [*enumerate(requests)]
+    responses = [None for _ in current_requests]
+
+    @self._storage_client_retry
+    def run_with_retry():
+      current_batch = self.client.batch(raise_exception=False)
+      with current_batch:
+        for _, request in current_requests:
+          fn(request)
+      last_retryable_exception = None
+      for (i, current_pair), response in zip(
+          [*current_requests], current_batch._responses, strict=True
+      ):
+        responses[i] = response
+        should_retry = (
+            response.status_code >= 400 and
+            self._storage_client_retry._predicate(from_http_response(response)))
+        if should_retry:
+          last_retryable_exception = from_http_response(response)
+        else:
+          current_requests.remove((i, current_pair))
+      if last_retryable_exception:
+        raise last_retryable_exception
+
+    try:
+      run_with_retry()
+    except GoogleCloudError:
+      pass
+
+    return responses
+
+  def _delete_batch_request(self, path):
+    bucket_name, blob_name = parse_gcs_path(path)
+    bucket = self.client.bucket(bucket_name)
+    bucket.delete_blob(blob_name)
+
   def delete_batch(self, paths):
     """Deletes the objects at the given GCS paths.
-    Warning: any exception during batch delete will NOT be retried.
 
     Args:
       paths: List of GCS file path patterns or Dict with GCS file path patterns
@@ -285,16 +323,11 @@ def delete_batch(self, paths):
         current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
       else:
         current_paths = paths[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for path in current_paths:
-          bucket_name, blob_name = parse_gcs_path(path)
-          bucket = self.client.bucket(bucket_name)
-          bucket.delete_blob(blob_name)
-
+      responses = self._batch_with_retry(
+          current_paths, self._delete_batch_request)
       for i, path in enumerate(current_paths):
         error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
         if resp.status_code >= 400 and resp.status_code != 404:
           error_code = resp.status_code
         final_results.append((path, error_code))
@@ -334,9 +367,16 @@ def copy(self, src, dest):
         source_generation=src_generation,
         retry=self._storage_client_retry)
 
+  def _copy_batch_request(self, pair):
+    src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
+    dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
+    src_bucket = self.client.bucket(src_bucket_name)
+    src_blob = src_bucket.blob(src_blob_name)
+    dest_bucket = self.client.bucket(dest_bucket_name)
+    src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
+
   def copy_batch(self, src_dest_pairs):
     """Copies the given GCS objects from src to dest.
-    Warning: any exception during batch copy will NOT be retried.
 
     Args:
       src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
@@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs):
         current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
       else:
         current_pairs = src_dest_pairs[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for pair in current_pairs:
-          src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
-          dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
-          src_bucket = self.client.bucket(src_bucket_name)
-          src_blob = src_bucket.blob(src_blob_name)
-          dest_bucket = self.client.bucket(dest_bucket_name)
-
-          src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
-
+      responses = self._batch_with_retry(
+          current_pairs, self._copy_batch_request)
       for i, pair in enumerate(current_pairs):
         error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
         if resp.status_code >= 400:
           error_code = resp.status_code
         final_results.append((pair[0], pair[1], error_code))
diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py
index 19df15dcf7fa..dd2a4c33e7e3 100644
--- a/sdks/python/apache_beam/io/gcp/gcsio_test.py
+++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -477,6 +477,59 @@ def test_copy(self):
           'gs://gcsio-test/non-existent',
           'gs://gcsio-test/non-existent-destination')
 
+  @mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3)
+  @mock.patch('time.sleep', mock.Mock())
+  def test_copy_batch(self):
+    src_dest_pairs = [
+        (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt')
+        for i in range(7)
+    ]
+
+    def _fake_responses(status_codes):
+      return mock.Mock(
+          __enter__=mock.Mock(),
+          __exit__=mock.Mock(),
+          _responses=[
+              mock.Mock(
+                  **{
+                      'json.return_value': {
+                          'error': {
+                              'message': 'error'
+                          }
+                      },
+                      'request.method': 'BATCH',
+                      'request.url': 'contentid://None',
+                  },
+                  status_code=code,
+              ) for code in status_codes
+          ],
+      )
+
+    gcs_io = gcsio.GcsIO(
+        storage_client=mock.Mock(
+            batch=mock.Mock(
+                side_effect=[
+                    _fake_responses([200, 404, 429]),
+                    _fake_responses([429]),
+                    _fake_responses([429]),
+                    _fake_responses([200]),
+                    _fake_responses([200, 429, 200]),
+                    _fake_responses([200]),
+                    _fake_responses([200]),
+                ]),
+        ))
+    results = gcs_io.copy_batch(src_dest_pairs)
+    expected = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None),
+        ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404),
+        ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None),
+        ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None),
+        ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None),
+        ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None),
+        ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None),
+    ]
+    self.assertEqual(results, expected)
+
   def test_copytree(self):
     src_dir_name = 'gs://gcsio-test/source/'
     dest_dir_name = 'gs://gcsio-test/dest/'
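Not part of the patch: a hypothetical usage sketch (bucket and object names
are made up) of what callers observe after this change. Transient per-request
failures such as 429s are retried inside the batch helpers, and the returned
per-file error codes reflect only the final outcome.

    from apache_beam.io.gcp import gcsio

    gcs = gcsio.GcsIO()
    results = gcs.delete_batch([
        'gs://my-bucket/shard-00000-of-00002.txt',
        'gs://my-bucket/shard-00001-of-00002.txt',
    ])
    for path, error_code in results:
      # error_code is None on success; otherwise it is the final HTTP status
      # observed once the internal retries have been exhausted.
      print(path, error_code)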