From 800ff2b61c2403950da46d8d826b5022d85d2f84 Mon Sep 17 00:00:00 2001
From: Ailing
Date: Fri, 20 Jan 2023 10:30:56 +0800
Subject: [PATCH] [aot] Fix ndarray aot with information from type hints
 (#7214)

Issue: fixes #7172

### Brief Summary

Ideally we should reconstruct the dtype as a TensorType from taichi_core
instead of the Python-side types, but that can be a separate PR.
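
A minimal repro sketch of the scenario this fixes (adapted from the new
`test_aot_ndarray_without_template_args` test below; the kernel name, the
`ti.init` call, and the Vulkan backend choice are illustrative only): adding a
kernel to an AOT module when the ndarray argument's dtype/ndim are given only
in the type hint, with no example ndarray or template args.

```python
import taichi as ti

ti.init(arch=ti.vulkan)  # any AOT-capable backend works; vulkan is an example


@ti.kernel
def fill_vec(arr: ti.types.ndarray(dtype=ti.math.vec2, ndim=2)):
    for I in ti.grouped(arr):
        arr[I] = 0.


# Before this fix, the dtype/element-shape information from the hint alone was
# not always reconstructed correctly; produce_injected_args now derives the
# injected placeholder ndarray directly from the annotated dtype.
m = ti.aot.Module()
m.add_kernel(fill_vec)
```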

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 python/taichi/aot/utils.py                      | 71 ++++++++++---------
 python/taichi/lang/matrix.py                    |  4 ++
 .../cpp/aot/python_scripts/graph_aot_test.py    | 17 ++---
 tests/python/test_aot.py                        | 23 ++++++
 4 files changed, 71 insertions(+), 44 deletions(-)

diff --git a/python/taichi/aot/utils.py b/python/taichi/aot/utils.py
index c8e06b57a1b4f..544f3ba88374c 100644
--- a/python/taichi/aot/utils.py
+++ b/python/taichi/aot/utils.py
@@ -1,7 +1,8 @@
 from taichi.lang._ndarray import ScalarNdarray
 from taichi.lang._texture import Texture
 from taichi.lang.exception import TaichiCompilationError
-from taichi.lang.matrix import Matrix, MatrixNdarray, MatrixType, VectorNdarray
+from taichi.lang.matrix import (Matrix, MatrixNdarray, MatrixType,
+                                VectorNdarray, VectorType)
 from taichi.lang.util import cook_dtype
 from taichi.types.annotations import template
 from taichi.types.ndarray_type import NdarrayType
@@ -11,9 +12,14 @@
 
 
 def check_type_match(lhs, rhs):
-    if cook_dtype(lhs) == cook_dtype(rhs):
-        return True
-    return False
+    if isinstance(lhs, MatrixType) and isinstance(rhs, MatrixType):
+        return lhs.n == rhs.n and lhs.m == rhs.m and (lhs.dtype == rhs.dtype
                                                      or lhs.dtype is None
                                                      or rhs.dtype is None)
+    if isinstance(lhs, MatrixType) or isinstance(rhs, MatrixType):
+        return False
+
+    return cook_dtype(lhs) == cook_dtype(rhs)
 
 
 def produce_injected_args_from_template(kernel, template_args):
@@ -43,52 +49,49 @@ def produce_injected_args(kernel, symbolic_args=None):
     for i, arg in enumerate(kernel.arguments):
         anno = arg.annotation
         if isinstance(anno, NdarrayType):
-            # TODO(Haidong) we should always use MatrixType and get rid of the element shapes
             if symbolic_args is not None:
-                element_shape = tuple(symbolic_args[i].element_shape)
-                element_dim = len(element_shape)
+                # TODO: reconstruct dtype to be TensorType from taichi_core instead of the Python ones
+                element_dim = len(symbolic_args[i].element_shape)
+                if element_dim == 0 or symbolic_args[i].element_shape == (1, ):
+                    dtype = symbolic_args[i].dtype()
+                elif element_dim == 1:
+                    dtype = VectorType(symbolic_args[i].element_shape[0],
+                                       symbolic_args[i].dtype())
+                elif element_dim == 2:
+                    dtype = MatrixType(symbolic_args[i].element_shape[0],
+                                       symbolic_args[i].element_shape[1], 2,
+                                       symbolic_args[i].dtype())
+                else:
+                    raise TaichiCompilationError('Not supported')
                 ndim = symbolic_args[i].field_dim
-                dtype = symbolic_args[i].dtype()
             else:
-                element_shape = anno.dtype.get_shape()
-                element_dim = anno.dtype.ndim
                 ndim = anno.ndim
                 dtype = anno.dtype
 
-            if element_shape is None or ndim is None:
-                raise TaichiCompilationError(
-                    'Please either specify both `element_shape` and `ndim` '
-                    'in the param annotation, or provide an example '
-                    f'ndarray for param={arg.name}')
             if anno.ndim is not None and ndim != anno.ndim:
                 raise TaichiCompilationError(
                     f'{ndim} from Arg {arg.name} doesn\'t match kernel\'s annotated ndim={anno.ndim}'
                 )
 
-            anno_dtype = anno.dtype
-            if isinstance(anno_dtype, MatrixType):
-                anno_dtype = anno.dtype.dtype
-            if anno_dtype is not None:
-                if not check_type_match(dtype, anno_dtype):
-                    raise TaichiCompilationError(
-                        f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno_dtype.to_string()}'
-                    )
-            if element_dim is None or element_dim == 0 or element_shape == (
-                    1, ):
-                injected_args.append(ScalarNdarray(dtype, (2, ) * ndim))
-            elif element_dim == 1:
+            if anno.dtype is not None and not check_type_match(
+                    dtype, anno.dtype):
+                raise TaichiCompilationError(
+                    f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.dtype.to_string()}'
+                )
+
+            if isinstance(dtype, VectorType):
                 injected_args.append(
-                    VectorNdarray(element_shape[0],
-                                  dtype=dtype,
+                    VectorNdarray(dtype.n,
+                                  dtype=dtype.dtype,
                                   shape=(2, ) * ndim))
-            elif element_dim == 2:
+            elif isinstance(dtype, MatrixType):
                 injected_args.append(
-                    MatrixNdarray(element_shape[0],
-                                  element_shape[1],
-                                  dtype=dtype,
+                    MatrixNdarray(dtype.n,
+                                  dtype.m,
+                                  dtype=dtype.dtype,
                                   shape=(2, ) * ndim))
             else:
-                raise RuntimeError('')
+                injected_args.append(ScalarNdarray(dtype, (2, ) * ndim))
         elif isinstance(anno, RWTextureType):
             texture_shape = (2, ) * anno.num_dimensions
             fmt = anno.fmt
diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index 2cc5c75c352ac..81691742e36b9 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -1534,6 +1534,10 @@ def get_shape(self):
             return (self.n, )
         return (self.n, self.m)
 
+    def to_string(self):
+        dtype_str = self.dtype.to_string() if self.dtype is not None else ''
+        return f'MatrixType[{self.n},{self.m}, {dtype_str}]'
+
 
 class VectorType(MatrixType):
     def __init__(self, n, dtype):
diff --git a/tests/cpp/aot/python_scripts/graph_aot_test.py b/tests/cpp/aot/python_scripts/graph_aot_test.py
index 66c711a8446fc..ae9afd0380e52 100644
--- a/tests/cpp/aot/python_scripts/graph_aot_test.py
+++ b/tests/cpp/aot/python_scripts/graph_aot_test.py
@@ -16,12 +16,9 @@ def run0(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
             arr[i] += base + i
 
     @ti.kernel
-    def run1(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
-        for i in arr:
-            arr[i] += base + i
-
-    @ti.kernel
-    def run2(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
+    def run1(base: int, arr: ti.types.ndarray(ndim=1,
                                              dtype=ti.types.vector(1,
                                                                    ti.i32))):
         for i in arr:
             arr[i] += base + i
 
@@ -41,12 +38,12 @@ def run2(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
     g_builder = ti.graph.GraphBuilder()
 
     g_builder.dispatch(run0, base0, arr0)
-    g_builder.dispatch(run1, base1, arr0)
-    g_builder.dispatch(run2, base2, arr0)
+    g_builder.dispatch(run0, base1, arr0)
+    g_builder.dispatch(run0, base2, arr0)
 
-    g_builder.dispatch(run0, base0, arr1)
+    g_builder.dispatch(run1, base0, arr1)
     g_builder.dispatch(run1, base1, arr1)
-    g_builder.dispatch(run2, base2, arr1)
+    g_builder.dispatch(run1, base2, arr1)
 
     run_graph = g_builder.compile()
 
diff --git a/tests/python/test_aot.py b/tests/python/test_aot.py
index c99d721830715..0c455dd833a4a 100644
--- a/tests/python/test_aot.py
+++ b/tests/python/test_aot.py
@@ -527,6 +527,29 @@ def run(arr: ti.types.ndarray(), val1: ti.f32, val2: ti.template()):
     assert args_count == 2, res  # `arr` and `val1`
 
 
+@test_utils.test(arch=[ti.opengl, ti.vulkan])
+def test_aot_ndarray_without_template_args():
+    @ti.kernel
+    def kernel1(arr: ti.types.ndarray(dtype=ti.f32, ndim=2)):
+        for I in ti.grouped(arr):
+            arr[I] = 0.
+
+    @ti.kernel
+    def kernel2(arr: ti.types.ndarray(dtype=ti.math.vec2, ndim=2)):
+        for I in ti.grouped(arr):
+            arr[I] = 0.
+
+    @ti.kernel
+    def kernel3(arr: ti.types.ndarray(dtype=ti.math.mat2, ndim=2)):
+        for I in ti.grouped(arr):
+            arr[I] = 0.
+
+    m = ti.aot.Module()
+    m.add_kernel(kernel1)
+    m.add_kernel(kernel2)
+    m.add_kernel(kernel3)
+
+
 @test_utils.test(arch=[ti.opengl, ti.vulkan])
 def test_archive():
     density = ti.field(float, shape=(4, 4))
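
For reference, a sketch of the placeholder ndarrays that the updated
`produce_injected_args` is expected to synthesize for annotations like the
ones in the new test above, written against the public `ti.ndarray` API (the
`ti.init` call, backend, and variable names are illustrative, not part of the
patch):

```python
import taichi as ti

ti.init(arch=ti.vulkan)  # illustrative; any AOT-capable backend

# ndim=2 becomes a placeholder shape of (2, 2) via `(2, ) * ndim`;
# the element type comes straight from the annotated dtype.
placeholder_scalar = ti.ndarray(dtype=ti.f32, shape=(2, 2))         # kernel1
placeholder_vector = ti.ndarray(dtype=ti.math.vec2, shape=(2, 2))   # kernel2
placeholder_matrix = ti.ndarray(dtype=ti.math.mat2, shape=(2, 2))   # kernel3
```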