Skip to content

Commit

Permalink
Merge pull request #70 from minnerbe/improve-concurrent-optimization
Browse files Browse the repository at this point in the history
Improve concurrent optimization
  • Loading branch information
axtimwalde authored Feb 13, 2024
2 parents cbbe21d + 032d3c8 commit a53245d
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 159 deletions.
141 changes: 88 additions & 53 deletions mpicbg/src/main/java/mpicbg/models/TileConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package mpicbg.models;

import mpicbg.util.RealSum;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
Expand All @@ -27,8 +29,8 @@
import java.util.ListIterator;
import java.util.Set;
import java.util.concurrent.ExecutionException;

import mpicbg.util.RealSum;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;


/**
Expand Down Expand Up @@ -119,29 +121,45 @@ public void clear()
*/
protected void apply()
{
// final ArrayList< Thread > threads = new ArrayList< Thread >();
// for ( final Tile< ? > t : tiles )
// {
// final Thread thread = new Thread(
// new Runnable()
// {
// final public void run()
// {
// t.apply();
// }
// } );
// threads.add( thread );
// thread.start();
// }
// for ( final Thread thread : threads )
// {
// try { thread.join(); }
// catch ( InterruptedException e ){ e.printStackTrace(); }
// }
for ( final Tile< ? > t : tiles )
t.apply();
}

/**
* Apply the model of each {@link Tile} to all its
* {@link PointMatch PointMatches} using a given
* {@link ThreadPoolExecutor}.
*/
protected void apply(final ThreadPoolExecutor executor) {
final List<Tile<?>> allTiles = new ArrayList<>(tiles);
final int nTiles = allTiles.size();
final int nThreads = executor.getMaximumPoolSize();
final int tilesPerThread = nTiles / nThreads + (nTiles % nThreads == 0 ? 0 : 1);
final List<Future<Void>> applyTasks = new ArrayList<>(nThreads);

for (int j = 0; j < nThreads; j++) {
final int start = j * tilesPerThread;
final int end = Math.min((j + 1) * tilesPerThread, nTiles);
applyTasks.add(executor.submit(() -> applyToRange(allTiles, start, end)));
}

for (final Future<Void> task : applyTasks) {
try {
task.get();
} catch (final InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}

private static Void applyToRange(final List<Tile<?>> tiles, final int start, final int end) {
for (int i = start; i < end; i++) {
final Tile<?> t = tiles.get(i);
t.apply();
}
return null;
}

/**
* Estimate min/max/average displacement of all
* {@link PointMatch PointMatches} in all {@link Tile Tiles}.
Expand All @@ -161,38 +179,55 @@ protected void updateErrors()
}
cd /= tiles.size();
error = cd;
}

/**
* Estimate min/max/average displacement of all
* {@link PointMatch PointMatches} in all {@link Tile Tiles} using
* a given {@link ThreadPoolExecutor}.
*/
protected void updateErrors(final ThreadPoolExecutor executor) {
final List<Tile<?>> allTiles = new ArrayList<>(tiles);
final int nTiles = allTiles.size();
final int nThreads = executor.getMaximumPoolSize();
final int tilesPerThread = nTiles / nThreads + (nTiles % nThreads == 0 ? 0 : 1);
final List<Future<Double[]>> applyTasks = new ArrayList<>(nThreads);

for (int j = 0; j < nThreads; j++) {
final int start = j * tilesPerThread;
final int end = Math.min((j + 1) * tilesPerThread, nTiles);
applyTasks.add(executor.submit(() -> computeErrorsOfRange(allTiles, start, end)));
}

minError = Double.MAX_VALUE;
maxError = 0.0;
double sum = 0.0;
for (final Future<Double[]> task : applyTasks) {
try {
final Double[] minMaxSum = task.get();
if (minMaxSum[0] < minError) minError = minMaxSum[0];
if (minMaxSum[1] > maxError) maxError = minMaxSum[1];
sum += minMaxSum[2];
} catch (final InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
error = sum / allTiles.size();
}

// final ArrayList< Thread > threads = new ArrayList< Thread >();
//
// error = 0.0;
// minError = Double.MAX_VALUE;
// maxError = 0.0;
// for ( final Tile< ? > t : tiles )
// {
// final Thread thread = new Thread(
// new Runnable()
// {
// final public void run()
// {
// t.updateCost();
// synchronized ( this )
// {
// double d = t.getDistance();
// if ( d < minError ) minError = d;
// if ( d > maxError ) maxError = d;
// error += d;
// }
// }
// } );
// thread.start();
// threads.add( thread );
// }
// for ( final Thread thread : threads )
// {
// try { thread.join(); }
// catch ( InterruptedException e ){ e.printStackTrace(); }
// }
// error /= tiles.size();
private static Double[] computeErrorsOfRange(List<Tile<?>> tiles, int start, int end) {
double sum = 0.0;
double minError = Double.MAX_VALUE;
double maxError = 0.0;
for (int i = start; i < end; i++) {
final Tile<?> t = tiles.get(i);
t.updateCost();
final double d = t.getDistance();
if (d < minError) minError = d;
if (d > maxError) maxError = d;
sum += d;
}
return new Double[] { minError, maxError, sum };
}

/**
Expand Down Expand Up @@ -305,7 +340,7 @@ public void optimizeSilentlyConcurrent(
final double maxAllowedError,
final int maxIterations,
final int maxPlateauwidth,
final double damp ) throws NotEnoughDataPointsException, IllDefinedDataPointsException, InterruptedException, ExecutionException
final double damp ) throws InterruptedException, ExecutionException
{
TileUtil.optimizeConcurrently(observer, maxAllowedError, maxIterations, maxPlateauwidth, damp,
this, tiles, fixedTiles, Runtime.getRuntime().availableProcessors());
Expand Down
Loading

0 comments on commit a53245d

Please sign in to comment.