Adds a multi-label linear SGD classifier #106
Conversation
…to make the tests pass.
- Adds the gradient and loss to BinaryCrossEntropy (formerly Sigmoid).
- Adds the loss to Hinge, and refactors the gradient computation.
- Adds a reduction method to the SGDVector interface.
(See the loss/gradient sketch after these commits.)
…ns when performing operations with dense arguments.
… in preparation for sharing code between the trainers.
… class. Note this commit changes the serialization format for all the models. Compatibility with the 4.0 serialised models will be restored later.
…earSGDModel. Adding a test for 4.0 models to the linear SGD package in classification and regression.
…bclasses. Suppressing the unchecked array creation warning, since we know it's safe.
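As a rough illustration of the loss and gradient mentioned in the commits above, here is a minimal binary cross-entropy sketch over multi-label outputs. It uses plain double arrays rather than Tribuo's SGDVector, the class and method names are illustrative rather than the PR's actual code, and the real objective may differ in sign convention (e.g. negating the gradient for gradient-ascent-style optimisers).

```java
// Minimal sketch of a binary cross-entropy loss and gradient for one
// multi-label example. Plain double arrays stand in for SGDVector; these
// names are illustrative, not Tribuo's API.
public final class BinaryCrossEntropySketch {

    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Summed binary cross-entropy loss over all output dimensions (labels are 0/1). */
    public static double loss(double[] labels, double[] scores) {
        double loss = 0.0;
        for (int i = 0; i < scores.length; i++) {
            double p = sigmoid(scores[i]);
            // Clamp to avoid log(0) when predictions saturate.
            p = Math.min(Math.max(p, 1e-15), 1.0 - 1e-15);
            loss -= labels[i] * Math.log(p) + (1.0 - labels[i]) * Math.log(1.0 - p);
        }
        return loss;
    }

    /** Gradient of the loss with respect to each score: sigmoid(score) - label. */
    public static double[] gradient(double[] labels, double[] scores) {
        double[] grad = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            grad[i] = sigmoid(scores[i]) - labels[i];
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 0.0, 1.0};
        double[] scores = {2.0, -1.0, 0.5};
        System.out.println("loss = " + loss(labels, scores));
        System.out.println("grad[0] = " + gradient(labels, scores)[0]);
    }
}
```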
Classification/SGD/src/main/java/org/tribuo/classification/sgd/objectives/Hinge.java
Just a few small changes here and there and it should be fine.
See the attached changes.
Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDTrainer.java
/**
 * Copies the supplied matrix.
 * @param other The matrix to copy.
 */
Just curious why we're not using arraycopy in this constructor?
Because it's not final: ShrinkingMatrix and AdaGradRDAMatrix both subclass DenseMatrix and override the get method to apply a transformation as part of the regularisation during training. I could check whether it's exactly a DenseMatrix rather than a subclass and use arraycopy in that case, falling back to this code otherwise, but it doesn't seem worth it at the moment.
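To make the trade-off concrete, here is a simplified sketch (not Tribuo's DenseMatrix) of a copy constructor that reads through get(): a subclass that overrides get, the way ShrinkingMatrix and AdaGradRDAMatrix do, has its transformation baked into the copy, whereas a System.arraycopy of the backing array would silently bypass it.

```java
// Simplified stand-ins for the matrix classes discussed above; not Tribuo's code.
class SimpleMatrix {
    protected final double[][] values;

    SimpleMatrix(int rows, int cols) {
        this.values = new double[rows][cols];
    }

    /**
     * Copy constructor. Reading through get() means subclass overrides are
     * honoured; System.arraycopy(other.values[i], ...) would copy the raw
     * backing array and skip any transformation the subclass applies.
     */
    SimpleMatrix(SimpleMatrix other) {
        this.values = new double[other.values.length][other.values[0].length];
        for (int i = 0; i < values.length; i++) {
            for (int j = 0; j < values[i].length; j++) {
                values[i][j] = other.get(i, j);
            }
        }
    }

    double get(int i, int j) {
        return values[i][j];
    }
}

// A hypothetical subclass in the spirit of ShrinkingMatrix: get() applies a
// scaling factor, so a raw arraycopy would produce the wrong copied values.
class ScaledMatrix extends SimpleMatrix {
    private final double scale;

    ScaledMatrix(int rows, int cols, double scale) {
        super(rows, cols);
        this.scale = scale;
    }

    @Override
    double get(int i, int j) {
        return scale * super.get(i, j);
    }
}
```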
MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDTrainer.java
Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDTrainer.java
@eelstretching any thoughts on using the Pair vs the concrete class to return the gradient and the loss?
The changes look good. I think returning the Pair is the correct decision.
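For reference, here is a rough sketch of the Pair-returning shape, using a hinge loss as the example: a single call hands back both the loss and the gradient so the trainer doesn't recompute the score. The nested Pair class is a minimal stand-in for the OLCUT Pair, and the method name and signature are illustrative rather than the PR's actual code.

```java
// Sketch of returning the loss and gradient together from one call.
public final class LossAndGradientSketch {

    /** Minimal immutable pair, standing in for the OLCUT Pair used in Tribuo. */
    static final class Pair<A, B> {
        final A first;
        final B second;
        Pair(A first, B second) {
            this.first = first;
            this.second = second;
        }
    }

    /** Hinge loss and gradient w.r.t. the weights for one example with label +/-1. */
    static Pair<Double, double[]> hingeLossAndGradient(double label, double[] features, double[] weights) {
        double score = 0.0;
        for (int i = 0; i < features.length; i++) {
            score += features[i] * weights[i];
        }
        double margin = label * score;
        if (margin >= 1.0) {
            // Correctly classified with margin: zero loss and zero gradient.
            return new Pair<>(0.0, new double[features.length]);
        }
        double[] grad = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            grad[i] = -label * features[i];
        }
        return new Pair<>(1.0 - margin, grad);
    }

    public static void main(String[] args) {
        double[] x = {1.0, -0.5, 2.0};
        double[] w = {0.1, 0.2, -0.3};
        Pair<Double, double[]> out = hingeLossAndGradient(1.0, x, w);
        System.out.println("loss = " + out.first + ", grad[0] = " + out.second[0]);
    }
}
```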
I removed the …
LGTM!
Description
This PR adds a native multi-label linear SGD model, refactors all the linear SGD models to share a base class, and adds additional support to the math library for dense vectors (in preparation for further optimizations to the linear SGD models). It also switches the vector normalizers over to an in-place normalization method to reduce allocations in the inner loop of training and inference.
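As a sketch of the in-place normalization change (method names are illustrative, not necessarily Tribuo's exact API): instead of allocating a fresh output array per example, the in-place variant overwrites the caller's buffer, which removes an allocation from the hot loop.

```java
// Sketch contrasting an allocating normalizer with an in-place one.
// Assumes non-negative inputs with a non-zero sum; names are illustrative.
public final class InPlaceNormalizerSketch {

    /** Allocating version: creates a new array on every call. */
    public static double[] normalize(double[] input) {
        double sum = 0.0;
        for (double v : input) {
            sum += v;
        }
        double[] output = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            output[i] = input[i] / sum;
        }
        return output;
    }

    /** In-place version: reuses the caller's buffer, avoiding the per-call allocation. */
    public static void normalizeInPlace(double[] input) {
        double sum = 0.0;
        for (double v : input) {
            sum += v;
        }
        for (int i = 0; i < input.length; i++) {
            input[i] /= sum;
        }
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 3.0, 6.0};
        normalizeInPlace(scores);
        System.out.println(scores[0] + ", " + scores[1] + ", " + scores[2]); // 0.1, 0.3, 0.6
    }
}
```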
Motivation
The multi-label package includes a wrapper that converts any classifier into a multi-label classifier, but this wrapper is too slow for large applications. The multi-label linear SGD classifier in this PR is roughly an order of magnitude faster when run on a large text corpus. It also prepares the linear models package for further work introducing vectorisation, L2 regularisation and other improvements.