Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds protobuf serialization for LibLinear models #273

Merged
merged 1 commit into from
Sep 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package org.tribuo.anomaly.liblinear;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
Expand All @@ -27,10 +29,15 @@
import org.tribuo.anomaly.Event;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -67,6 +74,42 @@ public class LibLinearAnomalyModel extends LibLinearModel<Event> {
super(name, description, featureIDMap, outputIDInfo, false, models);
}

/**
 * Deserialization factory.
 * <p>
 * Validates the version and class name, unpacks the {@code LibLinearModelProto},
 * checks that the output domain holds {@link Event}s, and rebuilds the single
 * wrapped liblinear model from its Java-serialized byte array.
 * @param version The serialized object version.
 * @param className The class name.
 * @param message The serialized data.
 * @return The deserialized anomaly detection model.
 * @throws InvalidProtocolBufferException If the protobuf could not be parsed.
 * @throws IllegalArgumentException If the version is negative or newer than {@code CURRENT_VERSION}.
 * @throws IllegalStateException If the class name, output domain, model count or model bytes are invalid.
 */
public static LibLinearAnomalyModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    if (version < 0 || version > CURRENT_VERSION) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }
    if (!"org.tribuo.anomaly.liblinear.LibLinearAnomalyModel".equals(className)) {
        throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearAnomalyModel");
    }
    LibLinearModelProto proto = message.unpack(LibLinearModelProto.class);

    ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
    if (!carrier.outputDomain().getOutput(0).getClass().equals(Event.class)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not an anomaly domain, found " + carrier.outputDomain().getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    ImmutableOutputInfo<Event> outputDomain = (ImmutableOutputInfo<Event>) carrier.outputDomain();

    if (proto.getModelsCount() != 1) {
        throw new IllegalStateException("Invalid protobuf, expected 1 model, found " + proto.getModelsCount());
    }
    // NOTE(review): the model bytes are read with Java serialization, so the protobuf must
    // come from a trusted source; consider installing an ObjectInputFilter if it may not.
    // try-with-resources closes the stream even when readObject throws.
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(proto.getModels(0).toByteArray()))) {
        Object decoded = ois.readObject();
        // Check the type explicitly so a foreign payload surfaces as an invalid protobuf
        // rather than as a stray ClassCastException.
        if (!(decoded instanceof de.bwaldvogel.liblinear.Model)) {
            throw new IllegalStateException("Invalid protobuf, expected a liblinear model, found "
                    + (decoded == null ? "null" : decoded.getClass().getName()));
        }
        de.bwaldvogel.liblinear.Model model = (de.bwaldvogel.liblinear.Model) decoded;
        return new LibLinearAnomalyModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,Collections.singletonList(model));
    } catch (IOException | ClassNotFoundException e) {
        throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e);
    }
}

@Override
public Prediction<Event> predict(Example<Event> example) {
FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,6 +66,7 @@ public void gaussianDataTest() {

// Test serialization
Helpers.testModelSerialization(model,Event.class);
Helpers.testModelProtoSerialization(model,Event.class,testData);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
package org.tribuo.classification.liblinear;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
Expand All @@ -31,6 +33,8 @@
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
Expand All @@ -39,6 +43,9 @@
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -104,6 +111,42 @@ public class LibLinearClassificationModel extends LibLinearModel<Label> implemen
}
}

/**
 * Deserialization factory.
 * <p>
 * Validates the version and class name, unpacks the {@code LibLinearModelProto},
 * checks that the output domain holds {@link Label}s, and rebuilds the single
 * wrapped liblinear model from its Java-serialized byte array.
 * @param version The serialized object version.
 * @param className The class name.
 * @param message The serialized data.
 * @return The deserialized classification model.
 * @throws InvalidProtocolBufferException If the protobuf could not be parsed.
 * @throws IllegalArgumentException If the version is negative or newer than {@code CURRENT_VERSION}.
 * @throws IllegalStateException If the class name, output domain, model count or model bytes are invalid.
 */
public static LibLinearClassificationModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    if (version < 0 || version > CURRENT_VERSION) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }
    if (!"org.tribuo.classification.liblinear.LibLinearClassificationModel".equals(className)) {
        throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearClassificationModel");
    }
    LibLinearModelProto proto = message.unpack(LibLinearModelProto.class);

    ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
    if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

    if (proto.getModelsCount() != 1) {
        throw new IllegalStateException("Invalid protobuf, expected 1 model, found " + proto.getModelsCount());
    }
    // NOTE(review): the model bytes are read with Java serialization, so the protobuf must
    // come from a trusted source; consider installing an ObjectInputFilter if it may not.
    // try-with-resources closes the stream even when readObject throws.
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(proto.getModels(0).toByteArray()))) {
        Object decoded = ois.readObject();
        // Check the type explicitly so a foreign payload surfaces as an invalid protobuf
        // rather than as a stray ClassCastException.
        if (!(decoded instanceof de.bwaldvogel.liblinear.Model)) {
            throw new IllegalStateException("Invalid protobuf, expected a liblinear model, found "
                    + (decoded == null ? "null" : decoded.getClass().getName()));
        }
        de.bwaldvogel.liblinear.Model model = (de.bwaldvogel.liblinear.Model) decoded;
        return new LibLinearClassificationModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,Collections.singletonList(model));
    } catch (IOException | ClassNotFoundException e) {
        throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e);
    }
}

@Override
public Prediction<Label> predict(Example<Label> example) {
FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ public void testDenseData() {

// Test serialization
Helpers.testModelSerialization(model,Label.class);

Helpers.testModelProtoSerialization(model, Label.class, p.getB());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,23 +16,36 @@

package org.tribuo.common.liblinear;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.SolverType;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.common.liblinear.protos.LibLinearProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* A {@link Model} which wraps a LibLinear-java model.
Expand All @@ -57,6 +70,11 @@ public abstract class LibLinearModel<T extends Output<T>> extends Model<T> {

private static final Logger logger = Logger.getLogger(LibLinearModel.class.getName());

/**
 * Protobuf serialization version.
 * <p>
 * This value is written into the {@code ModelProto} produced by {@link #serialize()}.
 */
public static final int CURRENT_VERSION = 0;

/**
* The list of LibLinear models. Multiple models are used by multi-label and multidimensional regression outputs.
* <p>
Expand Down Expand Up @@ -163,4 +181,97 @@ protected static de.bwaldvogel.liblinear.Model copyModel(de.bwaldvogel.liblinear
* @return An excuse for this example.
*/
protected abstract Excuse<T> innerGetExcuse(Example<T> e, double[][] featureWeights);

/**
 * Serializes this model into a {@link ModelProto}.
 * <p>
 * The model metadata comes from the data carrier, and each wrapped liblinear model
 * is written with Java serialization into its own byte string inside a
 * {@code LibLinearModelProto}, which is then packed into the returned proto along
 * with this object's class name and {@code CURRENT_VERSION}.
 * @return The protobuf representation of this model.
 * @throws IllegalStateException If a liblinear model could not be written to a byte array.
 */
@Override
public ModelProto serialize() {
    ModelDataCarrier<T> carrier = createDataCarrier();

    LibLinearModelProto.Builder modelBuilder = LibLinearModelProto.newBuilder();
    modelBuilder.setMetadata(carrier.serialize());
    for (de.bwaldvogel.liblinear.Model m : models) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // try-with-resources closes (and so flushes) the object stream before the bytes are read,
        // and also closes it if writeObject fails.
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(m);
        } catch (IOException e) {
            // Keep the cause so the underlying serialization failure is not lost.
            throw new IllegalStateException("Could not serialize liblinear model to byte array", e);
        }
        modelBuilder.addModels(ByteString.copyFrom(baos.toByteArray()));
    }

    ModelProto.Builder builder = ModelProto.newBuilder();
    builder.setSerializedData(Any.pack(modelBuilder.build()));
    builder.setClassName(this.getClass().getName());
    builder.setVersion(CURRENT_VERSION);

    return builder.build();
}

/**
 * Converts a single LibLinear model into its {@link LibLinearProto} form.
 * <p>
 * The bias, labels, class count, feature count, solver type name and feature weights
 * are always copied; rho is only copied for one-class solvers.
 * <p>
 * Note deserializing {@link LibLinearProto} requires reflective access into {@code de.bwaldvogel.liblinear.Model}
 * and thus requires additional command line permissions when running with the module system. For the time
 * being we use Java serialization to a byte array rather than this method and {@code LibLinearProto}.
 * @param model The model to serialize.
 * @return The protobuf.
 */
private static LibLinearProto serializeModel(de.bwaldvogel.liblinear.Model model) {
    SolverType solver = model.getSolverType();
    List<Integer> boxedLabels = Arrays.stream(model.getLabels()).boxed().collect(Collectors.toList());
    List<Double> boxedWeights = Arrays.stream(model.getFeatureWeights()).boxed().collect(Collectors.toList());

    LibLinearProto.Builder protoBuilder = LibLinearProto.newBuilder()
            .setBias(model.getBias())
            .addAllLabel(boxedLabels)
            .setNrClass(model.getNrClass())
            .setNrFeature(model.getNrFeature())
            .setSolverType(solver.name())
            .addAllW(boxedWeights);

    // rho is only written for one-class solvers.
    if (solver.isOneClass()) {
        protoBuilder.setRho(model.getDecfunRho());
    }

    return protoBuilder.build();
}

/**
 * Rebuilds a LibLinear model from its {@link LibLinearProto} form.
 * <p>
 * A fresh {@code de.bwaldvogel.liblinear.Model} is created and each of its fields
 * is populated from the protobuf via {@code setField}.
 * <p>
 * Note this method requires reflective access into {@code de.bwaldvogel.liblinear.Model} and thus
 * requires additional command line permissions when running with the module system. For the time
 * being we use Java serialization to a byte array rather than this method and {@code LibLinearProto}.
 * @param proto The protobuf to deserialize.
 * @return The model.
 */
private static de.bwaldvogel.liblinear.Model deserializeModels(LibLinearProto proto) {
    final de.bwaldvogel.liblinear.Model rebuilt = new de.bwaldvogel.liblinear.Model();
    final Class<de.bwaldvogel.liblinear.Model> target = de.bwaldvogel.liblinear.Model.class;

    // Scalar settings.
    setField(target, "bias", rebuilt, proto.getBias());
    // Label ids, unboxed into a primitive array.
    setField(target, "label", rebuilt, Util.toPrimitiveInt(proto.getLabelList()));
    setField(target, "nr_class", rebuilt, proto.getNrClass());
    setField(target, "nr_feature", rebuilt, proto.getNrFeature());
    // The solver is stored by enum constant name.
    setField(target, "solverType", rebuilt, SolverType.valueOf(proto.getSolverType()));
    // Feature weights, unboxed into a primitive array.
    setField(target, "w", rebuilt, Util.toPrimitiveDouble(proto.getWList()));
    setField(target, "rho", rebuilt, proto.getRho());

    return rebuilt;
}

/**
 * Sets a field on the supplied object.
 * <p>
 * The field is looked up among the fields declared by {@code clazz} regardless of
 * their access modifier, made accessible for the write, and then has its
 * accessibility flag cleared again.
 * <p>
 * Wraps {@link NoSuchFieldException} and {@link IllegalAccessException} in {@link IllegalStateException}.
 * @param clazz The class which declares the field.
 * @param fieldName The field name to set.
 * @param host The host object.
 * @param value The new field value.
 * @param <U> The type of the host object.
 * @throws IllegalStateException If the field does not exist on {@code clazz} or could not be written.
 */
private static <U> void setField(Class<U> clazz, String fieldName, U host, Object value) {
    try {
        // getDeclaredField is required here: getField only resolves public members, so it
        // could never find the non-public fields this helper needs setAccessible for.
        Field field = clazz.getDeclaredField(fieldName);
        field.setAccessible(true);
        field.set(host, value);
        field.setAccessible(false);
    } catch (NoSuchFieldException | IllegalAccessException e) {
        throw new IllegalStateException("Failed to write to field " + fieldName, e);
    }
}
}
Loading