From d19e1468c68d83ebbe5911f51eadf44a24c9eb2a Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 30 May 2021 00:35:50 -0400 Subject: [PATCH 01/31] Initial implementation of factorization machines without regularisation. --- .../sgd/fm/FMClassificationModel.java | 96 ++++++ .../sgd/fm/FMClassificationOptions.java | 90 ++++++ .../sgd/fm/FMClassificationTrainer.java | 141 +++++++++ .../classification/sgd/fm/TrainTest.java | 63 ++++ .../classification/sgd/fm/package-info.java | 20 ++ .../sgd/fm/TestFMClassification.java | 125 ++++++++ .../sgd/linear/TestSGDLinear.java | 4 +- .../tribuo/common/sgd/AbstractFMModel.java | 130 ++++++++ .../tribuo/common/sgd/AbstractFMTrainer.java | 113 +++++++ .../common/sgd/AbstractLinearSGDModel.java | 13 + .../common/sgd/AbstractLinearSGDTrainer.java | 1 + .../tribuo/common/sgd/AbstractSGDModel.java | 13 +- .../tribuo/common/sgd/AbstractSGDTrainer.java | 5 +- .../org/tribuo/common/sgd/FMParameters.java | 287 ++++++++++++++++++ .../org/tribuo/math/la/DenseSparseMatrix.java | 36 ++- .../java/org/tribuo/math/la/DenseVector.java | 8 + .../main/java/org/tribuo/math/la/Matrix.java | 8 +- .../java/org/tribuo/math/la/SGDVector.java | 9 + .../java/org/tribuo/math/la/SparseVector.java | 49 ++- .../main/java/org/tribuo/math/la/Tensor.java | 10 +- .../multilabel/sgd/fm/FMMultiLabelModel.java | 99 ++++++ .../sgd/fm/FMMultiLabelOptions.java | 80 +++++ .../sgd/fm/FMMultiLabelTrainer.java | 141 +++++++++ .../multilabel/sgd/fm/package-info.java | 20 ++ .../multilabel/sgd/linear/TestSGDLinear.java | 2 +- .../regression/sgd/fm/FMRegressionModel.java | 79 +++++ .../sgd/fm/FMRegressionTrainer.java | 150 +++++++++ .../tribuo/regression/sgd/fm/TrainTest.java | 164 ++++++++++ .../regression/sgd/fm/package-info.java | 20 ++ .../regression/sgd/linear/TestSGDLinear.java | 2 +- 30 files changed, 1963 insertions(+), 15 deletions(-) create mode 100644 Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java create mode 100644 Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationOptions.java create mode 100644 Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationTrainer.java create mode 100644 Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/TrainTest.java create mode 100644 Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/package-info.java create mode 100644 Classification/SGD/src/test/java/org/tribuo/classification/sgd/fm/TestFMClassification.java create mode 100644 Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java create mode 100644 Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMTrainer.java create mode 100644 Common/SGD/src/main/java/org/tribuo/common/sgd/FMParameters.java create mode 100644 MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java create mode 100644 MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelOptions.java create mode 100644 MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelTrainer.java create mode 100644 MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/package-info.java create mode 100644 Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java create mode 100644 Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionTrainer.java create mode 100644 Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/TrainTest.java create mode 100644 Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/package-info.java diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java new file mode 100644 index 000000000..7cb96d2c8 --- /dev/null +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.classification.sgd.fm; + +import org.tribuo.Example; +import org.tribuo.ImmutableFeatureMap; +import org.tribuo.ImmutableOutputInfo; +import org.tribuo.Prediction; +import org.tribuo.classification.Label; +import org.tribuo.common.sgd.AbstractFMModel; +import org.tribuo.common.sgd.FMParameters; +import org.tribuo.math.la.DenseVector; +import org.tribuo.math.util.VectorNormalizer; +import org.tribuo.provenance.ModelProvenance; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * The inference time version of a factorization machine trained using SGD. + *

+ * See: + *

+ * Rendle, S.
+ * Factorization machines.
+ * 2010 IEEE International Conference on Data Mining
+ * 
+ */ +public class FMClassificationModel extends AbstractFMModel