Early stopping configuration #38

Closed
jSaso opened this issue Mar 19, 2020 · 3 comments · Fixed by #2806

jSaso commented Mar 19, 2020

Description

Early stopping configuration: specifies the configuration options for running training with early stopping.

  • Early stopping model saver: only use the last best model; determines how the model will be saved (to disk, to memory, etc.).
  • Termination conditions:
    1. Iteration termination condition: how many epochs to run before termination.
    2. Score improvement termination condition: terminate training if the best model score does not improve for N epochs.
    3. Best expected score: terminate training once the expected score is achieved.
    4. Time-based termination condition: terminate training after a certain amount of time.
    5. Other termination conditions, if they make sense.
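
The conditions listed above could be modelled roughly as follows. This is only a sketch under stated assumptions: `TrainingState`, `TerminationCondition`, and `EarlyStoppingConfig` are hypothetical names invented for illustration (they are not DJL or DL4J types), the score is assumed to be a loss where lower is better, and the model saver is left out.

```java
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical snapshot of training progress; not a DJL type. */
final class TrainingState {
    final int epoch;            // number of completed epochs
    final double bestScore;     // best score so far (assumed: lower is better)
    final int epochsSinceBest;  // epochs since bestScore last improved
    final Duration elapsed;     // wall-clock time spent training

    TrainingState(int epoch, double bestScore, int epochsSinceBest, Duration elapsed) {
        this.epoch = epoch;
        this.bestScore = bestScore;
        this.epochsSinceBest = epochsSinceBest;
        this.elapsed = elapsed;
    }
}

/** One termination condition; training stops when any condition fires. */
interface TerminationCondition {
    boolean shouldTerminate(TrainingState state);
}

/** Hypothetical configuration that collects termination conditions. */
final class EarlyStoppingConfig {
    private final List<TerminationCondition> conditions = new ArrayList<>();

    // 1. iteration termination: stop after n epochs
    EarlyStoppingConfig maxEpochs(int n) {
        conditions.add(s -> s.epoch >= n);
        return this;
    }

    // 2. score improvement: stop if the best score has not improved for n epochs
    EarlyStoppingConfig patience(int n) {
        conditions.add(s -> s.epochsSinceBest >= n);
        return this;
    }

    // 3. best expected score: stop once the target is reached
    EarlyStoppingConfig targetScore(double target) {
        conditions.add(s -> s.bestScore <= target);
        return this;
    }

    // 4. time limit: stop after the given duration has elapsed
    EarlyStoppingConfig maxTime(Duration limit) {
        conditions.add(s -> s.elapsed.compareTo(limit) >= 0);
        return this;
    }

    boolean shouldTerminate(TrainingState state) {
        return conditions.stream().anyMatch(c -> c.shouldTerminate(state));
    }
}
```

A caller would build the configuration once, for example `new EarlyStoppingConfig().maxEpochs(10).patience(3)`, and ask `shouldTerminate(...)` at the end of every epoch.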

Will this change the current api? How?

Users can configure when model training stops: training ends as soon as one of the conditions above is met.
Early stopping should be implemented as a listener: the early stopping configuration listens for any of the conditions above and terminates training.

Who will benefit from this feature?

Everybody: users can easily configure when training ends.

References

Reference implementation:
https://github.com/eclipse/deeplearning4j/blob/b5f0ec072f3fd0da566e32f82c0e43ca36553f39/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java
There are other implementations in different NN frameworks.

@keerthanvasist keerthanvasist added the enhancement New feature or request label Mar 20, 2020
aksrajvanshi added a commit to aksrajvanshi/djl that referenced this issue Mar 18, 2021
@gforman44
Contributor

This is pretty important. I'd like to see three criteria:

  1. A minimum number of epochs (e.g. 2), no matter what.
  2. Stop if the validation loss doesn't improve for earlyStopPatience (e.g. 3) epochs.
  3. If the user sends a SIGINT (or another) process signal, we should stop at the end of the current epoch. This is different from a SIGKILL signal, which kills the process immediately.

https://www.cyberciti.biz/faq/unix-kill-command-examples/
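
One possible way to get the SIGINT behaviour in plain Java is a JVM shutdown hook, since Ctrl-C (SIGINT) starts the JVM shutdown sequence while SIGKILL does not. The sketch below is an assumption about how this might be wired, not part of DJL: `InterruptibleTraining` and its methods are hypothetical names, and the per-epoch work is passed in as a callback. The hook sets a flag and then waits for the training thread, so the epoch in progress can finish before the JVM exits.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.IntConsumer;

/** Hypothetical epoch loop that can be asked to stop at an epoch boundary. */
final class InterruptibleTraining {
    private final AtomicBoolean stopRequested = new AtomicBoolean(false);

    /** Ask the loop to stop once the current epoch has finished. */
    void requestStop() {
        stopRequested.set(true);
    }

    /** Runs up to maxEpochs epochs and returns how many were completed. */
    int run(int maxEpochs, IntConsumer trainOneEpoch) {
        int completed = 0;
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            trainOneEpoch.accept(epoch); // the epoch is never cut short
            completed++;
            if (stopRequested.get()) {
                break; // only checked between epochs
            }
        }
        return completed;
    }

    /**
     * Registers a shutdown hook that requests a stop and then waits for the
     * training thread, so the JVM does not exit in the middle of an epoch.
     */
    void installShutdownHook(Thread trainingThread) {
        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            requestStop();
            try {
                trainingThread.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }));
    }
}
```

Calling `installShutdownHook(Thread.currentThread())` before `run(...)` on the training thread would be the typical usage. A minimum-epoch guard is not shown here.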

@gforman44
Contributor

Here's a proposal for what I'd like to see:

Parameters for flexible stopping criteria:

    static int maxEpochs = 1000;
    static double objectiveSuccess = 0.5; // done if validation loss objective (e.g. L2Loss) < threshold
    static int minEpochs = 2; // after minimum # epochs, consider stopping if:
    static int maxMinutes = 5 * 60; // too much time elapsed
    static double earlyStopPctImprovement = 2; // consider early stopping if not 2% improvement
    static int earlyStopPatience = 3; // stop if insufficient improvement for 3 epochs in a row

With these parameters, you can implement it like this in EasyTrain.fit():

    public static void fit(Trainer trainer, RandomAccessDataset trainingSet, RandomAccessDataset validateSet) throws TranslateException, IOException {
        final long start = System.currentTimeMillis();
        double prevLoss = Double.NaN;
        int improvementFailures = 0;
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            for (Batch batch: trainer.iterateDataset(trainingSet)) {
                EasyTrain.trainBatch(trainer, batch);
                trainer.step();
                batch.close();
            }

            // After each epoch, test against the validation dataset if we have one
            EasyTrain.evaluateDataset(trainer, validateSet);

            // reset training and validation evaluators at end of epoch
            trainer.notifyListeners(listener -> listener.onEpoch(trainer));

            // stopping criteria
            // TODO: fall back to the training loss if there is no validation set
            final double vloss = trainer.getTrainingResult().getValidateLoss();
            if (vloss < objectiveSuccess) {
                System.out.printf("END: validation loss %s < objectiveSuccess %s\n", vloss, objectiveSuccess);
                return;
            }
            if (epoch+1 >= minEpochs) {
                double elapsedMinutes = (System.currentTimeMillis() - start) / 60_000.0;
                if (elapsedMinutes >= maxMinutes) {
                    System.out.printf("END: %.1f minutes elapsed >= %s maxMinutes\n", elapsedMinutes, maxMinutes);
                    return;
                }
                // consider early stopping?
                if (Double.isFinite(prevLoss)) {
                    double goalImprovement = prevLoss * (100 - earlyStopPctImprovement) / 100.0;
                    boolean improved = vloss <= goalImprovement; // false if either value is NaN
                    if (improved) {
                        improvementFailures = 0;
                    } else {
                        improvementFailures++;
                        if (improvementFailures >= earlyStopPatience) {
                            System.out.printf("END: failed to achieve %s%% improvement %s times in a row\n",
                                    earlyStopPctImprovement, earlyStopPatience);
                            return;
                        }
                    }
                }
            }
            if (Double.isFinite(vloss)) {
                prevLoss = vloss;
            }
        }
    }
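
To see how these criteria interact, the stopping decision can be pulled out into a small class that has no DJL dependency and can be exercised with a plain list of losses. This is a sketch only: `StoppingCriteria` is a hypothetical name, the thresholds mirror the parameters above, and the time limit is left out to keep the example deterministic.

```java
/** Hypothetical, DJL-independent version of the stopping checks above. */
final class StoppingCriteria {
    private final int minEpochs;
    private final double pctImprovement;
    private final int patience;
    private final double objectiveSuccess;

    private double prevLoss = Double.NaN;
    private int improvementFailures = 0;
    private int epochsSeen = 0;

    StoppingCriteria(int minEpochs, double pctImprovement, int patience, double objectiveSuccess) {
        this.minEpochs = minEpochs;
        this.pctImprovement = pctImprovement;
        this.patience = patience;
        this.objectiveSuccess = objectiveSuccess;
    }

    /** Call once per epoch with that epoch's validation loss; true means stop. */
    boolean shouldStop(double vloss) {
        epochsSeen++;
        if (vloss < objectiveSuccess) {
            return true; // objective reached, regardless of minEpochs
        }
        boolean stop = false;
        if (epochsSeen >= minEpochs && Double.isFinite(prevLoss)) {
            // the loss must drop by at least pctImprovement percent of the previous loss
            double goal = prevLoss * (100 - pctImprovement) / 100.0;
            if (vloss <= goal) {
                improvementFailures = 0;
            } else {
                improvementFailures++;
                stop = improvementFailures >= patience;
            }
        }
        if (Double.isFinite(vloss)) {
            prevLoss = vloss;
        }
        return stop;
    }
}
```

With minEpochs = 2, a 2% improvement threshold, and a patience of 3, the loss sequence 10.0, 9.0, 8.95, 8.9, 8.85 stops on the fifth epoch: 9.0 beats the goal of 9.8, but each of the next three losses misses its goal (8.82, 8.771, 8.722).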

@zachgk
Contributor

zachgk commented May 13, 2022

@gforman44 That looks pretty good. One thing I was thinking is that we could implement early stopping with a TrainingListener. That would give us a good place to add the early stopping configuration and would help manage all the different pieces of functionality that users may or may not want as part of their training. It could throw an EarlyStopException if it decides to end the training early.
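
The listener idea might look roughly like the sketch below. Everything here is a stand-in written for illustration: `EpochListener` is a simplified interface rather than DJL's TrainingListener, and `EarlyStopException`, `PatienceListener`, and `TrainingLoop` are hypothetical. The listener tracks the best validation loss and throws once it has gone `patience` epochs without improvement; the loop treats that exception as a normal way to finish.

```java
import java.util.List;

/** Hypothetical exception a listener throws to end training early. */
final class EarlyStopException extends RuntimeException {
    EarlyStopException(String message) {
        super(message);
    }
}

/** Stand-in for a training listener; only the per-epoch callback is modelled. */
interface EpochListener {
    void onEpoch(int epoch, double validationLoss);
}

/** Stops when the best validation loss has not improved for `patience` epochs. */
final class PatienceListener implements EpochListener {
    private final int patience;
    private double bestLoss = Double.POSITIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    PatienceListener(int patience) {
        this.patience = patience;
    }

    @Override
    public void onEpoch(int epoch, double validationLoss) {
        if (validationLoss < bestLoss) {
            bestLoss = validationLoss;
            epochsWithoutImprovement = 0;
            return;
        }
        epochsWithoutImprovement++;
        if (epochsWithoutImprovement >= patience) {
            throw new EarlyStopException(
                    "no improvement for " + patience + " epochs (best loss " + bestLoss + ")");
        }
    }
}

/** Minimal loop showing where the exception is caught. */
final class TrainingLoop {
    /** Feeds one loss per epoch to the listeners; returns the epochs completed. */
    static int run(double[] lossPerEpoch, List<EpochListener> listeners) {
        int completed = 0;
        try {
            for (int epoch = 0; epoch < lossPerEpoch.length; epoch++) {
                // a real loop would train and evaluate here
                completed++;
                for (EpochListener listener : listeners) {
                    listener.onEpoch(epoch, lossPerEpoch[epoch]);
                }
            }
        } catch (EarlyStopException e) {
            // expected control flow: end training and keep the results so far
        }
        return completed;
    }
}
```

Keeping the decision inside a listener means the training loop itself needs no knowledge of the stopping rules; users who do not want early stopping simply do not register the listener.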

Anyway, it sounds like you are really interested in this issue @gforman44. Do you want to implement it and submit a PR?

jagodevreede added a commit to jagodevreede/djl that referenced this issue Oct 12, 2023
jagodevreede added a commit to jagodevreede/djl that referenced this issue Oct 12, 2023
zachgk pushed a commit that referenced this issue Oct 24, 2023
* [api] Added Early stopping configuration (#38)

* [api] Added Builder for Early stopping configuration (#38)

* Explicitly set NDManager for dataset in EarlyStoppingListenerTest to make the test run on JDK11 in gradle.
frankfliu pushed a commit that referenced this issue Apr 26, 2024
* [api] Added Early stopping configuration (#38)

* [api] Added Builder for Early stopping configuration (#38)

* Explicitly set NDManager for dataset in EarlyStoppingListenerTest to make the test run on JDK11 in gradle.