diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 626ed168fe536..2ed64da432a66 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -226,20 +226,19 @@ class Scalarize : public BasicStmtVisitor { void visit(BinaryOpStmt *stmt) override { auto lhs_dtype = stmt->lhs->ret_type; auto rhs_dtype = stmt->rhs->ret_type; - - if (lhs_dtype->is() && rhs_dtype->is()) { - return; - } - - if (lhs_dtype->is() && rhs_dtype->is()) { + if (lhs_dtype->is() || rhs_dtype->is()) { + // Make sure broadcasting has been correctly applied by + // BinaryOpExpression::type_check(). + TI_ASSERT(lhs_dtype->is() && rhs_dtype->is()); + // However, since the type conversions are delayed until + // irpass::type_check(), we only check for the shape here. + TI_ASSERT(lhs_dtype->cast()->get_shape() == + rhs_dtype->cast()->get_shape()); // Scalarization for LoadStmt should have already replaced both operands - // to MatrixInitStmt + // to MatrixInitStmt. TI_ASSERT(stmt->lhs->is()); TI_ASSERT(stmt->rhs->is()); - TI_ASSERT(lhs_dtype->cast()->get_shape() == - rhs_dtype->cast()->get_shape()); - auto lhs_matrix_init_stmt = stmt->lhs->cast(); std::vector lhs_vals = lhs_matrix_init_stmt->values; @@ -322,21 +321,16 @@ class Scalarize : public BasicStmtVisitor { void visit(AtomicOpStmt *stmt) override { auto dest_dtype = stmt->dest->ret_type.ptr_removed(); auto val_dtype = stmt->val->ret_type; - - if (dest_dtype->is() && val_dtype->is()) { - return; - } - - // AtomicOpExpression::type_check() have taken care of the broadcasting, - // but the type conversions are delayed until irpass::type_check(). - // So we only check for the shape here. - TI_ASSERT(dest_dtype->is() && val_dtype->is()); - TI_ASSERT(dest_dtype->cast()->get_shape() == - val_dtype->cast()->get_shape()); - - if (dest_dtype->is() && val_dtype->is()) { + if (dest_dtype->is() || val_dtype->is()) { + // Make sure broadcasting has been correctly applied by + // AtomicOpExpression::type_check(). + TI_ASSERT(dest_dtype->is() && val_dtype->is()); + // However, since the type conversions are delayed until + // irpass::type_check(), we only check for the shape here. + TI_ASSERT(dest_dtype->cast()->get_shape() == + val_dtype->cast()->get_shape()); // Scalarization for LoadStmt should have already replaced val operand - // to MatrixInitStmt + // to MatrixInitStmt. TI_ASSERT(stmt->val->is()); auto val_matrix_init_stmt = stmt->val->cast(); @@ -411,20 +405,18 @@ class Scalarize : public BasicStmtVisitor { auto cond_dtype = stmt->op1->ret_type; auto op2_dtype = stmt->op2->ret_type; auto op3_dtype = stmt->op3->ret_type; - - if (cond_dtype->is() && op2_dtype->is() && - op3_dtype->is()) { - return; - } - - // TernaryOpExpression::type_check() have taken care of the broadcasting, - // but the type conversions are delayed until irpass::type_check(). - // So we only check for the shape here. - TI_ASSERT(cond_dtype.get_shape() == op2_dtype.get_shape()); - TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape()); - - if (cond_dtype->is() && op2_dtype->is() && + if (cond_dtype->is() || op2_dtype->is() || op3_dtype->is()) { + // Make sure broadcasting has been correctly applied by + // TernaryOpExpression::type_check(). + TI_ASSERT(cond_dtype->is() && op2_dtype->is() && + op3_dtype->is()); + // However, since the type conversions are delayed until + // irpass::type_check(), we only check for the shape here. + TI_ASSERT(cond_dtype.get_shape() == op2_dtype.get_shape()); + TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape()); + // Scalarization for LoadStmt should have already replaced all operands + // to MatrixInitStmt. TI_ASSERT(stmt->op1->is()); TI_ASSERT(stmt->op2->is()); TI_ASSERT(stmt->op3->is()); diff --git a/tests/python/test_quant_atomics.py b/tests/python/test_quant_atomics.py index fb98939f0b9fc..eee857800c849 100644 --- a/tests/python/test_quant_atomics.py +++ b/tests/python/test_quant_atomics.py @@ -43,9 +43,7 @@ def foo(): assert z[None] == 3 -@test_utils.test(require=[ti.extension.quant_basic, ti.extension.data64], - debug=True) -def test_quant_int_atomics_b64(): +def _test_quant_int_atomics_b64(): qi13 = ti.types.quant.int(13, True) x = ti.field(dtype=qi13) @@ -68,8 +66,21 @@ def foo(): assert x[2] == 315 -@test_utils.test(require=ti.extension.quant_basic, debug=True) -def test_quant_fixed_atomics(): +@test_utils.test(require=[ti.extension.quant_basic, ti.extension.data64], + debug=True) +def test_quant_int_atomics_b64(): + _test_quant_int_atomics_b64() + + +@test_utils.test(require=[ti.extension.quant_basic, ti.extension.data64], + debug=True, + real_matrix=True, + real_matrix_scalarize=True) +def test_quant_int_atomics_b64_real_matrix_scalarize(): + _test_quant_int_atomics_b64() + + +def _test_quant_fixed_atomics(): qfxt13 = ti.types.quant.fixed(bits=13, signed=True, scale=0.1) qfxt19 = ti.types.quant.fixed(bits=19, signed=False, scale=0.1) @@ -91,3 +102,16 @@ def foo(): foo() assert x[None] == approx(-3.3) assert y[None] == approx(1124.4) + + +@test_utils.test(require=ti.extension.quant_basic, debug=True) +def test_quant_fixed_atomics(): + _test_quant_fixed_atomics() + + +@test_utils.test(require=ti.extension.quant_basic, + debug=True, + real_matrix=True, + real_matrix_scalarize=True) +def test_quant_fixed_atomics_real_matrix_scalarize(): + _test_quant_fixed_atomics()