diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h
index f6699533ee4d5..b38846973ff06 100644
--- a/taichi/ir/transforms.h
+++ b/taichi/ir/transforms.h
@@ -32,7 +32,7 @@ void re_id(IRNode *root);
 void flag_access(IRNode *root);
 void eliminate_immutable_local_vars(IRNode *root);
 bool scalarize(IRNode *root, bool half2_optimization_enabled = false);
-void lower_matrix_ptr(IRNode *root);
+void lower_matrix_ptr(IRNode *root, bool force_scalarize = false);
 bool die(IRNode *root);
 bool simplify(IRNode *root, const CompileConfig &config);
 bool cfg_optimization(
diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp
index 0507b9fd52344..b8dc6b31bf94c 100644
--- a/taichi/transforms/compile_to_offloads.cpp
+++ b/taichi/transforms/compile_to_offloads.cpp
@@ -69,9 +69,16 @@ void compile_to_offloads(IRNode *ir,
   }
 
   // Removes MatrixOfMatrixPtrStmt & MatrixOfGlobalPtrStmt
-  irpass::lower_matrix_ptr(ir);
+  irpass::lower_matrix_ptr(ir, config.force_scalarize_matrix);
   print("Matrix ptr lowered");
 
+  if (config.force_scalarize_matrix) {
+    irpass::scalarize(ir, false /*half2_optimization_enabled*/);
+
+    irpass::die(ir);
+    print("Scalarized");
+  }
+
   irpass::full_simplify(
       ir, config,
       {false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone,
@@ -86,10 +93,6 @@ void compile_to_offloads(IRNode *ir,
     irpass::analysis::gather_meshfor_relation_types(ir);
   }
 
-  if (config.force_scalarize_matrix) {
-    irpass::scalarize(ir, false /*half2_optimization_enabled*/);
-  }
-
   if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) {
     // Check whether the kernel obeys the autodiff limitation e.g., gloabl data
     // access rule
@@ -366,7 +369,7 @@ void compile_function(IRNode *ir,
     }
 
     // Removes MatrixOfMatrixPtrStmt & MatrixOfGlobalPtrStmt
-    irpass::lower_matrix_ptr(ir);
+    irpass::lower_matrix_ptr(ir, config.force_scalarize_matrix);
     print("Matrix ptr lowered");
 
     irpass::demote_atomics(ir, config);
diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp
index e4deb59192e84..65c2ba602f400 100644
--- a/taichi/transforms/lower_matrix_ptr.cpp
+++ b/taichi/transforms/lower_matrix_ptr.cpp
@@ -593,13 +593,15 @@ class RemoveMatrixOfPtr : public BasicStmtVisitor {
 
 namespace irpass {
 
-void lower_matrix_ptr(IRNode *root) {
+void lower_matrix_ptr(IRNode *root, bool force_scalarize) {
   TI_AUTO_PROF;
 
-  GatherValidAOSGlobalPtrStmt gather_valid_aos_global_ptr_pass(root);
+  if (!force_scalarize) {
+    GatherValidAOSGlobalPtrStmt gather_valid_aos_global_ptr_pass(root);
 
-  LowerAOSGlobalPtrStmt lower_aos_global_ptr_stmt_pass(
-      root, gather_valid_aos_global_ptr_pass.invalid_aos_global_ptr_stmts_);
+    LowerAOSGlobalPtrStmt lower_aos_global_ptr_stmt_pass(
+        root, gather_valid_aos_global_ptr_pass.invalid_aos_global_ptr_stmts_);
+  }
 
   ScalarizeMatrixPtr scalarize_matrix_ptr_pass(root);
   LowerMatrixPtr::run(root);