Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit cast to uint8 when bool inputs passed to argsort in MAP #983

Merged
merged 7 commits into from
Apr 24, 2022

Conversation

krshrimali
Copy link
Contributor

@krshrimali krshrimali commented Apr 24, 2022

What does this PR do?

Fixes #981

I can confirm that this is a regression from release 0.7.2, IMO: an explicit cast from torch.bool to torch.uint8 (on CUDA only) while applying torch.argsort should fix this. Looks like PyTorch doesn't support sorting for boolean dtypes on CUDA devices:

// FIXME: remove this check once cub sort supports bool
TORCH_CHECK(self_dtype != ScalarType::Bool,
  "Sort currently does not support bool dtype on CUDA.");

https://github.com/pytorch/pytorch/blob/1a7e43be141ce01469d7605075cb1008bf19abd7/aten/src/ATen/native/cuda/Sort.cpp#L80

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@codecov
Copy link

codecov bot commented Apr 24, 2022

Codecov Report

Merging #983 (98c517a) into master (f8ef656) will decrease coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff          @@
##           master   #983   +/-   ##
=====================================
- Coverage      95%    95%   -0%     
=====================================
  Files         177    177           
  Lines        7513   7514    +1     
=====================================
- Hits         7139   7127   -12     
- Misses        374    387   +13     

@Borda Borda changed the title Explicit cast to uint8 when bool inputs passed to argsort in MAP (CUDA only) Explicit cast to uint8 when bool inputs passed to argsort in MAP Apr 24, 2022
@Borda Borda added the bug / fix Something isn't working label Apr 24, 2022
@Borda Borda added this to the v0.8 milestone Apr 24, 2022
@Borda Borda enabled auto-merge (squash) April 24, 2022 14:00
@mergify mergify bot added the ready label Apr 24, 2022
Copy link
Contributor

@stancld stancld left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Borda Borda merged commit a5f2e03 into master Apr 24, 2022
@Borda Borda deleted the fix/sorting_error/cuda branch April 24, 2022 15:03
@Chris-hughes10
Copy link

Thanks for getting this turned around so quickly! As #981 is a blocker for the project my team are working on, is there an estimate on when this will be released?

@Borda
Copy link
Member

Borda commented Apr 26, 2022

is there an estimate on when this will be released?

we can do one during this week, just would also include #985 🐰
pls feel free to ping me on slack if I forget... 😇

Borda pushed a commit that referenced this pull request Apr 26, 2022
…983)

* Explicit cast to torch.uint8 for bool types on CUDA
* Add changelog entry

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <[email protected]>
(cherry picked from commit a5f2e03)
@Borda Borda added this to the v0.8 milestone May 5, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug / fix Something isn't working ready
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Sort currently does not support bool dtype on CUDA: Regression in version 0.7.3 +
4 participants