Add retry logic to each batch method of the GCS IO
Transient errors can occur when writing a large number of shards to GCS, and right now
the GCS IO has no retry logic in place:

https://github.com/apache/beam/blob/a06454a2/sdks/python/apache_beam/io/gcp/gcsio.py#L269

This means that in such cases the entire bundle of elements fails; Beam then
retries the whole bundle, and fails the job once the retry limit is exceeded.

This change adds logic to retry only the failed requests within a batch, using the
typical exponential backoff strategy.
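
For illustration, "typical exponential backoff" means waits that roughly double between
attempts up to some cap. The parameters below are assumed for the example only and are
not taken from this change:

```python
# Illustration only: an exponential backoff schedule with assumed parameters
# (1s initial delay, doubling each attempt, capped at 60s).
initial, multiplier, maximum = 1.0, 2.0, 60.0

delay = initial
for attempt in range(6):
  print(f'attempt {attempt}: wait {min(delay, maximum):.1f}s before retrying')
  delay *= multiplier
# attempt 0: wait 1.0s ... attempt 5: wait 32.0s
```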

Note that this change accesses a private method (`_predicate`) of the retry
object, which we could avoid by copying that logic over here. But the existing
code already accesses the private `_responses` property, so it is arguably not
a big deal:

https://github.com/apache/beam/blob/b4c3a4ff/sdks/python/apache_beam/io/gcp/gcsio.py#L297
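
For reference, a minimal sketch of what "copying the logic over" could look like: a
local retryability check that avoids the private `_predicate`. The exception classes
listed are an assumption about what the GCS client typically treats as transient and
are not part of this change:

```python
from google.api_core import exceptions as api_exceptions
from google.cloud.exceptions import from_http_response

# Assumed set of transient error types; the real DEFAULT_RETRY predicate may
# differ between google-cloud-storage versions.
_TRANSIENT_TYPES = (
    api_exceptions.TooManyRequests,      # 429
    api_exceptions.InternalServerError,  # 500
    api_exceptions.BadGateway,           # 502
    api_exceptions.ServiceUnavailable,   # 503
)


def is_retryable(response):
  """Returns True if a batch sub-response looks like a transient failure."""
  if response.status_code < 400:
    return False
  return isinstance(from_http_response(response), _TRANSIENT_TYPES)
```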

Existing (unresolved) issue in the GCS client library:

googleapis/python-storage#1277
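
As a usage sketch of the intended effect (the bucket and object names below are made
up): transient failures of individual sub-requests are now retried inside the batch
call, and only requests that still fail after the retries surface a non-None error code:

```python
from apache_beam.io.gcp import gcsio

# Hypothetical paths; delete_batch returns one (path, error_code) tuple per
# input, where error_code is None on success (404s are also treated as success).
gcs = gcsio.GcsIO()
results = gcs.delete_batch([
    'gs://my-bucket/shard-00000-of-00100.txt',
    'gs://my-bucket/shard-00001-of-00100.txt',
])
for path, error_code in results:
  if error_code is not None:
    print(f'failed to delete {path}: HTTP {error_code}')
```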
sadovnychyi committed Jan 8, 2025
1 parent b4c3a4f commit 489ea9f
Showing 2 changed files with 106 additions and 22 deletions.
75 changes: 53 additions & 22 deletions sdks/python/apache_beam/io/gcp/gcsio.py
@@ -36,7 +36,9 @@
 from typing import Union
 
 from google.cloud import storage
+from google.cloud.exceptions import GoogleCloudError
 from google.cloud.exceptions import NotFound
+from google.cloud.exceptions import from_http_response
 from google.cloud.storage.fileio import BlobReader
 from google.cloud.storage.fileio import BlobWriter
 from google.cloud.storage.retry import DEFAULT_RETRY
@@ -264,9 +266,45 @@ def delete(self, path):
     except NotFound:
       return
 
+  def _batch_with_retry(self, requests, fn):
+    current_requests = [*enumerate(requests)]
+    responses = [None for _ in current_requests]
+
+    @self._storage_client_retry
+    def run_with_retry():
+      current_batch = self.client.batch(raise_exception=False)
+      with current_batch:
+        for _, request in current_requests:
+          fn(request)
+      last_retryable_exception = None
+      for (i, current_pair), response in zip(
+          [*current_requests], current_batch._responses
+      ):
+        responses[i] = response
+        should_retry = (
+            response.status_code >= 400 and
+            self._storage_client_retry._predicate(from_http_response(response)))
+        if should_retry:
+          last_retryable_exception = from_http_response(response)
+        else:
+          current_requests.remove((i, current_pair))
+      if last_retryable_exception:
+        raise last_retryable_exception
+
+    try:
+      run_with_retry()
+    except GoogleCloudError:
+      pass
+
+    return responses
+
+  def _delete_batch_request(self, path):
+    bucket_name, blob_name = parse_gcs_path(path)
+    bucket = self.client.bucket(bucket_name)
+    bucket.delete_blob(blob_name)
+
   def delete_batch(self, paths):
     """Deletes the objects at the given GCS paths.
-    Warning: any exception during batch delete will NOT be retried.
     Args:
       paths: List of GCS file path patterns or Dict with GCS file path patterns
@@ -285,16 +323,11 @@ def delete_batch(self, paths):
         current_paths = paths[s:s + MAX_BATCH_OPERATION_SIZE]
       else:
         current_paths = paths[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for path in current_paths:
-          bucket_name, blob_name = parse_gcs_path(path)
-          bucket = self.client.bucket(bucket_name)
-          bucket.delete_blob(blob_name)
-
+      responses = self._batch_with_retry(
+          current_paths, self._delete_batch_request)
       for i, path in enumerate(current_paths):
         error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
         if resp.status_code >= 400 and resp.status_code != 404:
           error_code = resp.status_code
         final_results.append((path, error_code))
@@ -334,9 +367,16 @@ def copy(self, src, dest):
         source_generation=src_generation,
         retry=self._storage_client_retry)
 
+  def _copy_batch_request(self, pair):
+    src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
+    dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
+    src_bucket = self.client.bucket(src_bucket_name)
+    src_blob = src_bucket.blob(src_blob_name)
+    dest_bucket = self.client.bucket(dest_bucket_name)
+    src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
+
   def copy_batch(self, src_dest_pairs):
     """Copies the given GCS objects from src to dest.
-    Warning: any exception during batch copy will NOT be retried.
     Args:
       src_dest_pairs: list of (src, dest) tuples of gs://<bucket>/<name> files
@@ -354,20 +394,11 @@ def copy_batch(self, src_dest_pairs):
         current_pairs = src_dest_pairs[s:s + MAX_BATCH_OPERATION_SIZE]
       else:
         current_pairs = src_dest_pairs[s:]
-      current_batch = self.client.batch(raise_exception=False)
-      with current_batch:
-        for pair in current_pairs:
-          src_bucket_name, src_blob_name = parse_gcs_path(pair[0])
-          dest_bucket_name, dest_blob_name = parse_gcs_path(pair[1])
-          src_bucket = self.client.bucket(src_bucket_name)
-          src_blob = src_bucket.blob(src_blob_name)
-          dest_bucket = self.client.bucket(dest_bucket_name)
-
-          src_bucket.copy_blob(src_blob, dest_bucket, dest_blob_name)
-
+      responses = self._batch_with_retry(
+          current_pairs, self._copy_batch_request)
       for i, pair in enumerate(current_pairs):
         error_code = None
-        resp = current_batch._responses[i]
+        resp = responses[i]
         if resp.status_code >= 400:
           error_code = resp.status_code
         final_results.append((pair[0], pair[1], error_code))
53 changes: 53 additions & 0 deletions sdks/python/apache_beam/io/gcp/gcsio_test.py
@@ -477,6 +477,59 @@ def test_copy(self):
           'gs://gcsio-test/non-existent',
           'gs://gcsio-test/non-existent-destination')
 
+  @mock.patch('apache_beam.io.gcp.gcsio.MAX_BATCH_OPERATION_SIZE', 3)
+  @mock.patch('time.sleep', mock.Mock())
+  def test_copy_batch(self):
+    src_dest_pairs = [
+        (f'gs://source_bucket/file{i}.txt', f'gs://dest_bucket/file{i}.txt')
+        for i in range(7)
+    ]
+
+    def _fake_responses(status_codes):
+      return mock.Mock(
+          __enter__=mock.Mock(),
+          __exit__=mock.Mock(),
+          _responses=[
+              mock.Mock(
+                  **{
+                      'json.return_value': {
+                          'error': {
+                              'message': 'error'
+                          }
+                      },
+                      'request.method': 'BATCH',
+                      'request.url': 'contentid://None',
+                  },
+                  status_code=code,
+              ) for code in status_codes
+          ],
+      )
+
+    gcs_io = gcsio.GcsIO(
+        storage_client=mock.Mock(
+            batch=mock.Mock(
+                side_effect=[
+                    _fake_responses([200, 404, 429]),
+                    _fake_responses([429]),
+                    _fake_responses([429]),
+                    _fake_responses([200]),
+                    _fake_responses([200, 429, 200]),
+                    _fake_responses([200]),
+                    _fake_responses([200]),
+                ]),
+        ))
+    results = gcs_io.copy_batch(src_dest_pairs)
+    expected = [
+        ('gs://source_bucket/file0.txt', 'gs://dest_bucket/file0.txt', None),
+        ('gs://source_bucket/file1.txt', 'gs://dest_bucket/file1.txt', 404),
+        ('gs://source_bucket/file2.txt', 'gs://dest_bucket/file2.txt', None),
+        ('gs://source_bucket/file3.txt', 'gs://dest_bucket/file3.txt', None),
+        ('gs://source_bucket/file4.txt', 'gs://dest_bucket/file4.txt', None),
+        ('gs://source_bucket/file5.txt', 'gs://dest_bucket/file5.txt', None),
+        ('gs://source_bucket/file6.txt', 'gs://dest_bucket/file6.txt', None),
+    ]
+    self.assertEqual(results, expected)
+
   def test_copytree(self):
     src_dir_name = 'gs://gcsio-test/source/'
     dest_dir_name = 'gs://gcsio-test/dest/'
