Skip to content

Commit

Permalink
Adding standardisation to ONNX export from FMRegressionModel.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Oct 1, 2021
1 parent a159bb9 commit 4ea0252
Showing 1 changed file with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,41 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) {
OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT);
OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build();
graphBuilder.addInput(inputValueProto);
String outputName = "output";
OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT);
OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build();
OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName(outputName).build();
graphBuilder.addOutput(outputValueProto);

// Build the output neutral bits of the onnx graph
String outputName = generateONNXGraph(context, graphBuilder, inputValueProto.getName());
String fmOutputName = generateONNXGraph(context, graphBuilder, inputValueProto.getName());

// Link up the output to the graph output
OnnxMl.NodeProto output = ONNXOperators.IDENTITY.build(context,outputName,"output");
graphBuilder.addNode(output);
if (standardise) {
// standardise the FM output
ImmutableRegressionInfo info = (ImmutableRegressionInfo) outputIDInfo;
double[] means = new double[outputIDInfo.size()];
double[] variances = new double[outputIDInfo.size()];
for (int i = 0; i < means.length; i++) {
means[i] = info.getMean(i);
variances[i] = info.getVariance(i);
}

// Create mean and variance initializers
OnnxMl.TensorProto outputMeanProto = ONNXUtils.arrayBuilder(context,context.generateUniqueName("y_mean"),means);
graphBuilder.addInitializer(outputMeanProto);
OnnxMl.TensorProto outputVarianceProto = ONNXUtils.arrayBuilder(context, context.generateUniqueName("y_var"),variances);
graphBuilder.addInitializer(outputVarianceProto);

// Add standardisation operations
String varianceOutput = context.generateUniqueName("y_var_scale_output");
OnnxMl.NodeProto varianceScale = ONNXOperators.MUL.build(context, new String[]{fmOutputName,outputVarianceProto.getName()}, varianceOutput);
graphBuilder.addNode(varianceScale);
OnnxMl.NodeProto meanScale = ONNXOperators.ADD.build(context, new String[]{varianceOutput,outputMeanProto.getName()}, outputName);
graphBuilder.addNode(meanScale);
} else {
// Not standardised, so link up the FM output to the graph output
OnnxMl.NodeProto output = ONNXOperators.IDENTITY.build(context, fmOutputName, outputName);
graphBuilder.addNode(output);
}

return graphBuilder.build();
}
Expand Down

0 comments on commit 4ea0252

Please sign in to comment.