From ed762e2ab4a0a717ec7d417ddbc8f9fa7bd3d926 Mon Sep 17 00:00:00 2001
From: Andreas Hoenselaar <ahoens@google.com>
Date: Thu, 19 Aug 2021 12:13:46 -0700
Subject: [PATCH] Process incoming chunk data immediately upon receipt. (#1730)

Use a callback mechanism in conjunction with `MPI_Waitsome` to process data immediately instead of waiting for the slowest participant.

Co-authored-by: Andreas Hoenselaar <ahoenselaar@gmail.com>
---
 src/meep.hpp  |   5 +-
 src/mympi.cpp |  20 ++++++--
 src/step.cpp  | 129 +++++++++++++++++++++++++++++---------------------
 3 files changed, 95 insertions(+), 59 deletions(-)

diff --git a/src/meep.hpp b/src/meep.hpp
index 7f336280f..ec5b8db96 100644
--- a/src/meep.hpp
+++ b/src/meep.hpp
@@ -784,9 +784,11 @@ struct comms_sequence {
 // Upon destruction, the comms_manager waits for completion of all enqueued operations.
 class comms_manager {
  public:
+  using receive_callback = std::function<void()>;
   virtual ~comms_manager() {}
   virtual void send_real_async(const void *buf, size_t count, int dest, int tag) = 0;
-  virtual void receive_real_async(void *buf, size_t count, int source, int tag) = 0;
+  virtual void receive_real_async(void *buf, size_t count, int source, int tag,
+                                  const receive_callback &cb) = 0;
   virtual size_t max_transfer_size() const { return std::numeric_limits<size_t>::max(); };
 };
 
@@ -2162,6 +2164,7 @@ class fields {
   double max_eps() const;
   // step.cpp
   void step_boundaries(field_type);
+  void process_incoming_chunk_data(field_type ft, const chunk_pair &comm_pair);
 
   bool nosize_direction(direction d) const;
   direction normal_direction(const volume &where) const;
diff --git a/src/mympi.cpp b/src/mympi.cpp
index 4193e5c69..647bb395c 100644
--- a/src/mympi.cpp
+++ b/src/mympi.cpp
@@ -89,8 +89,18 @@ class mpi_comms_manager : public comms_manager {
   mpi_comms_manager() {}
   ~mpi_comms_manager() override {
 #ifdef HAVE_MPI
-    if (!reqs.empty()) {
-      MPI_Waitall(reqs.size(), reqs.data(), MPI_STATUSES_IGNORE);
+    int num_pending_requests = reqs.size();
+    std::vector<int> completed_indices(num_pending_requests);
+    while (num_pending_requests) {
+      int num_completed_requests = 0;
+      MPI_Waitsome(reqs.size(), reqs.data(), &num_completed_requests, completed_indices.data(),
+                   MPI_STATUSES_IGNORE);
+      for (int i = 0; i < num_completed_requests; ++i) {
+        int request_idx = completed_indices[i];
+        callbacks[request_idx]();
+        reqs[request_idx] = MPI_REQUEST_NULL;
+        --num_pending_requests;
+      }
     }
 #endif
   }
@@ -98,6 +108,7 @@ class mpi_comms_manager : public comms_manager {
   void send_real_async(const void *buf, size_t count, int dest, int tag) override {
 #ifdef HAVE_MPI
     reqs.emplace_back();
+    callbacks.push_back(/*no-op*/ []{});
     MPI_Isend(buf, static_cast<int>(count), MPI_REALNUM, dest, tag, mycomm, &reqs.back());
 #else
     (void)buf;
@@ -107,9 +118,11 @@ class mpi_comms_manager : public comms_manager {
 #endif
   }
 
-  void receive_real_async(void *buf, size_t count, int source, int tag) override {
+  void receive_real_async(void *buf, size_t count, int source, int tag,
+                          const receive_callback &cb) override {
 #ifdef HAVE_MPI
     reqs.emplace_back();
+    callbacks.push_back(cb);
     MPI_Irecv(buf, static_cast<int>(count), MPI_REALNUM, source, tag, mycomm, &reqs.back());
 #else
     (void)buf;
@@ -127,6 +140,7 @@ class mpi_comms_manager : public comms_manager {
 #ifdef HAVE_MPI
   std::vector<MPI_Request> reqs;
 #endif
+  std::vector<receive_callback> callbacks;
 };
 
 } // namespace
diff --git a/src/step.cpp b/src/step.cpp
index 8bdf8a537..88f19e9b2 100644
--- a/src/step.cpp
+++ b/src/step.cpp
@@ -168,6 +168,59 @@ void fields_chunk::phase_material(int phasein_time) {
   }
 }
 
+void fields::process_incoming_chunk_data(field_type ft, const chunk_pair &comm_pair) {
+  am_now_working_on(Boundaries);
+  int this_chunk_idx = comm_pair.second;
+  const int pair_idx = chunk_pair_to_index(comm_pair);
+  const realnum *pair_comm_block = static_cast<realnum *>(comm_blocks[ft][pair_idx]);
+
+  {
+    const comms_key key = {ft, CONNECT_PHASE, comm_pair};
+    size_t num_transfers = get_comm_size(key) / 2; // Two realnums per complex
+    if (num_transfers) {
+      const std::complex<realnum> *pair_comm_block_complex =
+          reinterpret_cast<const std::complex<realnum> *>(pair_comm_block);
+      const std::vector<realnum *> &incoming_connection =
+          chunks[this_chunk_idx]->connections_in.at(key);
+      const std::vector<std::complex<realnum> > &connection_phase_for_ft =
+          chunks[this_chunk_idx]->connection_phases[key];
+
+      for (size_t n = 0; n < num_transfers; ++n) {
+        std::complex<realnum> temp = connection_phase_for_ft[n] * pair_comm_block_complex[n];
+        *(incoming_connection[2 * n]) = temp.real();
+        *(incoming_connection[2 * n + 1]) = temp.imag();
+      }
+      pair_comm_block += 2 * num_transfers;
+    }
+  }
+
+  {
+    const comms_key key = {ft, CONNECT_NEGATE, comm_pair};
+    const size_t num_transfers = get_comm_size(key);
+    if (num_transfers) {
+      const std::vector<realnum *> &incoming_connection =
+          chunks[this_chunk_idx]->connections_in.at(key);
+      for (size_t n = 0; n < num_transfers; ++n) {
+        *(incoming_connection[n]) = -pair_comm_block[n];
+      }
+      pair_comm_block += num_transfers;
+    }
+  }
+
+  {
+    const comms_key key = {ft, CONNECT_COPY, comm_pair};
+    const size_t num_transfers = get_comm_size(key);
+    if (num_transfers) {
+      const std::vector<realnum *> &incoming_connection =
+          chunks[this_chunk_idx]->connections_in.at(key);
+      for (size_t n = 0; n < num_transfers; ++n) {
+        *(incoming_connection[n]) = pair_comm_block[n];
+      }
+    }
+  }
+  finished_working();
+}
+
 void fields::step_boundaries(field_type ft) {
   connect_chunks(); // re-connect if !chunk_connections_valid
 
@@ -178,8 +231,12 @@ void fields::step_boundaries(field_type ft) {
     const auto &sequence = comms_sequence_for_field[ft];
     for (const comms_operation &op : sequence.receive_ops) {
       if (chunks[op.other_chunk_idx]->is_mine()) { continue; }
+      chunk_pair comm_pair{op.other_chunk_idx, op.my_chunk_idx};
+      comms_manager::receive_callback cb = [this, ft, comm_pair]() {
+        process_incoming_chunk_data(ft, comm_pair);
+      };
       manager->receive_real_async(comm_blocks[ft][op.pair_idx], static_cast<int>(op.transfer_size),
-                                  op.other_proc_id, op.tag);
+                                  op.other_proc_id, op.tag, cb);
     }
 
     // Do the metals first!
@@ -198,70 +255,32 @@ void fields::step_boundaries(field_type ft) {
       for (connect_phase ip : all_connect_phases) {
         const comms_key key = {ft, ip, comm_pair};
         const size_t pair_comm_size = get_comm_size(key);
-        const std::vector<realnum *> &outgoing_connection =
-            chunks[op.my_chunk_idx]->connections_out[key];
-        for (size_t n = 0; n < pair_comm_size; ++n) {
-          outgoing_comm_block[n] = *(outgoing_connection[n]);
+        if (pair_comm_size) {
+          const std::vector<realnum *> &outgoing_connection =
+              chunks[op.my_chunk_idx]->connections_out.at(key);
+          for (size_t n = 0; n < pair_comm_size; ++n) {
+            outgoing_comm_block[n] = *(outgoing_connection[n]);
+          }
+          outgoing_comm_block += pair_comm_size;
         }
-        outgoing_comm_block += pair_comm_size;
       }
       if (chunks[op.other_chunk_idx]->is_mine()) { continue; }
       manager->send_real_async(comm_blocks[ft][pair_idx], static_cast<int>(op.transfer_size),
                                op.other_proc_id, op.tag);
     }
+
+    // Process local transfers, which do not depend on a communication mechanism across nodes.
+    for (const comms_operation &op : sequence.receive_ops) {
+      if (chunks[op.other_chunk_idx]->is_mine()) {
+        process_incoming_chunk_data(ft, {op.other_chunk_idx, op.my_chunk_idx});
+      }
+    }
     finished_working();
 
     am_now_working_on(MpiOneTime);
     // Let the communication manager drop out of scope to complete all outstanding requests.
-  }
-  finished_working();
-
-  // Finally, copy incoming data to the fields themselves, multiplying phases:
-  am_now_working_on(Boundaries);
-  for (int i = 0; i < num_chunks; i++) {
-    if (!chunks[i]->is_mine()) continue;
-
-    for (int j = 0; j < num_chunks; j++) {
-      const chunk_pair comm_pair{j, i};
-      const int pair_idx = chunk_pair_to_index(comm_pair);
-      const realnum *pair_comm_block = static_cast<realnum *>(comm_blocks[ft][pair_idx]);
-
-      {
-        const std::complex<realnum> *pair_comm_block_complex =
-            reinterpret_cast<const std::complex<realnum> *>(pair_comm_block);
-        const comms_key key = {ft, CONNECT_PHASE, comm_pair};
-        const std::vector<realnum *> &incoming_connection = chunks[i]->connections_in[key];
-        const std::vector<std::complex<realnum> > &connection_phase_for_ft =
-	  chunks[i]->connection_phases[key];
-        size_t num_transfers = get_comm_size(key) / 2; // Two realnums per complex
-
-        for (size_t n = 0; n < num_transfers; ++n) {
-          std::complex<realnum> temp = connection_phase_for_ft[n] * pair_comm_block_complex[n];
-          *(incoming_connection[2 * n]) = temp.real();
-          *(incoming_connection[2 * n + 1]) = temp.imag();
-        }
-        pair_comm_block += 2 * num_transfers;
-      }
-
-      {
-        const comms_key key = {ft, CONNECT_NEGATE, comm_pair};
-        const std::vector<realnum *> &incoming_connection = chunks[i]->connections_in[key];
-        const size_t num_transfers = get_comm_size(key);
-        for (size_t n = 0; n < num_transfers; ++n) {
-          *(incoming_connection[n]) = -pair_comm_block[n];
-        }
-        pair_comm_block += num_transfers;
-      }
-
-      {
-        const comms_key key = {ft, CONNECT_COPY, comm_pair};
-        const std::vector<realnum *> &incoming_connection = chunks[i]->connections_in[key];
-        const size_t num_transfers = get_comm_size(key);
-        for (size_t n = 0; n < num_transfers; ++n) {
-          *(incoming_connection[n]) = pair_comm_block[n];
-        }
-      }
-    }
+    // As data is received, the installed callback handles copying the data from the comm buffer
+    // back into the chunk field array.
   }
   finished_working();
 }