Skip to content

Commit

Permalink
[air] Suppress "NumPy array is not writable" error in torch conversion (
Browse files Browse the repository at this point in the history
ray-project#29808)

When we convert NumPy arrays to torch tesnors in Ray Data, we run into a verbose Numpy error:

    The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior.

Since we don't write to numpy tensors we suppress this warning in the conversion. The alternative of copying the array would duplicate our memory usage.

Signed-off-by: Kai Fricke <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
krfricke authored and WeichenXu123 committed Dec 19, 2022
1 parent cc10a0d commit 1ed460d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/ray/air/_internal/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Dict, List, Optional, Union

import numpy as np
Expand Down Expand Up @@ -132,7 +133,15 @@ def convert_ndarray_to_torch_tensor(
Returns: A Torch Tensor.
"""
ndarray = _unwrap_ndarray_object_type_if_needed(ndarray)
return torch.as_tensor(ndarray, dtype=dtype, device=device)

# The numpy array is not always writeable as it can come from the Ray object store.
# Numpy will throw a verbose warning here, which we suppress, as we don't write
# to the tensors. We also don't want to copy the array to avoid memory overhead.
# Original warning: https://github.com/pytorch/pytorch/blob/v1.13.0/
# torch/csrc/utils/tensor_numpy.cpp#L198-L206
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return torch.as_tensor(ndarray, dtype=dtype, device=device)


def convert_ndarray_batch_to_torch_tensor_batch(
Expand Down
12 changes: 12 additions & 0 deletions python/ray/air/tests/test_data_batch_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pyarrow as pa

from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.air.util.data_batch_conversion import (
convert_batch_type_to_pandas,
Expand Down Expand Up @@ -185,6 +186,17 @@ def test_numpy_object_pandas():
)


@pytest.mark.parametrize("writable", [False, True])
def test_numpy_to_tensor_warning(writable):
input_data = np.array([[1, 2, 3]], dtype=int)
input_data.setflags(write=writable)

with pytest.warns(None) as record:
tensor = convert_ndarray_to_torch_tensor(input_data)
assert not record.list, [w.message for w in record.list]
assert tensor is not None


def test_dict_fail():
input_data = {"x": "y"}
with pytest.raises(ValueError):
Expand Down

0 comments on commit 1ed460d

Please sign in to comment.