
[AIR] Address UserWarning from Torch training #28004

Conversation

bveeramani (Member)

Why are these changes needed?

Related issue number

Fixes #28003

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

bveeramani marked this pull request as draft on August 18, 2022, 22:07
Comment on lines +122 to 125
# NOTE: PyTorch raises a `UserWarning` if `ndarray` isn't writeable. See #28003.
if not ndarray.flags["WRITEABLE"]:
    ndarray = np.copy(ndarray)
return torch.as_tensor(ndarray, dtype=dtype, device=device)
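
For context, the warning under discussion is easy to reproduce. A minimal sketch, assuming a recent NumPy/PyTorch (the exact warning text varies slightly across PyTorch versions):

import warnings

import numpy as np
import torch

# Simulate a read-only buffer, like a zero-copy view of a shared-memory object.
arr = np.ones(4, dtype=np.float32)
arr.setflags(write=False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    tensor = torch.as_tensor(arr)

print(caught[0].category)  # <class 'UserWarning'> about the non-writable array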
Contributor

I don't think that copying the ndarray here is the right solution, since we want to be able to reuse the shared-memory tensor data buffers without incurring unnecessary copies. We also operate under the expectation that the user/model will not mutate these input tensors during training and inference, so I don't think the copy is necessary for correctness.

What if we suppressed the warning instead via something like:

Suggested change
- # NOTE: PyTorch raises a `UserWarning` if `ndarray` isn't writeable. See #28003.
- if not ndarray.flags["WRITEABLE"]:
-     ndarray = np.copy(ndarray)
- return torch.as_tensor(ndarray, dtype=dtype, device=device)
+ with warnings.catch_warnings():
+     warnings.simplefilter("ignore")
+     tensor = torch.as_tensor(ndarray, dtype=dtype, device=device)
+ return tensor
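
If suppression is the route, a narrower filter than simplefilter("ignore") could silence just this warning while letting others through. A sketch, not from the PR; the message prefix is an assumption, since PyTorch's exact wording has changed across versions:

import warnings

import numpy as np
import torch

ndarray = np.ones(4, dtype=np.float32)
ndarray.setflags(write=False)

with warnings.catch_warnings():
    # Match only the non-writable-array warning; anything else raised
    # by `torch.as_tensor` still propagates.
    warnings.filterwarnings(
        "ignore",
        message="The given NumPy array is not writ",  # matches "writable"/"writeable"
        category=UserWarning,
    )
    tensor = torch.as_tensor(ndarray, dtype=torch.float32)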

bveeramani (Member Author)

Are we confident that the warning is harmless? If so, I think this is fine.

Contributor

If the user tries to mutate the tensor (e.g. with in-place tensor operations), then this is undefined behavior, but in the common case the tensor data buffer is left untouched in Plasma until we either (a) transfer the tensor to the GPU, or (b) create a new tensor via an operation that makes a copy; in either case, we're unlikely to mutate it.

Contributor

I'm not sure if we should suppress the warning, though, since it's a useful signal to the user that they will need to .clone() it if they wish to mutate it.
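
A quick illustration of both points above (the copy-on-operation lifecycle and the .clone() escape hatch); a sketch, not from the PR:

import numpy as np
import torch

arr = np.ones(4, dtype=np.float32)
arr.setflags(write=False)    # read-only, like a Plasma-backed buffer

view = torch.as_tensor(arr)  # zero-copy view; this is where the warning fires
fresh = view + 1             # out-of-place ops allocate a new, writable tensor
owned = view.clone()         # explicit private copy for callers who must mutate
owned += 1                   # safe: only the copy changes
# `view += 1` would write to the read-only buffer: undefined behavior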

Contributor

It'd be great to suppress the repeated warnings (if that is somehow possible).

bveeramani self-assigned this on Aug 30, 2022
bveeramani (Member Author)

Closing this PR because I don't know how to suppress the repeated warning with multiple workers.
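
For what it's worth, warning filters live in per-process state, so a filter set in the driver won't reach Ray's worker processes; one conceivable (untested) approach is to install the filter at the top of the function each Train worker runs. A hypothetical sketch; the message prefix and the placement are assumptions:

import warnings

def train_loop_per_worker(config):
    # Each Ray Train worker is a separate process, so the filter must be
    # installed here rather than once in the driver.
    warnings.filterwarnings(
        "ignore",
        message="The given NumPy array is not writ",
        category=UserWarning,
    )
    ...  # per-worker training logic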

bveeramani closed this on Sep 7, 2022