GH-474: model interface #681

Merged
merged 6 commits into from
Apr 26, 2019

alanakbik
Collaborator

This PR is the first step in a non-breaking refactoring of the flair.nn.Model and ModelTrainer functionality. More refactorings will follow, in particular of the evaluation and data loading parts (see #563). Doing the refactoring piece by piece should make it more manageable.

The main idea is to make the ModelTrainer usable with any model that implements the flair.nn.Model interface. Currently these are

  • the SequenceTagger,
  • the TextClassifier
  • and the beta TextRegressor.

The addition of the TextRegressor made it necessary to move evaluation into the downstream tasks, since regression uses different evaluation metrics than classification. In the future, all three models will use entirely different evaluation methods: TextClassifier will probably use scikit-learn, SequenceTagger will go back to using the CoNLL-03 script, and TextRegressor will also use scikit-learn, but with other metrics. Logging and plotting of results have also been refactored to handle the different evaluation outputs.

At the same time, we moved loading, saving and checkpointing up to the flair.nn.Model base class, since this logic is the same for every model and was otherwise duplicated across them.

So, the new flair.nn.Model interface has 5 methods that need to be implemented by a new downstream task model:

  • forward_loss() A method that takes sentences and produces a loss with autograd for backpropagation. Implementing this method will make it possible to train the downstream task model.
  • predict() A method that takes sentences and a mini-batch size to do prediction.
  • evaluate() The new localized evaluation method which may be entirely different depending on the downstream task. Returns an object with evaluation results (though this will likely change in future refactoring steps).
  • _get_state_dict() Returns the state dictionary of the model. Implementing this enables saving the model and model checkpoints.
  • _init_model_with_state_dict() A method that creates a model from a state dictionary. Implementing this enables restoring models and checkpoints.
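
To illustrate the shape of this interface, here is a minimal, dependency-free sketch. It is not the code from this PR: the stand-in `Model` base class, the `Example` container, the `TokenCountRegressor` toy model, the `save()`/`load()` helper names, the pickle-based storage and all argument lists are assumptions made for illustration. In particular, the real `forward_loss()` returns a loss with autograd, whereas this sketch returns a plain float.

```python
import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Example:
    """Hypothetical stand-in for a labeled sentence."""
    text: str
    target: float


class Model(ABC):
    """Hypothetical stand-in for the interface; not flair's actual class."""

    @abstractmethod
    def forward_loss(self, sentences: List[Example]) -> float: ...

    @abstractmethod
    def predict(self, sentences: List[Example], mini_batch_size: int = 32) -> List[float]: ...

    @abstractmethod
    def evaluate(self, sentences: List[Example]) -> Dict[str, float]: ...

    @abstractmethod
    def _get_state_dict(self) -> Dict[str, Any]: ...

    @classmethod
    @abstractmethod
    def _init_model_with_state_dict(cls, state: Dict[str, Any]) -> "Model": ...

    # Saving and loading live in the base class and only rely on the two
    # state-dict methods, so subclasses do not have to repeat this logic.
    def save(self, path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(self._get_state_dict(), f)

    @classmethod
    def load(cls, path: str) -> "Model":
        with open(path, "rb") as f:
            return cls._init_model_with_state_dict(pickle.load(f))


class TokenCountRegressor(Model):
    """Toy regression model: predicts scale * number of whitespace tokens."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def predict(self, sentences, mini_batch_size=32):
        predictions = []
        # Process the input in mini-batches of the requested size.
        for start in range(0, len(sentences), mini_batch_size):
            batch = sentences[start:start + mini_batch_size]
            predictions.extend(self.scale * len(s.text.split()) for s in batch)
        return predictions

    def forward_loss(self, sentences):
        # Mean squared error; a real model would return an autograd tensor.
        preds = self.predict(sentences)
        return sum((p - s.target) ** 2 for p, s in zip(preds, sentences)) / len(sentences)

    def evaluate(self, sentences):
        # Task-specific evaluation: a regression metric in this case.
        return {"mse": self.forward_loss(sentences)}

    def _get_state_dict(self):
        return {"scale": self.scale}

    @classmethod
    def _init_model_with_state_dict(cls, state):
        return cls(scale=state["scale"])
```

In this sketch, a trainer would only need to call `forward_loss()`, `evaluate()` and `save()` on whatever model it is given, which is what lets one training loop serve different downstream tasks.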

@alanakbik alanakbik merged commit 81cf1b5 into master Apr 26, 2019
@alanakbik alanakbik deleted the GH-474-model-interface branch May 9, 2019 18:09