diff --git a/lib/solver.cpp b/lib/solver.cpp index 2f9be27701..e7b5e52bd8 100644 --- a/lib/solver.cpp +++ b/lib/solver.cpp @@ -646,7 +646,7 @@ namespace quda { memcpy(out.true_res, true_res.data(), true_res.size() * sizeof(double)); memcpy(out.true_res_hq, true_res_hq.data(), true_res_hq.size() * sizeof(double)); - out.iter = in.iter; + out.iter = split_rank == 0 ? in.iter : 0; comm_allreduce_int(out.iter); out.ca_lambda_min = in.ca_lambda_min; diff --git a/tests/invert_test.cpp b/tests/invert_test.cpp index edbde2256d..8e400a3f7e 100644 --- a/tests/invert_test.cpp +++ b/tests/invert_test.cpp @@ -367,12 +367,7 @@ std::vector> solve(test_t param) inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i]; } - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= quda::comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= quda::comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, + printfQuda("Done: %d sub-partitions - %i total iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); if (inv_param.energy > 0) { printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n", diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 77d9cd5a75..a3ea5ef010 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -421,12 +421,7 @@ std::vector> solve(test_t param) inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i]; } - quda::comm_allreduce_int(inv_param.iter); - inv_param.iter /= comm_size() / num_sub_partition; - quda::comm_allreduce_sum(inv_param.gflops); - inv_param.gflops /= comm_size() / num_sub_partition; - quda::comm_allreduce_max(inv_param.secs); - printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, + printfQuda("Done: %d sub-partitions - %i total iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition, inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile); if (inv_param.energy > 0) { printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n",