From 8a771f2289ac57eb6da040fa37b397f56d49c094 Mon Sep 17 00:00:00 2001 From: John Sullivan Date: Thu, 14 Oct 2021 12:12:48 -0400 Subject: [PATCH 1/3] Refactoring methods that create `OnnxMl.TensorProto` instances into util methods. --- .../sgd/fm/FMClassificationModel.java | 6 +- .../sgd/linear/LinearSGDModel.java | 7 +- .../tribuo/common/sgd/AbstractFMModel.java | 77 +------- .../common/sgd/AbstractLinearSGDModel.java | 52 ++---- .../java/org/tribuo/onnx/ONNXOperators.java | 2 +- .../main/java/org/tribuo/onnx/ONNXUtils.java | 100 ---------- .../org/tribuo/math/onnx/ONNXMathUtils.java | 172 ++++++++++++++++++ .../multilabel/sgd/fm/FMMultiLabelModel.java | 6 +- .../multilabel/sgd/linear/LinearSGDModel.java | 7 +- .../regression/sgd/fm/FMRegressionModel.java | 10 +- .../regression/sgd/linear/LinearSGDModel.java | 8 +- .../regression/slm/SparseLinearModel.java | 86 +++------ 12 files changed, 241 insertions(+), 292 deletions(-) delete mode 100644 Core/src/main/java/org/tribuo/onnx/ONNXUtils.java create mode 100644 Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java index e9a44f29f..5749024a9 100644 --- a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/fm/FMClassificationModel.java @@ -29,7 +29,7 @@ import org.tribuo.onnx.ONNXContext; import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXShape; -import org.tribuo.onnx.ONNXUtils; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.provenance.ModelProvenance; import java.util.LinkedHashMap; @@ -116,10 +116,10 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { graphBuilder.setName("FMClassificationModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); - OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build(); graphBuilder.addOutput(outputValueProto); diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java index 344d83a1e..e59686861 100644 --- a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/linear/LinearSGDModel.java @@ -31,11 +31,10 @@ import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.onnx.ONNXShape; 
-import org.tribuo.onnx.ONNXUtils; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.provenance.ModelProvenance; import java.io.IOException; -import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -125,10 +124,10 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { graphBuilder.setName("Classification-LinearSGDModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); - OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build(); graphBuilder.addOutput(outputValueProto); diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java index 1ed2933f0..86042b67c 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java @@ -17,7 +17,6 @@ package org.tribuo.common.sgd; import ai.onnx.proto.OnnxMl; -import com.google.protobuf.ByteString; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.Excuse; @@ -27,16 +26,17 @@ import org.tribuo.Tribuo; import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.Matrix; +import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.Tensor; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.onnx.ONNXContext; import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.provenance.ModelProvenance; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -216,68 +216,6 @@ protected OnnxMl.ModelProto innerExportONNXModel(OnnxMl.GraphProto graph, String return builder.build(); } - /** - * Builds a TensorProto containing the supplied DenseMatrix. - * - * @param context The ONNX context for naming. - * @param name The name for this tensor proto. - * @param matrix The matrix to store. - * @param transpose Should the matrix be transposed into the tensor? - * @return The matrix TensorProto. 
- */ - protected static OnnxMl.TensorProto matrixBuilder(ONNXContext context, String name, DenseMatrix matrix, boolean transpose) { - OnnxMl.TensorProto.Builder matrixBuilder = OnnxMl.TensorProto.newBuilder(); - matrixBuilder.setName(context.generateUniqueName(name)); - int dim1, dim2; - if (transpose) { - dim1 = matrix.getDimension2Size(); - dim2 = matrix.getDimension1Size(); - } else { - dim1 = matrix.getDimension1Size(); - dim2 = matrix.getDimension2Size(); - } - matrixBuilder.addDims(dim1); - matrixBuilder.addDims(dim2); - matrixBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - ByteBuffer buffer = ByteBuffer.allocate(dim1 * dim2 * 4).order(ByteOrder.LITTLE_ENDIAN); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int i = 0; i < dim1; i++) { - for (int j = 0; j < dim2; j++) { - if (transpose) { - floatBuffer.put((float) matrix.get(j, i)); - } else { - floatBuffer.put((float) matrix.get(i, j)); - } - } - } - floatBuffer.rewind(); - matrixBuilder.setRawData(ByteString.copyFrom(buffer)); - return matrixBuilder.build(); - } - - /** - * Builds a TensorProto containing the supplied dense vector. - * - * @param context The ONNX context for naming. - * @param name The name for this tensor proto. - * @param vector The vector to store. - * @return The vector TensorProto. - */ - protected static OnnxMl.TensorProto vectorBuilder(ONNXContext context, String name, DenseVector vector) { - OnnxMl.TensorProto.Builder vectorBuilder = OnnxMl.TensorProto.newBuilder(); - vectorBuilder.setName(context.generateUniqueName(name)); - vectorBuilder.addDims(vector.size()); - vectorBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - ByteBuffer buffer = ByteBuffer.allocate(vector.size() * 4).order(ByteOrder.LITTLE_ENDIAN); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int i = 0; i < vector.size(); i++) { - floatBuffer.put((float) vector.get(i)); - } - floatBuffer.rewind(); - vectorBuilder.setRawData(ByteString.copyFrom(buffer)); - return vectorBuilder.build(); - } - /** * Constructs the shared stem of the Factorization Machine, used by all output types. *
* <p>
@@ -299,18 +237,19 @@ protected String generateONNXGraph(ONNXContext context, OnnxMl.GraphProto.Builde graphBuilder.addInitializer(twoConst); // Add weights - OnnxMl.TensorProto weightInitializerProto = matrixBuilder(context, "fm_linear_weights", (DenseMatrix) modelParams[1], true); + OnnxMl.TensorProto weightInitializerProto = ONNXMathUtils.floatMatrixBuilder(context, "fm_linear_weights", (Matrix) modelParams[1], true); graphBuilder.addInitializer(weightInitializerProto); // Add biases - OnnxMl.TensorProto biasInitializerProto = vectorBuilder(context, "fm_biases", (DenseVector) modelParams[0]); + OnnxMl.TensorProto biasInitializerProto = ONNXMathUtils.floatVectorBuilder(context, "fm_biases", (SGDVector) modelParams[0]); graphBuilder.addInitializer(biasInitializerProto); // Add embedding vectors OnnxMl.TensorProto[] embeddingProtos = new OnnxMl.TensorProto[outputIDInfo.size()]; for (int i = 0; i < outputIDInfo.size(); i++) { - embeddingProtos[i] = matrixBuilder(context, "fm_embedding_" + i, (DenseMatrix) modelParams[i + 2], true); + embeddingProtos[i] = ONNXMathUtils.floatMatrixBuilder(context, "fm_embedding_" + i, (Matrix) modelParams[i + 2], false); graphBuilder.addInitializer(embeddingProtos[i]); + System.out.println("base shape:" + Arrays.toString(modelParams[i + 2].getShape()) + "\nonnx shape: " + embeddingProtos[i].getDimsList().toString()); } // Make gemm diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java index 2469dff11..c5643ab44 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java @@ -17,7 +17,6 @@ package org.tribuo.common.sgd; import ai.onnx.proto.OnnxMl; -import com.google.protobuf.ByteString; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.Excuse; @@ -30,15 +29,15 @@ import org.tribuo.Tribuo; import org.tribuo.math.LinearParameters; import org.tribuo.math.la.DenseMatrix; +import org.tribuo.math.la.Matrix; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.onnx.ONNXContext; import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.provenance.ModelProvenance; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -197,22 +196,15 @@ protected OnnxMl.ModelProto innerExportONNXModel(OnnxMl.GraphProto graph, String * @return The weight TensorProto. 
*/ protected OnnxMl.TensorProto weightBuilder(ONNXContext context) { - DenseMatrix weightMatrix = (DenseMatrix) modelParameters.get()[0]; - OnnxMl.TensorProto.Builder weightBuilder = OnnxMl.TensorProto.newBuilder(); - weightBuilder.setName(context.generateUniqueName("linear_sgd_weights")); - weightBuilder.addDims(featureIDMap.size()); - weightBuilder.addDims(outputIDInfo.size()); - weightBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - ByteBuffer buffer = ByteBuffer.allocate(featureIDMap.size() * outputIDInfo.size() * 4).order(ByteOrder.LITTLE_ENDIAN); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int j = 0; j < weightMatrix.getDimension2Size() - 1; j++) { - for (int i = 0; i < weightMatrix.getDimension1Size(); i++) { - floatBuffer.put((float) weightMatrix.get(i, j)); - } - } - floatBuffer.rewind(); - weightBuilder.setRawData(ByteString.copyFrom(buffer)); - return weightBuilder.build(); + final Matrix weightMatrix = (Matrix) modelParameters.get()[0]; + return ONNXMathUtils.floatTensorBuilder(context, "linear_sgd_weights", Arrays.asList(featureIDMap.size(), outputIDInfo.size()), + fb -> { + for (int j = 0; j < weightMatrix.getDimension2Size() - 1; j++) { + for (int i = 0; i < weightMatrix.getDimension1Size(); i++) { + fb.put((float) weightMatrix.get(i, j)); + } + } + }); } /** @@ -221,19 +213,13 @@ protected OnnxMl.TensorProto weightBuilder(ONNXContext context) { * @return The bias TensorProto. */ protected OnnxMl.TensorProto biasBuilder(ONNXContext context) { - DenseMatrix weightMatrix = (DenseMatrix) modelParameters.get()[0]; - OnnxMl.TensorProto.Builder biasBuilder = OnnxMl.TensorProto.newBuilder(); - biasBuilder.setName(context.generateUniqueName("linear_sgd_biases")); - biasBuilder.addDims(outputIDInfo.size()); - biasBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - ByteBuffer buffer = ByteBuffer.allocate(outputIDInfo.size()*4).order(ByteOrder.LITTLE_ENDIAN); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int i = 0; i < weightMatrix.getDimension1Size(); i++) { - floatBuffer.put((float)weightMatrix.get(i,weightMatrix.getDimension2Size()-1)); - } - floatBuffer.rewind(); - biasBuilder.setRawData(ByteString.copyFrom(buffer)); - return biasBuilder.build(); + Matrix weightMatrix = (Matrix) modelParameters.get()[0]; + return ONNXMathUtils.floatTensorBuilder(context, "linear_sgd_biases", Collections.singletonList(outputIDInfo.size()), + fb -> { + for (int i = 0; i < weightMatrix.getDimension1Size(); i++) { + fb.put((float)weightMatrix.get(i,weightMatrix.getDimension2Size()-1)); + } + }); } } diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java b/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java index 1c054ab78..f547f34a8 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java +++ b/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java @@ -307,7 +307,7 @@ public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] out for (String o : outputs) { nodeBuilder.addOutput(o); } - nodeBuilder.setName(context.generateUniqueName(opName)); + nodeBuilder.setName(context.generateUniqueName(opName) + ":" + outputs[0]); nodeBuilder.setOpType(opName); for (Map.Entry e : attributeValues.entrySet()) { ONNXAttribute attr = attributes.get(e.getKey()); diff --git a/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java b/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java deleted file mode 100644 index e4cb406ee..000000000 --- a/Core/src/main/java/org/tribuo/onnx/ONNXUtils.java +++ /dev/null @@ -1,100 +0,0 @@ 
-/* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.tribuo.onnx; - -import ai.onnx.proto.OnnxMl; -import com.google.protobuf.ByteString; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; - -/** - * Helper functions for building ONNX protos. - */ -public abstract class ONNXUtils { - - /** - * Private constructor for abstract util class. - */ - private ONNXUtils() {} - - /** - * Builds a type proto for the specified shape and tensor type. - * @param shape The shape. - * @param type The tensor type. - * @return The type proto. - */ - public static OnnxMl.TypeProto buildTensorTypeNode(ONNXShape shape, OnnxMl.TensorProto.DataType type) { - OnnxMl.TypeProto.Builder builder = OnnxMl.TypeProto.newBuilder(); - - OnnxMl.TypeProto.Tensor.Builder tensorBuilder = OnnxMl.TypeProto.Tensor.newBuilder(); - tensorBuilder.setElemType(type.getNumber()); - tensorBuilder.setShape(shape.getProto()); - builder.setTensorType(tensorBuilder.build()); - - return builder.build(); - } - - /** - * Builds a TensorProto containing the array. - *
- * <p>
- * Downcasts the doubles into floats as ONNX's fp64 support is poor compared to fp32. - * @param context The naming context. - * @param name The base name for the proto. - * @param parameters The array to store in the proto. - * @return A TensorProto containing the array as floats. - */ - public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters) { - return arrayBuilder(context,name,parameters,true); - } - - /** - * Builds a TensorProto containing the array. - *
- * <p>
- * Optionally downcasts the doubles into floats. - * @param context The naming context. - * @param name The base name for the proto. - * @param parameters The array to store in the proto. - * @param downcast Downcasts the doubles into floats. - * @return A TensorProto containing the array as either floats or doubles. - */ - public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters, boolean downcast) { - OnnxMl.TensorProto.Builder arrBuilder = OnnxMl.TensorProto.newBuilder(); - arrBuilder.setName(context.generateUniqueName(name)); - arrBuilder.addDims(parameters.length); - int capacity = downcast ? parameters.length * 4 : parameters.length * 8; - ByteBuffer buffer = ByteBuffer.allocate(capacity).order(ByteOrder.LITTLE_ENDIAN); - if (downcast) { - arrBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int i = 0; i < parameters.length; i++) { - floatBuffer.put((float) parameters[i]); - } - floatBuffer.rewind(); - } else { - arrBuilder.setDataType(OnnxMl.TensorProto.DataType.DOUBLE.getNumber()); - DoubleBuffer doubleBuffer = buffer.asDoubleBuffer(); - doubleBuffer.put(parameters); - doubleBuffer.rewind(); - } - arrBuilder.setRawData(ByteString.copyFrom(buffer)); - return arrBuilder.build(); - } - -} diff --git a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java new file mode 100644 index 000000000..a3a4b946e --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.math.onnx; + +import ai.onnx.proto.OnnxMl; +import com.google.protobuf.ByteString; +import org.tribuo.math.la.Matrix; +import org.tribuo.math.la.SGDVector; +import org.tribuo.onnx.ONNXContext; +import org.tribuo.onnx.ONNXShape; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +/** + * Helper functions for building ONNX protos. + */ +public abstract class ONNXMathUtils { + + /** + * Private constructor for abstract util class. + */ + private ONNXMathUtils() {} + + /** + * Builds a type proto for the specified shape and tensor type. + * @param shape The shape. + * @param type The tensor type. + * @return The type proto. 
+     */
+    public static OnnxMl.TypeProto buildTensorTypeNode(ONNXShape shape, OnnxMl.TensorProto.DataType type) {
+        OnnxMl.TypeProto.Builder builder = OnnxMl.TypeProto.newBuilder();
+
+        OnnxMl.TypeProto.Tensor.Builder tensorBuilder = OnnxMl.TypeProto.Tensor.newBuilder();
+        tensorBuilder.setElemType(type.getNumber());
+        tensorBuilder.setShape(shape.getProto());
+        builder.setTensorType(tensorBuilder.build());
+
+        return builder.build();
+    }
+
+    /**
+     * Generic method to create float {@link ai.onnx.proto.OnnxMl.TensorProto} instances.
+     *
+     * @param context the naming context.
+     * @param name the base name for the proto.
+     * @param dims the dimensions of the input data.
+     * @param dataPopulator a method to populate a {@link FloatBuffer} that will be written into the TensorProto's rawData field.
+     * @return a float-typed TensorProto representation of the data.
+     */
+    public static OnnxMl.TensorProto floatTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<FloatBuffer> dataPopulator) {
+        int size = dims.stream().reduce((a, b) -> a * b).orElse(0);
+        ByteBuffer buffer = ByteBuffer.allocate(size * 4).order(ByteOrder.LITTLE_ENDIAN);
+        FloatBuffer floatBuffer = buffer.asFloatBuffer();
+        dataPopulator.accept(floatBuffer);
+        floatBuffer.rewind();
+        return OnnxMl.TensorProto.newBuilder()
+                .setName(context.generateUniqueName(name))
+                .setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber())
+                .addAllDims(() -> dims.stream().map(Integer::longValue).iterator())
+                .setRawData(ByteString.copyFrom(buffer))
+                .build();
+    }
+
+    /**
+     * Generic method to create double {@link ai.onnx.proto.OnnxMl.TensorProto} instances.
+     * <p>
+     * Note that ONNX fp64 support is poor compared to fp32.
+     * @param context the naming context.
+     * @param name the base name for the proto.
+     * @param dims the dimensions of the input data.
+     * @param dataPopulator a method to populate a {@link DoubleBuffer} that will be written into the TensorProto's rawData field.
+     * @return a double-typed TensorProto representation of the data.
+     */
+    public static OnnxMl.TensorProto doubleTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<DoubleBuffer> dataPopulator) {
+        int size = dims.stream().reduce((a, b) -> a * b).orElse(0);
+        ByteBuffer buffer = ByteBuffer.allocate(size * 8).order(ByteOrder.LITTLE_ENDIAN);
+        DoubleBuffer doubleBuffer = buffer.asDoubleBuffer();
+        dataPopulator.accept(doubleBuffer);
+        doubleBuffer.rewind();
+        return OnnxMl.TensorProto.newBuilder()
+                .setName(context.generateUniqueName(name))
+                .setDataType(OnnxMl.TensorProto.DataType.DOUBLE.getNumber())
+                .addAllDims(() -> dims.stream().map(Integer::longValue).iterator())
+                .setRawData(ByteString.copyFrom(buffer))
+                .build();
+    }
+
+    /**
+     * Builds a TensorProto containing the array.
+     * <p>
+ * Downcasts the doubles into floats as ONNX's fp64 support is poor compared to fp32. + * @param context The naming context. + * @param name The base name for the proto. + * @param parameters The array to store in the proto. + * @return A TensorProto containing the array as floats. + */ + public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters) { + return arrayBuilder(context,name,parameters,true); + } + + /** + * Builds a TensorProto containing the array. + *
+ * <p>
+ * Optionally downcasts the doubles into floats. + * @param context The naming context. + * @param name The base name for the proto. + * @param parameters The array to store in the proto. + * @param downcast Downcasts the doubles into floats. + * @return A TensorProto containing the array as either floats or doubles. + */ + public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters, boolean downcast) { + return downcast + ? floatTensorBuilder(context, name, Collections.singletonList(parameters.length), + fb -> Arrays.stream(parameters).forEachOrdered(d -> fb.put((float)d))) + : doubleTensorBuilder(context, name, Collections.singletonList(parameters.length), + db -> Arrays.stream(parameters).forEachOrdered(db::put)); + } + + /** + * Builds a TensorProto containing the {@link SGDVector}. + * @param context The naming context. + * @param name The base name for the proto. + * @param vector the SGDVector to store in the proto. + * @return A TensorProto containing the vector. + */ + public static OnnxMl.TensorProto floatVectorBuilder(ONNXContext context, String name, SGDVector vector) { + return floatTensorBuilder(context, name, Collections.singletonList(vector.size()), + fb -> vector.forEach(vt -> fb.put(vt.index,(float) vt.value))); + } + + /** + * Builds a TensorProto containing the {@link Matrix}. + * @param context The naming context. + * @param name The base name for the proto. + * @param matrix the matrix to store in the proto. + * @param transpose Whether to transpose the vector before writing it. + * @return A TensorProto containing the matrix + */ + public static OnnxMl.TensorProto floatMatrixBuilder(ONNXContext context, String name, Matrix matrix, boolean transpose) { + return floatTensorBuilder(context, name, + Arrays.stream(matrix.getShape()).boxed().collect(Collectors.toList()), + fb -> matrix.forEach(mt -> { + int address = transpose + ? 
mt.j * matrix.getDimension2Size() + mt.i + : mt.i * matrix.getDimension1Size() + mt.j; + System.out.println("tuple: " + mt.toString() + " address: " +address + " buffersize:" + fb.capacity()); + fb.put(address, (float) mt.value); + })); + } +} diff --git a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java index 0242c2c13..05006466d 100644 --- a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java +++ b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/fm/FMMultiLabelModel.java @@ -30,7 +30,7 @@ import org.tribuo.onnx.ONNXContext; import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXShape; -import org.tribuo.onnx.ONNXUtils; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.provenance.ModelProvenance; import java.util.HashMap; @@ -119,10 +119,10 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { graphBuilder.setName("FMMultiLabelModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); - OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build(); graphBuilder.addOutput(outputValueProto); diff --git a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java index 98f28a26f..91004a7f8 100644 --- a/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java +++ b/MultiLabel/SGD/src/main/java/org/tribuo/multilabel/sgd/linear/LinearSGDModel.java @@ -31,10 +31,9 @@ import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.onnx.ONNXShape; -import org.tribuo.onnx.ONNXUtils; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.provenance.ModelProvenance; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -121,10 +120,10 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { graphBuilder.setName("MultiLabel-LinearSGDModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); - OnnxMl.TypeProto 
outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build(); graphBuilder.addOutput(outputValueProto); diff --git a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java index 973be0250..cc8909975 100644 --- a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java +++ b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/fm/FMRegressionModel.java @@ -27,7 +27,7 @@ import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.onnx.ONNXShape; -import org.tribuo.onnx.ONNXUtils; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.provenance.ModelProvenance; import org.tribuo.regression.ImmutableRegressionInfo; import org.tribuo.regression.Regressor; @@ -121,11 +121,11 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { graphBuilder.setName("FMMultiLabelModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); String outputName = "output"; - OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName(outputName).build(); graphBuilder.addOutput(outputValueProto); @@ -143,9 +143,9 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { } // Create mean and variance initializers - OnnxMl.TensorProto outputMeanProto = ONNXUtils.arrayBuilder(context,context.generateUniqueName("y_mean"),means); + OnnxMl.TensorProto outputMeanProto = ONNXMathUtils.arrayBuilder(context,context.generateUniqueName("y_mean"),means); graphBuilder.addInitializer(outputMeanProto); - OnnxMl.TensorProto outputVarianceProto = ONNXUtils.arrayBuilder(context, context.generateUniqueName("y_var"),variances); + OnnxMl.TensorProto outputVarianceProto = ONNXMathUtils.arrayBuilder(context, context.generateUniqueName("y_var"),variances); graphBuilder.addInitializer(outputVarianceProto); // Add standardisation operations diff --git a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java index 801a2c0f9..3f16e5ba2 100644 --- a/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java +++ 
b/Regression/SGD/src/main/java/org/tribuo/regression/sgd/linear/LinearSGDModel.java @@ -28,14 +28,12 @@ import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.onnx.ONNXShape; -import org.tribuo.onnx.ONNXUtils; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.provenance.ModelProvenance; import org.tribuo.regression.Regressor; import java.io.IOException; import java.util.Arrays; -import java.util.Collections; -import java.util.List; /** * The inference time version of a linear model trained using SGD. @@ -107,10 +105,10 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { graphBuilder.setName("Regression-LinearSGDModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,featureIDMap.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); - OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1,outputIDInfo.size()}, new String[]{"batch",null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build(); graphBuilder.addOutput(outputValueProto); diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java index ea007eefb..67c4fd0f8 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java @@ -17,7 +17,6 @@ package org.tribuo.regression.slm; import ai.onnx.proto.OnnxMl; -import com.google.protobuf.ByteString; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.Excuse; @@ -30,11 +29,11 @@ import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SparseVector; import org.tribuo.math.la.VectorTuple; +import org.tribuo.math.onnx.ONNXMathUtils; import org.tribuo.onnx.ONNXContext; import org.tribuo.onnx.ONNXExportable; import org.tribuo.onnx.ONNXOperators; import org.tribuo.onnx.ONNXShape; -import org.tribuo.onnx.ONNXUtils; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.regression.ImmutableRegressionInfo; @@ -43,9 +42,6 @@ import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.FloatBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -247,91 +243,51 @@ public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) { return builder.build(); } - /** - * Builds a TensorProto containing the model weights. - * @param context The naming context. - * @return The weight TensorProto. 
- */ - protected OnnxMl.TensorProto weightBuilder(ONNXContext context) { - // Make a dense copy of the weights so the other logic is O(dn) not O(dn log n). - DenseVector[] denseWeights = new DenseVector[weights.length]; - for (int i = 0; i < denseWeights.length; i++) { - denseWeights[i] = weights[i].densify(); - } - OnnxMl.TensorProto.Builder weightBuilder = OnnxMl.TensorProto.newBuilder(); - weightBuilder.setName(context.generateUniqueName("slm_weights")); - weightBuilder.addDims(featureIDMap.size()); - weightBuilder.addDims(outputIDInfo.size()); - weightBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - ByteBuffer buffer = ByteBuffer.allocate(featureIDMap.size() * outputIDInfo.size() * 4).order(ByteOrder.LITTLE_ENDIAN); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int j = 0; j < featureIDMap.size(); j++) { - for (int i = 0; i < denseWeights.length; i++) { - floatBuffer.put((float) denseWeights[i].get(j)); - } - } - floatBuffer.rewind(); - weightBuilder.setRawData(ByteString.copyFrom(buffer)); - return weightBuilder.build(); - } - - /** - * Builds a TensorProto containing the model biases. - * @param context The naming context. - * @return The bias TensorProto. - */ - protected OnnxMl.TensorProto biasBuilder(ONNXContext context) { - OnnxMl.TensorProto.Builder biasBuilder = OnnxMl.TensorProto.newBuilder(); - biasBuilder.setName(context.generateUniqueName("slm_biases")); - biasBuilder.addDims(outputIDInfo.size()); - biasBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); - ByteBuffer buffer = ByteBuffer.allocate(outputIDInfo.size()*4).order(ByteOrder.LITTLE_ENDIAN); - FloatBuffer floatBuffer = buffer.asFloatBuffer(); - for (int i = 0; i < weights.length; i++) { - if (bias) { - floatBuffer.put((float)weights[i].get(featureIDMap.size())); - } else { - floatBuffer.put(0.0f); - } - } - floatBuffer.rewind(); - biasBuilder.setRawData(ByteString.copyFrom(buffer)); - return biasBuilder.build(); - } - @Override public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { OnnxMl.GraphProto.Builder graphBuilder = OnnxMl.GraphProto.newBuilder(); graphBuilder.setName("Regression-SparseLinearModel"); // Make inputs and outputs - OnnxMl.TypeProto inputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1, featureIDMap.size()}, new String[]{"batch", null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto inputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1, featureIDMap.size()}, new String[]{"batch", null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto inputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(inputType).setName("input").build(); graphBuilder.addInput(inputValueProto); - OnnxMl.TypeProto outputType = ONNXUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1, outputIDInfo.size()}, new String[]{"batch", null}), OnnxMl.TensorProto.DataType.FLOAT); + OnnxMl.TypeProto outputType = ONNXMathUtils.buildTensorTypeNode(new ONNXShape(new long[]{-1, outputIDInfo.size()}, new String[]{"batch", null}), OnnxMl.TensorProto.DataType.FLOAT); OnnxMl.ValueInfoProto outputValueProto = OnnxMl.ValueInfoProto.newBuilder().setType(outputType).setName("output").build(); graphBuilder.addOutput(outputValueProto); // Add weights - OnnxMl.TensorProto weightInitializerProto = weightBuilder(context); + OnnxMl.TensorProto weightInitializerProto = ONNXMathUtils.floatTensorBuilder(context, "slm_weights", Arrays.asList(featureIDMap.size(), outputIDInfo.size()), + fb -> { + DenseVector[] denseWeights = new 
DenseVector[weights.length]; + for (int i = 0; i < denseWeights.length; i++) { + denseWeights[i] = weights[i].densify(); + } + for (int j = 0; j < featureIDMap.size(); j++) { + for (int i = 0; i < denseWeights.length; i++) { + fb.put((float) denseWeights[i].get(j)); + } + } + }); graphBuilder.addInitializer(weightInitializerProto); // Add biases - OnnxMl.TensorProto biasInitializerProto = biasBuilder(context); + OnnxMl.TensorProto biasInitializerProto = ONNXMathUtils.floatTensorBuilder(context, "slm_biases", Collections.singletonList(outputIDInfo.size()), + fb -> Arrays.stream(weights).forEachOrdered(sv -> fb.put((float) sv.get(featureIDMap.size())))); graphBuilder.addInitializer(biasInitializerProto); // Add feature and output means double[] xMean = bias ? Arrays.copyOf(featureMeans.toArray(),featureIDMap.size()) : featureMeans.toArray(); - OnnxMl.TensorProto featureMeanProto = ONNXUtils.arrayBuilder(context, "feature_mean",xMean); + OnnxMl.TensorProto featureMeanProto = ONNXMathUtils.arrayBuilder(context, "feature_mean",xMean); graphBuilder.addInitializer(featureMeanProto); - OnnxMl.TensorProto outputMeanProto = ONNXUtils.arrayBuilder(context,"y_mean",yMean); + OnnxMl.TensorProto outputMeanProto = ONNXMathUtils.arrayBuilder(context,"y_mean",yMean); graphBuilder.addInitializer(outputMeanProto); // Add feature and output variances double[] xVariance = bias ? Arrays.copyOf(featureVariance.toArray(),featureIDMap.size()) : featureVariance.toArray(); - OnnxMl.TensorProto featureVarianceProto = ONNXUtils.arrayBuilder(context,"feature_var",xVariance); + OnnxMl.TensorProto featureVarianceProto = ONNXMathUtils.arrayBuilder(context,"feature_var",xVariance); graphBuilder.addInitializer(featureVarianceProto); - OnnxMl.TensorProto outputVarianceProto = ONNXUtils.arrayBuilder(context, "y_var",yVariance); + OnnxMl.TensorProto outputVarianceProto = ONNXMathUtils.arrayBuilder(context, "y_var",yVariance); graphBuilder.addInitializer(outputVarianceProto); // Scale features From 846fd00ad12ca324507a5b8d52dacd9e972deccb Mon Sep 17 00:00:00 2001 From: John Sullivan Date: Fri, 22 Oct 2021 16:55:31 -0400 Subject: [PATCH 2/3] Fix `OnnxMathUtils.floatMatrixBuilder` to properly transpose non-square matrices --- .../java/org/tribuo/common/sgd/AbstractFMModel.java | 4 +--- .../main/java/org/tribuo/math/onnx/ONNXMathUtils.java | 11 +++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java index 86042b67c..16088f475 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java @@ -36,7 +36,6 @@ import org.tribuo.provenance.ModelProvenance; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -247,9 +246,8 @@ protected String generateONNXGraph(ONNXContext context, OnnxMl.GraphProto.Builde // Add embedding vectors OnnxMl.TensorProto[] embeddingProtos = new OnnxMl.TensorProto[outputIDInfo.size()]; for (int i = 0; i < outputIDInfo.size(); i++) { - embeddingProtos[i] = ONNXMathUtils.floatMatrixBuilder(context, "fm_embedding_" + i, (Matrix) modelParams[i + 2], false); + embeddingProtos[i] = ONNXMathUtils.floatMatrixBuilder(context, "fm_embedding_" + i, (Matrix) modelParams[i + 2], true); graphBuilder.addInitializer(embeddingProtos[i]); - System.out.println("base shape:" + 
Arrays.toString(modelParams[i + 2].getShape()) + "\nonnx shape: " + embeddingProtos[i].getDimsList().toString()); } // Make gemm diff --git a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java index a3a4b946e..b150d42c2 100644 --- a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java +++ b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java @@ -159,13 +159,16 @@ public static OnnxMl.TensorProto floatVectorBuilder(ONNXContext context, String * @return A TensorProto containing the matrix */ public static OnnxMl.TensorProto floatMatrixBuilder(ONNXContext context, String name, Matrix matrix, boolean transpose) { + List dims = Arrays.stream(matrix.getShape()).boxed().collect(Collectors.toList()); + if(transpose) { + Collections.reverse(dims); + } return floatTensorBuilder(context, name, - Arrays.stream(matrix.getShape()).boxed().collect(Collectors.toList()), + dims, fb -> matrix.forEach(mt -> { int address = transpose - ? mt.j * matrix.getDimension2Size() + mt.i - : mt.i * matrix.getDimension1Size() + mt.j; - System.out.println("tuple: " + mt.toString() + " address: " +address + " buffersize:" + fb.capacity()); + ? mt.j * matrix.getDimension1Size() + mt.i + : mt.i * matrix.getDimension2Size() + mt.j; fb.put(address, (float) mt.value); })); } From d14f9238a642e506e5d8345b4ecbc908cdd01a07 Mon Sep 17 00:00:00 2001 From: John Sullivan Date: Mon, 25 Oct 2021 14:24:17 -0400 Subject: [PATCH 3/3] Updates based on comments --- .../common/sgd/AbstractLinearSGDModel.java | 5 +++-- .../main/java/org/tribuo/onnx/ONNXOperators.java | 2 +- .../org/tribuo/provenance/ModelProvenance.java | 2 -- .../java/org/tribuo/math/onnx/ONNXMathUtils.java | 16 +++++++++------- .../tribuo/regression/slm/SparseLinearModel.java | 11 ++++------- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java index c5643ab44..55d448a11 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractLinearSGDModel.java @@ -36,6 +36,7 @@ import org.tribuo.onnx.ONNXOperators; import org.tribuo.provenance.ModelProvenance; +import java.nio.FloatBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -198,7 +199,7 @@ protected OnnxMl.ModelProto innerExportONNXModel(OnnxMl.GraphProto graph, String protected OnnxMl.TensorProto weightBuilder(ONNXContext context) { final Matrix weightMatrix = (Matrix) modelParameters.get()[0]; return ONNXMathUtils.floatTensorBuilder(context, "linear_sgd_weights", Arrays.asList(featureIDMap.size(), outputIDInfo.size()), - fb -> { + (FloatBuffer fb) -> { for (int j = 0; j < weightMatrix.getDimension2Size() - 1; j++) { for (int i = 0; i < weightMatrix.getDimension1Size(); i++) { fb.put((float) weightMatrix.get(i, j)); @@ -215,7 +216,7 @@ protected OnnxMl.TensorProto weightBuilder(ONNXContext context) { protected OnnxMl.TensorProto biasBuilder(ONNXContext context) { Matrix weightMatrix = (Matrix) modelParameters.get()[0]; return ONNXMathUtils.floatTensorBuilder(context, "linear_sgd_biases", Collections.singletonList(outputIDInfo.size()), - fb -> { + (FloatBuffer fb) -> { for (int i = 0; i < weightMatrix.getDimension1Size(); i++) { fb.put((float)weightMatrix.get(i,weightMatrix.getDimension2Size()-1)); } diff --git 
a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java b/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java index f547f34a8..1c054ab78 100644 --- a/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java +++ b/Core/src/main/java/org/tribuo/onnx/ONNXOperators.java @@ -307,7 +307,7 @@ public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] out for (String o : outputs) { nodeBuilder.addOutput(o); } - nodeBuilder.setName(context.generateUniqueName(opName) + ":" + outputs[0]); + nodeBuilder.setName(context.generateUniqueName(opName)); nodeBuilder.setOpType(opName); for (Map.Entry e : attributeValues.entrySet()) { ONNXAttribute attr = attributes.get(e.getKey()); diff --git a/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java b/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java index 929233182..61d8d8c69 100644 --- a/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java +++ b/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java @@ -19,7 +19,6 @@ import com.oracle.labs.mlrg.olcut.provenance.MapProvenance; import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.Provenance; -import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException; import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance; import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance; import com.oracle.labs.mlrg.olcut.util.Pair; @@ -32,7 +31,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; /** * Contains provenance information for an instance of a {@link org.tribuo.Model}. diff --git a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java index b150d42c2..276f3f923 100644 --- a/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java +++ b/Math/src/main/java/org/tribuo/math/onnx/ONNXMathUtils.java @@ -78,7 +78,7 @@ public static OnnxMl.TensorProto floatTensorBuilder(ONNXContext context, String return OnnxMl.TensorProto.newBuilder() .setName(context.generateUniqueName(name)) .setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()) - .addAllDims(() -> dims.stream().map(Integer::longValue).iterator()) + .addAllDims(dims.stream().map(Integer::longValue).collect(Collectors.toList())) .setRawData(ByteString.copyFrom(buffer)) .build(); } @@ -131,11 +131,13 @@ public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, * @return A TensorProto containing the array as either floats or doubles. */ public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters, boolean downcast) { - return downcast - ? 
floatTensorBuilder(context, name, Collections.singletonList(parameters.length), - fb -> Arrays.stream(parameters).forEachOrdered(d -> fb.put((float)d))) - : doubleTensorBuilder(context, name, Collections.singletonList(parameters.length), - db -> Arrays.stream(parameters).forEachOrdered(db::put)); + if(downcast) { + return floatTensorBuilder(context, name, Collections.singletonList(parameters.length), + (FloatBuffer fb) -> Arrays.stream(parameters).forEachOrdered(d -> fb.put((float)d))); + } else { + return doubleTensorBuilder(context, name, Collections.singletonList(parameters.length), + (DoubleBuffer db) -> Arrays.stream(parameters).forEachOrdered(db::put)); + } } /** @@ -147,7 +149,7 @@ public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, */ public static OnnxMl.TensorProto floatVectorBuilder(ONNXContext context, String name, SGDVector vector) { return floatTensorBuilder(context, name, Collections.singletonList(vector.size()), - fb -> vector.forEach(vt -> fb.put(vt.index,(float) vt.value))); + (FloatBuffer fb) -> vector.forEach(vt -> fb.put(vt.index,(float) vt.value))); } /** diff --git a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java index 67c4fd0f8..b54b1de07 100644 --- a/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java +++ b/Regression/SLM/src/main/java/org/tribuo/regression/slm/SparseLinearModel.java @@ -42,6 +42,7 @@ import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel; import java.io.IOException; +import java.nio.FloatBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -259,13 +260,9 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { // Add weights OnnxMl.TensorProto weightInitializerProto = ONNXMathUtils.floatTensorBuilder(context, "slm_weights", Arrays.asList(featureIDMap.size(), outputIDInfo.size()), fb -> { - DenseVector[] denseWeights = new DenseVector[weights.length]; - for (int i = 0; i < denseWeights.length; i++) { - denseWeights[i] = weights[i].densify(); - } for (int j = 0; j < featureIDMap.size(); j++) { - for (int i = 0; i < denseWeights.length; i++) { - fb.put((float) denseWeights[i].get(j)); + for (int i = 0; i < weights.length; i++) { + fb.put((float) weights[i].get(j)); } } }); @@ -273,7 +270,7 @@ public OnnxMl.GraphProto exportONNXGraph(ONNXContext context) { // Add biases OnnxMl.TensorProto biasInitializerProto = ONNXMathUtils.floatTensorBuilder(context, "slm_biases", Collections.singletonList(outputIDInfo.size()), - fb -> Arrays.stream(weights).forEachOrdered(sv -> fb.put((float) sv.get(featureIDMap.size())))); + (FloatBuffer fb) -> Arrays.stream(weights).forEachOrdered(sv -> fb.put((float) sv.get(featureIDMap.size())))); graphBuilder.addInitializer(biasInitializerProto); // Add feature and output means