Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR is a first step of non-breaking refactoring for improving the
flair.nn.Model
andModelTrainer
functionality. More refactorings will follow, in particular of the evaluation and data loading parts (see #563). The idea of doing this refactoring piece-by-piece is to hopefully make it more manageable.The main idea is to make it possible to use the
ModelTrainer
over arbitrary models that implement theflair.nn.Model
interface, which is currentlySequenceTagger
,TextClassifier
TextRegressor
.The addition of the
TextRegressor
made it necessary to move evaluation into the downstream tasks, since regression uses different evaluation metrics than classification. In the future, all three models will use entirely different evaluation methods:TextClassifier
will probably use scikit-learn,SequenceTagger
will go back to using the CoNLL-03 script andTextRegressor
also scikit-learn but with other metrics. Logging and plotting of results has also been refactored to deal with different evaluation outputs.At the same time, we moved loading, saving and checkpointing up to the
flair.nn.Model
base class, since this is always the same and leads to current code redundancies otherwise.So, the new
flair.nn.Model
interface has 5 classes that need to be implemented by a new downstream task model:forward_loss()
A method that takes sentences and produces a loss with autograd for backpropagation. Implementing this method will make it possible to train the downstream task model.predict()
A method that takes sentences and a mini-batch size to do prediction.evaluate()
The new localized evaluation method which may be entirely different depending on the downstream task. Returns an object with evaluation results (though this will likely change in future refactoring steps)._get_state_dict()
Returns the state dictionary of the model. Implementing this enables saving the model and model checkpoints._init_model_with_state_dict()
A method that creates a model from a state dictionary. Implementing this enables restoring models and checkpoints.