diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index b15633832522f..7d82efb36a33e 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -125,6 +125,8 @@ def from_numpy(self, arr): ) if impl.current_cfg().ndarray_use_torch: self.arr = torch.from_numpy(arr).to(self.arr.dtype) + if impl.current_cfg().arch == _ti_core.Arch.cuda: + self.arr = self.arr.cuda() else: if hasattr(arr, 'contiguous'): arr = arr.contiguous()