diff --git a/sdks/python/apache_beam/io/gcp/gcsio.py b/sdks/python/apache_beam/io/gcp/gcsio.py index 8056de51f43f..3d6588456023 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio.py +++ b/sdks/python/apache_beam/io/gcp/gcsio.py @@ -36,6 +36,8 @@ from typing import Union from google.cloud import storage +from google.cloud.exceptions import from_http_response +from google.cloud.exceptions import GoogleCloudError from google.cloud.exceptions import NotFound from google.cloud.storage.fileio import BlobReader from google.cloud.storage.fileio import BlobWriter @@ -264,9 +266,45 @@ def delete(self, path): except NotFound: return + def _batch_with_retry(self, requests, fn): + current_requests = [*enumerate(requests)] + responses = [None for _ in current_requests] + + @self._storage_client_retry + def run_with_retry(): + current_batch = self.client.batch(raise_exception=False) + with current_batch: + for _, request in current_requests: + fn(request) + last_retryable_exception = None + for (i, current_pair), response in zip( + [*current_requests], current_batch._responses, strict=True + ): + responses[i] = response + should_retry = ( + response.status_code >= 400 and + self._storage_client_retry._predicate(from_http_response(response))) + if should_retry: + last_retryable_exception = from_http_response(response) + else: + current_requests.remove((i, current_pair)) + if last_retryable_exception: + raise last_retryable_exception + + try: + run_with_retry() + except GoogleCloudError: + pass + + return responses + + def _delete_batch_request(self, path): + bucket_name, blob_name = parse_gcs_path(path) + bucket = self.client.bucket(bucket_name) + bucket.delete_blob(blob_name) + def delete_batch(self, paths): """Deletes the objects at the given GCS paths. - Warning: any exception during batch delete will NOT be retried. Args: paths: List of GCS file path patterns or Dict with GCS file path patterns @@ -285,16 +323,11 @@ def delete_batch(self, paths): current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE] else: current_paths = paths[s:] - current_batch = self.client.batch(raise_exception=False) - with current_batch: - for path in current_paths: - bucket_name, blob_name = parse_gcs_path(path) - bucket = self.client.bucket(bucket_name) - bucket.delete_blob(blob_name) - + responses = self._batch_with_retry( + current_paths, self._delete_batch_request) for i, path in enumerate(current_paths): error_code = None - resp = current_batch._responses[i] + resp = responses[i] if resp.status_code >= 400 and resp.status_code != 404: error_code = resp.status_code final_results.append((path, error_code)) @@ -334,9 +367,16 @@ def copy(self, src, dest): source_generation=src_generation, retry=self._storage_client_retry) + def _copy_batch_request(self, pair): + src_bucket_name, src_blob_name = parse_gcs_path(pair[0]) + dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1]) + src_bucket = self.client.bucket(src_bucket_name) + src_blob = src_bucket.blob(src_blob_name) + dest_bucket = self.client.bucket(dest_bucket_name) + src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name) + def copy_batch(self, src_dest_pairs): """Copies the given GCS objects from src to dest. - Warning: any exception during batch copy will NOT be retried. Args: src_dest_pairs: list of (src, dest) tuples of gs:/// files @@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs): current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE] else: current_pairs = src_dest_pairs[s:] - current_batch = self.client.batch(raise_exception=False) - with current_batch: - for pair in current_pairs: - src_bucket_name, src_blob_name = parse_gcs_path(pair[0]) - dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1]) - src_bucket = self.client.bucket(src_bucket_name) - src_blob = src_bucket.blob(src_blob_name) - dest_bucket = self.client.bucket(dest_bucket_name) - - src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name) - + responses = self._batch_with_retry( + current_pairs, self._copy_batch_request) for i, pair in enumerate(current_pairs): error_code = None - resp = current_batch._responses[i] + resp = responses[i] if resp.status_code >= 400: error_code = resp.status_code final_results.append((pair[0], pair[1], error_code)) diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 19df15dcf7fa..dd2a4c33e7e3 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -477,6 +477,59 @@ def test_copy(self): 'gs://gcsio-test/non-existent', 'gs://gcsio-test/non-existent-destination') + @mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3) + @mock.patch('time.sleep', mock.Mock()) + def test_copy_batch(self): + src_dest_pairs = [ + (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt') + for i in range(7) + ] + + def _fake_responses(status_codes): + return mock.Mock( + __enter__=mock.Mock(), + __exit__=mock.Mock(), + _responses=[ + mock.Mock( + **{ + 'json.return_value': { + 'error': { + 'message': 'error' + } + }, + 'request.method': 'BATCH', + 'request.url': 'contentid://None', + }, + status_code=code, + ) for code in status_codes + ], + ) + + gcs_io = gcsio.GcsIO( + storage_client=mock.Mock( + batch=mock.Mock( + side_effect=[ + _fake_responses([200, 404, 429]), + _fake_responses([429]), + _fake_responses([429]), + _fake_responses([200]), + _fake_responses([200, 429, 200]), + _fake_responses([200]), + _fake_responses([200]), + ]), + )) + results = gcs_io.copy_batch(src_dest_pairs) + expected = [ + ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None), + ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404), + ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None), + ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None), + ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None), + ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None), + ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None), + ] + self.assertEqual(results, expected) + def test_copytree(self): src_dir_name = 'gs://gcsio-test/source/' dest_dir_name = 'gs://gcsio-test/dest/'