From 955d6b6c8e9c56117ee77fa7b1689e9a9fb9081c Mon Sep 17 00:00:00 2001
From: "Steven G. Johnson" <stevenj@mit.edu>
Date: Thu, 18 Jul 2019 21:13:32 -0400
Subject: [PATCH] bisection search for faster splitting (#966)

* bisection search for faster splitting

* make sure bisection search terminates

* refactor split_by_cost to not repeat bisection search n times for split_into_n
---
 src/meep/vec.hpp |   4 ++
 src/vec.cpp      | 120 ++++++++++++++++++++++++++++++++---------------
 2 files changed, 87 insertions(+), 37 deletions(-)

diff --git a/src/meep/vec.hpp b/src/meep/vec.hpp
index 697c59770..c6d4afa58 100644
--- a/src/meep/vec.hpp
+++ b/src/meep/vec.hpp
@@ -1029,6 +1029,10 @@ class grid_volume {
   double origin_z() const { return origin.z(); }
 
 private:
+  std::complex<double> get_split_costs(direction d, int split_point) const;
+  grid_volume split_by_cost(int desired_chunks, int proc_num,
+                            int best_split_point, direction best_split_direction, double left_effort_fraction) const;
+  void find_best_split(int desired_chunks, int &best_split_point, direction &best_split_direction, double &left_effort_fraction) const;
   grid_volume(ndim d, double ta, int na, int nb, int nc);
   ivec io;    // integer origin ... always change via set_origin etc.!
   vec origin; // cache of operator[](io), for performance
diff --git a/src/vec.cpp b/src/vec.cpp
index d7ba583b8..4859fc20e 100644
--- a/src/vec.cpp
+++ b/src/vec.cpp
@@ -1002,20 +1002,40 @@ double grid_volume::get_cost() const {
   return fstats.cost();
 }
 
-grid_volume grid_volume::split_by_cost(int desired_chunks, int proc_num) const {
-  const size_t grid_points_owned = nowned_min();
-  if (size_t(desired_chunks) > grid_points_owned) {
+// return complex(left cost, right cost).  Should really be a tuple, but we don't want to require C++11? yet?
+std::complex<double> grid_volume::get_split_costs(direction d, int split_point) const {
+  double left_cost = 0, right_cost = 0;
+  if (split_point > 0) {
+    grid_volume v_left = *this;
+    v_left.set_num_direction(d, split_point);
+    left_cost = v_left.get_cost();
+  }
+  if (split_point < num_direction(d)) {
+    grid_volume v_right = *this;
+    v_right.set_num_direction(d, num_direction(d) - split_point);
+    v_right.shift_origin(d, split_point * 2);
+    right_cost = v_right.get_cost();
+  }
+  return std::complex<double>(left_cost, right_cost);
+}
+
+static double cost_diff(int desired_chunks, std::complex<double> costs) {
+  double left_cost = real(costs), right_cost = imag(costs);
+  return right_cost / (desired_chunks - desired_chunks / 2) - left_cost / (desired_chunks / 2);
+}
+
+void grid_volume::find_best_split(int desired_chunks, int &best_split_point, direction &best_split_direction, double &left_effort_fraction) const {
+  if (size_t(desired_chunks) > nowned_min()) {
     abort("Cannot split %zd grid points into %d parts\n", nowned_min(), desired_chunks);
   }
-  if (desired_chunks == 1) return *this;
 
-  double best_split_measure = 1e20;
-  double left_effort_fraction = 0;
-  int best_split_point = 0;
-  direction best_split_direction = NO_DIRECTION;
+  left_effort_fraction = 0;
+  best_split_point = 0;
+  best_split_direction = NO_DIRECTION;
+  if (desired_chunks == 1) return;
+
   direction longest_axis = NO_DIRECTION;
   int num_in_longest_axis = 0;
-
   LOOP_OVER_DIRECTIONS(dim, d) {
     if (num_direction(d) > num_in_longest_axis) {
       longest_axis = d;
@@ -1023,42 +1043,61 @@ grid_volume grid_volume::split_by_cost(int desired_chunks, int proc_num) const {
     }
   }
 
+  double best_split_measure = 1e20;
   LOOP_OVER_DIRECTIONS(dim, d) {
-    for (int split_point = 1; split_point < num_direction(d); ++split_point) {
-      grid_volume v_left = *this;
-      v_left.set_num_direction(d, split_point);
-      grid_volume v_right = *this;
-      v_right.set_num_direction(d, num_direction(d) - split_point);
-      v_right.shift_origin(d, split_point * 2);
-
-      double left_cost = v_left.get_cost();
-      double right_cost = v_right.get_cost();
-      double total_cost = left_cost + right_cost;
-
-      double split_measure =
-          max(left_cost / (desired_chunks / 2), right_cost / (desired_chunks - desired_chunks / 2));
-      if (split_measure < best_split_measure) {
-        if (d == longest_axis ||
-            split_measure < (best_split_measure - (0.3 * best_split_measure))) {
-          // Only use this split_measure if we're on the longest_axis, or if the split_measure is
-          // more than 30% better than the best_split_measure. This is a heuristic to prefer lower
-          // communication costs when the split_measure is somewhat close.
-          // TODO: Use machine learning to get a cost function for the communication instead of hard
-          // coding 0.3
-
-          best_split_measure = split_measure;
-          best_split_point = split_point;
-          best_split_direction = d;
-          left_effort_fraction = left_cost / total_cost;
-        }
+    int first = 0, last = num_direction(d);
+    while (first < last) { // bisection search for balanced splitting
+      int mid = (first + last) / 2;
+      double mid_diff = cost_diff(desired_chunks, get_split_costs(d, mid));
+      if (mid_diff > 0) {
+        if (first == mid) break;
+        first = mid;
+      }
+      else if (mid_diff < 0) last = mid;
+      else break;
+    }
+    int split_point = (first + last) / 2;
+    std::complex<double> costs = get_split_costs(d, split_point);
+    double left_cost = real(costs), right_cost = imag(costs);
+    double total_cost = left_cost + right_cost;
+    double split_measure =
+        max(left_cost / (desired_chunks / 2), right_cost / (desired_chunks - desired_chunks / 2));
+    if (split_measure < best_split_measure) {
+      if (d == longest_axis ||
+          split_measure < (best_split_measure - (0.3 * best_split_measure))) {
+        // Only use this split_measure if we're on the longest_axis, or if the split_measure is
+        // more than 30% better than the best_split_measure. This is a heuristic to prefer lower
+        // communication costs when the split_measure is somewhat close.
+        // TODO: Use machine learning to get a cost function for the communication instead of hard
+        // coding 0.3
+
+        best_split_measure = split_measure;
+        best_split_point = split_point;
+        best_split_direction = d;
+        left_effort_fraction = left_cost / total_cost;
       }
     }
   }
+}
+
+grid_volume grid_volume::split_by_cost(int desired_chunks, int proc_num) const {
+  int best_split_point;
+  direction best_split_direction;
+  double left_effort_fraction;
+  find_best_split(desired_chunks, best_split_point, best_split_direction, left_effort_fraction);
+  return split_by_cost(desired_chunks, proc_num, best_split_point, best_split_direction, left_effort_fraction);
+}
+
+grid_volume grid_volume::split_by_cost(int desired_chunks, int proc_num,
+                                       int best_split_point, direction best_split_direction, double left_effort_fraction) const {
+  if (desired_chunks == 1) return *this;
+
   const int split_point = best_split_point;
   const int num_in_split_dir = num_direction(best_split_direction);
 
   const int num_low = (size_t)(left_effort_fraction * desired_chunks + 0.5);
   // Revert to split() when cost method gives less grid points than chunks
+  const size_t grid_points_owned = nowned_min();
   if (size_t(num_low) > best_split_point * (grid_points_owned / num_in_split_dir) ||
       size_t(desired_chunks - num_low) >
           (grid_points_owned - best_split_point * (grid_points_owned / num_in_split_dir)))
@@ -1153,9 +1192,16 @@ std::vector<grid_volume> grid_volume::split_into_n(int n) const {
 
   if (n == 3)
     split_into_three(result);
+  else if (n == 1) {
+    result.push_back(*this);
+  }
   else {
+    int best_split_point;
+    direction best_split_direction;
+    double left_effort_fraction;
+    find_best_split(n, best_split_point, best_split_direction, left_effort_fraction);
     for (int i = 0; i < n; ++i) {
-      grid_volume split_gv = split_by_cost(n, i);
+      grid_volume split_gv = split_by_cost(n, i, best_split_point, best_split_direction, left_effort_fraction);
       result.push_back(split_gv);
     }
   }