diff --git a/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java b/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java index c9a5fdf7036..8610f9e92bb 100644 --- a/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java +++ b/api/src/main/java/ai/djl/training/evaluator/AbstractAccuracy.java @@ -77,9 +77,22 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { Pair update = accuracyHelper(labels, predictions); - totalInstances.compute(key, (k, v) -> v + update.getKey()); - correctInstances.compute(key, (k, v) -> v + update.getValue().sum().getLong()); + NDArray value = update.getValue(); + NDArray sum = value.sum(); + long correct = sum.getLong(); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + update.getKey()); + correctInstances.compute(key, (k, v) -> v + correct); + } + value.close(); + sum.close(); } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java index 4af9e5de3d1..ab2d554142d 100644 --- a/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java +++ b/api/src/main/java/ai/djl/training/evaluator/BoundingBoxError.java @@ -63,10 +63,18 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { NDArray boundingBoxError = evaluate(labels, predictions); float update = boundingBoxError.sum().getFloat(); - totalInstances.compute(key, (k, v) -> v + boundingBoxError.size()); - ssdBoxPredictionError.compute(key, (k, v) -> v + update); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + boundingBoxError.size()); + ssdBoxPredictionError.compute(key, (k, v) -> v + update); + } } /** {@inheritDoc} */ diff --git a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java index 6d2c5995601..c373471f6cf 100644 --- a/api/src/main/java/ai/djl/training/evaluator/Evaluator.java +++ b/api/src/main/java/ai/djl/training/evaluator/Evaluator.java @@ -74,6 +74,25 @@ public String getName() { */ public abstract void addAccumulator(String key); + /** + * Updates the evaluator with the given keys based on a {@link NDList} of labels and + * predictions. + * + *

This is a synchronized operation. You should only call it at the end of a batch or epoch. + * + *

This is an alternative to @{link {@link #updateAccumulator(String, NDList, NDList)}} that + * may be more efficient when updating multiple accumulators at once. + * + * @param keys the keys of all the accumulators to update + * @param labels a {@code NDList} of labels + * @param predictions a {@code NDList} of predictions + */ + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { + for (String key : keys) { + updateAccumulator(key, labels, predictions); + } + } + /** * Updates the evaluator with the given key based on a {@link NDList} of labels and predictions. * diff --git a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java index a7fe08b610e..aa12cae628c 100644 --- a/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java +++ b/api/src/main/java/ai/djl/training/evaluator/IndexEvaluator.java @@ -67,6 +67,12 @@ public void updateAccumulator(String key, NDList labels, NDList predictions) { evaluator.updateAccumulator(key, getLabels(labels), getPredictions(predictions)); } + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { + evaluator.updateAccumulators(keys, getLabels(labels), getPredictions(predictions)); + } + /** {@inheritDoc} */ @Override public void resetAccumulator(String key) { diff --git a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java index 1dbfe4117cd..2556a026259 100644 --- a/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java +++ b/api/src/main/java/ai/djl/training/listener/EvaluatorTrainingListener.java @@ -144,9 +144,7 @@ private void updateEvaluators(Trainer trainer, BatchData batchData, String[] acc for (Device device : batchData.getLabels().keySet()) { NDList labels = batchData.getLabels().get(device); NDList predictions = batchData.getPredictions().get(device); - for (String accumulator : accumulators) { - evaluator.updateAccumulator(accumulator, labels, predictions); - } + evaluator.updateAccumulators(accumulators, labels, predictions); } } } diff --git a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java index 2a46416190a..2e2cdcb8c86 100644 --- a/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java +++ b/api/src/main/java/ai/djl/training/loss/AbstractCompositeLoss.java @@ -80,10 +80,10 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override - public void updateAccumulator(String key, NDList labels, NDList predictions) { + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { for (int i = 0; i < components.size(); i++) { Pair inputs = inputForComponent(i, labels, predictions); - components.get(i).updateAccumulator(key, inputs.getKey(), inputs.getValue()); + components.get(i).updateAccumulators(keys, inputs.getKey(), inputs.getValue()); } } diff --git a/api/src/main/java/ai/djl/training/loss/Loss.java b/api/src/main/java/ai/djl/training/loss/Loss.java index a661a3e9a0e..bcf39d23b39 100644 --- a/api/src/main/java/ai/djl/training/loss/Loss.java +++ b/api/src/main/java/ai/djl/training/loss/Loss.java @@ -385,10 +385,18 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { // this is a synchronized operation, only call it at end of batch or epoch float update = evaluate(labels, predictions).sum().getFloat(); - totalInstances.compute(key, (k, v) -> v + 1); - totalLoss.compute(key, (k, v) -> v + update); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + 1); + totalLoss.compute(key, (k, v) -> v + update); + } } /** {@inheritDoc} */ diff --git a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java index 5b642285c3e..9edb45ff5f0 100644 --- a/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java +++ b/extensions/timeseries/src/main/java/ai/djl/timeseries/evaluator/Rmsse.java @@ -94,15 +94,23 @@ public void addAccumulator(String key) { /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { + updateAccumulators(new String[] {key}, labels, predictions); + } + + /** {@inheritDoc} */ + @Override + public void updateAccumulators(String[] keys, NDList labels, NDList predictions) { Pair update = evaluateHelper(labels, predictions); - totalInstances.compute(key, (k, v) -> v + update.getKey()); - totalLoss.compute( - key, - (k, v) -> { - try (NDArray array = update.getValue().sum()) { - return v + array.getFloat(); - } - }); + for (String key : keys) { + totalInstances.compute(key, (k, v) -> v + update.getKey()); + totalLoss.compute( + key, + (k, v) -> { + try (NDArray array = update.getValue().sum()) { + return v + array.getFloat(); + } + }); + } } /** {@inheritDoc} */