Skip to content

Commit

Permalink
[api] Added Builder for Early stopping configuration (deepjavalibrary#38
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jagodevreede committed Oct 16, 2023
1 parent a2934c5 commit 2a64f36
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 90 deletions.
232 changes: 169 additions & 63 deletions api/src/main/java/ai/djl/training/listener/EarlyStoppingListener.java
Original file line number Diff line number Diff line change
@@ -1,24 +1,43 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.training.listener;

import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;

import java.time.Duration;

/**
* Listener that allows the training to be stopped early if the validation loss is not improving, or if time has expired.
* <br/>
* Usage:
* Add this listener to the training config, and add it as the last one.
* Listener that allows the training to be stopped early if the validation loss is not improving, or
* if time has expired. <br>
*
* <p>Usage: Add this listener to the training config, and add it as the last one.
*
* <pre>
* new DefaultTrainingConfig(...)
* .addTrainingListeners(new EarlyStoppingListener()
* .addTrainingListeners(EarlyStoppingListener.builder()
* .setEpochPatience(1)
* .setEarlyStopPctImprovement(1)
* .setMaxMinutes(60)
* .setMaxDuration(Duration.ofMinutes(42))
* .setMinEpochs(1)
* .build()
* );
* </pre>
* Then surround the fit with a try catch that catches the {@link EarlyStoppingListener.EarlyStoppedException}.
* <br/>
*
* <p>Then surround the fit with a try catch that catches the {@link
* EarlyStoppingListener.EarlyStoppedException}. <br>
* Example:
*
* <pre>
* try {
* EasyTrain.fit(trainer, 5, trainDataset, testDataset);
Expand All @@ -27,61 +46,54 @@
* log.info("Stopped early at epoch {} because: {}", e.getEpoch(), e.getMessage());
* }
* </pre>
* <br/>
*
* <br>
* Note: Ensure that Metrics are set on the trainer.
*/
public class EarlyStoppingListener implements TrainingListener {
public final class EarlyStoppingListener implements TrainingListener {
private final double objectiveSuccess;

/**
* after minimum # epochs, consider stopping if:
*/
private int minEpochs;
/**
* too much time elapsed
*/
private int maxMinutes;
/**
* consider early stopping if not x% improvement
*/
private double earlyStopPctImprovement;
/**
* stop if insufficient improvement for x epochs in a row
*/
private int epochPatience;
private final int minEpochs;
private final long maxMillis;
private final double earlyStopPctImprovement;
private final int epochPatience;

private long startTimeMills;
private double prevLoss;
private int numberOfEpochsWithoutImprovements;

public EarlyStoppingListener() {
this.objectiveSuccess = 0;
this.minEpochs = 0;
this.maxMinutes = Integer.MAX_VALUE;
this.earlyStopPctImprovement = 0;
this.epochPatience = 0;
}

public EarlyStoppingListener(double objectiveSuccess, int minEpochs, int maxMinutes, double earlyStopPctImprovement, int earlyStopPatience) {
private EarlyStoppingListener(
double objectiveSuccess,
int minEpochs,
long maxMillis,
double earlyStopPctImprovement,
int earlyStopPatience) {
this.objectiveSuccess = objectiveSuccess;
this.minEpochs = minEpochs;
this.maxMinutes = maxMinutes;
this.maxMillis = maxMillis;
this.earlyStopPctImprovement = earlyStopPctImprovement;
this.epochPatience = earlyStopPatience;
}

/** {@inheritDoc} */
@Override
public void onEpoch(Trainer trainer) {
int currentEpoch = trainer.getTrainingResult().getEpoch();
// stopping criteria
final double loss = getLoss(trainer);
if (loss < objectiveSuccess) {
throw new EarlyStoppedException(currentEpoch, String.format("validation loss %s < objectiveSuccess %s", loss, objectiveSuccess));
}
final double loss = getLoss(trainer.getTrainingResult());
if (currentEpoch >= minEpochs) {
double elapsedMinutes = (System.currentTimeMillis() - startTimeMills) / 60_000.0;
if (elapsedMinutes >= maxMinutes) {
throw new EarlyStoppedException(currentEpoch, String.format("Early stopping training: %.1f minutes elapsed >= %s maxMinutes", elapsedMinutes, maxMinutes));
if (loss < objectiveSuccess) {
throw new EarlyStoppedException(
currentEpoch,
String.format(
"validation loss %s < objectiveSuccess %s",
loss, objectiveSuccess));
}
long elapsedMillis = System.currentTimeMillis() - startTimeMills;
if (elapsedMillis >= maxMillis) {
throw new EarlyStoppedException(
currentEpoch,
String.format("%s ms elapsed >= %s maxMillis", elapsedMillis, maxMillis));
}
// consider early stopping?
if (Double.isFinite(prevLoss)) {
Expand All @@ -92,8 +104,11 @@ public void onEpoch(Trainer trainer) {
} else {
numberOfEpochsWithoutImprovements++;
if (numberOfEpochsWithoutImprovements >= epochPatience) {
throw new EarlyStoppedException(currentEpoch, String.format("failed to achieve %s%% improvement %s times in a row",
earlyStopPctImprovement, epochPatience));
throw new EarlyStoppedException(
currentEpoch,
String.format(
"failed to achieve %s%% improvement %s times in a row",
earlyStopPctImprovement, epochPatience));
}
}
}
Expand All @@ -103,71 +118,162 @@ public void onEpoch(Trainer trainer) {
}
}

private static double getLoss(Trainer trainer) {
Float vLoss = trainer.getTrainingResult().getValidateLoss();
private static double getLoss(TrainingResult trainingResult) {
Float vLoss = trainingResult.getValidateLoss();
if (vLoss != null) {
return vLoss;
}
Float tLoss = trainer.getTrainingResult().getTrainLoss();
Float tLoss = trainingResult.getTrainLoss();
if (tLoss == null) {
return Double.NaN;
}
return tLoss;
}

/** {@inheritDoc} */
@Override
public void onTrainingBatch(Trainer trainer, BatchData batchData) {
// do nothing
}

/** {@inheritDoc} */
@Override
public void onValidationBatch(Trainer trainer, BatchData batchData) {
// do nothing
}

/** {@inheritDoc} */
@Override
public void onTrainingBegin(Trainer trainer) {
this.startTimeMills = System.currentTimeMillis();
this.prevLoss = Double.NaN;
this.numberOfEpochsWithoutImprovements = 0;
}

/** {@inheritDoc} */
@Override
public void onTrainingEnd(Trainer trainer) {
// do nothing
}

public EarlyStoppingListener setMinEpochs(int minEpochs) {
this.minEpochs = minEpochs;
return this;
/**
* Creates a builder to build a {@link EarlyStoppingListener}.
*
* @return a new builder
*/
public static Builder builder() {
return new Builder();
}

public EarlyStoppingListener setMaxMinutes(int maxMinutes) {
this.maxMinutes = maxMinutes;
return this;
}
/** A builder for a {@link EarlyStoppingListener}. */
public static final class Builder {
private final double objectiveSuccess;
private int minEpochs;
private long maxMillis;
private double earlyStopPctImprovement;
private int epochPatience;

public EarlyStoppingListener setEarlyStopPctImprovement(double earlyStopPctImprovement) {
this.earlyStopPctImprovement = earlyStopPctImprovement;
return this;
}
/** Constructs a {@link Builder} with default values. */
public Builder() {
this.objectiveSuccess = 0;
this.minEpochs = 0;
this.maxMillis = Long.MAX_VALUE;
this.earlyStopPctImprovement = 0;
this.epochPatience = 0;
}

/**
* Set the minimum # epochs, defaults to 0.
*
* @param minEpochs the minimum # epochs
* @return this builder
*/
public Builder optMinEpochs(int minEpochs) {
this.minEpochs = minEpochs;
return this;
}

/**
* Set the maximum duration a training run should take, defaults to Long.MAX_VALUE in ms.
*
* @param duration the maximum duration a training run should take
* @return this builder
*/
public Builder optMaxDuration(Duration duration) {
this.maxMillis = duration.toMillis();
return this;
}

public EarlyStoppingListener setEpochPatience(int epochPatience) {
this.epochPatience = epochPatience;
return this;
/**
* Set the maximum # milliseconds a training run should take, defaults to Long.MAX_VALUE.
*
* @param maxMillis the maximum # milliseconds a training run should take
* @return this builder
*/
public Builder optMaxMillis(int maxMillis) {
this.maxMillis = maxMillis;
return this;
}

/**
* Consider early stopping if not x% improvement, defaults to 0.
*
* @param earlyStopPctImprovement the percentage improvement to consider early stopping,
* must be between 0 and 100.
* @return this builder
*/
public Builder optEarlyStopPctImprovement(double earlyStopPctImprovement) {
this.earlyStopPctImprovement = earlyStopPctImprovement;
return this;
}

/**
* Stop if insufficient improvement for x epochs in a row, defaults to 0.
*
* @param epochPatience the number of epochs without improvement to consider stopping, must
* be greater than 0.
* @return this builder
*/
public Builder optEpochPatience(int epochPatience) {
this.epochPatience = epochPatience;
return this;
}

/**
* Builds a {@link EarlyStoppingListener} with the specified values.
*
* @return a new {@link EarlyStoppingListener}
*/
public EarlyStoppingListener build() {
return new EarlyStoppingListener(
objectiveSuccess, minEpochs, maxMillis, earlyStopPctImprovement, epochPatience);
}
}

/**
* Thrown when training is stopped early, the message will contain the reason why it is stopped early.
* Thrown when training is stopped early, the message will contain the reason why it is stopped
* early.
*/
public static class EarlyStoppedException extends RuntimeException {
private static final long serialVersionUID = 1L;
private final int stopEpoch;

/**
* Constructs an {@link EarlyStoppedException} with the specified message and epoch.
*
* @param stopEpoch the epoch at which training was stopped early
* @param message the message/reason why training was stopped early
*/
public EarlyStoppedException(int stopEpoch, String message) {
super(message);
this.stopEpoch = stopEpoch;
}

/**
* Gets the epoch at which training was stopped early.
*
* @return the epoch at which training was stopped early.
*/
public int getStopEpoch() {
return stopEpoch;
}
Expand Down
Loading

0 comments on commit 2a64f36

Please sign in to comment.