diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 5913b64aceedc..92700327cfb0f 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -135,14 +135,6 @@ void compile_to_offloads(IRNode *ir, print("Offloaded"); irpass::analysis::verify(ir); - if (config.real_matrix_scalarize) { - irpass::scalarize(ir); - - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::die(ir); - print("Scalarized"); - } - // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { @@ -187,6 +179,14 @@ void offload_to_executable(IRNode *ir, print("Detect read-only accesses"); } + if (config.real_matrix_scalarize) { + irpass::scalarize(ir); + + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); + print("Scalarized"); + } + irpass::demote_atomics(ir, config); print("Atomics demoted I"); irpass::analysis::verify(ir); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 444c63097f770..69073000b5329 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -1049,8 +1049,12 @@ class ExtractLocalPointers : public BasicStmtVisitor { Block *top_level_; explicit ExtractLocalPointers(IRNode *root) : immediate_modifier_(root) { - TI_ASSERT(root->is()); - top_level_ = root->as(); + if (root->is()) { + top_level_ = root->as()->body.get(); + } else { + TI_ASSERT(root->is()); + top_level_ = root->as(); + } root->accept(this); delayed_modifier_.modify_ir(); }