Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds SGD protobuf serialization #275

Merged
merged 1 commit into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package org.tribuo.classification.sgd.crf;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
Expand All @@ -24,10 +26,14 @@
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel;
import org.tribuo.classification.sgd.protos.CRFModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.protos.core.SequenceModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;

Expand Down Expand Up @@ -59,6 +65,11 @@ public class CRFModel extends ConfidencePredictingSequenceModel {
private static final Logger logger = Logger.getLogger(CRFModel.class.getName());
private static final long serialVersionUID = 2L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final CRFParameters parameters;

/**
Expand Down Expand Up @@ -87,6 +98,37 @@ public enum ConfidenceType {
this.confidenceType = ConfidenceType.NONE;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static CRFModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
CRFModelProto proto = message.unpack(CRFModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

Parameters params = Parameters.deserialize(proto.getParams());
if (!(params instanceof CRFParameters)) {
throw new IllegalStateException("Invalid protobuf, parameters must be CRFParameters, found " + params.getClass());
}

ConfidenceType confidenceType = ConfidenceType.valueOf(proto.getConfidenceType());

CRFModel model = new CRFModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,(CRFParameters) params);
model.confidenceType = confidenceType;
return model;
}

/**
* Sets the inference method used for confidence prediction.
* If CONSTRAINED_BP uses the constrained belief propagation algorithm from Culotta and McCallum 2004,
Expand Down Expand Up @@ -257,6 +299,22 @@ public String generateWeightsString() {
return buffer.toString();
}

@Override
public SequenceModelProto serialize() {
ModelDataCarrier<Label> carrier = createDataCarrier();
CRFModelProto.Builder modelBuilder = CRFModelProto.newBuilder();
modelBuilder.setConfidenceType(confidenceType.name());
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.setParams(parameters.serialize());

SequenceModelProto.Builder builder = SequenceModelProto.newBuilder();
builder.setVersion(CURRENT_VERSION);
builder.setClassName(CRFModel.class.getName());
builder.setSerializedData(Any.pack(modelBuilder.build()));

return builder.build();
}

/**
* Converts a {@link SequenceExample} into an array of {@link SparseVector}s suitable for CRF prediction.
* @deprecated As it's replaced with {@link #convertToVector} which is more flexible.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,16 +16,22 @@

package org.tribuo.classification.sgd.fm;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.protos.FMClassificationModelProto;
import org.tribuo.common.sgd.AbstractFMModel;
import org.tribuo.common.sgd.FMParameters;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.Parameters;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.onnx.ONNXNode;

Expand All @@ -45,6 +51,11 @@
public class FMClassificationModel extends AbstractFMModel<Label> implements ONNXExportable {
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final VectorNormalizer normalizer;

/**
Expand All @@ -64,6 +75,35 @@ public class FMClassificationModel extends AbstractFMModel<Label> implements ONN
this.normalizer = normalizer;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static FMClassificationModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
FMClassificationModelProto proto = message.unpack(FMClassificationModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

Parameters params = Parameters.deserialize(proto.getParams());
if (!(params instanceof FMParameters)) {
throw new IllegalStateException("Invalid protobuf, parameters must be FMParameters, found " + params.getClass());
}

VectorNormalizer normalizer = VectorNormalizer.deserialize(proto.getNormalizer());

return new FMClassificationModel(carrier.name(),carrier.provenance(),carrier.featureDomain(),outputDomain,(FMParameters) params, normalizer, carrier.generatesProbabilities());
}

@Override
public Prediction<Label> predict(Example<Label> example) {
PredAndActive predTuple = predictSingle(example);
Expand All @@ -86,6 +126,22 @@ public Prediction<Label> predict(Example<Label> example) {
return new Prediction<>(maxLabel, predMap, predTuple.numActiveFeatures, example, generatesProbabilities);
}

@Override
public ModelProto serialize() {
ModelDataCarrier<Label> carrier = createDataCarrier();
FMClassificationModelProto.Builder modelBuilder = FMClassificationModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.setParams(modelParameters.serialize());
modelBuilder.setNormalizer(normalizer.serialize());

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setVersion(CURRENT_VERSION);
builder.setClassName(FMClassificationModel.class.getName());
builder.setSerializedData(Any.pack(modelBuilder.build()));

return builder.build();
}

@Override
protected FMClassificationModel copy(String newName, ModelProvenance newProvenance) {
return new FMClassificationModel(newName,newProvenance,featureIDMap,outputIDInfo,(FMParameters)modelParameters.copy(),normalizer,generatesProbabilities);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

package org.tribuo.classification.sgd.kernel;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
Expand All @@ -24,10 +26,15 @@
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.protos.KernelSVMModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

import java.util.Collections;
Expand All @@ -49,6 +56,11 @@
public class KernelSVMModel extends Model<Label> {
private static final long serialVersionUID = 2L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final Kernel kernel;
private final SparseVector[] supportVectors;
private final DenseMatrix weights;
Expand All @@ -62,6 +74,58 @@ public class KernelSVMModel extends Model<Label> {
this.weights = weights;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static KernelSVMModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
KernelSVMModelProto proto = message.unpack(KernelSVMModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<Label> outputDomain = (ImmutableOutputInfo<Label>) carrier.outputDomain();

SparseVector[] supportVectors = new SparseVector[proto.getSupportVectorsCount()];
int featureSize = carrier.featureDomain().size() + 1;
List<TensorProto> supportProtos = proto.getSupportVectorsList();
for (int i = 0; i < supportProtos.size(); i++) {
Tensor tensor = Tensor.deserialize(supportProtos.get(i));
if (!(tensor instanceof SparseVector)) {
throw new IllegalStateException("Invalid protobuf, support vector must be a sparse vector, found " + tensor.getClass());
}
SparseVector vec = (SparseVector) tensor;
if (vec.size() != featureSize) {
throw new IllegalStateException("Invalid protobuf, support vector size must equal feature domain size, found " + vec.size() + ", expected " + featureSize);
}
supportVectors[i] = vec;
}

Tensor weightTensor = Tensor.deserialize(proto.getWeights());
if (!(weightTensor instanceof DenseMatrix)) {
throw new IllegalStateException("Invalid protobuf, weights must be a dense matrix, found " + weightTensor.getClass());
}
DenseMatrix weights = (DenseMatrix) weightTensor;
if (weights.getDimension1Size() != carrier.outputDomain().size()) {
throw new IllegalStateException("Invalid protobuf, weights not the right size, expected " + carrier.outputDomain().size() + ", found " + weights.getDimension1Size());
}
if (weights.getDimension2Size() != supportVectors.length) {
throw new IllegalStateException("Invalid protobuf, weights not the right size, expected " + supportVectors.length + ", found " + weights.getDimension2Size());
}

Kernel kernel = Kernel.deserialize(proto.getKernel());

return new KernelSVMModel(carrier.name(), carrier.provenance(), carrier.featureDomain(), outputDomain,
kernel, supportVectors, weights);
}

/**
* Returns the number of support vectors used.
* @return The number of support vectors.
Expand Down Expand Up @@ -109,6 +173,25 @@ public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
return Optional.empty();
}

@Override
public ModelProto serialize() {
ModelDataCarrier<Label> carrier = createDataCarrier();
KernelSVMModelProto.Builder modelBuilder = KernelSVMModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.setKernel(kernel.serialize());
modelBuilder.setWeights(weights.serialize());
for (SparseVector v : supportVectors) {
modelBuilder.addSupportVectors(v.serialize());
}

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setVersion(CURRENT_VERSION);
builder.setClassName(KernelSVMModel.class.getName());
builder.setSerializedData(Any.pack(modelBuilder.build()));

return builder.build();
}

@Override
protected KernelSVMModel copy(String newName, ModelProvenance newProvenance) {
SparseVector[] vectorCopies = new SparseVector[supportVectors.length];
Expand Down
Loading