Skip to content

Commit

Permalink
Add conversions for unsigned types, torch > 2.3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
oliver-batchelor authored and taichi-gardener committed Jun 22, 2024
1 parent c40574a commit d2a0365
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,18 @@ def to_pytorch_type(dt):
return torch.uint8
if dt == f16:
return torch.float16

if dt in (u16, u32, u64):
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
if hasattr(torch, "uint16"):
if dt == u16:
return torch.uint16
if dt == u32:
return torch.uint32
if dt == u64:
return torch.uint64
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")

raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
assert False


Expand Down Expand Up @@ -266,9 +276,18 @@ def to_taichi_type(dt):
return u8
if dt == torch.float16:
return f16
if dt in (u16, u32, u64):
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")


if hasattr(torch, "uint16"):
if dt == torch.uint16:
return u16
if dt == torch.uint32:
return u32
if dt == torch.uint64:
return u64

raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")


if has_paddle():
import paddle # pylint: disable=C0415

Expand Down

0 comments on commit d2a0365

Please sign in to comment.