Skip to content

Commit

Permalink
Add Evaluator support to update multiple accumulators (#2894)
Browse files Browse the repository at this point in the history
* Evaluator support to update multiple accumulators

Improve the performance of EvaluatorTrainingListener by enabling evaluators to update multiple accumulators from the same labels and predictions, rather than needing to recompute values.

* Fix formatting

* Update AbstractCompositeLoss.java

Aims to fix failing test
  • Loading branch information
petebankhead authored Dec 20, 2023
1 parent 1060bdd commit 1eb54c0
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 19 deletions.
17 changes: 15 additions & 2 deletions api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,22 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
updateAccumulators(new String[] {key}, labels, predictions);
}

/** {@inheritDoc} */
@Override
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
Pair<Long, NDArray> update = accuracyHelper(labels, predictions);
totalInstances.compute(key, (k, v) -> v + update.getKey());
correctInstances.compute(key, (k, v) -> v + update.getValue().sum().getLong());
NDArray value = update.getValue();
NDArray sum = value.sum();
long correct = sum.getLong();
for (String key : keys) {
totalInstances.compute(key, (k, v) -> v + update.getKey());
correctInstances.compute(key, (k, v) -> v + correct);
}
value.close();
sum.close();
}

/** {@inheritDoc} */
Expand Down
12 changes: 10 additions & 2 deletions api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,18 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
updateAccumulators(new String[] {key}, labels, predictions);
}

/** {@inheritDoc} */
@Override
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
NDArray boundingBoxError = evaluate(labels, predictions);
float update = boundingBoxError.sum().getFloat();
totalInstances.compute(key, (k, v) -> v + boundingBoxError.size());
ssdBoxPredictionError.compute(key, (k, v) -> v + update);
for (String key : keys) {
totalInstances.compute(key, (k, v) -> v + boundingBoxError.size());
ssdBoxPredictionError.compute(key, (k, v) -> v + update);
}
}

/** {@inheritDoc} */
Expand Down
19 changes: 19 additions & 0 deletions api/src/main/java/ai/djl/training/evaluator/Evaluator.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@ public String getName() {
*/
public abstract void addAccumulator(String key);

/**
* Updates the evaluator with the given keys based on a {@link NDList} of labels and
* predictions.
*
* <p>This is a synchronized operation. You should only call it at the end of a batch or epoch.
*
* <p>This is an alternative to @{link {@link #updateAccumulator(String, NDList, NDList)}} that
* may be more efficient when updating multiple accumulators at once.
*
* @param keys the keys of all the accumulators to update
* @param labels a {@code NDList} of labels
* @param predictions a {@code NDList} of predictions
*/
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
for (String key : keys) {
updateAccumulator(key, labels, predictions);
}
}

/**
* Updates the evaluator with the given key based on a {@link NDList} of labels and predictions.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ public void updateAccumulator(String key, NDList labels, NDList predictions) {
evaluator.updateAccumulator(key, getLabels(labels), getPredictions(predictions));
}

/** {@inheritDoc} */
@Override
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
evaluator.updateAccumulators(keys, getLabels(labels), getPredictions(predictions));
}

/** {@inheritDoc} */
@Override
public void resetAccumulator(String key) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,7 @@ private void updateEvaluators(Trainer trainer, BatchData batchData, String[] acc
for (Device device : batchData.getLabels().keySet()) {
NDList labels = batchData.getLabels().get(device);
NDList predictions = batchData.getPredictions().get(device);
for (String accumulator : accumulators) {
evaluator.updateAccumulator(accumulator, labels, predictions);
}
evaluator.updateAccumulators(accumulators, labels, predictions);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ public void addAccumulator(String key) {

/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
for (int i = 0; i < components.size(); i++) {
Pair<NDList, NDList> inputs = inputForComponent(i, labels, predictions);
components.get(i).updateAccumulator(key, inputs.getKey(), inputs.getValue());
components.get(i).updateAccumulators(keys, inputs.getKey(), inputs.getValue());
}
}

Expand Down
12 changes: 10 additions & 2 deletions api/src/main/java/ai/djl/training/loss/Loss.java
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,18 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
updateAccumulators(new String[] {key}, labels, predictions);
}

/** {@inheritDoc} */
@Override
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
// this is a synchronized operation, only call it at end of batch or epoch
float update = evaluate(labels, predictions).sum().getFloat();
totalInstances.compute(key, (k, v) -> v + 1);
totalLoss.compute(key, (k, v) -> v + update);
for (String key : keys) {
totalInstances.compute(key, (k, v) -> v + 1);
totalLoss.compute(key, (k, v) -> v + update);
}
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,23 @@ public void addAccumulator(String key) {
/** {@inheritDoc} */
@Override
public void updateAccumulator(String key, NDList labels, NDList predictions) {
updateAccumulators(new String[] {key}, labels, predictions);
}

/** {@inheritDoc} */
@Override
public void updateAccumulators(String[] keys, NDList labels, NDList predictions) {
Pair<Long, NDArray> update = evaluateHelper(labels, predictions);
totalInstances.compute(key, (k, v) -> v + update.getKey());
totalLoss.compute(
key,
(k, v) -> {
try (NDArray array = update.getValue().sum()) {
return v + array.getFloat();
}
});
for (String key : keys) {
totalInstances.compute(key, (k, v) -> v + update.getKey());
totalLoss.compute(
key,
(k, v) -> {
try (NDArray array = update.getValue().sum()) {
return v + array.getFloat();
}
});
}
}

/** {@inheritDoc} */
Expand Down

0 comments on commit 1eb54c0

Please sign in to comment.