diff --git a/python/taichi/lang/util.py b/python/taichi/lang/util.py index a853daea64a65..1a70f12b74bd4 100644 --- a/python/taichi/lang/util.py +++ b/python/taichi/lang/util.py @@ -166,8 +166,18 @@ def to_pytorch_type(dt): return torch.uint8 if dt == f16: return torch.float16 + if dt in (u16, u32, u64): - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") + if hasattr(torch, "uint16"): + if dt == u16: + return torch.uint16 + if dt == u32: + return torch.uint32 + if dt == u64: + return torch.uint64 + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") + + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") assert False @@ -266,9 +276,18 @@ def to_taichi_type(dt): return u8 if dt == torch.float16: return f16 - if dt in (u16, u32, u64): - raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.") - + + if hasattr(torch, "uint16"): + if dt == torch.uint16: + return u16 + if dt == torch.uint32: + return u32 + if dt == torch.uint64: + return u64 + + raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.") + + if has_paddle(): import paddle # pylint: disable=C0415