From afb7731d21c2b490368d07765e6959f90856a274 Mon Sep 17 00:00:00 2001 From: Haidong Lan Date: Wed, 11 Jan 2023 16:20:23 +0800 Subject: [PATCH 1/6] Strictly check ndim with external array when ndim is present in ndarray type annotation. --- python/taichi/lang/_ndarray.py | 2 +- python/taichi/lang/kernel_impl.py | 18 ++++++++++++++---- tests/python/test_ndarray.py | 3 --- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index f10a73455cbc0..8aec42fcfb276 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -137,7 +137,7 @@ def _ndarray_matrix_from_numpy(self, arr, as_vector): raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided") if tuple(self.arr.total_shape()) != tuple(arr.shape): raise ValueError( - f"Mismatch shape: {tuple(self.arr.shape)} expected, but {tuple(arr.shape)} provided" + f"Mismatch shape: {tuple(self.arr.total_shape())} expected, but {tuple(arr.shape)} provided" ) if not arr.flags.c_contiguous: arr = np.ascontiguousarray(arr) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index a4d48dfc41e2b..5ed8d93e17ddb 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -416,11 +416,21 @@ def extract_arg(arg, anno): shape = tuple(shape) element_shape = () if isinstance(anno.dtype, MatrixType): - if len(shape) < anno.dtype.ndim: - raise ValueError( - f"Invalid argument into ti.types.ndarray() - required element_dim={anno.dtype.ndim}, " - f"but the argument has only {len(shape)} dimensions") + if anno.ndim is not None: + if len(shape) != anno.dtype.ndim + anno.ndim: + raise ValueError( + f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim} element_dim={anno.dtype.ndim}, " + f"but the argument has only {len(shape)} dimensions") + else: + if len(shape) < anno.dtype.ndim: + raise ValueError( + f"Invalid argument into ti.types.ndarray() - required element_dim={anno.dtype.ndim}, " + f"but the argument has only {len(shape)} dimensions") element_shape = shape[-anno.dtype.ndim:] + if element_shape != anno.dtype.get_shape(): + raise ValueError( + f"Invalid argument into ti.types.ndarray() - required element_shape={anno.dtype.get_shape()}, " + f"but the argument has element shape of {element_shape}") return to_taichi_type( arg.dtype), len(shape), element_shape, Layout.AOS if isinstance(anno, sparse_matrix_builder): diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index 01e1d5138b402..32d1fba4bd3e9 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -483,9 +483,6 @@ def func(a: ti.types.ndarray(ti.types.vector(n=10, dtype=ti.i32))): v = np.zeros((6, 10), dtype=np.int32) func(v) assert impl.get_runtime().get_num_compiled_functions() == 1 - v = np.zeros((6, 11), dtype=np.int32) - func(v) - assert impl.get_runtime().get_num_compiled_functions() == 2 @test_utils.test(arch=supported_archs_taichi_ndarray) From 444be47744798e07884a2fc72833b38919b4642f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Jan 2023 08:25:38 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/kernel_impl.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 5ed8d93e17ddb..5224a7c059e07 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -420,17 +420,20 @@ def extract_arg(arg, anno): if len(shape) != anno.dtype.ndim + anno.ndim: raise ValueError( f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim} element_dim={anno.dtype.ndim}, " - f"but the argument has only {len(shape)} dimensions") + f"but the argument has only {len(shape)} dimensions" + ) else: if len(shape) < anno.dtype.ndim: raise ValueError( f"Invalid argument into ti.types.ndarray() - required element_dim={anno.dtype.ndim}, " - f"but the argument has only {len(shape)} dimensions") + f"but the argument has only {len(shape)} dimensions" + ) element_shape = shape[-anno.dtype.ndim:] if element_shape != anno.dtype.get_shape(): raise ValueError( f"Invalid argument into ti.types.ndarray() - required element_shape={anno.dtype.get_shape()}, " - f"but the argument has element shape of {element_shape}") + f"but the argument has element shape of {element_shape}" + ) return to_taichi_type( arg.dtype), len(shape), element_shape, Layout.AOS if isinstance(anno, sparse_matrix_builder): From 4dcdc9780f1832712ab872a0b991f266cd262785 Mon Sep 17 00:00:00 2001 From: Haidong Lan Date: Thu, 12 Jan 2023 10:53:42 +0800 Subject: [PATCH 3/6] Skip check when element shape has None --- python/taichi/lang/kernel_impl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 5224a7c059e07..17967b3218474 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -429,9 +429,10 @@ def extract_arg(arg, anno): f"but the argument has only {len(shape)} dimensions" ) element_shape = shape[-anno.dtype.ndim:] - if element_shape != anno.dtype.get_shape(): + anno_element_shape = anno.dtype.get_shape() + if None not in anno_element_shape and element_shape != anno_element_shape: raise ValueError( - f"Invalid argument into ti.types.ndarray() - required element_shape={anno.dtype.get_shape()}, " + f"Invalid argument into ti.types.ndarray() - required element_shape={anno_element_shape}, " f"but the argument has element shape of {element_shape}" ) return to_taichi_type( From 45b4ebf8c429172626dec4127371b390a37c56aa Mon Sep 17 00:00:00 2001 From: Haidong Lan Date: Thu, 12 Jan 2023 11:48:25 +0800 Subject: [PATCH 4/6] Add tests for ndim check --- python/taichi/lang/kernel_impl.py | 9 ++++++++- tests/python/test_numpy.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 17967b3218474..5c7b808acaaa7 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -420,7 +420,7 @@ def extract_arg(arg, anno): if len(shape) != anno.dtype.ndim + anno.ndim: raise ValueError( f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim} element_dim={anno.dtype.ndim}, " - f"but the argument has only {len(shape)} dimensions" + f"but the argument has {len(shape)} dimensions" ) else: if len(shape) < anno.dtype.ndim: @@ -435,6 +435,13 @@ def extract_arg(arg, anno): f"Invalid argument into ti.types.ndarray() - required element_shape={anno_element_shape}, " f"but the argument has element shape of {element_shape}" ) + elif anno.dtype is not None: + # User specified scalar dtype + if anno.ndim is not None and len(shape) != anno.ndim: + raise ValueError( + f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim}, " + f"but the argument has {len(shape)} dimensions" + ) return to_taichi_type( arg.dtype), len(shape), element_shape, Layout.AOS if isinstance(anno, sparse_matrix_builder): diff --git a/tests/python/test_numpy.py b/tests/python/test_numpy.py index 694f1720a21f0..cc3781d3d5d9b 100644 --- a/tests/python/test_numpy.py +++ b/tests/python/test_numpy.py @@ -249,3 +249,34 @@ def fill(img: ti.types.ndarray()): with pytest.raises(ValueError, match='Non contiguous numpy arrays are not supported'): fill(a) + + +@test_utils.test() +def test_numpy_ndarray_dim_check(): + @ti.kernel + def add_one_mat(arr : ti.types.ndarray(dtype=ti.math.mat3, ndim=2)): + for i in ti.grouped(arr): + arr[i] = arr[i] + 1.0 + + @ti.kernel + def add_one_scalar(arr : ti.types.ndarray(dtype=ti.f32, ndim=2)): + for i in ti.grouped(arr): + arr[i] = arr[i] + 1.0 + + a = np.zeros(shape=(2,2,3,3), dtype=np.float32) + b = np.zeros(shape=(2,2,2,3), dtype=np.float32) + c = np.zeros(shape=(2,2,3), dtype=np.float32) + d = np.zeros(shape=(2,2), dtype=np.float32) + add_one_mat(a) + add_one_scalar(d) + np.testing.assert_allclose(a, np.ones(shape=(2,2,3,3), dtype=np.float32)) + np.testing.assert_allclose(d, np.ones(shape=(2,2), dtype=np.float32)) + with pytest.raises(ValueError, + match=r'Invalid argument into ti.types.ndarray\(\) - required element_shape=\(.*\), but the argument has element shape of \(.*\)'): + add_one_mat(b) + with pytest.raises(ValueError, + match=r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2 element_dim=2, but the argument has 3 dimensions'): + add_one_mat(c) + with pytest.raises(ValueError, + match=r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2, but the argument has 4 dimensions'): + add_one_scalar(a) \ No newline at end of file From 731d45436b649598530f4971f5db629b863901a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Jan 2023 03:49:45 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/kernel_impl.py | 10 ++++---- tests/python/test_numpy.py | 40 +++++++++++++++++++------------ 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 5c7b808acaaa7..75cbc05c6679d 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -420,8 +420,7 @@ def extract_arg(arg, anno): if len(shape) != anno.dtype.ndim + anno.ndim: raise ValueError( f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim} element_dim={anno.dtype.ndim}, " - f"but the argument has {len(shape)} dimensions" - ) + f"but the argument has {len(shape)} dimensions") else: if len(shape) < anno.dtype.ndim: raise ValueError( @@ -438,10 +437,9 @@ def extract_arg(arg, anno): elif anno.dtype is not None: # User specified scalar dtype if anno.ndim is not None and len(shape) != anno.ndim: - raise ValueError( - f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim}, " - f"but the argument has {len(shape)} dimensions" - ) + raise ValueError( + f"Invalid argument into ti.types.ndarray() - required array has ndim={anno.ndim}, " + f"but the argument has {len(shape)} dimensions") return to_taichi_type( arg.dtype), len(shape), element_shape, Layout.AOS if isinstance(anno, sparse_matrix_builder): diff --git a/tests/python/test_numpy.py b/tests/python/test_numpy.py index cc3781d3d5d9b..476bcdc2f1dbc 100644 --- a/tests/python/test_numpy.py +++ b/tests/python/test_numpy.py @@ -254,29 +254,39 @@ def fill(img: ti.types.ndarray()): @test_utils.test() def test_numpy_ndarray_dim_check(): @ti.kernel - def add_one_mat(arr : ti.types.ndarray(dtype=ti.math.mat3, ndim=2)): + def add_one_mat(arr: ti.types.ndarray(dtype=ti.math.mat3, ndim=2)): for i in ti.grouped(arr): arr[i] = arr[i] + 1.0 @ti.kernel - def add_one_scalar(arr : ti.types.ndarray(dtype=ti.f32, ndim=2)): + def add_one_scalar(arr: ti.types.ndarray(dtype=ti.f32, ndim=2)): for i in ti.grouped(arr): arr[i] = arr[i] + 1.0 - a = np.zeros(shape=(2,2,3,3), dtype=np.float32) - b = np.zeros(shape=(2,2,2,3), dtype=np.float32) - c = np.zeros(shape=(2,2,3), dtype=np.float32) - d = np.zeros(shape=(2,2), dtype=np.float32) + a = np.zeros(shape=(2, 2, 3, 3), dtype=np.float32) + b = np.zeros(shape=(2, 2, 2, 3), dtype=np.float32) + c = np.zeros(shape=(2, 2, 3), dtype=np.float32) + d = np.zeros(shape=(2, 2), dtype=np.float32) add_one_mat(a) add_one_scalar(d) - np.testing.assert_allclose(a, np.ones(shape=(2,2,3,3), dtype=np.float32)) - np.testing.assert_allclose(d, np.ones(shape=(2,2), dtype=np.float32)) - with pytest.raises(ValueError, - match=r'Invalid argument into ti.types.ndarray\(\) - required element_shape=\(.*\), but the argument has element shape of \(.*\)'): + np.testing.assert_allclose(a, np.ones(shape=(2, 2, 3, 3), + dtype=np.float32)) + np.testing.assert_allclose(d, np.ones(shape=(2, 2), dtype=np.float32)) + with pytest.raises( + ValueError, + match= + r'Invalid argument into ti.types.ndarray\(\) - required element_shape=\(.*\), but the argument has element shape of \(.*\)' + ): add_one_mat(b) - with pytest.raises(ValueError, - match=r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2 element_dim=2, but the argument has 3 dimensions'): + with pytest.raises( + ValueError, + match= + r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2 element_dim=2, but the argument has 3 dimensions' + ): add_one_mat(c) - with pytest.raises(ValueError, - match=r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2, but the argument has 4 dimensions'): - add_one_scalar(a) \ No newline at end of file + with pytest.raises( + ValueError, + match= + r'Invalid argument into ti.types.ndarray\(\) - required array has ndim=2, but the argument has 4 dimensions' + ): + add_one_scalar(a) From 719176ded741703ef4353bfd2d12728d9aaf4a01 Mon Sep 17 00:00:00 2001 From: Haidong Lan Date: Thu, 12 Jan 2023 14:15:18 +0800 Subject: [PATCH 6/6] Clear field dim in docs --- docs/lang/articles/get-started/accelerate_pytorch.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/lang/articles/get-started/accelerate_pytorch.md b/docs/lang/articles/get-started/accelerate_pytorch.md index 4527bf1b6ce94..c30efa7fd82e5 100644 --- a/docs/lang/articles/get-started/accelerate_pytorch.md +++ b/docs/lang/articles/get-started/accelerate_pytorch.md @@ -180,9 +180,9 @@ The Taichi reference code is almost identical to its Python counterpart. And a g ```python @ti.kernel def taichi_forward_v0( - out: ti.types.ndarray(field_dim=3), - w: ti.types.ndarray(field_dim=3), - k: ti.types.ndarray(field_dim=3), + out: ti.types.ndarray(ndim=3), + w: ti.types.ndarray(ndim=3), + k: ti.types.ndarray(ndim=3), eps: ti.f32): for b, c, t in out: