diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py index 6f00bf7255e5..0b3a0ea87f56 100644 --- a/python/ray/air/_internal/torch_utils.py +++ b/python/ray/air/_internal/torch_utils.py @@ -1,3 +1,4 @@ +import warnings from typing import Dict, List, Optional, Union import numpy as np @@ -132,7 +133,15 @@ def convert_ndarray_to_torch_tensor( Returns: A Torch Tensor. """ ndarray = _unwrap_ndarray_object_type_if_needed(ndarray) - return torch.as_tensor(ndarray, dtype=dtype, device=device) + + # The numpy array is not always writeable as it can come from the Ray object store. + # Numpy will throw a verbose warning here, which we suppress, as we don't write + # to the tensors. We also don't want to copy the array to avoid memory overhead. + # Original warning: https://github.com/pytorch/pytorch/blob/v1.13.0/ + # torch/csrc/utils/tensor_numpy.cpp#L198-L206 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return torch.as_tensor(ndarray, dtype=dtype, device=device) def convert_ndarray_batch_to_torch_tensor_batch( diff --git a/python/ray/air/tests/test_data_batch_conversion.py b/python/ray/air/tests/test_data_batch_conversion.py index d7e2307a23db..3294cb11e502 100644 --- a/python/ray/air/tests/test_data_batch_conversion.py +++ b/python/ray/air/tests/test_data_batch_conversion.py @@ -4,6 +4,7 @@ import numpy as np import pyarrow as pa +from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor from ray.air.constants import TENSOR_COLUMN_NAME from ray.air.util.data_batch_conversion import ( convert_batch_type_to_pandas, @@ -185,6 +186,17 @@ def test_numpy_object_pandas(): ) +@pytest.mark.parametrize("writable", [False, True]) +def test_numpy_to_tensor_warning(writable): + input_data = np.array([[1, 2, 3]], dtype=int) + input_data.setflags(write=writable) + + with pytest.warns(None) as record: + tensor = convert_ndarray_to_torch_tensor(input_data) + assert not record.list, [w.message for w in record.list] + assert tensor is not None + + def test_dict_fail(): input_data = {"x": "y"} with pytest.raises(ValueError):