From 2a64f36d7213b4ca7821640c0a70254c2e03efa4 Mon Sep 17 00:00:00 2001 From: Jago de Vreede Date: Mon, 16 Oct 2023 07:15:33 +0200 Subject: [PATCH] [api] Added Builder for Early stopping configuration (#38) --- .../listener/EarlyStoppingListener.java | 232 +++++++++++++----- .../listener/EarlyStoppingListenerTest.java | 116 +++++++-- .../tests/training/listener/package-info.java | 15 ++ 3 files changed, 273 insertions(+), 90 deletions(-) create mode 100644 integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java index 9aaf0433564..6c013c37715 100644 --- a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java +++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java @@ -1,24 +1,43 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ package ai.djl.training.listener; import ai.djl.training.Trainer; +import ai.djl.training.TrainingResult; + +import java.time.Duration; /** - * Listener that allows the training to be stopped early if the validation loss is not improving, or if time has expired. - *
- * Usage: - * Add this listener to the training config, and add it as the last one. + * Listener that allows the training to be stopped early if the validation loss is not improving, or + * if time has expired.
+ * + *

Usage: Add this listener to the training config, and add it as the last one. + * *

  *  new DefaultTrainingConfig(...)
- *        .addTrainingListeners(new EarlyStoppingListener()
+ *        .addTrainingListeners(EarlyStoppingListener.builder()
  *                .setEpochPatience(1)
  *                .setEarlyStopPctImprovement(1)
- *                .setMaxMinutes(60)
+ *                .setMaxDuration(Duration.ofMinutes(42))
  *                .setMinEpochs(1)
+ *                .build()
  *        );
  * 
- * Then surround the fit with a try catch that catches the {@link EarlyStoppingListener.EarlyStoppedException}. - *
+ * + *

Then surround the fit with a try catch that catches the {@link + * EarlyStoppingListener.EarlyStoppedException}.
* Example: + * *

  * try {
  *   EasyTrain.fit(trainer, 5, trainDataset, testDataset);
@@ -27,61 +46,54 @@
  *   log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
  * }
  * 
- *
+ * + *
* Note: Ensure that Metrics are set on the trainer. */ -public class EarlyStoppingListener implements TrainingListener { +public final class EarlyStoppingListener implements TrainingListener { private final double objectiveSuccess; - /** - * after minimum # epochs, consider stopping if: - */ - private int minEpochs; - /** - * too much time elapsed - */ - private int maxMinutes; - /** - * consider early stopping if not x% improvement - */ - private double earlyStopPctImprovement; - /** - * stop if insufficient improvement for x epochs in a row - */ - private int epochPatience; + private final int minEpochs; + private final long maxMillis; + private final double earlyStopPctImprovement; + private final int epochPatience; private long startTimeMills; private double prevLoss; private int numberOfEpochsWithoutImprovements; - public EarlyStoppingListener() { - this.objectiveSuccess = 0; - this.minEpochs = 0; - this.maxMinutes = Integer.MAX_VALUE; - this.earlyStopPctImprovement = 0; - this.epochPatience = 0; - } - - public EarlyStoppingListener(double objectiveSuccess, int minEpochs, int maxMinutes, double earlyStopPctImprovement, int earlyStopPatience) { + private EarlyStoppingListener( + double objectiveSuccess, + int minEpochs, + long maxMillis, + double earlyStopPctImprovement, + int earlyStopPatience) { this.objectiveSuccess = objectiveSuccess; this.minEpochs = minEpochs; - this.maxMinutes = maxMinutes; + this.maxMillis = maxMillis; this.earlyStopPctImprovement = earlyStopPctImprovement; this.epochPatience = earlyStopPatience; } + /** {@inheritDoc} */ @Override public void onEpoch(Trainer trainer) { int currentEpoch = trainer.getTrainingResult().getEpoch(); // stopping criteria - final double loss = getLoss(trainer); - if (loss < objectiveSuccess) { - throw new EarlyStoppedException(currentEpoch, String.format("validation loss %s < objectiveSuccess %s", loss, objectiveSuccess)); - } + final double loss = getLoss(trainer.getTrainingResult()); if (currentEpoch >= minEpochs) { - double elapsedMinutes = (System.currentTimeMillis() - startTimeMills) / 60_000.0; - if (elapsedMinutes >= maxMinutes) { - throw new EarlyStoppedException(currentEpoch, String.format("Early stopping training: %.1f minutes elapsed >= %s maxMinutes", elapsedMinutes, maxMinutes)); + if (loss < objectiveSuccess) { + throw new EarlyStoppedException( + currentEpoch, + String.format( + "validation loss %s < objectiveSuccess %s", + loss, objectiveSuccess)); + } + long elapsedMillis = System.currentTimeMillis() - startTimeMills; + if (elapsedMillis >= maxMillis) { + throw new EarlyStoppedException( + currentEpoch, + String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis)); } // consider early stopping? if (Double.isFinite(prevLoss)) { @@ -92,8 +104,11 @@ public void onEpoch(Trainer trainer) { } else { numberOfEpochsWithoutImprovements++; if (numberOfEpochsWithoutImprovements >= epochPatience) { - throw new EarlyStoppedException(currentEpoch, String.format("failed to achieve %s%% improvement %s times in a row", - earlyStopPctImprovement, epochPatience)); + throw new EarlyStoppedException( + currentEpoch, + String.format( + "failed to achieve %s%% improvement %s times in a row", + earlyStopPctImprovement, epochPatience)); } } } @@ -103,28 +118,31 @@ public void onEpoch(Trainer trainer) { } } - private static double getLoss(Trainer trainer) { - Float vLoss = trainer.getTrainingResult().getValidateLoss(); + private static double getLoss(TrainingResult trainingResult) { + Float vLoss = trainingResult.getValidateLoss(); if (vLoss != null) { return vLoss; } - Float tLoss = trainer.getTrainingResult().getTrainLoss(); + Float tLoss = trainingResult.getTrainLoss(); if (tLoss == null) { return Double.NaN; } return tLoss; } + /** {@inheritDoc} */ @Override public void onTrainingBatch(Trainer trainer, BatchData batchData) { // do nothing } + /** {@inheritDoc} */ @Override public void onValidationBatch(Trainer trainer, BatchData batchData) { // do nothing } + /** {@inheritDoc} */ @Override public void onTrainingBegin(Trainer trainer) { this.startTimeMills = System.currentTimeMillis(); @@ -132,42 +150,130 @@ public void onTrainingBegin(Trainer trainer) { this.numberOfEpochsWithoutImprovements = 0; } + /** {@inheritDoc} */ @Override public void onTrainingEnd(Trainer trainer) { // do nothing } - public EarlyStoppingListener setMinEpochs(int minEpochs) { - this.minEpochs = minEpochs; - return this; + /** + * Creates a builder to build a {@link EarlyStoppingListener}. + * + * @return a new builder + */ + public static Builder builder() { + return new Builder(); } - public EarlyStoppingListener setMaxMinutes(int maxMinutes) { - this.maxMinutes = maxMinutes; - return this; - } + /** A builder for a {@link EarlyStoppingListener}. */ + public static final class Builder { + private final double objectiveSuccess; + private int minEpochs; + private long maxMillis; + private double earlyStopPctImprovement; + private int epochPatience; - public EarlyStoppingListener setEarlyStopPctImprovement(double earlyStopPctImprovement) { - this.earlyStopPctImprovement = earlyStopPctImprovement; - return this; - } + /** Constructs a {@link Builder} with default values. */ + public Builder() { + this.objectiveSuccess = 0; + this.minEpochs = 0; + this.maxMillis = Long.MAX_VALUE; + this.earlyStopPctImprovement = 0; + this.epochPatience = 0; + } + + /** + * Set the minimum # epochs, defaults to 0. + * + * @param minEpochs the minimum # epochs + * @return this builder + */ + public Builder optMinEpochs(int minEpochs) { + this.minEpochs = minEpochs; + return this; + } + + /** + * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms. + * + * @param duration the maximum duration a training run should take + * @return this builder + */ + public Builder optMaxDuration(Duration duration) { + this.maxMillis = duration.toMillis(); + return this; + } - public EarlyStoppingListener setEpochPatience(int epochPatience) { - this.epochPatience = epochPatience; - return this; + /** + * Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE. + * + * @param maxMillis the maximum # milliseconds a training run should take + * @return this builder + */ + public Builder optMaxMillis(int maxMillis) { + this.maxMillis = maxMillis; + return this; + } + + /** + * Consider early stopping if not x% improvement, defaults to 0. + * + * @param earlyStopPctImprovement the percentage improvement to consider early stopping, + * must be between 0 and 100. + * @return this builder + */ + public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) { + this.earlyStopPctImprovement = earlyStopPctImprovement; + return this; + } + + /** + * Stop if insufficient improvement for x epochs in a row, defaults to 0. + * + * @param epochPatience the number of epochs without improvement to consider stopping, must + * be greater than 0. + * @return this builder + */ + public Builder optEpochPatience(int epochPatience) { + this.epochPatience = epochPatience; + return this; + } + + /** + * Builds a {@link EarlyStoppingListener} with the specified values. + * + * @return a new {@link EarlyStoppingListener} + */ + public EarlyStoppingListener build() { + return new EarlyStoppingListener( + objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience); + } } /** - * Thrown when training is stopped early, the message will contain the reason why it is stopped early. + * Thrown when training is stopped early, the message will contain the reason why it is stopped + * early. */ public static class EarlyStoppedException extends RuntimeException { private static final long serialVersionUID = 1L; private final int stopEpoch; + + /** + * Constructs an {@link EarlyStoppedException} with the specified message and epoch. + * + * @param stopEpoch the epoch at which training was stopped early + * @param message the message/reason why training was stopped early + */ public EarlyStoppedException(int stopEpoch, String message) { super(message); this.stopEpoch = stopEpoch; } + /** + * Gets the epoch at which training was stopped early. + * + * @return the epoch at which training was stopped early. + */ public int getStopEpoch() { return stopEpoch; } diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java index c8e4d30b428..91b6993a2d9 100644 --- a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java @@ -1,3 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ package ai.djl.integration.tests.training.listener; import ai.djl.Model; @@ -18,50 +30,59 @@ import ai.djl.training.optimizer.Optimizer; import ai.djl.training.tracker.Tracker; import ai.djl.translate.TranslateException; + import org.testng.Assert; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; import java.io.IOException; +import java.time.Duration; public class EarlyStoppingListenerTest { - private final Optimizer sgd = Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build(); + private final Optimizer sgd = + Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build(); private Mnist testMnistDataset; private Mnist trainMnistDataset; @BeforeTest public void setUp() throws IOException, TranslateException { - testMnistDataset = Mnist.builder() - .optUsage(Dataset.Usage.TEST) - .setSampling(32, false) - .build(); + testMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TEST) + .optLimit(8) + .setSampling(8, false) + .build(); testMnistDataset.prepare(); - trainMnistDataset = Mnist.builder() - .optUsage(Dataset.Usage.TRAIN) - .setSampling(32, false) - .build(); + trainMnistDataset = + Mnist.builder() + .optUsage(Dataset.Usage.TRAIN) + .optLimit(16) + .setSampling(8, false) + .build(); trainMnistDataset.prepare(); } @Test public void testEarlyStoppingStopsOnEpoch2() throws Exception { - Mlp mlpModel = new Mlp(784, 1, new int[]{256}, Activation::relu); + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { model.setBlock(mlpModel); - DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optOptimizer(sgd) - .addTrainingListeners(TrainingListener.Defaults.logging()) - .addTrainingListeners(new EarlyStoppingListener() - .setEpochPatience(1) - .setEarlyStopPctImprovement(50) - .setMaxMinutes(60) - .setMinEpochs(1) - ); + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(99) + .optMaxDuration(Duration.ofMinutes(1)) + .optMinEpochs(1) + .build()); try (Trainer trainer = model.newTrainer(config)) { trainer.initialize(new Shape(1, 784)); @@ -72,7 +93,8 @@ public void testEarlyStoppingStopsOnEpoch2() throws Exception { // Set epoch to 5 as we expect the early stopping to stop after the second epoch EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); } catch (EarlyStoppingListener.EarlyStoppedException e) { - Assert.assertEquals(e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row"); + Assert.assertEquals( + e.getMessage(), "failed to achieve 99.0% improvement 1 times in a row"); Assert.assertEquals(e.getStopEpoch(), 2); } @@ -84,15 +106,22 @@ public void testEarlyStoppingStopsOnEpoch2() throws Exception { @Test public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception { - Mlp mlpModel = new Mlp(784, 1, new int[]{256}, Activation::relu); + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { model.setBlock(mlpModel); - DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss()) - .optOptimizer(sgd) - .addTrainingListeners(TrainingListener.Defaults.logging()) - .addTrainingListeners(new EarlyStoppingListener(0, 3, 60, 50, 1)); + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder() + .optEpochPatience(1) + .optEarlyStopPctImprovement(50) + .optMaxMillis(60_000) + .optMinEpochs(3) + .build()); try (Trainer trainer = model.newTrainer(config)) { trainer.initialize(new Shape(1, 784)); @@ -103,7 +132,8 @@ public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception { // Set epoch to 5 as we expect the early stopping to stop after the second epoch EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); } catch (EarlyStoppingListener.EarlyStoppedException e) { - Assert.assertEquals(e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row"); + Assert.assertEquals( + e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row"); Assert.assertEquals(e.getStopEpoch(), 3); } @@ -113,4 +143,36 @@ public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception { } } -} \ No newline at end of file + @Test + public void testEarlyStoppingStopsOnEpoch1AsMaxDurationIs1ms() throws Exception { + Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu); + + try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) { + model.setBlock(mlpModel); + + DefaultTrainingConfig config = + new DefaultTrainingConfig(Loss.l2Loss()) + .optOptimizer(sgd) + .addTrainingListeners(TrainingListener.Defaults.logging()) + .addTrainingListeners( + EarlyStoppingListener.builder().optMaxMillis(1).build()); + + try (Trainer trainer = model.newTrainer(config)) { + trainer.initialize(new Shape(1, 784)); + Metrics metrics = new Metrics(); + trainer.setMetrics(metrics); + + try { + // Set epoch to 5 as we expect the early stopping to stop after the second epoch + EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset); + } catch (EarlyStoppingListener.EarlyStoppedException e) { + Assert.assertTrue(e.getMessage().contains("ms elapsed >= 1 maxMillis")); + Assert.assertEquals(e.getStopEpoch(), 1); + } + + TrainingResult trainingResult = trainer.getTrainingResult(); + Assert.assertEquals(trainingResult.getEpoch(), 1); + } + } + } +} diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java new file mode 100644 index 00000000000..88680e5fe89 --- /dev/null +++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains tests using the listeners {@link ai.djl.training}. */ +package ai.djl.integration.tests.training.listener;