diff --git a/python/taichi/_kernels.py b/python/taichi/_kernels.py index 2d4b52b82359b..4b1d86302f84a 100644 --- a/python/taichi/_kernels.py +++ b/python/taichi/_kernels.py @@ -34,7 +34,7 @@ def fill_ndarray(ndarray: ndarray_type.ndarray(), val: template()): @kernel def fill_ndarray_matrix(ndarray: ndarray_type.ndarray(), val: template()): for I in grouped(ndarray): - ndarray[I].fill(val) + ndarray[I] = val @kernel diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index 5cdbc345af8c9..d437145371eb4 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -6,6 +6,7 @@ from taichi.lang.util import cook_dtype, python_scope, to_numpy_type from taichi.types import primitive_types from taichi.types.ndarray_type import NdarrayTypeMetadata +from taichi.types.compound_types import TensorType from taichi.types.utils import is_real, is_signed @@ -69,6 +70,8 @@ def fill(self, val): """ if impl.current_cfg().arch != _ti_core.Arch.cuda and impl.current_cfg().arch != _ti_core.Arch.x64: self._fill_by_kernel(val) + elif isinstance(self.element_type, TensorType): + self._fill_by_kernel(val) elif self.dtype == primitive_types.f32: impl.get_runtime().prog.fill_float(self.arr, val) elif self.dtype == primitive_types.i32: diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 4d2123b3e65ec..7d7fb13e531d9 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1674,7 +1674,19 @@ def __deepcopy__(self, memo=None): def _fill_by_kernel(self, val): from taichi._kernels import fill_ndarray_matrix # pylint: disable=C0415 - fill_ndarray_matrix(self, val) + shape = self.element_type.shape() + n = shape[0] + m = 1 + if len(shape) > 1: + m = shape[1] + + prim_dtype = self.element_type.element_type() + matrix_type = MatrixType(n, m, len(shape), prim_dtype) + if isinstance(val, Matrix): + value = val + else: + value = matrix_type(val) + fill_ndarray_matrix(self, value) @python_scope def __repr__(self): @@ -1770,7 +1782,14 @@ def __deepcopy__(self, memo=None): def _fill_by_kernel(self, val): from taichi._kernels import fill_ndarray_matrix # pylint: disable=C0415 - fill_ndarray_matrix(self, val) + shape = self.element_type.shape() + prim_dtype = self.element_type.element_type() + vector_type = VectorType(shape[0], prim_dtype) + if isinstance(val, Vector): + value = val + else: + value = vector_type(val) + fill_ndarray_matrix(self, value) @python_scope def __repr__(self): diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index 15a1c971bb122..e3c89fa3e0f00 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -848,3 +848,22 @@ def test_read_write_f64_python_scope(): y = ti.ndarray(dtype=ti.math.vec2, shape=2) y[0] = [1.0, 2.0] assert y[0] == [1.0, 2.0] + + +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_fill(): + vec2 = ti.types.vector(2, ti.f32) + x_vec = ti.ndarray(vec2, (512, 512)) + x_vec.fill(1.0) + assert (x_vec[2, 2] == [1.0, 1.0]).all() + + x_vec.fill(vec2(2.0, 4.0)) + assert (x_vec[3, 3] == [2.0, 4.0]).all() + + mat2x2 = ti.types.matrix(2, 2, ti.f32) + x_mat = ti.ndarray(mat2x2, (512, 512)) + x_mat.fill(2.0) + assert (x_mat[2, 2] == [[2.0, 2.0], [2.0, 2.0]]).all() + + x_mat.fill(mat2x2([[2.0, 4.0], [1.0, 3.0]])) + assert (x_mat[3, 3] == [[2.0, 4.0], [1.0, 3.0]]).all()