From 26c76215301832fb66b72e36eda55773cd1cf0e0 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 9 Sep 2022 21:15:26 -0400 Subject: [PATCH] Adding protobuf serialization for liblinear models. --- .../liblinear/LibLinearAnomalyModel.java | 45 +- .../LibLinearAnomalyTrainerTest.java | 3 +- .../LibLinearClassificationModel.java | 45 +- .../liblinear/TestLibLinearModel.java | 2 + .../common/liblinear/LibLinearModel.java | 113 +- .../liblinear/protos/LibLinearModelProto.java | 777 +++++++++++ .../protos/LibLinearModelProtoOrBuilder.java | 41 + .../liblinear/protos/LibLinearProto.java | 1178 +++++++++++++++++ .../protos/LibLinearProtoOrBuilder.java | 79 ++ .../liblinear/protos/TribuoLiblinear.java | 67 + .../resources/protos/tribuo-liblinear.proto | 49 + .../liblinear/LibLinearRegressionModel.java | 49 +- .../regression/liblinear/TestLibLinear.java | 3 +- 13 files changed, 2445 insertions(+), 6 deletions(-) create mode 100644 Common/LibLinear/src/main/java/org/tribuo/common/liblinear/protos/LibLinearModelProto.java create mode 100644 Common/LibLinear/src/main/java/org/tribuo/common/liblinear/protos/LibLinearModelProtoOrBuilder.java create mode 100644 Common/LibLinear/src/main/java/org/tribuo/common/liblinear/protos/LibLinearProto.java create mode 100644 Common/LibLinear/src/main/java/org/tribuo/common/liblinear/protos/LibLinearProtoOrBuilder.java create mode 100644 Common/LibLinear/src/main/java/org/tribuo/common/liblinear/protos/TribuoLiblinear.java create mode 100644 Common/LibLinear/src/main/resources/protos/tribuo-liblinear.proto diff --git a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyModel.java b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyModel.java index 114acdf48..2669c7c2e 100644 --- a/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyModel.java +++ b/AnomalyDetection/LibLinear/src/main/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyModel.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.tribuo.anomaly.liblinear; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.Excuse; @@ -27,10 +29,15 @@ import org.tribuo.anomaly.Event; import org.tribuo.common.liblinear.LibLinearModel; import org.tribuo.common.liblinear.LibLinearTrainer; +import org.tribuo.common.liblinear.protos.LibLinearModelProto; +import org.tribuo.impl.ModelDataCarrier; import org.tribuo.provenance.ModelProvenance; import de.bwaldvogel.liblinear.FeatureNode; import de.bwaldvogel.liblinear.Linear; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -67,6 +74,42 @@ public class LibLinearAnomalyModel extends LibLinearModel { super(name, description, featureIDMap, outputIDInfo, false, models); } + /** + * Deserialization factory. + * @param version The serialized object version. + * @param className The class name. + * @param message The serialized data. + */ + public static LibLinearAnomalyModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException { + if (version < 0 || version > CURRENT_VERSION) { + throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION); + } + if (!"org.tribuo.anomaly.liblinear.LibLinearAnomalyModel".equals(className)) { + throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearAnomalyModel"); + } + LibLinearModelProto proto = message.unpack(LibLinearModelProto.class); + + ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata()); + if (!carrier.outputDomain().getOutput(0).getClass().equals(Event.class)) { + throw new IllegalStateException("Invalid protobuf, output domain is not an anomaly domain, found " + carrier.outputDomain().getClass()); + } + @SuppressWarnings("unchecked") // guarded by getClass + ImmutableOutputInfo outputDomain = (ImmutableOutputInfo) carrier.outputDomain(); + + if (proto.getModelsCount() != 1) { + throw new IllegalStateException("Invalid protobuf, expected 1 model, found " + proto.getModelsCount()); + } + try { + ByteArrayInputStream bais = new ByteArrayInputStream(proto.getModels(0).toByteArray()); + ObjectInputStream ois = new ObjectInputStream(bais); + de.bwaldvogel.liblinear.Model model = (de.bwaldvogel.liblinear.Model) ois.readObject(); + ois.close(); + return new LibLinearAnomalyModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,Collections.singletonList(model)); + } catch (IOException | ClassNotFoundException e) { + throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e); + } + } + @Override public Prediction predict(Example example) { FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null); diff --git a/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java b/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java index ac05399b8..e66eb7bf1 100644 --- a/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java +++ b/AnomalyDetection/LibLinear/src/test/java/org/tribuo/anomaly/liblinear/LibLinearAnomalyTrainerTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -66,6 +66,7 @@ public void gaussianDataTest() { // Test serialization Helpers.testModelSerialization(model,Event.class); + Helpers.testModelProtoSerialization(model,Event.class,testData); } } diff --git a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java index 462e23b09..5b082bfb3 100644 --- a/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java +++ b/Classification/LibLinear/src/main/java/org/tribuo/classification/liblinear/LibLinearClassificationModel.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ package org.tribuo.classification.liblinear; import ai.onnx.proto.OnnxMl; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; import com.oracle.labs.mlrg.olcut.util.Pair; import de.bwaldvogel.liblinear.FeatureNode; import de.bwaldvogel.liblinear.Linear; @@ -31,6 +33,8 @@ import org.tribuo.classification.Label; import org.tribuo.common.liblinear.LibLinearModel; import org.tribuo.common.liblinear.LibLinearTrainer; +import org.tribuo.common.liblinear.protos.LibLinearModelProto; +import org.tribuo.impl.ModelDataCarrier; import org.tribuo.provenance.ModelProvenance; import org.tribuo.util.onnx.ONNXContext; import org.tribuo.util.onnx.ONNXInitializer; @@ -39,6 +43,9 @@ import org.tribuo.util.onnx.ONNXPlaceholder; import org.tribuo.util.onnx.ONNXRef; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -104,6 +111,42 @@ public class LibLinearClassificationModel extends LibLinearModel