Skip to content

Commit

Permalink
Adding tests for transformation serialization. Made the transformer i…
Browse files Browse the repository at this point in the history
…mplementations package private so they are easier to test.
  • Loading branch information
Craigacp committed Apr 9, 2022
1 parent 763b026 commit 578763a
Show file tree
Hide file tree
Showing 14 changed files with 271 additions and 96 deletions.
9 changes: 9 additions & 0 deletions Core/src/main/java/org/tribuo/transform/Transformer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import org.tribuo.ProtoSerializable;
import org.tribuo.protos.core.TransformerProto;
import org.tribuo.util.ProtoUtil;

import java.io.Serializable;

Expand All @@ -39,4 +40,12 @@ public interface Transformer extends ProtoSerializable<TransformerProto>, Serial
*/
public double transform(double input);

/**
* Deserializes a {@link TransformerProto} into a {@link Transformer} subclass.
* @param proto The proto to deserialize.
* @return The deserialized FeatureMap.
*/
public static Transformer deserialize(TransformerProto proto) {
return (Transformer) ProtoUtil.instantiate(proto.getVersion(), proto.getClassName(), proto.getSerializedData());
}
}
17 changes: 15 additions & 2 deletions Core/src/main/java/org/tribuo/transform/TransformerMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ public int size() {
}

/**
* Gets the transformer associated with a given feature name.
* Gets the transformers associated with a given feature name.
* @param featureName the name of the feature for which we want the transformer
* @return the transformer associated with the feature name, which may be <code>null</code>
* @return the transformer list associated with the feature name, which may be <code>null</code>
* if there is no feature with that name.
*/
public List<Transformer> get(String featureName) {
Expand All @@ -209,6 +209,19 @@ public String toString() {
return "TransformerMap(map="+map.toString()+")";
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TransformerMap that = (TransformerMap) o;
return map.equals(that.map) && datasetProvenance.equals(that.datasetProvenance) && transformationMapProvenance.equals(that.transformationMapProvenance);
}

@Override
public int hashCode() {
return Objects.hash(map, datasetProvenance, transformationMapProvenance);
}

/**
* Get the feature names and associated list of transformers.
* @return The entry set of the transformer map.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,14 +439,14 @@ public String toString() {
}
}

private static final class BinningTransformer implements Transformer {
static final class BinningTransformer implements Transformer {
private static final long serialVersionUID = 1L;

private final BinningType type;
private final double[] bins;
private final double[] values;

public BinningTransformer(BinningType type, double[] bins, double[] values) {
BinningTransformer(BinningType type, double[] bins, double[] values) {
this.type = type;
this.bins = bins;
this.values = values;
Expand All @@ -459,7 +459,7 @@ public BinningTransformer(BinningType type, double[] bins, double[] values) {
* @param message The serialized data.
* @throws InvalidProtocolBufferException If the message is not a {@link BinningTransformerProto}.
*/
public static BinningTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
static BinningTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
BinningTransformerProto proto = message.unpack(BinningTransformerProto.class);
if (version == 0) {
if (proto.getBinsCount() == proto.getValuesCount()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public Transformer generateTransformer() {

}

private static class IDFTransformer implements Transformer {
static class IDFTransformer implements Transformer {
private static final long serialVersionUID = 1L;

private final double df;
Expand All @@ -108,7 +108,7 @@ private static class IDFTransformer implements Transformer {
* @param df The document frequency.
* @param N The number of documents.
*/
public IDFTransformer(int df, int N) {
IDFTransformer(int df, int N) {
if ((df < 0) || (N < 0)) {
throw new IllegalArgumentException("Both df and N must be positive");
}
Expand All @@ -123,7 +123,7 @@ public IDFTransformer(int df, int N) {
* @param message The serialized data.
* @throws InvalidProtocolBufferException If the message is not a {@link IDFTransformerProto}.
*/
public static IDFTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
static IDFTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
IDFTransformerProto proto = message.unpack(IDFTransformerProto.class);
if (version == 0) {
return new IDFTransformer((int)proto.getDf(), (int)proto.getN());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public String toString() {
}
}

private static final class LinearScalingTransformer implements Transformer {
static final class LinearScalingTransformer implements Transformer {
private static final long serialVersionUID = 1L;

private final double observedMin;
Expand All @@ -207,7 +207,7 @@ private static final class LinearScalingTransformer implements Transformer {
private final double scalingFactor;
private final boolean constant;

public LinearScalingTransformer(double observedMin, double observedMax, double targetMin, double targetMax) {
LinearScalingTransformer(double observedMin, double observedMax, double targetMin, double targetMax) {
if ((observedMin > observedMax) || (targetMin > targetMax)) {
throw new IllegalArgumentException("observedMin and targetMin must be less than observedMax and targetMax respectively");
}
Expand All @@ -228,7 +228,7 @@ public LinearScalingTransformer(double observedMin, double observedMax, double t
* @param message The serialized data.
* @throws InvalidProtocolBufferException If the message is not a {@link LinearScalingTransformerProto}.
*/
public static LinearScalingTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
static LinearScalingTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
LinearScalingTransformerProto proto = message.unpack(LinearScalingTransformerProto.class);
if (version == 0) {
return new LinearScalingTransformer(proto.getObservedMin(),proto.getObservedMax(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,15 @@ public String toString() {
}
}

private static final class MeanStdDevTransformer implements Transformer {
static final class MeanStdDevTransformer implements Transformer {
private static final long serialVersionUID = 1L;

private final double observedMean;
private final double observedStdDev;
private final double targetMean;
private final double targetStdDev;

public MeanStdDevTransformer(double observedMean, double observedStdDev, double targetMean, double targetStdDev) {
MeanStdDevTransformer(double observedMean, double observedStdDev, double targetMean, double targetStdDev) {
if ((observedStdDev < 0) || (targetStdDev < 0)) {
throw new IllegalArgumentException("Standard deviations must be non-negative.");
}
Expand All @@ -230,7 +230,7 @@ public MeanStdDevTransformer(double observedMean, double observedStdDev, double
* @param message The serialized data.
* @throws InvalidProtocolBufferException If the message is not a {@link MeanStdDevTransformerProto}.
*/
public static MeanStdDevTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
static MeanStdDevTransformer deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
MeanStdDevTransformerProto proto = message.unpack(MeanStdDevTransformerProto.class);
if (version == 0) {
return new MeanStdDevTransformer(proto.getObservedMean(),proto.getObservedStdDev(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public enum Operation {
*/
private SimpleTransform() {}

private SimpleTransform(Operation op, double operand, double secondOperand) {
SimpleTransform(Operation op, double operand, double secondOperand) {
this.op = op;
this.operand = operand;
this.secondOperand = secondOperand;
Expand Down Expand Up @@ -192,7 +192,7 @@ public void postConfig() {
* @param message The serialized data.
* @throws InvalidProtocolBufferException If the message is not a {@link SimpleTransformProto}.
*/
public static SimpleTransform deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
static SimpleTransform deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
SimpleTransformProto proto = message.unpack(SimpleTransformProto.class);
if (version == 0) {
return new SimpleTransform(Operation.valueOf(proto.getOp()), proto.getFirstOperand(), proto.getSecondOperand());
Expand Down
2 changes: 1 addition & 1 deletion Core/src/main/java/org/tribuo/util/ProtoUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static Object instantiate(int version, String className, Any message) {
String targetClassName = REDIRECT_MAP.getOrDefault(key, className);
try {
Class<?> targetClass = Class.forName(targetClassName);
Method method = targetClass.getMethod(DESERIALIZATION_METHOD_NAME, int.class, String.class, Any.class);
Method method = targetClass.getDeclaredMethod(DESERIALIZATION_METHOD_NAME, int.class, String.class, Any.class);
method.setAccessible(true);
Object o = method.invoke(null, version, className, message);
method.setAccessible(false);
Expand Down
12 changes: 11 additions & 1 deletion Core/src/main/resources/protos/tribuo-core-impl.proto
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,14 @@ message BinningTransformerProto {
string binning_type = 1;
repeated double bins = 2;
repeated double values = 3;
}
}

/*
CountTransformer proto (used in tests)
*/
message TestCountTransformerProto {
int32 count = 1;
int32 sparseCount = 2;
repeated double countMapKeys = 3;
repeated int64 countMapValues = 4;
}
Loading

0 comments on commit 578763a

Please sign in to comment.