Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Factorization machines #179

Merged
merged 31 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d19e146
Initial implementation of factorization machines without regularisation.
Craigacp May 30, 2021
9aec044
Tidying up copyrights and some comments.
Craigacp May 30, 2021
ebfc347
Fixing a bug in the model naming for FMRegressionModel and LinearSGDM…
Craigacp May 30, 2021
9d18400
Fixing a gradient initialisation bug for dense FMs.
Craigacp May 31, 2021
cf14b49
Adding int id lookup methods to ImmutableRegressionInfo's dimension s…
Craigacp May 31, 2021
bdffcde
Adding standardisation to factorisation machine regressors.
Craigacp May 31, 2021
6c2b17f
Adding feature scaling to DataOptions for easier testing of factoriza…
Craigacp May 31, 2021
aa563f1
Adding additional fine level logging to AbstractSGDTrainer for debugg…
Craigacp Jun 2, 2021
2cb5d4c
Fixing a bug in ArrayExample.transform which reduced the effectivenes…
Craigacp Jun 2, 2021
23d2845
Adding another rescaling option to DataOptions.
Craigacp Jun 2, 2021
d9e0083
Adding copy accessors for the FM model parameters.
Craigacp Jun 4, 2021
fd16538
Adding smoke tests for FM multi-label and regression.
Craigacp Jun 4, 2021
49d6dbc
Adding missing override annotations to various methods in SGD and Math.
Craigacp Jul 7, 2021
34e4c30
Adding missing override annotations to other classes.
Craigacp Jul 7, 2021
e0467b6
Removing the unimplemented l2 regularisation hooks. It's been pushed …
Craigacp Sep 7, 2021
87f4cc8
Adding some more overloads to ONNXOperators.build for single input an…
Craigacp Sep 26, 2021
d3ef792
Adding ONNX export to FMClassificationModel.
Craigacp Sep 26, 2021
a8af452
Adding optional inputs to ONNXOperators to support GEMM.
Craigacp Sep 26, 2021
f2a9795
Fixing FMClassificationModel onnx export so it executes. The model pr…
Craigacp Sep 26, 2021
c65f6b6
Tidying up javadoc and fixing a broken merge.
Craigacp Sep 26, 2021
2d91600
Adding a name to the generated onnx graph and some javadoc cleanup
Craigacp Sep 27, 2021
bbe2be6
Fixing ONNX export so we don't transpose already transposed matrices.
Craigacp Sep 27, 2021
7cd61b8
Moving FM ONNX export to AbstractFMModel and wiring it into multilabe…
Craigacp Sep 27, 2021
e828128
Adding graph names to the existing ONNX exports. The graph name is ma…
Craigacp Sep 27, 2021
a159bb9
Tidying up the bias tensor proto generation.
Craigacp Sep 27, 2021
4ea0252
Adding standardisation to ONNX export from FMRegressionModel.
Craigacp Sep 27, 2021
e4d1c28
Relaxing the similarity test in FMRegressionModel as the standardisat…
Craigacp Sep 27, 2021
9b500af
Fixing a bug in AbstractFMTrainer when instantiated from configuration.
Craigacp Oct 7, 2021
db5dd86
Adding example configs for factorization machines
Craigacp Oct 8, 2021
d5859a3
Apply Jack's suggestions from code review
Craigacp Oct 12, 2021
2185b4a
Fixing PR comments.
Craigacp Oct 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
-->

<config>
<component name="cart" type="org.tribuo.classification.mnb.MultinomialNaiveBayesTrainer">
<component name="mnb" type="org.tribuo.classification.mnb.MultinomialNaiveBayesTrainer">
<property name="alpha" value="1.0"/>
</component>
</config>
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<?xml version="1.0" encoding="UTF-8"?>

<!--
~ Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<!--
Description:
An example configuration file for a factorization machine trained using AdaGrad.
-->

<config>
<component name="fm" type="org.tribuo.classification.sgd.fm.FMClassificationTrainer">
<property name="objective" value="log"/>
<property name="optimiser" value="adagrad"/>
<property name="epochs" value="5"/>
<property name="loggingInterval" value="100"/>
<property name="minibatchSize" value="1"/>
<property name="seed" value="1"/>
<property name="factoredDimSize" value="5"/>
<property name="variance" value="0.5"/>
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
</component>

<component name="log" type="org.tribuo.classification.sgd.objectives.LogMulticlass"/>

<component name="adagrad" type="org.tribuo.math.optimisers.AdaGrad">
<property name="initialLearningRate" value="1.0"/>
<property name="epsilon" value="0.1"/>
</component>
</config>
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.classification.sgd.fm;

import ai.onnx.proto.OnnxMl;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.sgd.AbstractFMModel;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXShape;
import org.tribuo.onnx.ONNXUtils;
import org.tribuo.provenance.ModelProvenance;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
* The inference time version of a factorization machine trained using SGD.
* <p>
* See:
* <pre>
* Rendle, S.
* Factorization machines.
* 2010 IEEE International Conference on Data Mining
* </pre>
*/
public class FMClassificationModel extends AbstractFMModel<Label> implements ONNXExportable {
private static final long serialVersionUID = 1L;

private final VectorNormalizer normalizer;

/**
* Constructs a classification factorization machine trained via SGD.
* @param name The model name.
* @param provenance The model provenance.
* @param featureIDMap The feature domain.
* @param outputIDInfo The output domain.
* @param parameters The model parameters.
* @param normalizer The normalization function.
* @param generatesProbabilities Does this model generate probabilities?
*/
FMClassificationModel(String name, ModelProvenance provenance,
ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> outputIDInfo,
FMParameters parameters, VectorNormalizer normalizer, boolean generatesProbabilities) {
super(name, provenance, featureIDMap, outputIDInfo, parameters, generatesProbabilities);
this.normalizer = normalizer;
}

@Override
public Prediction<Label> predict(Example<Label> example) {
PredAndActive predTuple = predictSingle(example);
DenseVector prediction = predTuple.prediction;
prediction.normalize(normalizer);

double maxScore = Double.NEGATIVE_INFINITY;
Label maxLabel = null;
Map<String,Label> predMap = new LinkedHashMap<>();
for (int i = 0; i < prediction.size(); i++) {
String labelName = outputIDInfo.getOutput(i).getLabel();
double score = prediction.get(i);
Label label = new Label(labelName, score);
predMap.put(labelName,label);
if (score > maxScore) {
maxScore = score;
maxLabel = label;
}
}
return new Prediction<>(maxLabel, predMap, predTuple.numActiveFeatures, example, generatesProbabilities);
}

@Override
protected FMClassificationModel copy(String newName, ModelProvenance newProvenance) {
return new FMClassificationModel(newName,newProvenance,featureIDMap,outputIDInfo,(FMParameters)modelParameters.copy(),normalizer,generatesProbabilities);
}

@Override
protected String getDimensionName(int index) {
return outputIDInfo.getOutput(index).getLabel();
}

@Override
public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXContext context = new ONNXContext();

// Build graph
OnnxMl.GraphProto graph = exportONNXGraph(context);

return innerExportONNXModel(graph,domain,modelVersion);
}

@Override
public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) {
OnnxMl.GraphProto.Builder graphBuilder = OnnxMl.GraphProto.newBuilder();
graphBuilder.setName("FMClassificationModel");

// Make inputs and outputs
OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT);
OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build();
graphBuilder.addInput(inputValueProto);
OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT);
OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build();
graphBuilder.addOutput(outputValueProto);

// Build the output neutral bits of the onnx graph
String outputName = generateONNXGraph(context, graphBuilder, inputValueProto.getName());

// Make output normalizer
List<OnnxMl.NodeProto> normalizerProtos = normalizer.exportNormalizer(context,outputName,"output");
if (normalizerProtos.isEmpty()) {
throw new IllegalArgumentException("Normalizer " + normalizer.getClass() + " cannot be exported in ONNX models.");
} else {
graphBuilder.addAllNode(normalizerProtos);
}

return graphBuilder.build();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.classification.sgd.fm;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.Option;
import org.tribuo.Trainer;
import org.tribuo.classification.ClassificationOptions;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.classification.sgd.objectives.Hinge;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.math.optimisers.GradientOptimiserOptions;

import java.util.logging.Logger;

/**
* CLI options for training a factorization machine classifier.
*/
public class FMClassificationOptions implements ClassificationOptions<FMClassificationTrainer> {
private static final Logger logger = Logger.getLogger(FMClassificationOptions.class.getName());

/**
* Available loss types.
*/
public enum LossEnum {
/**
* Hinge loss (like an SVM).
*/
HINGE,
/**
* Log loss (i.e., a logistic regression).
*/
LOG
}

public GradientOptimiserOptions sgoOptions;

/**
* Number of SGD epochs.
*/
@Option(longName = "fm-epochs", usage = "Number of SGD epochs.")
public int fmEpochs = 5;
/**
* Loss function.
*/
@Option(longName = "fm-objective", usage = "Loss function.")
public LossEnum fmObjective = LossEnum.LOG;
/**
* Log the objective after n examples.
*/
@Option(longName = "fm-logging-interval", usage = "Log the objective after <int> examples.")
public int fmLoggingInterval = 100;
/**
* Minibatch size.
*/
@Option(longName = "fm-minibatch-size", usage = "Minibatch size.")
public int fmMinibatchSize = 1;
/**
* Sets the random seed for the FMClassificationTrainer.
*/
@Option(longName = "fm-seed", usage = "Sets the random seed for the FMClassificationTrainer.")
private long fmSeed = Trainer.DEFAULT_SEED;
/**
* Factor size.
*/
@Option(longName = "fm-factor-size", usage = "Factor size.")
public int fmFactorSize = 5;
/**
* Variance of the initialization gaussian.
*/
@Option(longName = "fm-variance", usage = "Variance of the initialization gaussian.")
public double fmVariance = 0.5;

/**
* Returns the loss function specified in the arguments.
* @return The loss function.
*/
public LabelObjective getLoss() {
switch (fmObjective) {
case HINGE:
return new Hinge();
case LOG:
return new LogMulticlass();
default:
throw new ArgumentException("sgd-objective", "Unknown loss function " + fmObjective);
}
}

@Override
public FMClassificationTrainer getTrainer() {
logger.info(String.format("Set logging interval to %d", fmLoggingInterval));
return new FMClassificationTrainer(getLoss(), sgoOptions.getOptimiser(), fmEpochs, fmLoggingInterval,
fmMinibatchSize, fmSeed, fmFactorSize, fmVariance);
}
}
Loading