make IO functions more robust (#109)
epwalsh authored Nov 20, 2024
1 parent 4f2c8ef commit c47df7c
Showing 3 changed files with 133 additions and 137 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
 - (Optimization) Mark model input sizes as dynamic for `torch.compile()` to avoid recompile during evals or variable-sequence / batch size training. This doesn't seem to hurt throughput.
+- Made HTTPS and GCS IO functions more robust.
 
 ## [v1.6.3](https://github.com/allenai/OLMo-core/releases/tag/v1.6.3) - 2024-11-15
265 changes: 128 additions & 137 deletions src/olmo_core/io.py
@@ -5,8 +5,9 @@
 import re
 import shutil
 import time
+from functools import wraps
 from pathlib import Path
-from typing import Any, Generator, Optional, Tuple, Union
+from typing import Any, Callable, Generator, Optional, Tuple, Type, Union
 
 try:
     from functools import cache
@@ -355,23 +356,63 @@ def _format_bytes(num: Union[int, float], suffix="B") -> str:
     return f"{num:.1f}Yi{suffix}"
 
 
+def retriable(
+    max_attempts: int = 3,
+    retriable_errors: Tuple[Type[Exception], ...] = (
+        requests.exceptions.ConnectionError,
+        requests.exceptions.Timeout,
+    ),
+    retry_condition: Optional[Callable[[Exception], bool]] = None,
+):
+    def decorator(func):
+        @wraps(func)
+        def new_func(*args, **kwargs):
+            for attempt in range(1, max_attempts + 1):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as exc:
+                    if isinstance(exc, retriable_errors) or (
+                        retry_condition is not None and retry_condition(exc)
+                    ):
+                        if attempt >= max_attempts:
+                            # When torch's DataLoader intercepts exceptions, it may try to re-raise them
+                            # by recalling their constructor with a single message arg. Torch has some
+                            # logic to deal with the absence of a single-parameter constructor, but it
+                            # doesn't gracefully handle other possible failures in calling such a constructor.
+                            # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
+                            # in us losing the true exception info. To avoid this, we change the exception
+                            # to a type that has a single-parameter constructor.
+                            raise OLMoNetworkError(
+                                f"'{func.__name__}' failed {max_attempts} attempts with: {exc}"
+                            ) from exc
+                        else:
+                            log.warning(
+                                f"'{func.__name__}' failed attempt {attempt} with retriable error: {exc}"
+                            )
+                            _wait_before_retry(attempt)
+                    else:
+                        raise
+
+        return new_func
+
+    return decorator
+
+
 ######################
 ## HTTPS IO helpers ##
 ######################
 
 
+@retriable()
 def _http_file_size(url: str) -> int:
-    import requests
-
     response = requests.head(url, allow_redirects=True)
     content_length = response.headers.get("content-length")
     assert content_length
     return int(content_length)
 
 
+@retriable()
 def _http_get_bytes_range(url: str, bytes_start: int, num_bytes: int) -> bytes:
-    import requests
-
     response = requests.get(
         url, headers={"Range": f"bytes={bytes_start}-{bytes_start+num_bytes-1}"}
     )
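
For a sense of how the new decorator behaves, here is a small usage sketch (assuming `olmo_core` at this commit is installed; `flaky` and `_retry_on_value_error` are illustrative names, not part of the diff):

from olmo_core.io import retriable


def _retry_on_value_error(err: Exception) -> bool:
    # retry_condition hook, analogous to _s3_retry_condition below:
    # retry on ValueError in addition to the default requests errors.
    return isinstance(err, ValueError)


calls = {"n": 0}


@retriable(max_attempts=3, retry_condition=_retry_on_value_error)
def flaky() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ValueError("transient failure")
    return "ok"


print(flaky())  # logs two warnings, sleeps between attempts, then prints "ok"

If every attempt fails, the caller instead sees a single `OLMoNetworkError` whose message embeds the last underlying error, which plays nicely with the torch DataLoader re-raising behavior described in the comment above.
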
@@ -387,9 +428,8 @@ def _http_get_bytes_range(url: str, bytes_start: int, num_bytes: int) -> bytes:
     return result
 
 
+@retriable()
 def _http_file_exists(url: str) -> bool:
-    import requests
-
     response = requests.head(url)
     if response.status_code == 404:
         return False
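
The HTTPS helpers now share the decorator's default retry policy (requests connection errors and timeouts). A hypothetical call (the URL is illustrative):

from olmo_core.io import _http_get_bytes_range

# Range request for the first 16 bytes of a remote file; transient
# connection errors and timeouts are retried up to 3 times by
# @retriable() before surfacing as OLMoNetworkError.
data = _http_get_bytes_range("https://example.com/data.npy", 0, 16)
assert len(data) == 16
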
@@ -433,6 +473,7 @@ def _get_gcs_conditional_retry():
     return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"])
 
 
+@retriable()
 def _gcs_file_size(bucket_name: str, key: str) -> int:
     from google.api_core.exceptions import NotFound
 
@@ -447,6 +488,7 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
     return blob.size
 
 
+@retriable()
 def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
     from google.api_core.exceptions import NotFound
 
@@ -462,6 +504,7 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes
     )
 
 
+@retriable()
 def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
     storage_client = _get_gcs_client()
     bucket = storage_client.bucket(bucket_name)
@@ -473,6 +516,25 @@ def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool =
     blob.upload_from_filename(source, retry=_get_gcs_conditional_retry())
 
 
+@retriable()
+def _gcs_clear_directory(bucket_name: str, prefix: str):
+    from google.api_core.exceptions import NotFound
+
+    storage_client = _get_gcs_client()
+
+    prefix = prefix.strip("/")
+    if prefix:
+        prefix += "/"
+
+    try:
+        bucket = storage_client.bucket(bucket_name)
+        blobs = bucket.list_blobs(prefix=prefix, retry=_get_gcs_retry())
+        for blob in blobs:
+            bucket.delete_blob(blob.name)
+    except NotFound:
+        return
+
+
 def _gcs_list_directory(bucket_name: str, prefix: str) -> Generator[str, None, None]:
     from google.api_core.exceptions import NotFound
 
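`_gcs_clear_directory` was moved above `_gcs_list_directory` and gained the decorator; its semantics are otherwise unchanged. A hypothetical call (bucket name illustrative; assumes `google-cloud-storage` credentials are configured):

from olmo_core.io import _gcs_clear_directory

# Deletes every blob under gs://my-bucket/checkpoints/step1000/.
# A missing prefix is treated as success (NotFound is swallowed), and
# requests-level connection errors/timeouts are retried by @retriable().
_gcs_clear_directory("my-bucket", "checkpoints/step1000")
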
@@ -504,24 +566,6 @@ def _gcs_list_directory(bucket_name: str, prefix: str) -> Generator[str, None, N
         yield f"gs://{bucket_name}/{folder.strip('/')}"
 
 
-def _gcs_clear_directory(bucket_name: str, prefix: str):
-    from google.api_core.exceptions import NotFound
-
-    storage_client = _get_gcs_client()
-
-    prefix = prefix.strip("/")
-    if prefix:
-        prefix += "/"
-
-    try:
-        bucket = storage_client.bucket(bucket_name)
-        blobs = bucket.list_blobs(prefix=prefix, retry=_get_gcs_retry())
-        for blob in blobs:
-            bucket.delete_blob(blob.name)
-    except NotFound:
-        return
-
-
 ###################
 ## S3 IO helpers ##
 ###################
@@ -588,146 +632,93 @@ def _get_s3_endpoint_url(scheme: str) -> Optional[str]:
     raise NotImplementedError(f"Cannot get endpoint url for scheme {scheme}")
 
 
-def _s3_file_size(scheme: str, bucket_name: str, key: str, max_attempts: int = 3) -> int:
-    from botocore.exceptions import ClientError
-
-    err: Optional[Exception] = None
-    for attempt in range(1, max_attempts + 1):
-        try:
-            return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
-        except ClientError as e:
-            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
-            err = e
-
-        if attempt < max_attempts:
-            log.warning(
-                "%s failed attempt %d with retriable error: %s",
-                _s3_file_size.__name__,
-                attempt,
-                err,
-            )
-            _wait_before_retry(attempt)
-
-    raise OLMoNetworkError("Failed to get s3 file size") from err
+def _s3_retry_condition(err: Exception) -> bool:
+    import botocore.exceptions as boto_errors
+
+    return isinstance(
+        err, (boto_errors.ClientError, boto_errors.HTTPClientError, boto_errors.ConnectionError)
+    )
+
+
+@retriable(retry_condition=_s3_retry_condition)
+def _s3_file_size(scheme: str, bucket_name: str, key: str) -> int:
+    from botocore.exceptions import ClientError
+
+    try:
+        return _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)["ContentLength"]
+    except ClientError as e:
+        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
+            raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
+        else:
+            raise
 
 
+@retriable(retry_condition=_s3_retry_condition)
 def _s3_get_bytes_range(
-    scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int, max_attempts: int = 3
+    scheme: str, bucket_name: str, key: str, bytes_start: int, num_bytes: int
 ) -> bytes:
-    from botocore.exceptions import ClientError, ConnectionError, HTTPClientError
-
-    err: Optional[Exception] = None
-    for attempt in range(1, max_attempts + 1):
-        try:
-            return (
-                _get_s3_client(scheme)
-                .get_object(
-                    Bucket=bucket_name,
-                    Key=key,
-                    Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}",
-                )["Body"]
-                .read()
-            )
-        except ClientError as e:
-            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
-            err = e
-        except (HTTPClientError, ConnectionError) as e:
-            # ResponseStreamingError (subclass of HTTPClientError) can happen as
-            # a result of a failed read from the stream (http.client.IncompleteRead).
-            # Retrying can help in this case.
-            err = e
-
-        if attempt < max_attempts:
-            log.warning(
-                "%s failed attempt %d with retriable error: %s",
-                _s3_get_bytes_range.__name__,
-                attempt,
-                err,
-            )
-            _wait_before_retry(attempt)
-
-    # When torch's DataLoader intercepts exceptions, it may try to re-raise them
-    # by recalling their constructor with a single message arg. Torch has some
-    # logic to deal with the absence of a single-parameter constructor, but it
-    # doesn't gracefully handle other possible failures in calling such a constructor
-    # This can cause an irrelevant exception (e.g. KeyError: 'error'), resulting
-    # in us losing the true exception info. To avoid this, we change the exception
-    # to a type that has a single-parameter constructor.
-    raise OLMoNetworkError("Failed to get bytes range from s3") from err
+    from botocore.exceptions import ClientError
+
+    try:
+        return (
+            _get_s3_client(scheme)
+            .get_object(
+                Bucket=bucket_name,
+                Key=key,
+                Range=f"bytes={bytes_start}-{bytes_start + num_bytes - 1}",
+            )["Body"]
+            .read()
+        )
+    except ClientError as e:
+        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
+            raise FileNotFoundError(f"s3://{bucket_name}/{key}") from e
+        else:
+            raise
 
 
+@retriable(retry_condition=_s3_retry_condition)
 def _s3_upload(
     source: Path,
     scheme: str,
     bucket_name: str,
     key: str,
     save_overwrite: bool = False,
-    max_attempts: int = 3,
 ):
     from botocore.exceptions import ClientError
 
-    err: Optional[Exception] = None
     if not save_overwrite:
-        for attempt in range(1, max_attempts + 1):
-            try:
-                _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
-                raise FileExistsError(
-                    f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
-                )
-            except ClientError as e:
-                if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                    err = None
-                    break
-                err = e
-
-            if attempt < max_attempts:
-                log.warning(
-                    "%s failed attempt %d with retriable error: %s",
-                    _s3_upload.__name__,
-                    attempt,
-                    err,
-                )
-                _wait_before_retry(attempt)
-
-        if err is not None:
-            raise OLMoNetworkError("Failed to check object existence during s3 upload") from err
+        try:
+            _get_s3_client(scheme).head_object(Bucket=bucket_name, Key=key)
+            raise FileExistsError(
+                f"s3://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
+            )
+        except ClientError as e:
+            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
+                pass
+            else:
+                raise
 
-    try:
-        _get_s3_client(scheme).upload_file(source, bucket_name, key)
-    except ClientError as e:
-        raise OLMoNetworkError("Failed to upload to s3") from e
+    _get_s3_client(scheme).upload_file(source, bucket_name, key)
 
 
-def _s3_clear_directory(scheme: str, bucket_name: str, prefix: str, max_attempts: int = 3):
+@retriable(retry_condition=_s3_retry_condition)
+def _s3_clear_directory(scheme: str, bucket_name: str, prefix: str):
     from botocore.exceptions import ClientError
 
     if not prefix.endswith("/"):
         prefix = prefix + "/"
 
-    err: Optional[Exception] = None
-    for attempt in range(1, max_attempts + 1):
-        try:
-            for o in _get_s3_client(scheme).list_objects_v2(Bucket=bucket_name, Prefix=prefix)[
-                "Contents"
-            ]:
-                _get_s3_client(scheme).delete_object(Bucket=bucket_name, Key=o["Key"])
-            return
-        except ClientError as e:
-            if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
-                return
-            err = e
-        except KeyError:
-            return
-
-        if attempt < max_attempts:
-            log.warning(
-                "%s failed attempt %d with retriable error: %s", _s3_upload.__name__, attempt, err
-            )
-            _wait_before_retry(attempt)
-
-    raise OLMoNetworkError("Failed to remove S3 directory") from err
+    try:
+        for o in _get_s3_client(scheme).list_objects_v2(Bucket=bucket_name, Prefix=prefix)[
+            "Contents"
+        ]:
+            _get_s3_client(scheme).delete_object(Bucket=bucket_name, Key=o["Key"])
+        return
+    except ClientError as e:
+        if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
+            return
+    except KeyError:
+        return
 
 
 def _s3_list_directory(scheme: str, bucket_name: str, prefix: str) -> Generator[str, None, None]:
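
Net effect on the S3 path: the per-function retry loops and their `max_attempts` parameters are gone, and `@retriable(retry_condition=_s3_retry_condition)` now retries transient botocore errors, while 404s still surface immediately. A hypothetical check (bucket and key are illustrative; assumes boto3 credentials are configured):

from olmo_core.io import _s3_file_size

try:
    size = _s3_file_size("s3", "my-bucket", "checkpoints/model.pt")
except FileNotFoundError:
    # A 404 is converted to FileNotFoundError on the first attempt; the
    # decorator does not retry it, since FileNotFoundError matches neither
    # the default retriable_errors nor _s3_retry_condition.
    size = None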