Refactor Java ONNX Interface #199

Merged (8 commits) on Dec 10, 2021
@@ -16,19 +16,18 @@

package org.tribuo.classification.ensemble;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXUtils;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
@@ -134,68 +133,46 @@ public ConfiguredObjectProvenance getProvenance() {
}

/**
* Exports this voting combiner as a list of ONNX NodeProtos.
* Exports this voting combiner to ONNX.
* <p>
* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
* @param context The ONNX context object for name generation.
* @param input The name of the input tensor to combine.
* @param output The name of the voting output.
* @return A list of node protos representing the voting operation.
* @param input The node to be ensembled according to this implementation.
* @return The leaf node of the voting operation.
*/
@Override
public List<OnnxMl.NodeProto> exportCombiner(ONNXContext context, String input, String output) {
List<OnnxMl.NodeProto> nodes = new ArrayList<>();

public ONNXNode exportCombiner(ONNXNode input) {
// Take the mean over the predictions
Map<String,Object> attributes = new HashMap<>();
attributes.put("axes",new int[]{2});
attributes.put("keepdims",0);
OnnxMl.NodeProto mean = ONNXOperators.REDUCE_MEAN.build(context,input,output,attributes);
nodes.add(mean);

return nodes;
return input.apply(ONNXOperators.REDUCE_MEAN, attributes);
}
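
Side note for reviewers unfamiliar with the operator: the single REDUCE_MEAN node with axes={2} and keepdims=0 is just an average over the ensemble axis. A minimal plain-Java sketch of the same arithmetic, assuming the [batch_size, num_outputs, num_ensemble_members] layout described in the javadoc (method and variable names are illustrative, not part of this diff):

```java
// Illustrative only: plain-Java equivalent of the REDUCE_MEAN node built above.
// Input layout is assumed to be [batch_size][num_outputs][num_ensemble_members];
// keepdims=0 means the member axis is dropped from the output.
static double[][] meanOverMembers(double[][][] predictions) {
    int batchSize = predictions.length;
    int numOutputs = predictions[0].length;
    int numMembers = predictions[0][0].length;
    double[][] combined = new double[batchSize][numOutputs];
    for (int b = 0; b < batchSize; b++) {
        for (int o = 0; o < numOutputs; o++) {
            double sum = 0.0;
            for (int m = 0; m < numMembers; m++) {
                sum += predictions[b][o][m];
            }
            combined[b][o] = sum / numMembers;
        }
    }
    return combined;
}
```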

/**
* Exports this ensemble combiner as a list of ONNX NodeProtos.
* Exports this voting combiner to ONNX.
* <p>
* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
* @param context The ONNX context object for name generation.
* @param input The name of the input tensor to combine.
* @param output The name of the voting output.
* @param weight The name of the combination weight initializer.
* @return A list of node protos representing the voting operation.
* @param input The node to be ensembled according to this implementation.
* @param weight The node of weights for ensembling.
* @return The leaf node of the voting operation.
*/
@Override
public List<OnnxMl.NodeProto> exportCombiner(ONNXContext context, String input, String output, String weight) {
List<OnnxMl.NodeProto> nodes = new ArrayList<>();

public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
// Unsqueeze the weights to make sure they broadcast how I want them to.
// Now the size is [1, 1, num_members].
OnnxMl.TensorProto unsqueezeAxes = ONNXUtils.arrayBuilder(context,"unsqueeze_ensemble_output",new long[]{0,1});
context.addInitializer(unsqueezeAxes);
OnnxMl.NodeProto unsqueeze = ONNXOperators.UNSQUEEZE.build(context,new String[]{weight,unsqueezeAxes.getName()},context.generateUniqueName("unsqueezed_weights"));
nodes.add(unsqueeze);
ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1});

ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);

// Multiply the input by the weights.
OnnxMl.NodeProto mulByWeights = ONNXOperators.MUL.build(context,new String[]{input,unsqueeze.getOutput(0)},context.generateUniqueName("mul_predictions_by_weights"));
nodes.add(mulByWeights);
ONNXNode mulByWeights = input.apply(ONNXOperators.MUL, unsqueezed);

// Sum the weights
OnnxMl.NodeProto weightSum = ONNXOperators.REDUCE_SUM.build(context,weight,context.generateUniqueName("ensemble_weight_sum"));
nodes.add(weightSum);
ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);

// Take the weighted mean over the outputs
OnnxMl.TensorProto sumAxes = ONNXUtils.arrayBuilder(context,"sum_across_ensemble_axes",new long[]{2});
context.addInitializer(sumAxes);
OnnxMl.NodeProto sumAcrossMembers = ONNXOperators.REDUCE_SUM.build(context,
new String[]{mulByWeights.getOutput(0),sumAxes.getName()},
context.generateUniqueName("sum_across_ensemble"),
Collections.singletonMap("keepdims",0));
nodes.add(sumAcrossMembers);
OnnxMl.NodeProto divideByWeightSum = ONNXOperators.DIV.build(context,new String[]{sumAcrossMembers.getOutput(0),weightSum.getOutput(0)},output);
nodes.add(divideByWeightSum);

return nodes;
ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2});
return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0))
.apply(ONNXOperators.DIV, weightSum);
}
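
For reference, the UNSQUEEZE, MUL, REDUCE_SUM, DIV chain above computes a weighted mean over the ensemble axis. A sketch of the equivalent arithmetic, assuming one weight per ensemble member (names are illustrative, not part of this diff):

```java
// Illustrative only: the weighted mean encoded by the graph above.
// The unsqueezed weights broadcast over [batch][output], so each member's
// prediction is scaled by its weight before summing and normalising.
static double[][] weightedMeanOverMembers(double[][][] predictions, double[] weights) {
    int batchSize = predictions.length;
    int numOutputs = predictions[0].length;
    int numMembers = weights.length;
    double weightSum = 0.0;
    for (double w : weights) {
        weightSum += w; // REDUCE_SUM over the weights
    }
    double[][] combined = new double[batchSize][numOutputs];
    for (int b = 0; b < batchSize; b++) {
        for (int o = 0; o < numOutputs; o++) {
            double acc = 0.0;
            for (int m = 0; m < numMembers; m++) {
                acc += predictions[b][o][m] * weights[m]; // MUL then REDUCE_SUM over axis 2
            }
            combined[b][o] = acc / weightSum; // final DIV by the weight sum
        }
    }
    return combined;
}
```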
}
@@ -16,19 +16,18 @@

package org.tribuo.classification.ensemble;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXNode;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.onnx.ONNXUtils;
import org.tribuo.onnx.ONNXRef;
import org.tribuo.onnx.ONNXInitializer;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
@@ -126,76 +125,50 @@ public ConfiguredObjectProvenance getProvenance() {
}

/**
* Exports this voting combiner as a list of ONNX NodeProtos.
* Exports this voting combiner to ONNX.
* <p>
* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
* @param context The ONNX context object for name generation.
* @param input The name of the input tensor to combine.
* @param output The name of the voting output.
* @return A list of node protos representing the voting operation.
* @param input The input node to combine.
* @return The leaf node of the voting operation.
*/
@Override
public List<OnnxMl.NodeProto> exportCombiner(ONNXContext context, String input, String output) {
List<OnnxMl.NodeProto> nodes = new ArrayList<>();

public ONNXNode exportCombiner(ONNXNode input) {
// Hardmax!
OnnxMl.NodeProto hardMaxEnsemble = ONNXOperators.HARDMAX.build(context,input,context.generateUniqueName("hardmax_predictions"),Collections.singletonMap("axis",1));
nodes.add(hardMaxEnsemble);

// Take the mean over the maxed predictions
Map<String,Object> attributes = new HashMap<>();
attributes.put("axes",new int[]{2});
attributes.put("keepdims",0);
OnnxMl.NodeProto mean = ONNXOperators.REDUCE_MEAN.build(context,hardMaxEnsemble.getOutput(0),output,attributes);
nodes.add(mean);

return nodes;
return input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1))
.apply(ONNXOperators.REDUCE_MEAN, attributes);
}
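
For reference, HARDMAX over axis 1 one-hot encodes each ensemble member's prediction, and the REDUCE_MEAN over axis 2 turns those votes into per-class proportions. A sketch of the equivalent arithmetic, assuming ties resolve to the first maximal index as in the ONNX Hardmax spec (names are illustrative, not part of this diff):

```java
// Illustrative only: hardmax voting followed by a mean over the ensemble axis.
// Each member contributes a single vote (1/numMembers) to its argmax class.
static double[][] votingCombine(double[][][] predictions) {
    int batchSize = predictions.length;
    int numOutputs = predictions[0].length;
    int numMembers = predictions[0][0].length;
    double[][] combined = new double[batchSize][numOutputs];
    for (int b = 0; b < batchSize; b++) {
        for (int m = 0; m < numMembers; m++) {
            int argmax = 0;
            for (int o = 1; o < numOutputs; o++) {
                if (predictions[b][o][m] > predictions[b][argmax][m]) {
                    argmax = o; // strict '>' keeps the first occurrence on ties
                }
            }
            combined[b][argmax] += 1.0 / numMembers;
        }
    }
    return combined;
}
```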

/**
* Exports this voting combiner as a list of ONNX NodeProtos.
* Exports this voting combiner to ONNX.
* <p>
* The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
* @param context The ONNX context object for name generation.
* @param input The name of the input tensor to combine.
* @param output The name of the voting output.
* @param weight The name of the combination weight initializer.
* @return A list of node protos representing the voting operation.
* @param input The input node to combine.
* @param weight The combination weight node.
* @return The leaf node of the voting operation.
*/
@Override
public List<OnnxMl.NodeProto> exportCombiner(ONNXContext context, String input, String output, String weight) {
List<OnnxMl.NodeProto> nodes = new ArrayList<>();

public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
// Unsqueeze the weights to make sure they broadcast how I want them to.
// Now the size is [1, 1, num_members].
OnnxMl.TensorProto unsqueezeAxes = ONNXUtils.arrayBuilder(context,"unsqueeze_ensemble_output",new long[]{0,1});
context.addInitializer(unsqueezeAxes);
OnnxMl.NodeProto unsqueeze = ONNXOperators.UNSQUEEZE.build(context,new String[]{weight,unsqueezeAxes.getName()},context.generateUniqueName("unsqueezed_weights"));
nodes.add(unsqueeze);
ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1});
ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2});

// Hardmax!
OnnxMl.NodeProto hardMaxEnsemble = ONNXOperators.HARDMAX.build(context,input,context.generateUniqueName("hardmax_predictions"),Collections.singletonMap("axis",1));
nodes.add(hardMaxEnsemble);
ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);

// Hardmax!
// Multiply the input by the weights.
OnnxMl.NodeProto mulByWeights = ONNXOperators.MUL.build(context,new String[]{hardMaxEnsemble.getOutput(0),unsqueeze.getOutput(0)},context.generateUniqueName("mul_predictions_by_weights"));
nodes.add(mulByWeights);
ONNXNode mulByWeights = input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1))
.apply(ONNXOperators.MUL, unsqueezed);

// Sum the weights
OnnxMl.NodeProto weightSum = ONNXOperators.REDUCE_SUM.build(context,weight,context.generateUniqueName("ensemble_weight_sum"));
nodes.add(weightSum);
ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);

// Take the weighted mean over the outputs
OnnxMl.TensorProto sumAxes = ONNXUtils.arrayBuilder(context,"sum_across_ensemble_axes",new long[]{2});
context.addInitializer(sumAxes);
OnnxMl.NodeProto sumAcrossMembers = ONNXOperators.REDUCE_SUM.build(context,
new String[]{mulByWeights.getOutput(0),sumAxes.getName()},
context.generateUniqueName("sum_across_ensemble"),
Collections.singletonMap("keepdims",0));
nodes.add(sumAcrossMembers);
OnnxMl.NodeProto divideByWeightSum = ONNXOperators.DIV.build(context,new String[]{sumAcrossMembers.getOutput(0),weightSum.getOutput(0)},output);
nodes.add(divideByWeightSum);

return nodes;
return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0))
.apply(ONNXOperators.DIV, weightSum);
}
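
Similarly, the weighted variant above scales each one-hot vote by that member's weight and normalises by the total weight. A sketch under the same assumptions (names are illustrative, not part of this diff):

```java
// Illustrative only: weighted hardmax voting. Each member's argmax class
// receives that member's weight, and the result is normalised by the weight sum.
static double[][] weightedVotingCombine(double[][][] predictions, double[] weights) {
    int batchSize = predictions.length;
    int numOutputs = predictions[0].length;
    int numMembers = weights.length;
    double weightSum = 0.0;
    for (double w : weights) {
        weightSum += w;
    }
    double[][] combined = new double[batchSize][numOutputs];
    for (int b = 0; b < batchSize; b++) {
        for (int m = 0; m < numMembers; m++) {
            int argmax = 0;
            for (int o = 1; o < numOutputs; o++) {
                if (predictions[b][o][m] > predictions[b][argmax][m]) {
                    argmax = o;
                }
            }
            combined[b][argmax] += weights[m] / weightSum;
        }
    }
    return combined;
}
```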
}