Factorization machines #179

Merged (31 commits) on Oct 12, 2021


Craigacp
Member

Description

Adds pairwise Factorization Machines trained using SGD for multi-class classification, multi-label classification and regression. Factorization Machines can also be exported as ONNX models, though the export uses dense tensors because ONNX Runtime lacks support for sparse tensors (we might revisit this as ORT's support improves). There are few implementations of multi-class factorization machines, so the code uses the obvious extension to multiple output dimensions: each output dimension gets its own bias, linear weight vector and feature embedding matrix. Because there is one embedding matrix per output dimension, this might not scale particularly well to large numbers of output dimensions; in such cases an MLP might be preferable.
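
For reference, the multi-output extension described above can be written as follows. The notation is chosen here for illustration and does not come from the PR: $n$ is the number of features, $C$ the number of output dimensions, $k$ the embedding size, and $\mathbf{v}_{c,i} \in \mathbb{R}^k$ the embedding of feature $i$ for output $c$.

```latex
\hat{y}_c(\mathbf{x}) = b_c
  + \sum_{i=1}^{n} w_{c,i}\, x_i
  + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_{c,i}, \mathbf{v}_{c,j} \rangle\, x_i x_j,
  \qquad c = 1, \dots, C
```

Under this notation the model has $C\,(1 + n + nk)$ parameters, which is why the cost grows linearly with the number of output dimensions.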

Compared with other factorization machine libraries, the Tribuo implementation is missing l2 regularisation (which could be added using the ShrinkingTensor, but we didn't have time to add the necessary support), though it does support any gradient optimiser rather than just SGD or Adagrad. We plan to add regularisation support at a later date, along with l1 and l2 regularisation for LinearSGDTrainer.

In addition to the Factorization Machines there are:

  • several cleanups to the SGD packages, including adding names to exported ONNX graphs (this is mandatory according to the ONNX spec, but ORT doesn't seem to check it).
  • a few additions to the Math package.
  • DataOptions can now standardize the features on load.
  • RegressionInfo has methods to access the mean, standard deviation and other computed statistics. These are recomputed on deserialization rather than being stored in the serialized object.
  • several cleanups for the ONNX support including extra overloads for ONNXOperators.build and optional inputs support.
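
The "recomputed on deserialization" approach mentioned for RegressionInfo can be sketched with standard Java serialization. This is a hypothetical illustration, not Tribuo's code: the class `SummaryStats`, its fields and the `roundTrip` helper are invented here, and the sum-of-squares variance is used only for brevity.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical illustration only; this is not Tribuo's RegressionInfo.
final class SummaryStats implements Serializable {
    private static final long serialVersionUID = 1L;

    // Serialized state: raw accumulators only.
    private final long count;
    private final double sum;
    private final double sumSquares;

    // Derived state: transient, so it is never written to the stream.
    private transient double mean;
    private transient double stdDev;

    SummaryStats(double[] values) {
        double s = 0.0;
        double ss = 0.0;
        for (double v : values) {
            s += v;
            ss += v * v;
        }
        this.count = values.length;
        this.sum = s;
        this.sumSquares = ss;
        computeDerived();
    }

    // Rebuilds the mean and population standard deviation from the accumulators.
    private void computeDerived() {
        if (count == 0) {
            mean = Double.NaN;
            stdDev = Double.NaN;
        } else {
            mean = sum / count;
            double variance = Math.max(0.0, sumSquares / count - mean * mean);
            stdDev = Math.sqrt(variance);
        }
    }

    // Invoked by Java serialization: restore the stored fields, then recompute the rest.
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        computeDerived();
    }

    double mean() {
        return mean;
    }

    double stdDev() {
        return stdDev;
    }

    // Serializes and deserializes in memory to exercise readObject.
    static SummaryStats roundTrip(SummaryStats stats) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(stats);
            }
            try (ObjectInputStream in =
                     new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (SummaryStats) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException("Round trip failed", e);
        }
    }

    public static void main(String[] args) {
        SummaryStats restored = roundTrip(new SummaryStats(new double[] {1.0, 2.0, 3.0, 4.0}));
        System.out.println("mean=" + restored.mean() + " stdDev=" + restored.stdDev());
    }
}
```

Keeping the derived values out of the stream means the serialized form stays small and the statistics cannot drift out of sync with the stored accumulators.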

Motivation

Factorization Machines are a powerful prediction technique that models pairwise feature interactions via a feature embedding matrix. This makes them more expressive than logistic regression but less complex than an MLP, providing a useful intermediate model that scales well under sparsity: features that aren't present in an example don't contribute to the factorized representation.
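
The sparsity argument can be made concrete with a minimal sketch of scoring one output dimension. It assumes the example is given as parallel arrays of non-zero feature indices and values; the class and method names are hypothetical and are not Tribuo's API. It uses the identity from Rendle's paper that the pairwise term equals half of the sum over embedding dimensions of (sum of v·x) squared minus the sum of (v·x) squared, so the cost is linear in the number of non-zero features times the embedding size.

```java
// Hypothetical sketch; names and layout are assumptions, not Tribuo's implementation.
final class SparseFMSketch {

    // Scores a sparse example, touching only the features that are present.
    // indices/values hold the non-zero features; embedding is [numFeatures][k].
    static double score(int[] indices, double[] values,
                        double bias, double[] linear, double[][] embedding) {
        int k = embedding[0].length;
        double result = bias;
        double[] sum = new double[k];    // sum_i v_{i,f} * x_i
        double[] sumSq = new double[k];  // sum_i (v_{i,f} * x_i)^2
        for (int p = 0; p < indices.length; p++) {
            int i = indices[p];
            double x = values[p];
            result += linear[i] * x;
            for (int f = 0; f < k; f++) {
                double t = embedding[i][f] * x;
                sum[f] += t;
                sumSq[f] += t * t;
            }
        }
        double pairwise = 0.0;
        for (int f = 0; f < k; f++) {
            pairwise += sum[f] * sum[f] - sumSq[f];
        }
        return result + 0.5 * pairwise;
    }

    // Reference version with an explicit loop over feature pairs, for checking.
    static double scoreNaive(int[] indices, double[] values,
                             double bias, double[] linear, double[][] embedding) {
        double result = bias;
        for (int p = 0; p < indices.length; p++) {
            result += linear[indices[p]] * values[p];
            for (int q = p + 1; q < indices.length; q++) {
                double dot = 0.0;
                for (int f = 0; f < embedding[0].length; f++) {
                    dot += embedding[indices[p]][f] * embedding[indices[q]][f];
                }
                result += dot * values[p] * values[q];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[] linear = {0.1, 0.2, 0.3, 0.4};
        double[][] embedding = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}, {2.0, -1.0}};
        int[] indices = {0, 2, 3};          // feature 1 is absent and never touched
        double[] values = {1.0, 2.0, 0.5};
        System.out.println(score(indices, values, 0.5, linear, embedding));
    }
}
```

A multi-output model in the style described in this PR would call such a routine once per output dimension, each with its own bias, linear weights and embedding matrix.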

Paper reference

S. Rendle, Factorization Machines, 2010 IEEE International Conference on Data Mining (ICDM).

@Craigacp added the labels "Oracle employee" (This PR is from an Oracle employee) and "squash-commits" (Squash the commits when merging this PR) on Sep 27, 2021
…d output operators, updating other code to use the new overloads.
@Craigacp force-pushed the factorization-machines branch from ac2b21e to e4d1c28 on October 1, 2021 21:48
Member

@JackSullivan left a comment

Mostly looks good, see inline comments/suggestions

Member

@JackSullivan left a comment

Looks good to me

@Craigacp merged commit e9b5ed3 into oracle:main on Oct 12, 2021
@Craigacp deleted the factorization-machines branch on October 12, 2021 22:05