Skip to content

Commit

Permalink
[Lang] MatrixNdarray refactor part8: Add scalarization for BinaryOpSt…
Browse files Browse the repository at this point in the history
…mt with TensorType-operands (#6086)

Related issue = #5873,
#5819

This PR implements "Part ④" of
#5873.
  • Loading branch information
jim19930609 authored Sep 26, 2022
1 parent 40464a0 commit 05e037d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
2 changes: 2 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,14 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
auto ret = ctx->push_back<LocalLoadStmt>(result);
ret->tb = tb;
stmt = ret;
stmt->ret_type = ret_type;
return;
}
flatten_rvalue(rhs, ctx);
ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs->stmt, rhs->stmt));
ctx->stmts.back()->tb = tb;
stmt = ctx->back_stmt();
stmt->ret_type = ret_type;
}

void make_ifte(Expression::FlattenContext *ctx,
Expand Down
68 changes: 67 additions & 1 deletion taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class Scalarize : public IRVisitor {

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);

matrix_init_stmt->ret_type = src_dtype;

stmt->replace_usages_with(matrix_init_stmt.get());
Expand Down Expand Up @@ -178,6 +177,73 @@ class Scalarize : public IRVisitor {
}
}

/*
  Scalarizes a BinaryOpStmt whose operands are TensorType values.

  Before:
    TensorType<4 x i32> val = BinaryStmt(TensorType<4 x i32> lhs,
                                         TensorType<4 x i32> rhs)
    (both "lhs" and "rhs" have already been scalarized into MatrixInitStmt)

  After:
    i32 calc_val_i = BinaryStmt(lhs->cast<MatrixInitStmt>()->val[i],
                                rhs->cast<MatrixInitStmt>()->val[i])
                     ... one per element i ...
    tmp = MatrixInitStmt(calc_val_0, ..., calc_val_3)
    stmt->replace_all_usages_with(tmp)
*/
void visit(BinaryOpStmt *stmt) override {
  const auto lhs_dtype = stmt->lhs->ret_type;
  const auto rhs_dtype = stmt->rhs->ret_type;

  // BinaryOpExpression::type_check() is responsible for broadcasting and any
  // necessary conversions, so both operands must already agree in shape and
  // dtype by the time we get here.
  TI_ASSERT(lhs_dtype == rhs_dtype);

  if (!lhs_dtype->is<TensorType>() || !rhs_dtype->is<TensorType>()) {
    return;  // Scalar operands: nothing to scalarize.
  }

  // LoadStmt scalarization has already rewritten both operands into
  // MatrixInitStmt.
  TI_ASSERT(stmt->lhs->is<MatrixInitStmt>());
  TI_ASSERT(stmt->rhs->is<MatrixInitStmt>());

  const auto &lhs_vals = stmt->lhs->cast<MatrixInitStmt>()->values;
  const auto &rhs_vals = stmt->rhs->cast<MatrixInitStmt>()->values;
  TI_ASSERT(lhs_vals.size() == rhs_vals.size());

  // Emit one scalar BinaryOpStmt per element, inserted before the original
  // tensor-typed statement so the new MatrixInitStmt can reference them.
  std::vector<Stmt *> scalarized_elems;
  scalarized_elems.reserve(lhs_vals.size());
  for (size_t i = 0; i < lhs_vals.size(); i++) {
    auto elem_stmt = std::make_unique<BinaryOpStmt>(stmt->op_type, lhs_vals[i],
                                                    rhs_vals[i]);
    scalarized_elems.push_back(elem_stmt.get());
    modifier_.insert_before(stmt, std::move(elem_stmt));
  }

  // Gather the scalar results back into a matrix value and retire the
  // original tensor-typed statement.
  auto replacement = std::make_unique<MatrixInitStmt>(scalarized_elems);
  replacement->ret_type = stmt->ret_type;

  stmt->replace_usages_with(replacement.get());
  modifier_.insert_before(stmt, std::move(replacement));
  modifier_.erase(stmt);
}

void visit(Block *stmt_list) override {
for (auto &stmt : stmt_list->statements) {
stmt->accept(this);
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,3 +970,25 @@ def verify(x):
field = ti.Matrix.field(2, 2, ti.f32, shape=5)
ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
_test_field_and_ndarray(field, ndarray, func, verify)


@test_utils.test(arch=[ti.cuda, ti.cpu],
                 real_matrix=True,
                 real_matrix_scalarize=True)
def test_binary_op_scalarize():
    # Exercises scalarization of element-wise binary ops (+, *, ti.max)
    # on 2x2 matrix values stored in a field / ndarray of shape 5.
    @ti.func
    def write_ops(a: ti.template()):
        a[0] = [[0.0, 1.0], [2.0, 3.0]]
        a[1] = [[3.0, 4.0], [5.0, 6.0]]
        a[2] = a[0] + a[0]
        a[3] = a[1] * a[1]
        a[4] = ti.max(a[2], a[3])

    def check(x):
        # Expected element-wise results for each written slot.
        expected = {
            2: [[0.0, 2.0], [4.0, 6.0]],
            3: [[9.0, 16.0], [25.0, 36.0]],
            4: [[9.0, 16.0], [25.0, 36.0]],
        }
        for idx, want in expected.items():
            assert (x[idx] == want).all()

    field = ti.Matrix.field(2, 2, ti.f32, shape=5)
    ndarray = ti.Matrix.ndarray(2, 2, ti.f32, shape=5)
    _test_field_and_ndarray(field, ndarray, write_ops, check)

0 comments on commit 05e037d

Please sign in to comment.