diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/CRFModel.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/CRFModel.java index 99882128a..f30d29915 100644 --- a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/CRFModel.java +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/CRFModel.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.tribuo.classification.sgd.crf; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; @@ -24,10 +26,14 @@ import org.tribuo.Prediction; import org.tribuo.classification.Label; import org.tribuo.classification.sequence.ConfidencePredictingSequenceModel; +import org.tribuo.classification.sgd.protos.CRFModelProto; +import org.tribuo.impl.ModelDataCarrier; +import org.tribuo.math.Parameters; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.SparseVector; import org.tribuo.math.la.Tensor; +import org.tribuo.protos.core.SequenceModelProto; import org.tribuo.provenance.ModelProvenance; import org.tribuo.sequence.SequenceExample; @@ -59,6 +65,11 @@ public class CRFModel extends ConfidencePredictingSequenceModel { private static final Logger logger = Logger.getLogger(CRFModel.class.getName()); private static final long serialVersionUID = 2L; + /** + * Protobuf serialization version. + */ + public static final int CURRENT_VERSION = 0; + private final CRFParameters parameters; /** @@ -87,6 +98,37 @@ public enum ConfidenceType { this.confidenceType = ConfidenceType.NONE; } + /** + * Deserialization factory. + * @param version The serialized object version. + * @param className The class name. + * @param message The serialized data. + */ + public static CRFModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException { + if (version < 0 || version > CURRENT_VERSION) { + throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION); + } + CRFModelProto proto = message.unpack(CRFModelProto.class); + + ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata()); + if (!carrier.outputDomain().getOutput(0).getClass().equals(Label.class)) { + throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + carrier.outputDomain().getClass()); + } + @SuppressWarnings("unchecked") // guarded by getClass + ImmutableOutputInfo