Fixing the transformer behaviour on sparse features #122

Merged 13 commits on Apr 2, 2021
Changes from 6 commits
42 changes: 36 additions & 6 deletions Core/src/main/java/org/tribuo/Dataset.java
@@ -236,16 +236,44 @@ public String toString(){
* Does not mutate the dataset; if you wish to apply the TransformerMap, use
* {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
* <p>
* Currently TransformationMaps and TransformerMaps only operate on feature values
* which are present, sparse values are ignored and not transformed. If the zeros
* should be transformed, call {@link MutableDataset#densify} on the datasets.
* TransformerMaps operate on feature values which are present, sparse values
* are ignored and not transformed. If the zeros should be transformed, call
* {@link MutableDataset#densify} on the datasets before applying a transformer.
* <p>
* This method calls {@link #createTransformers(TransformationMap, boolean)} with
* {@code includeImplicitZeroFeatures} set to false, thus ignoring implicitly zero
* features when fitting the transformations. This is the default behaviour in
* Tribuo 4.0, but causes erroneous behaviour in
* {@link org.tribuo.transform.transformations.IDFTransformation} so should be
* avoided with that transformation.
* <p>
* Throws {@link IllegalArgumentException} if the TransformationMap object has
* regexes which apply to multiple features.
* @param transformations The transformations to fit.
* @return A TransformerMap which can apply the transformations to a dataset.
*/
public TransformerMap createTransformers(TransformationMap transformations) {
return createTransformers(transformations, false);
}

/**
* Takes a {@link TransformationMap} and converts it into a {@link TransformerMap} by
* observing all the values in this dataset.
* <p>
* Does not mutate the dataset; if you wish to apply the TransformerMap, use
* {@link MutableDataset#transform} or {@link TransformerMap#transformDataset}.
* <p>
* TransformerMaps operate on feature values which are present, sparse values
* are ignored and not transformed. If the zeros should be transformed, call
* {@link MutableDataset#densify} on the datasets before applying a transformer.
* <p>
* Throws {@link IllegalArgumentException} if the TransformationMap object has
* regexes which apply to multiple features.
* @param transformations The transformations to fit.
* @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
* @return A TransformerMap which can apply the transformations to a dataset.
*/
public TransformerMap createTransformers(TransformationMap transformations, boolean includeImplicitZeroFeatures) {
ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());

// Validate map by checking no regex applies to multiple features.
@@ -330,10 +358,12 @@ public TransformerMap createTransformers(TransformationMap transformations) {
removeSet.clear();
// Emit the new transformers onto the end of the list in the output map.
for (Map.Entry<String,Queue<TransformStatistics>> entry : featureStats.entrySet()) {
// Observe all the sparse feature values
int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
TransformStatistics currentStats = entry.getValue().poll();
currentStats.observeSparse(unobservedFeatures);
if (includeImplicitZeroFeatures) {
// Observe all the sparse feature values
int unobservedFeatures = sparseCount.get(entry.getKey()).intValue();
currentStats.observeSparse(unobservedFeatures);
}
// Get the transformer list for that feature (if absent)
List<Transformer> l = output.computeIfAbsent(entry.getKey(), (k) -> new ArrayList<>());
// Generate the transformer and add it to the appropriate list.
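For context on the new overload above, here is a minimal usage sketch, not part of the diff. It assumes a `MutableDataset<Label>` called `train` is already loaded, and uses `LinearScalingTransformation` purely as an illustrative transformation.

```java
import java.util.Collections;

import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.LinearScalingTransformation;

// Illustrative sketch, not from the PR; class name and variables are hypothetical.
public final class CreateTransformersSketch {
    public static void fit(MutableDataset<Label> train) {
        // Rescale every feature into the range [0,1].
        TransformationMap transformations = new TransformationMap(
                Collections.singletonList(new LinearScalingTransformation(0.0, 1.0)));

        // Single-argument method: implicit zeros are ignored when fitting (Tribuo 4.0 behaviour).
        TransformerMap sparseFit = train.createTransformers(transformations);

        // New overload from this PR: implicit zeros contribute to the fitted statistics.
        TransformerMap zeroAwareFit = train.createTransformers(transformations, true);

        // Fitting does not mutate the dataset; the transformers are applied separately.
        MutableDataset<Label> rescaled = zeroAwareFit.transformDataset(train);
    }
}
```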
Core/src/main/java/org/tribuo/transform/TransformStatistics.java
@@ -35,7 +35,9 @@ public interface TransformStatistics {

/**
* Observes a sparse (i.e., zero) value.
* @deprecated in 4.1 as it's unnecessary.
*/
@Deprecated
public void observeSparse();

/**
39 changes: 33 additions & 6 deletions Core/src/main/java/org/tribuo/transform/TransformTrainer.java
@@ -35,8 +35,9 @@
* A {@link Trainer} which encapsulates another trainer plus a {@link TransformationMap} object
* to apply to each {@link Dataset} before training each {@link Model}.
* <p>
* Transformations only operate on observed values. To operate on implicit zeros then
* first call {@link MutableDataset#densify} on the datasets.
* By default transformations only operate on explicit feature values. To include implicit zeros
* in transformation fitting, set {@code includeImplicitZeroFeatures}. To convert implicit
* zeros to explicit zeros before applying the transformations, set {@code densify}.
*/
public final class TransformTrainer<T extends Output<T>> implements Trainer<T> {

@@ -51,6 +52,9 @@ public final class TransformTrainer<T extends Output<T>> implements Trainer<T> {
@Config(description="Densify all the features before applying transformations.")
private boolean densify;

@Config(description="Include the implicit zeros in the transformation statistics collection")
private boolean includeImplicitZeroFeatures;

/**
* For OLCUT.
*/
@@ -60,7 +64,10 @@ private TransformTrainer() {}
* Creates a trainer which transforms the data before training, and stores
* the transformers along with the trained model in a {@link TransformedModel}.
* <p>
* This constructor makes a trainer which keeps the data sparse.
* Sets {@code includeImplicitZeroFeatures} to false, so this constructor makes a trainer
* which keeps the data sparse, and does not use the implicit zeros to construct
* the transformations. Models produced by this trainer will not convert implicit
* zeros in the feature space to explicit zeros (i.e., densify is false).
* @param innerTrainer The trainer to use.
* @param transformations The transformations to apply to the data first.
*/
@@ -71,22 +78,42 @@ public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformatio
/**
* Creates a trainer which transforms the data before training, and stores
* the transformers along with the trained model in a {@link TransformedModel}.
*
* <p>
* Sets {@code includeImplicitZeroFeatures} to false, so this constructor makes a trainer
* which keeps the data sparse, and does not use the implicit zeros to construct
* the transformations.
* @param innerTrainer The trainer to use.
* @param transformations The transformations to apply to the data first.
* @param densify Densify the dataset (and any predict time data) before training/prediction.
* @param densify Convert the implicit zeros in each training and prediction example
* to explicit zeros before training/prediction.
*/
public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformations, boolean densify) {
this(innerTrainer,transformations,densify,false);
}

/**
* Creates a trainer which transforms the data before training, and stores
* the transformers along with the trained model in a {@link TransformedModel}.
*
* @param innerTrainer The trainer to use.
* @param transformations The transformations to apply to the data first.
* @param densify Convert the implicit zeros in each training and prediction example
* to explicit zeros before training/prediction.
* @param includeImplicitZeroFeatures Use the implicit zero feature values to construct the transformations.
*/
public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformations, boolean densify, boolean includeImplicitZeroFeatures) {
this.innerTrainer = innerTrainer;
this.transformations = transformations;
this.densify = densify;
this.includeImplicitZeroFeatures = includeImplicitZeroFeatures;
}

@Override
public TransformedModel<T> train(Dataset<T> examples, Map<String, Provenance> instanceProvenance) {

logger.fine(String.format("Creating transformers"));
TransformerMap transformerMap = examples.createTransformers(transformations);

TransformerMap transformerMap = examples.createTransformers(transformations, includeImplicitZeroFeatures);

logger.fine("Transforming data set");

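A hedged sketch of the new four-argument constructor, not part of the diff; LogisticRegressionTrainer, MeanStdDevTransformation, and the Label output type are illustrative stand-ins for whatever inner trainer, transformations, and output type you actually use.

```java
import java.util.Collections;

import org.tribuo.Dataset;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.transform.TransformTrainer;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.TransformedModel;
import org.tribuo.transform.transformations.MeanStdDevTransformation;

// Illustrative sketch, not from the PR; class name and variables are hypothetical.
public final class TransformTrainerSketch {
    public static TransformedModel<Label> train(Dataset<Label> train) {
        // Standardise each feature to zero mean and unit standard deviation.
        TransformationMap transformations = new TransformationMap(
                Collections.singletonList(new MeanStdDevTransformation(0.0, 1.0)));

        // densify = false keeps examples sparse at train and predict time;
        // includeImplicitZeroFeatures = true folds the implicit zeros into the
        // statistics gathered when fitting the transformations.
        TransformTrainer<Label> trainer = new TransformTrainer<>(
                new LogisticRegressionTrainer(), transformations, false, true);

        return trainer.train(train, Collections.emptyMap());
    }
}
```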
12 changes: 12 additions & 0 deletions Core/src/main/java/org/tribuo/transform/TransformedModel.java
@@ -65,6 +65,18 @@ public class TransformedModel<T extends Output<T>> extends Model<T> {
Collections.sort(featureNames);
}

/**
* Gets the transformers that this model applies to each example.
* <p>
* Note that if you use these transformers to modify some data and then
* feed that data to this model, the data will be transformed twice,
* which is almost certainly not what you want.
* @return The transformers.
*/
public TransformerMap getTransformerMap() {
return transformerMap;
}

@Override
public Prediction<T> predict(Example<T> example) {
Example<T> transformedExample;
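A short sketch, illustrative and not from the diff, of the pitfall the new Javadoc warns about; the class name, the model, and the test dataset are assumed for the example.

```java
import org.tribuo.Dataset;
import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.transform.TransformedModel;
import org.tribuo.transform.TransformerMap;

// Illustrative sketch, not from the PR; class name and variables are hypothetical.
public final class TransformerMapAccessSketch {
    public static void inspect(TransformedModel<Label> model, Dataset<Label> test) {
        // Pull out the transformers fitted during training.
        TransformerMap fitted = model.getTransformerMap();

        // Fine: inspect the fitted transformers or apply them to data consumed elsewhere.
        MutableDataset<Label> externallyTransformed = fitted.transformDataset(test);

        // Not fine: predict() already applies the transformers internally, so feeding
        // externally transformed examples to this model would transform them twice.
        // model.predict(externallyTransformed.getExample(0));
    }
}
```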
3 changes: 2 additions & 1 deletion Core/src/main/java/org/tribuo/transform/TransformerMap.java
@@ -32,6 +32,7 @@

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -66,7 +67,7 @@ public final class TransformerMap implements Provenancable<TransformerMapProvena
* @param transformationMapProvenance The provenance of the transformation map that was fit.
*/
public TransformerMap(Map<String,List<Transformer>> map, DatasetProvenance datasetProvenance, ConfiguredObjectProvenance transformationMapProvenance) {
this.map = map;
this.map = Collections.unmodifiableMap(map);
this.datasetProvenance = datasetProvenance;
this.transformationMapProvenance = transformationMapProvenance;
}
Core/src/main/java/org/tribuo/transform/transformations/BinningTransformation.java
@@ -31,6 +31,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Logger;

/**
* A Transformation which bins values.
@@ -53,7 +54,20 @@ public final class BinningTransformation implements Transformation {
/**
* The allowed binning types.
*/
public enum BinningType { EQUAL_WIDTH, EQUAL_FREQUENCY, STD_DEVS }
public enum BinningType {
/**
* Creates bins of equal width over the data range.
*/
EQUAL_WIDTH,
/**
* Creates bins of equal frequency (i.e., equal numbers of data points).
*/
EQUAL_FREQUENCY,
/**
* Creates bins based on the mean and then +/- multiples of standard deviations.
*/
STD_DEVS
}

private static final String NUM_BINS = "numBins";
private static final String TYPE = "type";
@@ -229,10 +243,16 @@ public void observeValue(double value) {
}

@Override
public void observeSparse() { }
@Deprecated
public void observeSparse() {
observeValue(0.0);
}

@Override
public void observeSparse(int count) { }
public void observeSparse(int count) {
// This just tracks max and min, so seeing many 0.0 is the same as one 0.0.
observeValue(0.0);
}

@Override
public Transformer generateTransformer() {
@@ -308,10 +328,19 @@ protected void growArray() {
}

@Override
public void observeSparse() { }
@Deprecated
public void observeSparse() {
observeValue(0.0);
}

@Override
public void observeSparse(int count) { }
public void observeSparse(int sparseCount) {
if (observedValues.length < (count + sparseCount)) {
growArray(count + sparseCount);
}
count += sparseCount;
// Java initializes the array to zero so we don't need to write the values in.
}

@Override
public Transformer generateTransformer() {
@@ -338,6 +367,8 @@ public String toString() {
}

private static class StdDevStats implements TransformStatistics {
private static final Logger logger = Logger.getLogger(StdDevStats.class.getName());

private final int numBins;

private double mean = 0;
@@ -358,13 +389,25 @@ public void observeValue(double value) {
}

@Override
public void observeSparse() { }
@Deprecated
public void observeSparse() {
observeValue(0.0);
}

@Override
public void observeSparse(int count) { }
public void observeSparse(int sparseCount) {
count += sparseCount;
double delta = -mean; // each implicit value is zero, so delta = 0 - mean
mean += (sparseCount * delta) / count; // weight the mean update by the number of implicit zeros
double delta2 = -mean;
sumSquares += sparseCount * (delta * delta2);
}

@Override
public Transformer generateTransformer() {
if (sumSquares == 0.0) {
logger.info("Only observed a single value (" + mean + ") when building a BinningTransformer using standard deviation bins.");
}
double[] bins = new double[numBins];
double[] values = new double[numBins];

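To make the enum documentation above concrete, here is a small sketch, not from the diff, of fitting each binning type; it assumes BinningTransformation's static factory methods and a `MutableDataset<Label>` named `train`.

```java
import java.util.Collections;

import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.BinningTransformation;

// Illustrative sketch, not from the PR; class name and variables are hypothetical.
public final class BinningSketch {
    public static void fit(MutableDataset<Label> train) {
        // EQUAL_WIDTH: 10 bins of equal width across the observed range.
        TransformationMap equalWidth = new TransformationMap(
                Collections.singletonList(BinningTransformation.equalWidth(10)));

        // EQUAL_FREQUENCY: 10 bins, each containing roughly the same number of points.
        TransformationMap equalFreq = new TransformationMap(
                Collections.singletonList(BinningTransformation.equalFrequency(10)));

        // STD_DEVS: bins at +/- 3 standard deviations around the mean.
        TransformationMap stdDevs = new TransformationMap(
                Collections.singletonList(BinningTransformation.stdDevs(3)));

        // With includeImplicitZeroFeatures = true the implicit zeros now shift the
        // bin edges, which is the behaviour change this file implements.
        TransformerMap fitted = train.createTransformers(equalWidth, true);
    }
}
```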
Core/src/main/java/org/tribuo/transform/transformations/IDFTransformation.java
@@ -2,7 +2,6 @@

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
@@ -24,7 +23,7 @@ public TransformStatistics createStats() {

@Override
public TransformationProvenance getProvenance() {
if(provenance == null) {
if (provenance == null) {
provenance = new IDFTransformationProvenance();
}
return provenance;
@@ -53,7 +52,9 @@ public void observeValue(double value) {
}

@Override
@Deprecated
public void observeSparse() {
sparseObservances++;
}

@Override
@@ -69,7 +70,8 @@ public Transformer generateTransformer() {
}

private static class IDFTransformer implements Transformer {

private static final long serialVersionUID = 1L;

private double df;

private double N;
@@ -85,12 +87,21 @@ public double transform(double tf) {
}

}


/**
* Provenance for {@link IDFTransformation}.
*/
public final static class IDFTransformationProvenance implements TransformationProvenance {
private static final long serialVersionUID = 1L;

IDFTransformationProvenance() { }

// IDFTransformation has no state to record.
public IDFTransformationProvenance(Map<String,Provenance> map) { }

@Override
public Map<String, Provenance> getConfiguredParameters() {
return Collections.unmodifiableMap(new HashMap<>());
return Collections.emptyMap();
}

@Override
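Since the Dataset Javadoc earlier in this diff singles out IDFTransformation as needing the implicit zero counts, here is a hedged sketch, not from the diff, of fitting it with the new flag; the `corpus` variable and the Label output type are illustrative assumptions.

```java
import java.util.Collections;

import org.tribuo.MutableDataset;
import org.tribuo.classification.Label;
import org.tribuo.transform.TransformationMap;
import org.tribuo.transform.TransformerMap;
import org.tribuo.transform.transformations.IDFTransformation;

// Illustrative sketch, not from the PR; class name and variables are hypothetical.
public final class IdfSketch {
    public static MutableDataset<Label> weight(MutableDataset<Label> corpus) {
        TransformationMap idf = new TransformationMap(
                Collections.singletonList(new IDFTransformation()));

        // IDF needs to know how many examples do not contain each feature, so the
        // implicit zeros must be included when fitting (includeImplicitZeroFeatures = true).
        TransformerMap fitted = corpus.createTransformers(idf, true);

        // Apply the fitted IDF weighting to produce a new, transformed copy of the corpus.
        return fitted.transformDataset(corpus);
    }
}
```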
Core/src/main/java/org/tribuo/transform/transformations/LinearScalingTransformation.java
@@ -161,10 +161,16 @@ public void observeValue(double value) {
}

@Override
public void observeSparse() { }
@Deprecated
public void observeSparse() {
observeValue(0.0);
}

@Override
public void observeSparse(int count) { }
public void observeSparse(int count) {
// This just tracks max and min, so seeing many 0.0 is the same as one 0.0.
observeValue(0.0);
}

@Override
public Transformer generateTransformer() {