Refactoring AbstractLinearSGDModel and Trainer to extract SGD base classes #134
Description
This PR extracts AbstractSGDModel and AbstractSGDTrainer from AbstractLinearSGDModel and AbstractLinearSGDTrainer, and introduces FeedForwardParameters, which has predict and gradient methods. Those methods don't live on Parameters because Parameters is also used by the CRF, and sequences have a differently shaped input. AbstractSGDTrainer has a lot of generic parameters, but they are all hidden from users, and the concrete subclasses are still typed with just the output type, like most of Tribuo. The code is covered by the existing tests and can still deserialize Tribuo 4.0 models.
Motivation
The recent introduction of AbstractLinearSGDModel/Trainer wasn't quite abstract enough. The SGD package could be used for things beyond linear models, like factorization machines. This PR makes it straightforward to subclass AbstractSGDTrainer for a different model class so you don't have to reimplement the training loop.
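As a rough illustration of the pattern described above, here is a minimal, self-contained sketch of a base trainer that owns the SGD loop while a subclass supplies the parameters and the loss. It is not Tribuo's API: FeedForwardParams, SketchSGDTrainer, LinearParams, LinearSquaredLossTrainer, and every method signature below are hypothetical, simplified stand-ins, and the real classes carry more generic parameters and work on Tribuo's own tensor and example types.

```java
import java.util.Arrays;

public class SgdSketch {
    /** Hypothetical stand-in for a feed-forward parameter store with predict and gradient methods. */
    interface FeedForwardParams {
        double predict(double[] features);
        double[] gradient(double lossGradient, double[] features);
        void update(double[] gradient, double learningRate);
    }

    /** Hypothetical stand-in for the abstract trainer: the training loop is written once here. */
    static abstract class SketchSGDTrainer<P extends FeedForwardParams> {
        private final int epochs;
        private final double learningRate;

        protected SketchSGDTrainer(int epochs, double learningRate) {
            this.epochs = epochs;
            this.learningRate = learningRate;
        }

        /** Subclasses decide which model class (parameter structure) is trained. */
        protected abstract P createParameters(int numFeatures);

        /** Subclasses supply the derivative of the loss with respect to the prediction. */
        protected abstract double lossGradient(double prediction, double target);

        /** Plain per-example SGD, shared by every subclass. */
        public final P train(double[][] features, double[] targets) {
            P params = createParameters(features[0].length);
            for (int epoch = 0; epoch < epochs; epoch++) {
                for (int i = 0; i < features.length; i++) {
                    double prediction = params.predict(features[i]);
                    double lossGrad = lossGradient(prediction, targets[i]);
                    params.update(params.gradient(lossGrad, features[i]), learningRate);
                }
            }
            return params;
        }
    }

    /** A linear model: one weight per feature, with the bias stored in the last slot. */
    static final class LinearParams implements FeedForwardParams {
        final double[] weights;

        LinearParams(int numFeatures) {
            this.weights = new double[numFeatures + 1];
        }

        @Override
        public double predict(double[] features) {
            double sum = weights[weights.length - 1];
            for (int i = 0; i < features.length; i++) {
                sum += weights[i] * features[i];
            }
            return sum;
        }

        @Override
        public double[] gradient(double lossGradient, double[] features) {
            double[] grad = new double[weights.length];
            for (int i = 0; i < features.length; i++) {
                grad[i] = lossGradient * features[i];
            }
            grad[weights.length - 1] = lossGradient; // bias term
            return grad;
        }

        @Override
        public void update(double[] gradient, double learningRate) {
            for (int i = 0; i < weights.length; i++) {
                weights[i] -= learningRate * gradient[i];
            }
        }
    }

    /** Concrete trainer: only the model class and the loss are specified, not the loop. */
    static final class LinearSquaredLossTrainer extends SketchSGDTrainer<LinearParams> {
        LinearSquaredLossTrainer(int epochs, double learningRate) {
            super(epochs, learningRate);
        }

        @Override
        protected LinearParams createParameters(int numFeatures) {
            return new LinearParams(numFeatures);
        }

        @Override
        protected double lossGradient(double prediction, double target) {
            return prediction - target; // derivative of 0.5 * (prediction - target)^2
        }
    }

    public static void main(String[] args) {
        double[][] x = {{0.0}, {1.0}, {2.0}, {3.0}};
        double[] y = {1.0, 3.0, 5.0, 7.0}; // generated from y = 2x + 1
        LinearParams learned = new LinearSquaredLossTrainer(2000, 0.05).train(x, y);
        System.out.println(Arrays.toString(learned.weights));
    }
}
```

In this sketch, supporting a different model class such as a factorization machine would mean writing another FeedForwardParams implementation and a small trainer subclass that returns it from createParameters; the train method itself would not change.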