Skip to content

Commit

Permalink
Fix iter and flop counters for split grid
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Dec 17, 2024
1 parent 1455aae commit c0373f0
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 13 deletions.
2 changes: 1 addition & 1 deletion lib/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ namespace quda {
memcpy(out.true_res, true_res.data(), true_res.size() * sizeof(double));
memcpy(out.true_res_hq, true_res_hq.data(), true_res_hq.size() * sizeof(double));

- out.iter = in.iter;
+ out.iter = split_rank == 0 ? in.iter : 0;
comm_allreduce_int(out.iter);

out.ca_lambda_min = in.ca_lambda_min;
Expand Down
7 changes: 1 addition & 6 deletions tests/invert_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,7 @@ std::vector<std::array<double, 2>> solve(test_t param)
inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i];
}

- quda::comm_allreduce_int(inv_param.iter);
- inv_param.iter /= quda::comm_size() / num_sub_partition;
- quda::comm_allreduce_sum(inv_param.gflops);
- inv_param.gflops /= quda::comm_size() / num_sub_partition;
- quda::comm_allreduce_max(inv_param.secs);
- printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
+ printfQuda("Done: %d sub-partitions - %i total iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile);
if (inv_param.energy > 0) {
printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n",
Expand Down
7 changes: 1 addition & 6 deletions tests/staggered_invert_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,7 @@ std::vector<std::array<double, 2>> solve(test_t param)
inv_param.true_res_hq[j + i] = inv_param.true_res_hq[i];
}

- quda::comm_allreduce_int(inv_param.iter);
- inv_param.iter /= comm_size() / num_sub_partition;
- quda::comm_allreduce_sum(inv_param.gflops);
- inv_param.gflops /= comm_size() / num_sub_partition;
- quda::comm_allreduce_max(inv_param.secs);
- printfQuda("Done: %d sub-partitions - %i iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
+ printfQuda("Done: %d sub-partitions - %i total iter / %g secs = %g Gflops, %g secs per source\n", num_sub_partition,
inv_param.iter, inv_param.secs, inv_param.gflops / inv_param.secs, inv_param.secs / Nsrc_tile);
if (inv_param.energy > 0) {
printfQuda("Energy = %g J (%g J per source), Mean power = %g W, mean temp = %g C, mean clock = %f\n\n",
Expand Down

0 comments on commit c0373f0

Please sign in to comment.