Refactoring AbstractLinearSGDModel and Trainer to extract SGD base classes #134

Merged 2 commits on May 11, 2021

Craigacp
Member

Description

This PR extracts AbstractSGDModel and AbstractSGDTrainer from AbstractLinearSGDModel and AbstractLinearSGDTrainer, and introduces FeedForwardParameters, which has predict and gradient methods. Those methods are not added to Parameters itself because Parameters is also used by the CRF, where the inputs are sequences and so have a different shape.
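
As a rough illustration of the split described above, here is a minimal sketch of a feed-forward parameter interface with a linear implementation. The names `SketchFeedForwardParams` and `SketchLinearParams` and the use of plain `double[]` arrays are assumptions made for this example; they are simplified stand-ins, not the actual Tribuo types or signatures.

```java
/** Simplified stand-in for a feed-forward parameter set; not Tribuo's actual interface. */
interface SketchFeedForwardParams {
    /** Maps one fixed-size feature vector to one score per output dimension. */
    double[] predict(double[] features);

    /** Returns the weight gradient, given the loss gradient with respect to the scores. */
    double[][] gradients(double[] scoreGradient, double[] features);

    /** Returns an independent copy of these parameters. */
    SketchFeedForwardParams copy();
}

/** Hypothetical linear implementation: scores = weights * features. */
final class SketchLinearParams implements SketchFeedForwardParams {
    private final double[][] weights; // indexed as [output][feature]

    SketchLinearParams(double[][] weights) {
        // Defensive deep copy so callers cannot mutate the stored weights.
        this.weights = new double[weights.length][];
        for (int i = 0; i < weights.length; i++) {
            this.weights[i] = weights[i].clone();
        }
    }

    @Override
    public double[] predict(double[] features) {
        double[] scores = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            for (int j = 0; j < features.length; j++) {
                scores[i] += weights[i][j] * features[j];
            }
        }
        return scores;
    }

    @Override
    public double[][] gradients(double[] scoreGradient, double[] features) {
        // For a linear model the weight gradient is the outer product
        // of the score gradient and the input features.
        double[][] grad = new double[scoreGradient.length][features.length];
        for (int i = 0; i < scoreGradient.length; i++) {
            for (int j = 0; j < features.length; j++) {
                grad[i][j] = scoreGradient[i] * features[j];
            }
        }
        return grad;
    }

    @Override
    public SketchLinearParams copy() {
        return new SketchLinearParams(weights);
    }
}
```

Both methods take a single fixed-size vector, which is why an interface of this shape would not fit a sequence model whose input is a list of feature vectors.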

AbstractSGDTrainer has a lot of generic parameters, but those are all hidden from users and the concrete subclasses are still typed with just the output type like most of Tribuo. The code is tested by the existing tests, and can still deserialize Tribuo 4.0 models.

Motivation

The recent introduction of AbstractLinearSGDModel/Trainer wasn't quite abstract enough. The SGD package could be used for models beyond linear ones, such as factorization machines. This PR makes it straightforward to subclass AbstractSGDTrainer for a different model class, so you don't have to reimplement the training loop.
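
The idea of sharing one training loop across model classes can be sketched as a template method. Everything below is hypothetical and heavily simplified: `SketchSgdTrainer`, `SketchLinearRegressionTrainer`, the plain-array weights, and the fixed learning rate without shuffling are assumptions for illustration, not the classes or behaviour in this PR.

```java
/** Hypothetical base class that owns the SGD loop; subclasses supply the model-specific parts. */
abstract class SketchSgdTrainer {
    private final double learningRate;
    private final int epochs;

    protected SketchSgdTrainer(double learningRate, int epochs) {
        this.learningRate = learningRate;
        this.epochs = epochs;
    }

    /** Creates the initial weights for the given number of features. */
    protected abstract double[] createWeights(int numFeatures);

    /** Computes the loss gradient with respect to the weights for one example. */
    protected abstract double[] gradient(double[] weights, double[] features, double target);

    /** The shared training loop: written once, reused by every subclass. */
    public final double[] train(double[][] features, double[] targets) {
        double[] weights = createWeights(features[0].length);
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < features.length; i++) {
                double[] grad = gradient(weights, features[i], targets[i]);
                for (int j = 0; j < weights.length; j++) {
                    weights[j] -= learningRate * grad[j];
                }
            }
        }
        return weights;
    }
}

/** Hypothetical subclass: linear regression with squared loss 0.5 * (prediction - target)^2. */
final class SketchLinearRegressionTrainer extends SketchSgdTrainer {
    SketchLinearRegressionTrainer(double learningRate, int epochs) {
        super(learningRate, epochs);
    }

    @Override
    protected double[] createWeights(int numFeatures) {
        return new double[numFeatures]; // start from all zeros
    }

    @Override
    protected double[] gradient(double[] weights, double[] features, double target) {
        double prediction = 0.0;
        for (int j = 0; j < weights.length; j++) {
            prediction += weights[j] * features[j];
        }
        double error = prediction - target;
        double[] grad = new double[weights.length];
        for (int j = 0; j < weights.length; j++) {
            grad[j] = error * features[j];
        }
        return grad;
    }
}
```

A different model class would, under this sketch, override the same two hooks and inherit `train` unchanged.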

…Model and Trainer which operate on Parameters. This will allow future non-linear additions to Tribuo's SGD models.
@Craigacp added the "Oracle employee" label (This PR is from an Oracle employee) Apr 20, 2021
@pogren (Member) left a comment

Introduces AbstractSGDModel, which essentially moves the AbstractLinearSGDModel method 'predictSingle' and its supporting inner class PredAndActive into this class. Also introduces AbstractSGDTrainer, which takes most of its code from AbstractLinearSGDTrainer; the latter now basically delegates to its superclass. Two protected constructors were removed from AbstractLinearSGDTrainer, so this is not a backwards compatible change and will affect subclasses of AbstractLinearSGDTrainer (if they exist).

LinearParameters now implements FeedForwardParameters, which defines a 'predict' method whose return type is DenseVector. Previously LinearParameters implemented Parameters directly and defined its own 'predict' method, which returned an SGDVector, so this is not a backwards compatible change and will affect subclasses of LinearParameters (if they exist). The methods 'predict' and 'gradients' are now annotated with '@Override' because they are defined in FeedForwardParameters, which also introduces the 'copy' method.
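
To make the return type concern concrete, here is a small sketch using stand-in types. `SketchVector`, `SketchDenseVector`, `SketchFeedForward`, `SketchLinear` and `SketchScaledLinear` are hypothetical names invented for this example and only mirror the roles of the types mentioned above; the point is the general Java rule that an override may narrow a return type but not widen it.

```java
/** Stand-in for a general vector type (the role SGDVector plays in the comment above). */
interface SketchVector {
    int size();
}

/** Stand-in for a dense vector, a subtype of SketchVector. */
final class SketchDenseVector implements SketchVector {
    private final double[] values;

    SketchDenseVector(double[] values) {
        this.values = values.clone();
    }

    @Override
    public int size() {
        return values.length;
    }

    double get(int i) {
        return values[i];
    }
}

/** Stand-in interface whose predict method declares the narrower return type. */
interface SketchFeedForward {
    SketchDenseVector predict(double[] features);
}

/** Left non-final so it can be subclassed, which is where the incompatibility shows up. */
class SketchLinear implements SketchFeedForward {
    @Override
    public SketchDenseVector predict(double[] features) {
        // Identity "model": enough to show the types involved.
        return new SketchDenseVector(features);
    }
}

/** A subclass that keeps the narrower return type compiles fine. */
class SketchScaledLinear extends SketchLinear {
    @Override
    public SketchDenseVector predict(double[] features) {
        double[] scaled = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            scaled[i] = 2.0 * features[i];
        }
        return new SketchDenseVector(scaled);
    }
}

// A subclass written against the older, wider signature would no longer compile,
// because the overriding method's return type must be a subtype of the overridden one:
//
// class SketchOldStyle extends SketchLinear {
//     @Override
//     public SketchVector predict(double[] features) { ... } // compile error
// }
```

Callers that only assign the result to the wider type, for example `SketchVector v = new SketchLinear().predict(x);`, still compile against the new signature; it is overriding subclasses that break at the source level.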

Other than the concerns noted above, this is a straightforward refactoring of the abstract superclasses of the SGD model and training code to better share code.

@Craigacp
Member Author

AbstractLinearSGDTrainer was introduced after the last release, so removing those constructors doesn't change compatibility: the concrete LinearSGDTrainer classes for Label and Regressor still have the same constructors. We'll note the LinearParameters change in the release notes, but I think it's unlikely to break anyone (that class should probably be final, but we didn't do a thorough job of hardening Tribuo in this respect before the initial release).

@Craigacp
Member Author

Thanks Philip!

@Craigacp merged commit e901b9c into main May 11, 2021
@Craigacp deleted the yet-another-sgd-refactor branch May 11, 2021 15:09