Skip to content

Commit

Permalink
[bug] MatrixType bug fix: Fix error with nested StructType and Matrix…
Browse files Browse the repository at this point in the history
…Type (#6689)

Issue: #5819

### Brief Summary
1. Modified `Matrix::fill()` to broadcast `val` into VectorType if `ndim
== 1`, and to MatrixType if `ndim == 2`
2. Modified `Struct::fill()` to apply `matrix_op.fill()` in case of
`Expr` with TensorType
  • Loading branch information
jim19930609 authored Nov 24, 2022
1 parent d748795 commit e8f9816
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 17 deletions.
7 changes: 6 additions & 1 deletion python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,7 +1542,12 @@ def fill(self, val):
"""
if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr)
and not val.is_tensor()):
val = list(list(val for _ in range(self.m)) for _ in range(self.n))
if self.ndim == 2:
val = list(
list(val for _ in range(self.m)) for _ in range(self.n))
else:
assert self.ndim == 1
val = list(val for _ in range(self.n))
elif isinstance(val, Matrix):
val = val.to_list()
else:
Expand Down
12 changes: 8 additions & 4 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,14 @@ def fill(self, val):
Args:
val (Union[int, float]): Value to fill.
"""
def assign_renamed(x, y):
return ops.assign(x, y)

return self._element_wise_writeback_binary(assign_renamed, val)
for k, v in self.items:
if isinstance(v, impl.Expr) and v.ptr.is_tensor():
from taichi.lang import matrix_ops # pylint: disable=C0415
matrix_ops.fill(v, val)
elif isinstance(v, (Struct, Matrix)):
v._element_wise_binary(ops.assign, val)
else:
ops.assign(v, val)

def __len__(self):
"""Get the number of entries in a custom struct"""
Expand Down
8 changes: 2 additions & 6 deletions taichi/transforms/lower_ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,25 +396,21 @@ class LowerAST : public IRVisitor {
auto expr = assign->rhs;
auto fctx = make_flatten_ctx();
flatten_rvalue(expr, &fctx);
flatten_lvalue(dest, &fctx);
if (dest.is<IdExpression>()) {
fctx.push_back<LocalStoreStmt>(
assign->parent->lookup_var(assign->lhs.cast<IdExpression>()->id),
expr->stmt);
fctx.push_back<LocalStoreStmt>(dest->stmt, expr->stmt);
} else if (dest.is<IndexExpression>()) {
auto ix = dest.cast<IndexExpression>();
flatten_lvalue(dest, &fctx);
if (ix->is_local()) {
fctx.push_back<LocalStoreStmt>(dest->stmt, expr->stmt);
} else {
fctx.push_back<GlobalStoreStmt>(dest->stmt, expr->stmt);
}
} else if (dest.is<StrideExpression>()) {
flatten_lvalue(dest, &fctx);
fctx.push_back<GlobalStoreStmt>(dest->stmt, expr->stmt);
} else {
TI_ASSERT(dest.is<ArgLoadExpression>() &&
dest.cast<ArgLoadExpression>()->is_ptr);
flatten_lvalue(dest, &fctx);
fctx.push_back<GlobalStoreStmt>(dest->stmt, expr->stmt);
}
fctx.stmts.back()->set_tb(assign->tb);
Expand Down
39 changes: 33 additions & 6 deletions tests/python/test_custom_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def run_python_scope():
assert y[i].b == int(1.01 * i)


@test_utils.test()
def test_struct_fill():
def _test_struct_fill():
n = 32

# also tests implicit cast
Expand Down Expand Up @@ -114,6 +113,16 @@ def fill_elements():
assert np.allclose(x[i].b.to_numpy(), int(x[i].a))


@test_utils.test()
def test_struct_fill():
_test_struct_fill()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_struct_fill_matrix_scalarize():
_test_struct_fill()


@test_utils.test()
def test_matrix_type():
n = 32
Expand Down Expand Up @@ -142,8 +151,7 @@ def run_python_scope():
assert np.allclose(x[i].to_numpy(), np.array([i + 1, i, i]))


@test_utils.test()
def test_struct_type():
def _test_struct_type():
n = 32
vec3f = ti.types.vector(3, float)
line3f = ti.types.struct(linedir=vec3f, length=float)
Expand Down Expand Up @@ -204,6 +212,16 @@ def run_python_scope():
assert x[i].line.length == 5.0


@test_utils.test()
def test_struct_type():
_test_struct_type()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_struct_type_matrix_scalarize():
_test_struct_type()


@test_utils.test(exclude=ti.cc)
def test_dataclass():
# example struct class type
Expand Down Expand Up @@ -245,8 +263,7 @@ def get_area_field() -> ti.f32:
assert np.isclose(get_area_field(), 4.0 * 3.14 * 4.0)


@test_utils.test()
def test_struct_assign():
def _test_struct_assign():
n = 32
vec3f = ti.types.vector(3, float)
line3f = ti.types.struct(linedir=vec3f, length=float)
Expand Down Expand Up @@ -284,6 +301,16 @@ def run_python_scope():
assert x[i].line.length == i + 0.5


@test_utils.test()
def test_struct_assign():
_test_struct_assign()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_struct_assign_matrix_scalarize():
_test_struct_assign()


@test_utils.test()
def test_compound_type_implicit_cast():
vec2i = ti.types.vector(2, int)
Expand Down

0 comments on commit e8f9816

Please sign in to comment.