Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[air] Suppress "NumPy array is not writable" error in torch conversion #29808

Merged
merged 2 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/ray/air/_internal/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -132,7 +133,15 @@ def convert_ndarray_to_torch_tensor(
Returns: A Torch Tensor.
"""
ndarray = _unwrap_ndarray_object_type_if_needed(ndarray)
return torch.as_tensor(ndarray, dtype=dtype, device=device)

# The numpy array is not always writeable as it can come from the Ray object store.
# Numpy will throw a verbose warning here, which we suppress, as we don't write
# to the tensors. We also don't want to copy the array to avoid memory overhead.
# Original warning: https://github.com/pytorch/pytorch/blob/v1.13.0/
# torch/csrc/utils/tensor_numpy.cpp#L198-L206
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return torch.as_tensor(ndarray, dtype=dtype, device=device)


def convert_ndarray_batch_to_torch_tensor_batch(
Expand Down
12 changes: 12 additions & 0 deletions python/ray/air/tests/test_data_batch_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pyarrow as pa

from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.air.util.data_batch_conversion import (
convert_batch_type_to_pandas,
Expand Down Expand Up @@ -185,6 +186,17 @@ def test_numpy_object_pandas():
)


@pytest.mark.parametrize("writable", [False, True])
def test_numpy_to_tensor_warning(writable):
input_data = np.array([[1, 2, 3]], dtype=int)
input_data.setflags(write=writable)

with pytest.warns(None) as record:
tensor = convert_ndarray_to_torch_tensor(input_data)
assert not record.list, [w.message for w in record.list]
assert tensor is not None


def test_dict_fail():
input_data = {"x": "y"}
with pytest.raises(ValueError):
Expand Down