Skip to content

Commit

Permalink
[misc] Add conversions for unsigned types, torch > 2.3.0 (#8528)
Browse files Browse the repository at this point in the history
### Brief Summary

pytorch 2.3.0 now has unsigned datatypes, add conversions for those from
taichi unsigned types.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bob Cao <[email protected]>
  • Loading branch information
3 people authored Jun 23, 2024
1 parent c40574a commit e4b0bf0
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,18 @@ def to_pytorch_type(dt):
return torch.uint8
if dt == f16:
return torch.float16

if dt in (u16, u32, u64):
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
assert False
if hasattr(torch, "uint16"):
if dt == u16:
return torch.uint16
if dt == u32:
return torch.uint32
if dt == u64:
return torch.uint64
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")

raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")


def to_paddle_type(dt):
Expand Down Expand Up @@ -266,8 +275,16 @@ def to_taichi_type(dt):
return u8
if dt == torch.float16:
return f16
if dt in (u16, u32, u64):
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")

if hasattr(torch, "uint16"):
if dt == torch.uint16:
return u16
if dt == torch.uint32:
return u32
if dt == torch.uint64:
return u64

raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")

if has_paddle():
import paddle # pylint: disable=C0415
Expand Down

0 comments on commit e4b0bf0

Please sign in to comment.