Adds protobuf serialization to FeatureMap, Hasher, and VariableInfo implementations (#226)

* Roughing out the core tribuo protos.

* Adding protobuf serialization for FeatureMap, subclasses and fields of FeatureMap.

* Adding generated protos.

* Adding equals methods to CategoricalInfo and RealInfo, and added an equality test to CategoricalInfoTest.

* Adding equals methods to Hashers, and adding HashedFeatureMap.serialize().

* Adding equals methods to FeatureMap and MutableFeatureMap, and tests for equality and serialization in FeatureMapTest.

* Removing generated protos for ease of reviewing.

* Fix the javadoc in ProtoSerializable.

* Adding support for serializing TransformerMap.

* Adding protobuf serialization to org.tribuo.transform.transformations.

* Fixing copyrights.

* Adding tests for transformation serialization. Made the transformer implementations package private so they are easier to test.

* initial scratchings of ProtobufClass/Field annotations and ProtoUtil.serialize method

* moved ProtoUtilTest to org.tribuo to allow testing of protected method

* added protobuf generated source files

* added @ProtobufField  and @ProtobufClass to a bunch of class defs

- major update to ProtoUtil.serialize to handle more cases
- added a bunch of tests to ProtoUtilTest
- replaced serialize method impls with ProtoUtil.serialize
- CategoricalIDInfo/CategoricalInfo - handles observedValue and
observedCount separately from valueCounts
- CategoricalInfo has an id member set to -1
- added version attribute to ProtobufClass annotation

* renamed annotations

ProtoSerializable has default serialize method
ProtoSerializableClass no longer has 'serializedClass' attribute and is
no longer repeatable
added RealIDInfoProto and CategoricalIDInfoProto
removed CategoricalInfo.id and RealInfo.id
removed various serialize methods in favor of default method
new annotation ProtoSerializableKeysValuesField
towards cleaning up ProtoUtil.serialize
ReflectUtil gives the type parameter assignments for an object's
interface type parameters

* default deserialize method replaces "instantiate"

added @ProtoSerializableMapValuesField
remove register redirects
clean up serialization of fields logic
- remove getMapFields
- clean up getFields
findMethod can be used for setters and getters by specifying expected
param count.

* moved protoserializable code into protos package

remove various compile warnings in package.

* initial commit of ProtoSerializableArrayField

which allows you to annotate e.g. an array of doubles as done in
BinningTransformer (found in BinningTransformation).

* removed unneeded serialize methods

* updated tests to be consistent

* added @ProtoSerializable annotations to two transformers

* removes initial attempt at generic but underspecified deserialize method

developer must implement deserializeFromProto and won't get a crappy
solution that does something if they don't.

also, put back in detailed error handling messages in the higher-level
'deserialize' method (FKA 'instantiate').

* reimplemented ReflectUtil to pull out one type parameter resolution

this version has a bunch of sysout statements
added unit tests

* cleaned up reimplementation of ReflectUtil functionality.

better unit tests
removed sysout statements
fixed a bug found in unit tests.

* SimpleTransform replace serialize() impl

fix problem with field name in annotation

* moved functionality in ReflectUtil to ProtoUtil and made it private

* added @ProtoSerializableClass annotation to LinearScalingTransformer

refactored serialize()

* added @ProtoSerializableClass to IDFTransformer and refactored serialize

added unit test

* removed default impl for ProtoSerializable.serialize

added default impls to all the affected subclasses that now require
them.

* Formatting and licenses.

* Removing ProtoSerializableArrayField, cleaning up the use of maps, making version not default, adding javadoc.

* Tidying up ProtoUtil.

* Adding internal redirect hook.

* Narrowing the visibility of some methods in ProtoUtil.

* Adding static version field. Improving validation of deserialized objects.

* Adding docs about hasher deserialization.

* Adding more logging and exceptions to ProtoUtil.

* Fix ModHashCodeHasher deserialize.

Co-authored-by: Philip Ogren <[email protected]>
Craigacp and pogren authored Jun 2, 2022
1 parent 2b4a9a4 commit 5940cb3
Showing 104 changed files with 34,124 additions and 120 deletions.
97 changes: 97 additions & 0 deletions Core/src/main/java/org/tribuo/CategoricalIDInfo.java
@@ -16,12 +16,30 @@

package org.tribuo;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.core.CategoricalIDInfoProto;

import java.util.HashMap;
import java.util.List;
import java.util.Objects;

/**
* Same as a {@link CategoricalInfo}, but with an additional int id field.
*/
@ProtoSerializableClass(version = CategoricalIDInfo.CURRENT_VERSION, serializedDataClass = CategoricalIDInfoProto.class)
public class CategoricalIDInfo extends CategoricalInfo implements VariableIDInfo {
private static final long serialVersionUID = 2L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

@ProtoSerializableField
private final int id;

/**
@@ -31,6 +49,9 @@ public class CategoricalIDInfo extends CategoricalInfo implements VariableIDInfo
*/
public CategoricalIDInfo(CategoricalInfo info, int id) {
super(info);
if (id < 0) {
throw new IllegalArgumentException("Invalid id number, must be non-negative, found " + id);
}
this.id = id;
}

@@ -46,6 +67,61 @@ private CategoricalIDInfo(CategoricalIDInfo info, String newName) {
this.id = info.id;
}

/**
* Deserialization constructor.
* @param name The info name.
* @param id The info id.
*/
private CategoricalIDInfo(String name, int id) {
super(name);
if (id < 0) {
throw new IllegalArgumentException("Invalid id number, must be non-negative, found " + id);
}
this.id = id;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static CategoricalIDInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
CategoricalIDInfoProto proto = message.unpack(CategoricalIDInfoProto.class);
CategoricalIDInfo info = new CategoricalIDInfo(proto.getName(),proto.getId());
List<Double> keys = proto.getKeyList();
List<Long> values = proto.getValueList();
if (keys.size() != values.size()) {
throw new IllegalStateException("Invalid protobuf, keys and values don't match. keys.size() = " + keys.size() + ", values.size() = " + values.size());
}
int newCount = 0;
if (keys.size() > 1) {
info.valueCounts = new HashMap<>(keys.size());
for (int i = 0; i < keys.size(); i++) {
if (values.get(i) < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + values.get(i) + " for value " + keys.get(i));
}
info.valueCounts.put(keys.get(i),new MutableLong(values.get(i)));
newCount += values.get(i).intValue();
}
} else {
info.observedValue = proto.getObservedValue();
info.observedCount = proto.getObservedCount();
newCount = (int) proto.getObservedCount();
if (info.observedCount < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + info.observedCount + " for value " + info.observedValue);
}
}
if (newCount != proto.getCount()) {
throw new IllegalStateException("Invalid protobuf, count " + newCount + " did not match expected value " + proto.getCount());
}
info.count = newCount;
return info;
}

@Override
public int getID() {
return id;
@@ -84,4 +160,25 @@ public String toString() {
return "CategoricalFeature(name=" + name + ",id=" + id + ",count=" + count + ",map={" +observedValue+","+observedCount+"})";
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
CategoricalIDInfo that = (CategoricalIDInfo) o;
return id == that.id;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), id);
}

}
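
The deserialization factory above follows a recognisable pattern: reject unknown versions, check that the parallel key and value lists have the same length, reject negative counts, and confirm that the recomputed total matches the stored one. The following is a minimal, self-contained sketch of that pattern only. `CountsSketch` and `fromLists` are hypothetical names that do not exist in Tribuo, plain lists stand in for the protobuf message, and the total is accumulated in a `long` rather than an `int`, which is a simplification of the code in the diff.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for a categorical info; not a Tribuo class. */
final class CountsSketch {
    /** Hypothetical version constant, mirroring CURRENT_VERSION in the diff. */
    static final int CURRENT_VERSION = 0;

    final Map<Double, Long> valueCounts;
    final long total;

    private CountsSketch(Map<Double, Long> valueCounts, long total) {
        this.valueCounts = valueCounts;
        this.total = total;
    }

    /**
     * Rebuilds the counts from parallel key/value lists, validating as it goes.
     * The lists are an assumption standing in for a protobuf message's repeated fields.
     */
    static CountsSketch fromLists(int version, List<Double> keys, List<Long> values, long expectedTotal) {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version
                    + ", this class supports at most version " + CURRENT_VERSION);
        }
        if (keys.size() != values.size()) {
            throw new IllegalStateException("keys.size() = " + keys.size()
                    + ", values.size() = " + values.size());
        }
        Map<Double, Long> counts = new HashMap<>(keys.size());
        long total = 0;
        for (int i = 0; i < keys.size(); i++) {
            long c = values.get(i);
            if (c < 0) {
                throw new IllegalStateException("Negative count " + c + " for value " + keys.get(i));
            }
            counts.put(keys.get(i), c);
            total += c;
        }
        if (total != expectedTotal) {
            throw new IllegalStateException("Count " + total
                    + " did not match expected value " + expectedTotal);
        }
        return new CountsSketch(counts, total);
    }
}
```

Validating inside a static factory rather than a public constructor keeps partially populated objects from escaping when the input is malformed.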
99 changes: 99 additions & 0 deletions Core/src/main/java/org/tribuo/CategoricalInfo.java
@@ -16,15 +16,24 @@

package org.tribuo;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoSerializableKeysValuesField;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.CategoricalInfoProto;
import org.tribuo.protos.core.VariableInfoProto;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
@@ -47,9 +56,15 @@
* are recomputed. Care should be taken if data is read while {@link #observe(double)} is called.
* </p>
*/
@ProtoSerializableClass(version = CategoricalInfo.CURRENT_VERSION, serializedDataClass = CategoricalInfoProto.class)
public class CategoricalInfo extends SkeletalVariableInfo {
private static final long serialVersionUID = 2L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private static final MutableLong ZERO = new MutableLong(0);
/**
* The default threshold for converting a categorical info into a {@link RealInfo}.
@@ -60,16 +75,19 @@ public class CategoricalInfo extends SkeletalVariableInfo {
/**
* The occurrence counts of each value.
*/
@ProtoSerializableKeysValuesField(keysName="key", valuesName="value")
protected Map<Double,MutableLong> valueCounts = null;

/**
* The observed value if it's only seen a single one.
*/
@ProtoSerializableField
protected double observedValue = Double.NaN;

/**
* The count of the observed value if it's only seen a single one.
*/
@ProtoSerializableField
protected long observedCount = 0;

// These variables are used in the sampling methods, and regenerated after serialization if a sample is required.
@@ -117,6 +135,50 @@ protected CategoricalInfo(CategoricalInfo info, String newName) {
}
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static CategoricalInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
CategoricalInfoProto proto = message.unpack(CategoricalInfoProto.class);
CategoricalInfo info = new CategoricalInfo(proto.getName());
List<Double> keys = proto.getKeyList();
List<Long> values = proto.getValueList();
if (keys.size() != values.size()) {
throw new IllegalStateException("Invalid protobuf, keys and values don't match. keys.size() = " + keys.size() + ", values.size() = " + values.size());
}
int newCount = 0;
if (keys.size() > 1) {
info.valueCounts = new HashMap<>(keys.size());
for (int i = 0; i < keys.size(); i++) {
if (values.get(i) < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + values.get(i) + " for value " + keys.get(i));
}
info.valueCounts.put(keys.get(i),new MutableLong(values.get(i)));
newCount += values.get(i).intValue();
}
} else {
info.observedValue = proto.getObservedValue();
info.observedCount = proto.getObservedCount();
newCount = (int) proto.getObservedCount();
if (info.observedCount < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + info.observedCount + " for value " + info.observedValue);
}
}
info.count = newCount;
return info;
}

@Override
public VariableInfoProto serialize() {
return ProtoUtil.serialize(this);
}

@Override
protected void observe(double value) {
if (value != 0.0) {
@@ -361,6 +423,43 @@ private synchronized void regenerateValues() {
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
CategoricalInfo that = (CategoricalInfo) o;
// MutableLong in OLCUT 5.2.0 doesn't implement equals,
// so we can't compare valueCounts with Objects.equals.
// That'll be fixed in the next OLCUT but for the time being we've got this workaround.
if (valueCounts != null ^ that.valueCounts != null) {
return false;
} else if (valueCounts != null && that.valueCounts != null) {
if (valueCounts.size() != that.valueCounts.size()) {
return false;
} else {
for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
MutableLong other = that.valueCounts.get(e.getKey());
if ((other == null) || (e.getValue().longValue() != other.longValue())) {
return false;
}
}
}
}
return Double.compare(that.observedValue, observedValue) == 0 && observedCount == that.observedCount;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), valueCounts, observedValue, observedCount);
}

@Override
public String toString() {
if (valueCounts != null) {
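
The `equals` workaround in this file compares the two maps entry by entry because, according to its comment, the value type does not implement `equals`. The sketch below isolates that idea. `Counter` is a hypothetical stand-in for such a value type, and `CountMaps` is a hypothetical helper; neither exists in Tribuo or OLCUT. The `countsHash` helper is an addition that rests on the further assumption that the value type also inherits identity-based `hashCode`, in which case hashing the map directly would disagree with value-based equality.

```java
import java.util.Map;

/** Hypothetical mutable counter that does not override equals or hashCode. */
final class Counter {
    private final long value;

    Counter(long value) {
        this.value = value;
    }

    long longValue() {
        return value;
    }
}

/** Hypothetical helper comparing count maps by the numeric value of each counter. */
final class CountMaps {
    private CountMaps() {}

    /** Returns true if both maps are null, or both hold the same keys with equal counts. */
    static boolean countsEqual(Map<Double, Counter> a, Map<Double, Counter> b) {
        if (a == b) {
            return true;
        }
        if (a == null || b == null) {
            return false;
        }
        if (a.size() != b.size()) {
            return false;
        }
        for (Map.Entry<Double, Counter> e : a.entrySet()) {
            Counter other = b.get(e.getKey());
            if (other == null || e.getValue().longValue() != other.longValue()) {
                return false;
            }
        }
        return true;
    }

    /** Order-independent hash built from keys and numeric counts, consistent with countsEqual. */
    static int countsHash(Map<Double, Counter> m) {
        if (m == null) {
            return 0;
        }
        int h = 0;
        for (Map.Entry<Double, Counter> e : m.entrySet()) {
            h += Double.hashCode(e.getKey()) ^ Long.hashCode(e.getValue().longValue());
        }
        return h;
    }
}
```

Summing the per-entry hashes makes the result independent of iteration order, which is the same approach `java.util.AbstractMap.hashCode` takes.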
35 changes: 34 additions & 1 deletion Core/src/main/java/org/tribuo/FeatureMap.java
@@ -16,23 +16,30 @@

package org.tribuo;

import org.tribuo.protos.core.FeatureDomainProto;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableMapValuesField;
import org.tribuo.protos.ProtoUtil;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;

/**
* A map from Strings to {@link VariableInfo} objects storing
* information about a feature.
*/
-public abstract class FeatureMap implements Serializable, Iterable<VariableInfo> {
+public abstract class FeatureMap implements Serializable, ProtoSerializable<FeatureDomainProto>, Iterable<VariableInfo> {
private static final long serialVersionUID = 1L;

/**
* Map from the feature names to their info.
*/
@ProtoSerializableMapValuesField(valuesName = "info")
protected final Map<String, VariableInfo> m;

/**
@@ -101,6 +108,23 @@ public Iterator<VariableInfo> iterator() {
return m.values().iterator();
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FeatureMap that = (FeatureMap) o;
return m.equals(that.m);
}

@Override
public int hashCode() {
return Objects.hash(m);
}

/**
* Same as the toString, but ordered by name, and with newlines.
* @return A String representation of this FeatureMap.
@@ -136,4 +160,13 @@ public boolean domainEquals(FeatureMap other) {
}
}

/**
* Deserializes a {@link FeatureDomainProto} into a {@link FeatureMap} subclass.
* @param proto The proto to deserialize.
* @return The deserialized FeatureMap.
*/
public static FeatureMap deserialize(FeatureDomainProto proto) {
return ProtoUtil.deserialize(proto);
}

}
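
`FeatureMap.deserialize` delegates to `ProtoUtil.deserialize`, whose body is not shown in this part of the diff. The commit message describes a higher-level `deserialize` method and a `deserializeFromProto` factory that each class must implement. One plausible way to connect the two is to look up a static factory by class name via reflection, sketched below. This is an assumption about the general technique, not the actual `ProtoUtil` code: `DispatchSketch` and `PointSketch` are hypothetical, and a `String` payload stands in for the protobuf `Any` message so that the example has no dependencies.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

/** Hypothetical serializable type exposing a static deserialization factory. */
final class PointSketch {
    final int x;

    private PointSketch(int x) {
        this.x = x;
    }

    /** Factory with the same shape as the diff's, but taking a String payload (an assumption). */
    public static PointSketch deserializeFromProto(int version, String className, String payload) {
        if (version != 0) {
            throw new IllegalArgumentException("Unknown version " + version);
        }
        return new PointSketch(Integer.parseInt(payload));
    }
}

/** Hypothetical dispatcher: finds the static factory on the named class and invokes it. */
final class DispatchSketch {
    private DispatchSketch() {}

    static Object deserialize(int version, String className, String payload) {
        try {
            Class<?> clazz = Class.forName(className);
            Method factory = clazz.getDeclaredMethod(
                    "deserializeFromProto", int.class, String.class, String.class);
            if (!Modifier.isStatic(factory.getModifiers())) {
                throw new IllegalStateException("deserializeFromProto on " + className + " must be static");
            }
            factory.setAccessible(true);
            // Static method, so the receiver argument is null.
            return factory.invoke(null, version, className, payload);
        } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException e) {
            throw new IllegalStateException("Failed to find deserialization factory on " + className, e);
        } catch (InvocationTargetException e) {
            // Unwrap so the caller sees the factory's own exception as the cause.
            throw new IllegalStateException("Deserialization factory on " + className + " failed", e.getCause());
        }
    }
}
```

A call such as `DispatchSketch.deserialize(0, PointSketch.class.getName(), "42")` returns a `PointSketch` whose `x` field is 42. Wrapping the reflective failures in one exception type with a descriptive message is in the spirit of the "detailed error handling messages" mentioned in the commit message.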