Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds protobuf serialization to FeatureMap, Hasher, and VariableInfo implementations #226

Merged
merged 40 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
b91a8ed
Roughing out the core tribuo protos.
Craigacp Mar 15, 2022
28583ce
Adding protobuf serialization for FeatureMap, subclasses and fields o…
Craigacp Mar 22, 2022
5a8e477
Adding generated protos.
Craigacp Mar 22, 2022
0935e0e
Adding equals methods to CategoricalInfo and RealInfo, and added an e…
Craigacp Mar 25, 2022
acad1f4
Adding equals methods to Hashers, and adding HashedFeatureMap.seriali…
Craigacp Mar 25, 2022
911c0d5
Adding equals methods to FeatureMap and MutableFeatureMap, and tests …
Craigacp Mar 25, 2022
ac69a0b
Removing generated protos for ease of reviewing.
Craigacp Mar 25, 2022
cb7c643
Fix the javadoc in ProtoSerializable.
Craigacp Mar 25, 2022
3f7a8f7
Adding support for serializing TransformerMap.
Craigacp Mar 27, 2022
f9c6d30
Adding protobuf serialization to org.tribuo.transform.transformations.
Craigacp Mar 28, 2022
6cf7354
Fixing copyrights.
Craigacp Mar 28, 2022
aa8674a
Adding tests for transformation serialization. Made the transformer i…
Craigacp Mar 29, 2022
426eb95
initial scratchings of ProtobufClass/Field annotations and ProtoUtil.…
pogren Apr 19, 2022
12a88a1
moved ProtoUtilTest to org.tribuo to allow testing of protected method
pogren Apr 20, 2022
ebe4b55
added protobuf generated source files
pogren Apr 20, 2022
8a72af7
added @ProtobufField and @ProtobufClass to a bunch of class defs
pogren Apr 20, 2022
0b6a374
renamed annotations
pogren Apr 27, 2022
bc1ca9d
default deserialize method replaces "instantiate"
pogren Apr 27, 2022
188e483
moved protoserializable code into protos package
pogren Apr 27, 2022
8cd2f1c
initial commit of ProtoSerializableArrayField
pogren Apr 28, 2022
62dfe04
removed unneeded serialize methods
pogren Apr 29, 2022
2d2492a
updated tests to be consistent
pogren Apr 29, 2022
3f6a769
added @ProtoSerializable annotations to two transformers
pogren Apr 29, 2022
32799bd
removes initial attempt at generic but underspecified deserialize method
pogren May 2, 2022
fe7b499
reimplemented ReflectUtil to pull out one type parameter resolution
pogren May 2, 2022
f07236e
cleaned up reimplementation of ReflectUtil functionality.
pogren May 2, 2022
5cecb47
SimpleTransform replace serialize() impl
pogren May 2, 2022
e54bc0b
moved functionality in ReflectUtil to ProtoUtil and made it private
pogren May 2, 2022
f5e7c9b
added @ProtoSerializableClass annotation to LinearScalingTransformer
pogren May 2, 2022
4eb3b71
added @ProtoSerializableClass to IDFTransformer and refactored serialize
pogren May 2, 2022
cc35e29
removed default impl for ProtoSerializable.serialize
pogren May 2, 2022
df698f3
Formatting and licenses.
Craigacp May 3, 2022
c87a8aa
Removing ProtoSerializableArrayField, cleaning up the use of maps, ma…
Craigacp May 3, 2022
1f338ad
Tidying up ProtoUtil.
Craigacp May 3, 2022
d3f6e82
Adding internal redirect hook.
Craigacp May 3, 2022
9be2553
Narrowing the visibility of some methods in ProtoUtil.
Craigacp May 3, 2022
f284dbf
Adding static version field. Improving validation of deserialized obj…
Craigacp May 9, 2022
4a99c67
Adding docs about hasher deserialization.
Craigacp May 25, 2022
e44220b
Adding more logging and exceptions to ProtoUtil.
Craigacp May 25, 2022
bfb1c2b
Fix ModHashCodeHasher deserialize.
Craigacp May 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions Core/src/main/java/org/tribuo/CategoricalIDInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,30 @@

package org.tribuo;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.core.CategoricalIDInfoProto;

import java.util.HashMap;
import java.util.List;
import java.util.Objects;

/**
* Same as a {@link CategoricalInfo}, but with an additional int id field.
*/
@ProtoSerializableClass(version = CategoricalIDInfo.CURRENT_VERSION, serializedDataClass = CategoricalIDInfoProto.class)
public class CategoricalIDInfo extends CategoricalInfo implements VariableIDInfo {
private static final long serialVersionUID = 2L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

@ProtoSerializableField
private final int id;

/**
Expand All @@ -31,6 +49,9 @@ public class CategoricalIDInfo extends CategoricalInfo implements VariableIDInfo
*/
public CategoricalIDInfo(CategoricalInfo info, int id) {
super(info);
if (id < 0) {
throw new IllegalArgumentException("Invalid id number, must be non-negative, found " + id);
}
this.id = id;
}

Expand All @@ -46,6 +67,61 @@ private CategoricalIDInfo(CategoricalIDInfo info, String newName) {
this.id = info.id;
}

/**
* Deserialization constructor.
* @param name The info name.
* @param id The info id.
*/
private CategoricalIDInfo(String name, int id) {
super(name);
if (id < 0) {
throw new IllegalArgumentException("Invalid id number, must be non-negative, found " + id);
}
this.id = id;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static CategoricalIDInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
CategoricalIDInfoProto proto = message.unpack(CategoricalIDInfoProto.class);
CategoricalIDInfo info = new CategoricalIDInfo(proto.getName(),proto.getId());
List<Double> keys = proto.getKeyList();
List<Long> values = proto.getValueList();
if (keys.size() != values.size()) {
throw new IllegalStateException("Invalid protobuf, keys and values don't match. keys.size() = " + keys.size() + ", values.size() = " + values.size());
}
int newCount = 0;
if (keys.size() > 1) {
info.valueCounts = new HashMap<>(keys.size());
for (int i = 0; i < keys.size(); i++) {
if (values.get(i) < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + values.get(i) + " for value " + keys.get(i));
}
info.valueCounts.put(keys.get(i),new MutableLong(values.get(i)));
newCount += values.get(i).intValue();
}
} else {
info.observedValue = proto.getObservedValue();
info.observedCount = proto.getObservedCount();
newCount = (int) proto.getObservedCount();
if (info.observedCount < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + info.observedCount + " for value " + info.observedValue);
}
}
if (newCount != proto.getCount()) {
throw new IllegalStateException("Invalid protobuf, count " + newCount + " did not match expected value " + proto.getCount());
}
info.count = newCount;
return info;
}

@Override
public int getID() {
return id;
Expand Down Expand Up @@ -84,4 +160,25 @@ public String toString() {
return "CategoricalFeature(name=" + name + ",id=" + id + ",count=" + count + ",map={" +observedValue+","+observedCount+"})";
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
CategoricalIDInfo that = (CategoricalIDInfo) o;
return id == that.id;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), id);
}

}
99 changes: 99 additions & 0 deletions Core/src/main/java/org/tribuo/CategoricalInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,24 @@

package org.tribuo;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.MutableLong;
import com.oracle.labs.mlrg.olcut.util.MutableNumber;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoSerializableKeysValuesField;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.CategoricalInfoProto;
import org.tribuo.protos.core.VariableInfoProto;
import org.tribuo.util.Util;

import java.io.IOException;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
Expand All @@ -47,9 +56,15 @@
* are recomputed. Care should be taken if data is read while {@link #observe(double)} is called.
* </p>
*/
@ProtoSerializableClass(version = CategoricalInfo.CURRENT_VERSION, serializedDataClass = CategoricalInfoProto.class)
public class CategoricalInfo extends SkeletalVariableInfo {
private static final long serialVersionUID = 2L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private static final MutableLong ZERO = new MutableLong(0);
/**
* The default threshold for converting a categorical info into a {@link RealInfo}.
Expand All @@ -60,16 +75,19 @@ public class CategoricalInfo extends SkeletalVariableInfo {
/**
* The occurrence counts of each value.
*/
@ProtoSerializableKeysValuesField(keysName="key", valuesName="value")
protected Map<Double,MutableLong> valueCounts = null;

/**
* The observed value if it's only seen a single one.
*/
@ProtoSerializableField
protected double observedValue = Double.NaN;

/**
* The count of the observed value if it's only seen a single one.
*/
@ProtoSerializableField
protected long observedCount = 0;

// These variables are used in the sampling methods, and regenerated after serialization if a sample is required.
Expand Down Expand Up @@ -117,6 +135,50 @@ protected CategoricalInfo(CategoricalInfo info, String newName) {
}
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
*/
public static CategoricalInfo deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
CategoricalInfoProto proto = message.unpack(CategoricalInfoProto.class);
CategoricalInfo info = new CategoricalInfo(proto.getName());
List<Double> keys = proto.getKeyList();
List<Long> values = proto.getValueList();
if (keys.size() != values.size()) {
throw new IllegalStateException("Invalid protobuf, keys and values don't match. keys.size() = " + keys.size() + ", values.size() = " + values.size());
}
int newCount = 0;
if (keys.size() > 1) {
info.valueCounts = new HashMap<>(keys.size());
for (int i = 0; i < keys.size(); i++) {
if (values.get(i) < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + values.get(i) + " for value " + keys.get(i));
}
info.valueCounts.put(keys.get(i),new MutableLong(values.get(i)));
newCount += values.get(i).intValue();
}
} else {
info.observedValue = proto.getObservedValue();
info.observedCount = proto.getObservedCount();
newCount = (int) proto.getObservedCount();
if (info.observedCount < 0) {
throw new IllegalStateException("Invalid protobuf, counts must be positive, found " + info.observedCount + " for value " + info.observedValue);
}
}
info.count = newCount;
return info;
}

@Override
public VariableInfoProto serialize() {
return ProtoUtil.serialize(this);
}

@Override
protected void observe(double value) {
if (value != 0.0) {
Expand Down Expand Up @@ -361,6 +423,43 @@ private synchronized void regenerateValues() {
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
CategoricalInfo that = (CategoricalInfo) o;
// MutableLong in OLCUT 5.2.0 doesn't implement equals,
// so we can't compare valueCounts with Objects.equals.
// That'll be fixed in the next OLCUT but for the time being we've got this workaround.
if (valueCounts != null ^ that.valueCounts != null) {
return false;
} else if (valueCounts != null && that.valueCounts != null) {
if (valueCounts.size() != that.valueCounts.size()) {
return false;
} else {
for (Map.Entry<Double, MutableLong> e : valueCounts.entrySet()) {
MutableLong other = that.valueCounts.get(e.getKey());
if ((other == null) || (e.getValue().longValue() != other.longValue())) {
return false;
}
}
}
}
return Double.compare(that.observedValue, observedValue) == 0 && observedCount == that.observedCount;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), valueCounts, observedValue, observedCount);
}

@Override
public String toString() {
if (valueCounts != null) {
Expand Down
35 changes: 34 additions & 1 deletion Core/src/main/java/org/tribuo/FeatureMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,30 @@

package org.tribuo;

import org.tribuo.protos.core.FeatureDomainProto;
import org.tribuo.protos.ProtoSerializable;
import org.tribuo.protos.ProtoSerializableMapValuesField;
import org.tribuo.protos.ProtoUtil;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;

/**
* A map from Strings to {@link VariableInfo} objects storing
* information about a feature.
*/
public abstract class FeatureMap implements Serializable, Iterable<VariableInfo> {
public abstract class FeatureMap implements Serializable, ProtoSerializable<FeatureDomainProto>, Iterable<VariableInfo> {
private static final long serialVersionUID = 1L;

/**
* Map from the feature names to their info.
*/
@ProtoSerializableMapValuesField(valuesName = "info")
protected final Map<String, VariableInfo> m;

/**
Expand Down Expand Up @@ -101,6 +108,23 @@ public Iterator<VariableInfo> iterator() {
return m.values().iterator();
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FeatureMap that = (FeatureMap) o;
return m.equals(that.m);
}

@Override
public int hashCode() {
return Objects.hash(m);
}

/**
* Same as the toString, but ordered by name, and with newlines.
* @return A String representation of this FeatureMap.
Expand Down Expand Up @@ -136,4 +160,13 @@ public boolean domainEquals(FeatureMap other) {
}
}

/**
* Deserializes a {@link FeatureDomainProto} into a {@link FeatureMap} subclass.
* @param proto The proto to deserialize.
* @return The deserialized FeatureMap.
*/
public static FeatureMap deserialize(FeatureDomainProto proto) {
return ProtoUtil.deserialize(proto);
}

}
Loading