Skip to content

Commit

Permalink
Adding protobuf serialization for liblinear models. (#273)
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp authored Sep 20, 2022
1 parent c82ec00 commit b35827b
Show file tree
Hide file tree
Showing 13 changed files with 2,445 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package org.tribuo.anomaly.liblinear;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
Expand All @@ -27,10 +29,15 @@
import org.tribuo.anomaly.Event;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.provenance.ModelProvenance;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -67,6 +74,42 @@ public class LibLinearAnomalyModel extends LibLinearModel<Event> {
super(name, description, featureIDMap, outputIDInfo, false, models);
}

/**
 * Deserialization factory.
 * <p>
 * Validates the version and class name, unpacks the {@link LibLinearModelProto},
 * checks that the first output in the output domain is an {@link Event}, and then reads the single
 * liblinear model out of the Java-serialized byte array stored in the protobuf.
 * @param version The serialized object version.
 * @param className The class name.
 * @param message The serialized data.
 * @return The deserialized anomaly detection model.
 * @throws InvalidProtocolBufferException If the message could not be unpacked as a {@code LibLinearModelProto}.
 * @throws IllegalArgumentException If the version is negative or greater than {@code CURRENT_VERSION}.
 * @throws IllegalStateException If the class name, output domain, model count or model bytes are invalid.
 */
public static LibLinearAnomalyModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    if (version < 0 || version > CURRENT_VERSION) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }
    if (!"org.tribuo.anomaly.liblinear.LibLinearAnomalyModel".equals(className)) {
        throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearAnomalyModel");
    }
    LibLinearModelProto proto = message.unpack(LibLinearModelProto.class);

    ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
    if (!carrier.outputDomain().getOutput(0).getClass().equals(Event.class)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not an anomaly domain, found " + carrier.outputDomain().getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    ImmutableOutputInfo<Event> outputDomain = (ImmutableOutputInfo<Event>) carrier.outputDomain();

    if (proto.getModelsCount() != 1) {
        throw new IllegalStateException("Invalid protobuf, expected 1 model, found " + proto.getModelsCount());
    }

    // NOTE(review): this uses Java native deserialization on bytes taken from the protobuf.
    // If protobufs can come from untrusted sources, consider installing an ObjectInputFilter.
    de.bwaldvogel.liblinear.Model model;
    // try-with-resources closes the stream even when readObject fails.
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(proto.getModels(0).toByteArray()))) {
        Object payload = ois.readObject();
        // Check the type explicitly so a wrong payload is reported as an invalid protobuf
        // rather than surfacing as a raw ClassCastException.
        if (!(payload instanceof de.bwaldvogel.liblinear.Model)) {
            throw new IllegalStateException("Invalid protobuf, expected a liblinear model, found "
                    + (payload == null ? "null" : payload.getClass().getName()));
        }
        model = (de.bwaldvogel.liblinear.Model) payload;
    } catch (IOException | ClassNotFoundException e) {
        throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e);
    }
    return new LibLinearAnomalyModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,Collections.singletonList(model));
}

@Override
public Prediction<Event> predict(Example<Event> example) {
FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022 Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,6 +66,7 @@ public void gaussianDataTest() {

// Test serialization
Helpers.testModelSerialization(model,Event.class);
Helpers.testModelProtoSerialization(model,Event.class,testData);
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
package org.tribuo.classification.liblinear;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import de.bwaldvogel.liblinear.FeatureNode;
import de.bwaldvogel.liblinear.Linear;
Expand All @@ -31,6 +33,8 @@
import org.tribuo.classification.Label;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.common.liblinear.LibLinearTrainer;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
Expand All @@ -39,6 +43,9 @@
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -104,6 +111,42 @@ public class LibLinearClassificationModel extends LibLinearModel<Label> implemen
}
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static LibLinearClassificationModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    // Steps: validate version and class name, unpack the LibLinearModelProto, check the first
    // output in the domain is a Label, then read the single Java-serialized liblinear model.
    // Throws IllegalArgumentException for an unsupported version and IllegalStateException for
    // any other invalid content.
    if (version < 0 || version > CURRENT_VERSION) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }
    if (!"org.tribuo.classification.liblinear.LibLinearClassificationModel".equals(className)) {
        throw new IllegalStateException("Invalid protobuf, this class can only deserialize LibLinearClassificationModel");
    }
    LibLinearModelProto proto = message.unpack(LibLinearModelProto.class);

    ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
    if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

    if (proto.getModelsCount() != 1) {
        throw new IllegalStateException("Invalid protobuf, expected 1 model, found " + proto.getModelsCount());
    }

    // NOTE(review): this uses Java native deserialization on bytes taken from the protobuf.
    // If protobufs can come from untrusted sources, consider installing an ObjectInputFilter.
    de.bwaldvogel.liblinear.Model model;
    // try-with-resources closes the stream even when readObject fails.
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(proto.getModels(0).toByteArray()))) {
        Object payload = ois.readObject();
        // Check the type explicitly so a wrong payload is reported as an invalid protobuf
        // rather than surfacing as a raw ClassCastException.
        if (!(payload instanceof de.bwaldvogel.liblinear.Model)) {
            throw new IllegalStateException("Invalid protobuf, expected a liblinear model, found "
                    + (payload == null ? "null" : payload.getClass().getName()));
        }
        model = (de.bwaldvogel.liblinear.Model) payload;
    } catch (IOException | ClassNotFoundException e) {
        throw new IllegalStateException("Invalid protobuf, failed to deserialize liblinear model", e);
    }
    return new LibLinearClassificationModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,Collections.singletonList(model));
}

@Override
public Prediction<Label> predict(Example<Label> example) {
FeatureNode[] features = LibLinearTrainer.exampleToNodes(example, featureIDMap, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ public void testDenseData() {

// Test serialization
Helpers.testModelSerialization(model,Label.class);

Helpers.testModelProtoSerialization(model, Label.class, p.getB());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,23 +16,36 @@

package org.tribuo.common.liblinear;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.SolverType;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.common.liblinear.protos.LibLinearProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* A {@link Model} which wraps a LibLinear-java model.
Expand All @@ -57,6 +70,11 @@ public abstract class LibLinearModel<T extends Output<T>> extends Model<T> {

private static final Logger logger = Logger.getLogger(LibLinearModel.class.getName());

/**
 * Protobuf serialization version.
 * <p>
 * This value is written into the {@code ModelProto} version field by {@link #serialize()}.
 */
public static final int CURRENT_VERSION = 0;

/**
* The list of LibLinear models. Multiple models are used by multi-label and multidimensional regression outputs.
* <p>
Expand Down Expand Up @@ -163,4 +181,97 @@ protected static de.bwaldvogel.liblinear.Model copyModel(de.bwaldvogel.liblinear
* @return An excuse for this example.
*/
protected abstract Excuse<T> innerGetExcuse(Example<T> e, double[][] featureWeights);

/**
 * Serializes this model into a {@link ModelProto}.
 * <p>
 * Each wrapped liblinear model is written with Java serialization into a byte array and stored
 * in the {@link LibLinearModelProto} alongside the model metadata.
 * @return The protobuf representation of this model.
 * @throws IllegalStateException If a liblinear model could not be written to a byte array.
 */
@Override
public ModelProto serialize() {
    ModelDataCarrier<T> carrier = createDataCarrier();

    LibLinearModelProto.Builder modelBuilder = LibLinearModelProto.newBuilder();
    modelBuilder.setMetadata(carrier.serialize());
    for (de.bwaldvogel.liblinear.Model m : models) {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // try-with-resources closes (and therefore flushes) the object stream on every path.
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(m);
        } catch (IOException e) {
            // Keep the cause so the underlying I/O failure is not lost.
            throw new IllegalStateException("Could not serialize liblinear model to byte array", e);
        }
        // Read the bytes only after the object stream has been closed.
        modelBuilder.addModels(ByteString.copyFrom(baos.toByteArray()));
    }

    ModelProto.Builder builder = ModelProto.newBuilder();
    builder.setSerializedData(Any.pack(modelBuilder.build()));
    builder.setClassName(this.getClass().getName());
    builder.setVersion(CURRENT_VERSION);

    return builder.build();
}

/**
 * Converts a LibLinear model into its {@link LibLinearProto} representation.
 * <p>
 * Reading a {@link LibLinearProto} back requires reflective access into
 * {@code de.bwaldvogel.liblinear.Model}, which needs additional command line permissions when
 * running with the module system. For the time being Java serialization to a byte array is used
 * rather than this method and {@code LibLinearProto}.
 * @param model The model to serialize.
 * @return The protobuf.
 */
private static LibLinearProto serializeModel(de.bwaldvogel.liblinear.Model model) {
    LibLinearProto.Builder protoBuilder = LibLinearProto.newBuilder();

    protoBuilder.setBias(model.getBias());
    // Copy the class labels one at a time.
    for (int classLabel : model.getLabels()) {
        protoBuilder.addLabel(classLabel);
    }
    protoBuilder.setNrClass(model.getNrClass());
    protoBuilder.setNrFeature(model.getNrFeature());
    protoBuilder.setSolverType(model.getSolverType().name());
    // Copy the feature weights one at a time.
    for (double weight : model.getFeatureWeights()) {
        protoBuilder.addW(weight);
    }
    // rho is only recorded for one-class solvers.
    boolean oneClass = model.getSolverType().isOneClass();
    if (oneClass) {
        protoBuilder.setRho(model.getDecfunRho());
    }

    return protoBuilder.build();
}

/**
 * Rebuilds a LibLinear model from its {@link LibLinearProto} representation.
 * <p>
 * The model's fields are populated reflectively, so this method needs reflective access into
 * {@code de.bwaldvogel.liblinear.Model} and therefore additional command line permissions when
 * running with the module system. For the time being Java serialization to a byte array is used
 * rather than this method and {@code LibLinearProto}.
 * @param proto The protobuf to deserialize.
 * @return The model.
 */
private static de.bwaldvogel.liblinear.Model deserializeModels(LibLinearProto proto) {
    final Class<de.bwaldvogel.liblinear.Model> target = de.bwaldvogel.liblinear.Model.class;
    final de.bwaldvogel.liblinear.Model rebuilt = new de.bwaldvogel.liblinear.Model();

    // Scalar and array fields are written in the same order as before; each call may throw.
    setField(target, "bias", rebuilt, proto.getBias());
    setField(target, "label", rebuilt, Util.toPrimitiveInt(proto.getLabelList()));
    setField(target, "nr_class", rebuilt, proto.getNrClass());
    setField(target, "nr_feature", rebuilt, proto.getNrFeature());
    setField(target, "solverType", rebuilt, SolverType.valueOf(proto.getSolverType()));
    setField(target, "w", rebuilt, Util.toPrimitiveDouble(proto.getWList()));
    setField(target, "rho", rebuilt, proto.getRho());

    return rebuilt;
}

/**
 * Sets a field on the supplied object.
 * <p>
 * The field is looked up among the fields declared directly on {@code clazz}, whatever their
 * access level, falling back to the public (possibly inherited) fields if no declared field matches.
 * <p>
 * Wraps lookup and access failures in {@link IllegalStateException}.
 * @param clazz The class.
 * @param fieldName The field name to set.
 * @param host The host object.
 * @param value The new field value.
 * @param <U> The type of the host object.
 * @throws IllegalStateException If the field does not exist or could not be written.
 */
private static <U> void setField(Class<U> clazz, String fieldName, U host, Object value) {
    try {
        Field field;
        try {
            // getField only sees public members, so non-public fields need getDeclaredField.
            field = clazz.getDeclaredField(fieldName);
        } catch (NoSuchFieldException e) {
            // Fall back to the public lookup so inherited public fields still resolve.
            field = clazz.getField(fieldName);
        }
        field.setAccessible(true);
        field.set(host, value);
        field.setAccessible(false);
    } catch (NoSuchFieldException | IllegalAccessException e) {
        throw new IllegalStateException("Failed to write to field " + fieldName, e);
    }
}
}
Loading

0 comments on commit b35827b

Please sign in to comment.