Skip to content

Commit

Permalink
Adds XGBoost protobuf serialization. (#270)
Browse files Browse the repository at this point in the history
* Roughing out XGBoost serialization.

* Finishing XGBoost protobuf serialization.

* Bug fixes.
  • Loading branch information
Craigacp authored Sep 12, 2022
1 parent 649cece commit 2a04a30
Show file tree
Hide file tree
Showing 18 changed files with 3,594 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,11 +16,16 @@

package org.tribuo.classification.xgboost;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.xgboost.XGBoostOutputConverter;
import org.tribuo.common.xgboost.protos.XGBoostOutputConverterProto;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoUtil;

import java.util.ArrayList;
import java.util.LinkedHashMap;
Expand All @@ -29,14 +34,36 @@
/**
* Converts XGBoost outputs into {@link Label} {@link Prediction}s.
*/
@ProtoSerializableClass(version = XGBoostClassificationConverter.CURRENT_VERSION)
public final class XGBoostClassificationConverter implements XGBoostOutputConverter<Label> {
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

/**
* Constructs an XGBoostClassificationConverter.
*/
public XGBoostClassificationConverter() {}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static XGBoostClassificationConverter deserializeFromProto(int version, String className, Any message) {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
if (message.getValue() != ByteString.EMPTY) {
throw new IllegalArgumentException("Invalid proto");
}
return new XGBoostClassificationConverter();
}

@Override
public boolean generatesProbabilities() {
return true;
Expand Down Expand Up @@ -92,4 +119,14 @@ public List<Prediction<Label>> convertBatchOutput(ImmutableOutputInfo<Label> inf

return predictions;
}

@Override
public XGBoostOutputConverterProto serialize() {
return ProtoUtil.serialize(this);
}

@Override
public Class<Label> getTypeWitness() {
return Label.class;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Model<Label> model = testXGBoost(t,p);
Helpers.testModelSerialization(model,Label.class);
Helpers.testModelProtoSerialization(model, Label.class, p.getB());
testXGBoost(dart,p);
testXGBoost(linear,p);
XGBoostModel<Label> m = testXGBoost(gbtree,p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public void testMNIST() throws IOException, URISyntaxException {
assertEquals(0.0, evaluation.balancedErrorRate(), 1e-6);

Helpers.testModelSerialization(transposedMNISTXGB,Label.class);
Helpers.testModelProtoSerialization(transposedMNISTXGB, Label.class, transposedMNIST);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,13 @@

package org.tribuo.common.xgboost;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Arrays;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
Expand All @@ -29,12 +34,17 @@
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.common.xgboost.protos.XGBoostExternalModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

import java.io.ByteArrayInputStream;
import java.io.File;
Expand Down Expand Up @@ -89,6 +99,11 @@ public final class XGBoostExternalModel<T extends Output<T>> extends ExternalMod

private static final Logger logger = Logger.getLogger(XGBoostExternalModel.class.getName());

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final XGBoostOutputConverter<T> converter;

/**
Expand All @@ -115,6 +130,37 @@ private XGBoostExternalModel(String name, ModelProvenance provenance,
this.converter = converter;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
@SuppressWarnings({"unchecked","rawtypes"}) // output converter and domain are checked via getClass.
public static XGBoostExternalModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException, XGBoostError, IOException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
XGBoostExternalModelProto proto = message.unpack(XGBoostExternalModelProto.class);

XGBoostOutputConverter<?> converter = ProtoUtil.deserialize(proto.getConverter());
Class<?> converterWitness = converter.getTypeWitness();
ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(converterWitness)) {
throw new IllegalStateException("Invalid protobuf, output domain does not match the converter, found " + carrier.outputDomain().getClass() + " and " + converterWitness);
}
int[] featureForwardMapping = Util.toPrimitiveInt(proto.getForwardFeatureMappingList());
int[] featureBackwardMapping = Util.toPrimitiveInt(proto.getBackwardFeatureMappingList());
if (!validateFeatureMapping(featureForwardMapping,featureBackwardMapping,carrier.featureDomain())) {
throw new IllegalStateException("Invalid protobuf, external<->Tribuo feature mapping does not form a bijection");
}

Booster model = XGBoost.loadModel(proto.getModel().toByteArray());

return new XGBoostExternalModel(carrier.name(), carrier.provenance(), carrier.featureDomain(),
carrier.outputDomain(), featureForwardMapping, featureBackwardMapping, model, converter);
}

@Override
protected DMatrix convertFeatures(SparseVector input) {
try {
Expand Down Expand Up @@ -201,6 +247,30 @@ public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
}
}

@Override
public ModelProto serialize() {
ModelDataCarrier<T> carrier = createDataCarrier();

XGBoostExternalModelProto.Builder modelBuilder = XGBoostExternalModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.setConverter(converter.serialize());
modelBuilder.addAllForwardFeatureMapping(Arrays.stream(featureForwardMapping).boxed().collect(
Collectors.toList()));
modelBuilder.addAllBackwardFeatureMapping(Arrays.stream(featureBackwardMapping).boxed().collect(Collectors.toList()));
try {
modelBuilder.setModel(ByteString.copyFrom(model.toByteArray()));
} catch (XGBoostError e) {
throw new IllegalStateException("Failed to serialize XGBoost model");
}

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setSerializedData(Any.pack(modelBuilder.build()));
builder.setClassName(XGBoostExternalModel.class.getName());
builder.setVersion(CURRENT_VERSION);

return builder.build();
}

@Override
protected XGBoostExternalModel<T> copy(String newName, ModelProvenance newProvenance) {
return new XGBoostExternalModel<>(newName, newProvenance, featureIDMap, outputIDInfo,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,9 @@

package org.tribuo.common.xgboost;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableDouble;
import com.oracle.labs.mlrg.olcut.util.Pair;
import ml.dmlc.xgboost4j.java.Booster;
Expand All @@ -30,6 +33,10 @@
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.xgboost.XGBoostTrainer.DMatrixTuple;
import org.tribuo.common.xgboost.protos.XGBoostModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;

Expand Down Expand Up @@ -81,6 +88,11 @@ public final class XGBoostModel<T extends Output<T>> extends Model<T> {

private static final Logger logger = Logger.getLogger(XGBoostModel.class.getName());

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final XGBoostOutputConverter<T> converter;

// Used to signal if the model has been rewritten to fix the issue with multidimensional XGBoost regression models in 4.0 and 4.1.0.
Expand All @@ -100,6 +112,35 @@ public final class XGBoostModel<T extends Output<T>> extends Model<T> {
this.regression41MappingFix = true;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static XGBoostModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException, XGBoostError, IOException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
XGBoostModelProto proto = message.unpack(XGBoostModelProto.class);

XGBoostOutputConverter<?> converter = ProtoUtil.deserialize(proto.getConverter());
Class<?> converterWitness = converter.getTypeWitness();
ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(converterWitness)) {
throw new IllegalStateException("Invalid protobuf, output domain does not match the converter, found " + carrier.outputDomain().getClass() + " and " + converterWitness);
}
List<Booster> models = new ArrayList<>();
for (ByteString b : proto.getModelsList()) {
models.add(XGBoost.loadModel(b.toByteArray()));
}
if (models.isEmpty()) {
throw new IllegalStateException("Invalid protobuf, no XGBoost models were found");
}

return new XGBoostModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),carrier.outputDomain(),models,converter);
}

/**
* Returns an unmodifiable list containing a copy of each model.
* <p>
Expand Down Expand Up @@ -292,6 +333,29 @@ protected Model<T> copy(String newName, ModelProvenance newProvenance) {
return new XGBoostModel<>(newName, newProvenance, featureIDMap, outputIDInfo, newModels, converter);
}

@Override
public ModelProto serialize() {
ModelDataCarrier<T> carrier = createDataCarrier();

XGBoostModelProto.Builder modelBuilder = XGBoostModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.setConverter(converter.serialize());
try {
for (Booster b : models) {
modelBuilder.addModels(ByteString.copyFrom(b.toByteArray()));
}
} catch (XGBoostError e) {
throw new IllegalStateException("Failed to serialize XGBoost model");
}

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setSerializedData(Any.pack(modelBuilder.build()));
builder.setClassName(XGBoostModel.class.getName());
builder.setVersion(CURRENT_VERSION);

return builder.build();
}

private void writeObject(ObjectOutputStream out) throws IOException {
out.defaultWriteObject();
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,11 +23,14 @@

import java.io.Serializable;
import java.util.List;
import org.tribuo.common.xgboost.protos.XGBoostOutputConverterProto;
import org.tribuo.protos.ProtoSerializable;

/**
* Converts the output of XGBoost into the appropriate prediction type.
*/
public interface XGBoostOutputConverter<T extends Output<T>> extends Serializable {
public interface XGBoostOutputConverter<T extends Output<T>> extends
ProtoSerializable<XGBoostOutputConverterProto>, Serializable {

/**
* Does this converter produce probabilities?
Expand Down Expand Up @@ -55,4 +58,15 @@ public interface XGBoostOutputConverter<T extends Output<T>> extends Serializabl
*/
public List<Prediction<T>> convertBatchOutput(ImmutableOutputInfo<T> info, List<float[][]> probabilities, int[] numValidFeatures, Example<T>[] examples);

/**
* Gets the type witness for the output this converter uses.
* <p>
* The default implementation throws {@link UnsupportedOperationException} for compatibility
* with subclasses which don't support protobuf serialization.
* @return The class of the output.
*/
default public Class<T> getTypeWitness() {
throw new UnsupportedOperationException("This class has not been updated to support protobuf serialization.");
}

}
Loading

0 comments on commit 2a04a30

Please sign in to comment.