Factorization machines #179
Merged
Craigacp added the Oracle employee (This PR is from an Oracle employee) and squash-commits (Squash the commits when merging this PR) labels Sep 27, 2021
…s of a performance optimization.
…to the next release.
…d output operators, updating other code to use the new overloads.
…oduced has not yet been tested.
…l and regression FMs.
…ndatory but ORT doesn't appear to care.
…ion is quite different in floats.
Craigacp force-pushed the factorization-machines branch from ac2b21e to e4d1c28 October 1, 2021 21:48
JackSullivan reviewed Oct 4, 2021
Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractSGDTrainer.java
JackSullivan reviewed Oct 4, 2021
JackSullivan reviewed Oct 4, 2021
Common/SGD/src/main/java/org/tribuo/common/sgd/FMParameters.java
Classification/SGD/configs/classification-factorization-machine.xml
Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java
JackSullivan requested changes Oct 12, 2021
Mostly looks good, see inline comments/suggestions
Co-authored-by: Jack Sullivan <[email protected]>
JackSullivan approved these changes Oct 12, 2021
Looks good to me
Description
Adds pairwise Factorization Machines trained using SGD, for multi-class classification, multi-label classification and regression. Factorization Machines can also be exported as ONNX models, though the translation uses dense tensors due to a lack of support for sparse tensors in ONNX Runtime (we might revise this as ORT's support improves). There are few implementations of multi-class factorization machines, so the code uses the obvious extension to multiple output dimensions: each output dimension gets its own bias, linear weight vector and feature embedding matrix. Because there is one embedding matrix per output dimension, this might not scale particularly well to large numbers of output dimensions; in such cases an MLP might be preferable.
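To make the multi-output extension described above concrete, here is a minimal sketch of how one sparse example could be scored with a per-output bias, linear weight vector and embedding matrix. The class `FMSketch`, the method `predict` and the plain-array parameter layout are hypothetical names chosen for illustration; they are not the classes added in this PR. The pairwise term uses the standard rearrangement from the Rendle paper, and this description does not say whether the Tribuo code computes it the same way.

```java
/** Illustrative sketch only; not Tribuo's API. */
public final class FMSketch {
    private FMSketch() {}

    /**
     * Scores one sparse example for every output dimension.
     *
     * Assumed (hypothetical) parameter layout:
     *   bias[o]            - bias for output o
     *   linear[o][i]       - linear weight for feature i, output o
     *   embedding[o][i][f] - factor f of feature i's embedding, output o
     * The example is given as parallel arrays of active feature indices
     * and their values; absent features are simply not listed.
     */
    public static double[] predict(double[] bias, double[][] linear, double[][][] embedding,
                                   int[] indices, double[] values) {
        int numOutputs = bias.length;
        double[] scores = new double[numOutputs];
        for (int o = 0; o < numOutputs; o++) {
            int numFactors = embedding[o][0].length;
            double[] sum = new double[numFactors];    // sum_i v_{i,f} x_i
            double[] sumSq = new double[numFactors];  // sum_i (v_{i,f} x_i)^2
            double score = bias[o];
            // Only the active features are visited, so cost is O(active * factors).
            for (int n = 0; n < indices.length; n++) {
                int i = indices[n];
                double x = values[n];
                score += linear[o][i] * x;
                for (int f = 0; f < numFactors; f++) {
                    double vx = embedding[o][i][f] * x;
                    sum[f] += vx;
                    sumSq[f] += vx * vx;
                }
            }
            // Pairwise term: 0.5 * sum_f ((sum_i v x)^2 - sum_i (v x)^2).
            double pairwise = 0.0;
            for (int f = 0; f < numFactors; f++) {
                pairwise += sum[f] * sum[f] - sumSq[f];
            }
            scores[o] = score + 0.5 * pairwise;
        }
        return scores;
    }
}
```

The sketch also shows where the scaling concern comes from: the embedding storage grows as outputs × features × factors, since nothing is shared between output dimensions.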
Compared with other libraries for factorization machines, the Tribuo implementation is missing l2 regularisation. This could be added using `ShrinkingTensor`, but we didn't have time to add the necessary support. On the other hand, it supports any gradient optimiser rather than just SGD or Adagrad. We plan to add regularisation support at a later date, along with l1 and l2 regularisation for `LinearSGDTrainer`.
In addition to the Factorization Machines, this PR includes:
- `DataOptions` can now standardize the features on load.
- `RegressionInfo` has methods to access the mean, standard deviation and other computed statistics. These are re-computed on deserialization rather than being stored in the serialized object.
- `ONNXOperators.build` and optional inputs support.
Motivation
Factorization Machines are a powerful prediction technique which models pairwise feature interactions via a feature embedding matrix. This makes them more expressive than logistic regression, but less complex than an MLP, providing a useful intermediate model. They also scale well under sparsity, because features that are absent from an example do not appear in the factorized representation.
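For reference, the second-order model from the paper cited below can be written as follows. The symbols ($w_0$ for the bias, $w_i$ for the linear weights, $\mathbf{v}_i \in \mathbb{R}^k$ for the embedding of feature $i$, $n$ for the number of features) follow the paper's notation rather than any identifiers in this PR.

$$
\hat{y}(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j
$$

The pairwise term can be rearranged so that it is linear in the number of features:

$$
\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j = \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{n} v_{i,f} \, x_i \right)^{2} - \sum_{i=1}^{n} v_{i,f}^{2} \, x_i^{2} \right]
$$

Every term contains a factor of $x_i$, so features with $x_i = 0$ contribute nothing and the sums only need to run over the features present in an example.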
Paper reference
S. Rendle, Factorization Machines, 2010 IEEE International Conference on Data Mining (ICDM).