diff --git a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
index 9aaf0433564..6c013c37715 100644
--- a/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
+++ b/api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
@@ -1,24 +1,43 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
package ai.djl.training.listener;
import ai.djl.training.Trainer;
+import ai.djl.training.TrainingResult;
+
+import java.time.Duration;
/**
- * Listener that allows the training to be stopped early if the validation loss is not improving, or if time has expired.
- *
- * Usage:
- * Add this listener to the training config, and add it as the last one.
+ * Listener that allows the training to be stopped early if the validation loss is not improving, or
+ * if time has expired.
+ *
+ * <p>Usage: Add this listener to the training config, and add it as the last one.
+ *
*
* new DefaultTrainingConfig(...)
- * .addTrainingListeners(new EarlyStoppingListener()
+ * .addTrainingListeners(EarlyStoppingListener.builder()
- * .setEpochPatience(1)
- * .setEarlyStopPctImprovement(1)
- * .setMaxMinutes(60)
- * .setMinEpochs(1)
+ * .optEpochPatience(1)
+ * .optEarlyStopPctImprovement(1)
+ * .optMaxDuration(Duration.ofMinutes(42))
+ * .optMinEpochs(1)
+ * .build()
* );
*
- * Then surround the fit with a try catch that catches the {@link EarlyStoppingListener.EarlyStoppedException}.
- *
+ *
+ * Then surround the fit with a try catch that catches the {@link
+ * EarlyStoppingListener.EarlyStoppedException}.
* Example:
+ *
*
* try {
* EasyTrain.fit(trainer, 5, trainDataset, testDataset);
@@ -27,61 +46,54 @@
* log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
* }
*
- *
+ *
+ *
* Note: Ensure that Metrics are set on the trainer.
*/
-public class EarlyStoppingListener implements TrainingListener {
+public final class EarlyStoppingListener implements TrainingListener {
private final double objectiveSuccess;
- /**
- * after minimum # epochs, consider stopping if:
- */
- private int minEpochs;
- /**
- * too much time elapsed
- */
- private int maxMinutes;
- /**
- * consider early stopping if not x% improvement
- */
- private double earlyStopPctImprovement;
- /**
- * stop if insufficient improvement for x epochs in a row
- */
- private int epochPatience;
+ private final int minEpochs;
+ private final long maxMillis;
+ private final double earlyStopPctImprovement;
+ private final int epochPatience;
private long startTimeMills;
private double prevLoss;
private int numberOfEpochsWithoutImprovements;
- public EarlyStoppingListener() {
- this.objectiveSuccess = 0;
- this.minEpochs = 0;
- this.maxMinutes = Integer.MAX_VALUE;
- this.earlyStopPctImprovement = 0;
- this.epochPatience = 0;
- }
-
- public EarlyStoppingListener(double objectiveSuccess, int minEpochs, int maxMinutes, double earlyStopPctImprovement, int earlyStopPatience) {
+ private EarlyStoppingListener(
+ double objectiveSuccess,
+ int minEpochs,
+ long maxMillis,
+ double earlyStopPctImprovement,
+ int earlyStopPatience) {
this.objectiveSuccess = objectiveSuccess;
this.minEpochs = minEpochs;
- this.maxMinutes = maxMinutes;
+ this.maxMillis = maxMillis;
this.earlyStopPctImprovement = earlyStopPctImprovement;
this.epochPatience = earlyStopPatience;
}
+ /** {@inheritDoc} */
@Override
public void onEpoch(Trainer trainer) {
int currentEpoch = trainer.getTrainingResult().getEpoch();
// stopping criteria
- final double loss = getLoss(trainer);
- if (loss < objectiveSuccess) {
- throw new EarlyStoppedException(currentEpoch, String.format("validation loss %s < objectiveSuccess %s", loss, objectiveSuccess));
- }
+ final double loss = getLoss(trainer.getTrainingResult());
if (currentEpoch >= minEpochs) {
- double elapsedMinutes = (System.currentTimeMillis() - startTimeMills) / 60_000.0;
- if (elapsedMinutes >= maxMinutes) {
- throw new EarlyStoppedException(currentEpoch, String.format("Early stopping training: %.1f minutes elapsed >= %s maxMinutes", elapsedMinutes, maxMinutes));
+ if (loss < objectiveSuccess) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format(
+ "validation loss %s < objectiveSuccess %s",
+ loss, objectiveSuccess));
+ }
+ long elapsedMillis = System.currentTimeMillis() - startTimeMills;
+ if (elapsedMillis >= maxMillis) {
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis));
}
// consider early stopping?
if (Double.isFinite(prevLoss)) {
@@ -92,8 +104,11 @@ public void onEpoch(Trainer trainer) {
} else {
numberOfEpochsWithoutImprovements++;
if (numberOfEpochsWithoutImprovements >= epochPatience) {
- throw new EarlyStoppedException(currentEpoch, String.format("failed to achieve %s%% improvement %s times in a row",
- earlyStopPctImprovement, epochPatience));
+ throw new EarlyStoppedException(
+ currentEpoch,
+ String.format(
+ "failed to achieve %s%% improvement %s times in a row",
+ earlyStopPctImprovement, epochPatience));
}
}
}
@@ -103,28 +118,31 @@ public void onEpoch(Trainer trainer) {
}
}
- private static double getLoss(Trainer trainer) {
- Float vLoss = trainer.getTrainingResult().getValidateLoss();
+ private static double getLoss(TrainingResult trainingResult) {
+ Float vLoss = trainingResult.getValidateLoss();
if (vLoss != null) {
return vLoss;
}
- Float tLoss = trainer.getTrainingResult().getTrainLoss();
+ Float tLoss = trainingResult.getTrainLoss();
if (tLoss == null) {
return Double.NaN;
}
return tLoss;
}
+ /** {@inheritDoc} */
@Override
public void onTrainingBatch(Trainer trainer, BatchData batchData) {
// do nothing
}
+ /** {@inheritDoc} */
@Override
public void onValidationBatch(Trainer trainer, BatchData batchData) {
// do nothing
}
+ /** {@inheritDoc} */
@Override
public void onTrainingBegin(Trainer trainer) {
this.startTimeMills = System.currentTimeMillis();
@@ -132,42 +150,130 @@ public void onTrainingBegin(Trainer trainer) {
this.numberOfEpochsWithoutImprovements = 0;
}
+ /** {@inheritDoc} */
@Override
public void onTrainingEnd(Trainer trainer) {
// do nothing
}
- public EarlyStoppingListener setMinEpochs(int minEpochs) {
- this.minEpochs = minEpochs;
- return this;
+ /**
+ * Creates a builder to build a {@link EarlyStoppingListener}.
+ *
+ * @return a new builder
+ */
+ public static Builder builder() {
+ return new Builder();
}
- public EarlyStoppingListener setMaxMinutes(int maxMinutes) {
- this.maxMinutes = maxMinutes;
- return this;
- }
+ /** A builder for a {@link EarlyStoppingListener}. */
+ public static final class Builder {
+ private final double objectiveSuccess;
+ private int minEpochs;
+ private long maxMillis;
+ private double earlyStopPctImprovement;
+ private int epochPatience;
- public EarlyStoppingListener setEarlyStopPctImprovement(double earlyStopPctImprovement) {
- this.earlyStopPctImprovement = earlyStopPctImprovement;
- return this;
- }
+ /** Constructs a {@link Builder} with default values. */
+ public Builder() {
+ this.objectiveSuccess = 0;
+ this.minEpochs = 0;
+ this.maxMillis = Long.MAX_VALUE;
+ this.earlyStopPctImprovement = 0;
+ this.epochPatience = 0;
+ }
+
+ /**
+ * Set the minimum # epochs, defaults to 0.
+ *
+ * @param minEpochs the minimum # epochs
+ * @return this builder
+ */
+ public Builder optMinEpochs(int minEpochs) {
+ this.minEpochs = minEpochs;
+ return this;
+ }
+
+ /**
+ * Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms.
+ *
+ * @param duration the maximum duration a training run should take
+ * @return this builder
+ */
+ public Builder optMaxDuration(Duration duration) {
+ this.maxMillis = duration.toMillis();
+ return this;
+ }
- public EarlyStoppingListener setEpochPatience(int epochPatience) {
- this.epochPatience = epochPatience;
- return this;
+ /**
+ * Sets the maximum number of milliseconds a training run should take, defaults to Long.MAX_VALUE.
+ *
+ * @param maxMillis the maximum number of milliseconds, accepted as a long to match the default
+ * @return this builder
+ */
+ public Builder optMaxMillis(long maxMillis) {
+ this.maxMillis = maxMillis;
+ return this;
+ }
+
+ /**
+ * Consider early stopping if not x% improvement, defaults to 0.
+ *
+ * @param earlyStopPctImprovement the percentage improvement to consider early stopping,
+ * must be between 0 and 100.
+ * @return this builder
+ */
+ public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) {
+ this.earlyStopPctImprovement = earlyStopPctImprovement;
+ return this;
+ }
+
+ /**
+ * Stop if insufficient improvement for x epochs in a row, defaults to 0.
+ *
+ * @param epochPatience the number of epochs in a row without sufficient improvement after
+ * which training is stopped; the default of 0 behaves the same as 1.
+ * @return this builder
+ */
+ public Builder optEpochPatience(int epochPatience) {
+ this.epochPatience = epochPatience;
+ return this;
+ }
+
+ /**
+ * Builds a {@link EarlyStoppingListener} with the specified values.
+ *
+ * @return a new {@link EarlyStoppingListener}
+ */
+ public EarlyStoppingListener build() {
+ return new EarlyStoppingListener(
+ objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience);
+ }
}
/**
- * Thrown when training is stopped early, the message will contain the reason why it is stopped early.
+ * Thrown when training is stopped early, the message will contain the reason why it is stopped
+ * early.
*/
public static class EarlyStoppedException extends RuntimeException {
private static final long serialVersionUID = 1L;
private final int stopEpoch;
+
+ /**
+ * Constructs an {@link EarlyStoppedException} with the specified message and epoch.
+ *
+ * @param stopEpoch the epoch at which training was stopped early
+ * @param message the message/reason why training was stopped early
+ */
public EarlyStoppedException(int stopEpoch, String message) {
super(message);
this.stopEpoch = stopEpoch;
}
+ /**
+ * Gets the epoch at which training was stopped early.
+ *
+ * @return the epoch at which training was stopped early.
+ */
public int getStopEpoch() {
return stopEpoch;
}
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java
index c8e4d30b428..91b6993a2d9 100644
--- a/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java
+++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/EarlyStoppingListenerTest.java
@@ -1,3 +1,15 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
package ai.djl.integration.tests.training.listener;
import ai.djl.Model;
@@ -18,50 +30,59 @@
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
+
import org.testng.Assert;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
import java.io.IOException;
+import java.time.Duration;
public class EarlyStoppingListenerTest {
- private final Optimizer sgd = Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build();
+ private final Optimizer sgd =
+ Optimizer.sgd().setLearningRateTracker(Tracker.fixed(0.1f)).build();
private Mnist testMnistDataset;
private Mnist trainMnistDataset;
@BeforeTest
public void setUp() throws IOException, TranslateException {
- testMnistDataset = Mnist.builder()
- .optUsage(Dataset.Usage.TEST)
- .setSampling(32, false)
- .build();
+ testMnistDataset =
+ Mnist.builder()
+ .optUsage(Dataset.Usage.TEST)
+ .optLimit(8)
+ .setSampling(8, false)
+ .build();
testMnistDataset.prepare();
- trainMnistDataset = Mnist.builder()
- .optUsage(Dataset.Usage.TRAIN)
- .setSampling(32, false)
- .build();
+ trainMnistDataset =
+ Mnist.builder()
+ .optUsage(Dataset.Usage.TRAIN)
+ .optLimit(16)
+ .setSampling(8, false)
+ .build();
trainMnistDataset.prepare();
}
@Test
public void testEarlyStoppingStopsOnEpoch2() throws Exception {
- Mlp mlpModel = new Mlp(784, 1, new int[]{256}, Activation::relu);
+ Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu);
try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) {
model.setBlock(mlpModel);
- DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss())
- .optOptimizer(sgd)
- .addTrainingListeners(TrainingListener.Defaults.logging())
- .addTrainingListeners(new EarlyStoppingListener()
- .setEpochPatience(1)
- .setEarlyStopPctImprovement(50)
- .setMaxMinutes(60)
- .setMinEpochs(1)
- );
+ DefaultTrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optOptimizer(sgd)
+ .addTrainingListeners(TrainingListener.Defaults.logging())
+ .addTrainingListeners(
+ EarlyStoppingListener.builder()
+ .optEpochPatience(1)
+ .optEarlyStopPctImprovement(99)
+ .optMaxDuration(Duration.ofMinutes(1))
+ .optMinEpochs(1)
+ .build());
try (Trainer trainer = model.newTrainer(config)) {
trainer.initialize(new Shape(1, 784));
@@ -72,7 +93,8 @@ public void testEarlyStoppingStopsOnEpoch2() throws Exception {
// Set epoch to 5 as we expect the early stopping to stop after the second epoch
EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset);
} catch (EarlyStoppingListener.EarlyStoppedException e) {
- Assert.assertEquals(e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row");
+ Assert.assertEquals(
+ e.getMessage(), "failed to achieve 99.0% improvement 1 times in a row");
Assert.assertEquals(e.getStopEpoch(), 2);
}
@@ -84,15 +106,22 @@ public void testEarlyStoppingStopsOnEpoch2() throws Exception {
@Test
public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception {
- Mlp mlpModel = new Mlp(784, 1, new int[]{256}, Activation::relu);
+ Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu);
try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) {
model.setBlock(mlpModel);
- DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss())
- .optOptimizer(sgd)
- .addTrainingListeners(TrainingListener.Defaults.logging())
- .addTrainingListeners(new EarlyStoppingListener(0, 3, 60, 50, 1));
+ DefaultTrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optOptimizer(sgd)
+ .addTrainingListeners(TrainingListener.Defaults.logging())
+ .addTrainingListeners(
+ EarlyStoppingListener.builder()
+ .optEpochPatience(1)
+ .optEarlyStopPctImprovement(50)
+ .optMaxMillis(60_000)
+ .optMinEpochs(3)
+ .build());
try (Trainer trainer = model.newTrainer(config)) {
trainer.initialize(new Shape(1, 784));
@@ -103,7 +132,8 @@ public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception {
- // Set epoch to 5 as we expect the early stopping to stop after the second epoch
+ // Set epoch to 5 as we expect the early stopping to stop after the third epoch
EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset);
} catch (EarlyStoppingListener.EarlyStoppedException e) {
- Assert.assertEquals(e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row");
+ Assert.assertEquals(
+ e.getMessage(), "failed to achieve 50.0% improvement 1 times in a row");
Assert.assertEquals(e.getStopEpoch(), 3);
}
@@ -113,4 +143,36 @@ public void testEarlyStoppingStopsOnEpoch3AsMinEpochsIs3() throws Exception {
}
}
-}
\ No newline at end of file
+ @Test
+ public void testEarlyStoppingStopsOnEpoch1AsMaxDurationIs1ms() throws Exception {
+ Mlp mlpModel = new Mlp(784, 1, new int[] {256}, Activation::relu);
+
+ try (Model model = Model.newInstance("lin-reg", TestUtils.getEngine())) {
+ model.setBlock(mlpModel);
+
+ DefaultTrainingConfig config =
+ new DefaultTrainingConfig(Loss.l2Loss())
+ .optOptimizer(sgd)
+ .addTrainingListeners(TrainingListener.Defaults.logging())
+ .addTrainingListeners(
+ EarlyStoppingListener.builder().optMaxMillis(1).build());
+
+ try (Trainer trainer = model.newTrainer(config)) {
+ trainer.initialize(new Shape(1, 784));
+ Metrics metrics = new Metrics();
+ trainer.setMetrics(metrics);
+
+ try {
+ // Set epoch to 5 as we expect the early stopping to stop after the first epoch
+ EasyTrain.fit(trainer, 5, trainMnistDataset, testMnistDataset);
+ } catch (EarlyStoppingListener.EarlyStoppedException e) {
+ Assert.assertTrue(e.getMessage().contains("ms elapsed >= 1 maxMillis"));
+ Assert.assertEquals(e.getStopEpoch(), 1);
+ }
+
+ TrainingResult trainingResult = trainer.getTrainingResult();
+ Assert.assertEquals(trainingResult.getEpoch(), 1);
+ }
+ }
+ }
+}
diff --git a/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java
new file mode 100644
index 00000000000..88680e5fe89
--- /dev/null
+++ b/integration/src/main/java/ai/djl/integration/tests/training/listener/package-info.java
@@ -0,0 +1,15 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+
+/** Contains tests for the training listeners in {@link ai.djl.training.listener}. */
+package ai.djl.integration.tests.training.listener;