Skip to content

Commit

Permalink
[lang] MatrixType bug fix: Demote ret_type for ArgLoadStmt after scal…
Browse files Browse the repository at this point in the history
…arization (#6433)

Issue: #5819

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jim19930609 and pre-commit-ci[bot] authored Oct 31, 2022
1 parent 0cc0489 commit 02309bf
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 13 deletions.
5 changes: 4 additions & 1 deletion python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
Vector, _IntermediateMatrix,
Vector, VectorNdarray, _IntermediateMatrix,
_MatrixFieldElement, make_matrix)
from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
MeshRelationAccessProxy,
Expand Down Expand Up @@ -824,6 +824,9 @@ def ndarray(dtype, shape):
if dtype in all_types:
return ScalarNdarray(dtype, shape)
if isinstance(dtype, MatrixType):
if dtype.ndim == 1:
return VectorNdarray(dtype.n, dtype.dtype, shape)

return MatrixNdarray(dtype.n, dtype.m, dtype.dtype, shape)

raise TaichiRuntimeError(
Expand Down
19 changes: 14 additions & 5 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,15 @@ class Scalarize : public BasicStmtVisitor {
return;
}

// BinaryOpExpression::type_check() should have taken care of the
// broadcasting and neccessary conversions. So we simply add an assertion
// here to make sure that the operands are of the same shape and dtype
TI_ASSERT(lhs_dtype == rhs_dtype);

if (lhs_dtype->is<TensorType>() && rhs_dtype->is<TensorType>()) {
// Scalarization for LoadStmt should have already replaced both operands
// to MatrixInitStmt
TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());

TI_ASSERT(lhs_dtype->cast<TensorType>()->get_shape() ==
rhs_dtype->cast<TensorType>()->get_shape());

auto lhs_matrix_init_stmt = stmt->lhs->cast<MatrixInitStmt>();
std::vector<Stmt *> lhs_vals = lhs_matrix_init_stmt->values;

Expand Down Expand Up @@ -568,6 +566,17 @@ class ScalarizePointers : public BasicStmtVisitor {
}
}

void visit(ArgLoadStmt *stmt) override {
auto ret_type = stmt->ret_type.ptr_removed().get_element_type();
auto arg_load =
std::make_unique<ArgLoadStmt>(stmt->arg_id, ret_type, stmt->is_ptr);

stmt->replace_usages_with(arg_load.get());

modifier_.insert_before(stmt, std::move(arg_load));
modifier_.erase(stmt);
}

private:
using BasicStmtVisitor::visit;
};
Expand Down
16 changes: 14 additions & 2 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,7 @@ def bar():
bar()


@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True)
def test_func_ndarray_arg():
def _test_func_ndarray_arg():
vec3 = ti.types.vector(3, ti.f32)

@ti.func
Expand Down Expand Up @@ -383,6 +382,19 @@ def test_error(x: ti.types.ndarray(field_dim=1)):
test_error(arr)


@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True)
def test_func_ndarray_arg():
_test_func_ndarray_arg()


@test_utils.test(arch=[ti.cpu, ti.cuda],
debug=True,
real_matrix=True,
real_matrix_scalarize=True)
def test_func_ndarray_arg_matrix_scalarize():
_test_func_ndarray_arg()


def _test_func_matrix_arg():
vec3 = ti.types.vector(3, ti.f32)

Expand Down
18 changes: 15 additions & 3 deletions tests/python/test_ggui.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,7 @@ def test_get_camera_view_and_projection_matrix():
assert (abs(projection_matrix[3, 2] - 1.0001000e-1) <= 1e-5)


@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available")
@test_utils.test(arch=supported_archs)
def test_fetching_color_attachment():
def _test_fetching_color_attachment():
window = ti.ui.Window('test', (640, 480), show_window=False)
canvas = window.get_canvas()

Expand All @@ -389,6 +387,20 @@ def render():
window.destroy()


@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available")
@test_utils.test(arch=supported_archs)
def test_fetching_color_attachment():
_test_fetching_color_attachment()


@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available")
@test_utils.test(arch=supported_archs,
real_matrix=True,
real_matrix_scalarize=True)
def test_fetching_color_attachment_matrix_scalarize():
_test_fetching_color_attachment()


@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available")
@test_utils.test(arch=supported_archs, exclude=[(ti.vulkan, "Darwin")])
def test_fetching_depth_attachment():
Expand Down
46 changes: 44 additions & 2 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ def test_ndarray_compound_element():

vec3 = ti.types.vector(3, ti.i32)
b = ti.ndarray(vec3, shape=(n, n))
assert isinstance(b, ti.MatrixNdarray)
assert isinstance(b, ti.VectorNdarray)
assert b.shape == (n, n)
assert b.element_type.element_type() == ti.i32
assert b.element_type.shape() == (3, 1)
assert b.element_type.shape() == (3, )

matrix34 = ti.types.matrix(3, 4, float)
c = ti.ndarray(matrix34, shape=(n, n + 1))
Expand Down Expand Up @@ -224,6 +224,13 @@ def test_ndarray_copy_from_ndarray():
_test_ndarray_copy_from_ndarray()


@test_utils.test(arch=supported_archs_taichi_ndarray,
real_matrix=True,
real_matrix_scalarize=True)
def test_ndarray_copy_from_ndarray_matrix_scalarize():
_test_ndarray_copy_from_ndarray()


def _test_ndarray_deepcopy():
n = 16
x = ti.ndarray(ti.i32, shape=n)
Expand Down Expand Up @@ -324,6 +331,13 @@ def test_ndarray_deepcopy():
_test_ndarray_deepcopy()


@test_utils.test(arch=supported_archs_taichi_ndarray,
real_matrix=True,
real_matrix_scalarize=True)
def test_ndarray_deepcopy_matrix_scalarize():
_test_ndarray_deepcopy()


def _test_ndarray_numpy_io():
n = 7
m = 4
Expand Down Expand Up @@ -412,6 +426,13 @@ def test_matrix_ndarray_taichi_scope_real_matrix():
_test_matrix_ndarray_taichi_scope()


@test_utils.test(arch=supported_archs_taichi_ndarray,
real_matrix=True,
real_matrix_scalarize=True)
def test_matrix_ndarray_taichi_scope_real_matrix_scalarize():
_test_matrix_ndarray_taichi_scope()


def _test_matrix_ndarray_taichi_scope_struct_for():
@ti.kernel
def func(a: ti.types.ndarray()):
Expand All @@ -438,6 +459,13 @@ def test_matrix_ndarray_taichi_scope_struct_for_real_matrix():
_test_matrix_ndarray_taichi_scope_struct_for()


@test_utils.test(arch=supported_archs_taichi_ndarray,
real_matrix=True,
real_matrix_scalarize=True)
def test_matrix_ndarray_taichi_scope_struct_for_matrix_scalarize():
_test_matrix_ndarray_taichi_scope_struct_for()


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_vector_ndarray_python_scope():
a = ti.Vector.ndarray(10, ti.i32, 5)
Expand Down Expand Up @@ -472,6 +500,13 @@ def test_vector_ndarray_taichi_scope():
_test_vector_ndarray_taichi_scope()


@test_utils.test(arch=supported_archs_taichi_ndarray,
real_matrix=True,
real_matrix_scalarize=True)
def test_vector_ndarray_taichi_scope_matrix_scalarize():
_test_vector_ndarray_taichi_scope()


@test_utils.test(arch=[ti.cpu, ti.cuda], real_matrix=True)
def test_vector_ndarray_taichi_scope_real_matrix():
_test_vector_ndarray_taichi_scope()
Expand Down Expand Up @@ -633,6 +668,13 @@ def test_ndarray_grouped_real_matrix():
_test_ndarray_grouped()


@test_utils.test(arch=supported_archs_taichi_ndarray,
real_matrix=True,
real_matrix_scalarize=True)
def test_ndarray_grouped_real_matrix_scalarize():
_test_ndarray_grouped()


@test_utils.test(arch=supported_archs_taichi_ndarray)
def test_ndarray_as_template():
@ti.kernel
Expand Down

0 comments on commit 02309bf

Please sign in to comment.