Addition of a Reproducibility Framework #185

Merged: 42 commits, Oct 27, 2021
Changes from 39 commits

Commits
d316475
Added setInvocationCount method to trainers for use by a reproducibil…
jwons Jun 24, 2021
4b6a7b6
Added a reproducibility utility that can reproduce a model from a Mod…
jwons Jul 13, 2021
6a88f8c
added optional parameter to allow a DataSource config to be overwritt…
jwons Jul 15, 2021
e2ab826
Refactored ReproUtil to not be static
jwons Jul 16, 2021
134c98d
Added tests for reproducibility utility that test for reproducing fro…
jwons Jul 17, 2021
4da9efa
add configManager getter to ReproUtil, tested that it can be used to …
jwons Jul 21, 2021
af5bec6
Added overloaded train method that takes an invocation count using it…
jwons Aug 2, 2021
1b0fda0
added diff feature to ReproUtility and added more comments to the ent…
jwons Aug 10, 2021
e4c256a
refactor and comments on ReproUtil, and diffs will now diff prov obje…
jwons Aug 11, 2021
30239a5
Merge branch 'oracle:main' into reproducibility
jwons Aug 11, 2021
cc8e4e2
fixed import merge conflict in KMeansTest
jwons Aug 23, 2021
5f255ce
git push origin reproducibilityMerge branch 'oracle-main' into reprod…
jwons Aug 23, 2021
ebed123
Merge pull request #1 from oracle/main
jwons Aug 23, 2021
5de9dc6
removed unneeded reflection to diff provenance, check for external mo…
jwons Aug 24, 2021
0601d14
Checks datasource is configurable before recovering, starting to remo…
jwons Aug 31, 2021
245e96b
Check if datasource is configurable after checking if it's a traintes…
jwons Sep 2, 2021
67f5453
Merge branch 'oracle:main' into reproducibility
jwons Sep 2, 2021
82b69b0
Continued to fix use of generics in reproducibility framework, remove…
jwons Sep 15, 2021
a950f37
Fixed bug where trainInvocationCounter would not be reset in setInvoc…
jwons Sep 21, 2021
0788f63
Fixed merge conflicts mostly relating to combining new invocationCoun…
jwons Sep 21, 2021
b70c94a
Merge branch 'oracle-main' into reproducibility
jwons Sep 21, 2021
2477ee7
Attempting to finish merge
jwons Sep 21, 2021
68af149
Merge branch 'oracle-main' into reproducibility
jwons Sep 21, 2021
2f2eb86
Some styling changes, fixed final merge issue, and some extra documen…
jwons Sep 21, 2021
edef8c0
In reproducibility framework, the invocation counts for ALL trainers …
jwons Sep 23, 2021
c560ab9
Fixed bug in reproducibility framework where Transform trainers could…
jwons Sep 30, 2021
c4705cd
Merging changes
jwons Oct 4, 2021
14939f0
merged upstream tribuo changes, applied setInvocation changes to Inde…
jwons Oct 19, 2021
32b8148
initial round of feedback, fixing generics, exceptions, print outs, u…
jwons Oct 20, 2021
dde0074
Made some exceptions more explicit
jwons Oct 21, 2021
65b234d
Added exception I thought I did yesterday that turned out to be a pri…
jwons Oct 21, 2021
fcb5015
removed more snake_case
jwons Oct 21, 2021
c0afdc8
Brought object mapper up to a static variable and initialize it in a …
jwons Oct 21, 2021
6f503da
simplified default train method with invocationCount parameter, fixed…
jwons Oct 22, 2021
13c0473
added unneeded check for dataset provenance, and removed completed todos
jwons Oct 25, 2021
08586d3
Diff will now add lists that only appear in one model
jwons Oct 25, 2021
13a8a91
Cleaned up exceptions, begin filling out javadocs particularly to inc…
jwons Oct 26, 2021
de98f3a
Moved test resources
jwons Oct 26, 2021
3a7d19e
Changed reproduceFromModel return type to record that has model, feat…
jwons Oct 26, 2021
f0760fc
removed rogue import and specified Set types in diff records
jwons Oct 27, 2021
0383ba4
only build reproducibility module with java 16
jwons Oct 27, 2021
0e7cd6d
fixed generic in diff record
jwons Oct 27, 2021
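
Taken together, these commits add two layers: every Trainer gains invocation-count control (a setInvocationCount method plus an overloaded train that accepts an explicit invocation count), and a new ReproUtil class uses a model's provenance to rebuild its trainer and data source, replay the trainer to the recorded invocation count, retrain, and diff the resulting provenance against the original. The sketch below shows the intended end-to-end usage; it is assembled from the commit messages only, so the package name (org.tribuo.reproducibility), the ModelReproduction record accessor, and the diffProvenance method are assumptions and may differ from the merged module.

import java.io.FileInputStream;
import java.io.ObjectInputStream;

import org.tribuo.Model;
import org.tribuo.classification.Label;
import org.tribuo.reproducibility.ReproUtil;

public class ReproduceExample {
    @SuppressWarnings("unchecked")
    public static void main(String[] args) throws Exception {
        // Load a previously trained classification model (Tribuo models are Serializable).
        Model<Label> original;
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream("model.ser"))) {
            original = (Model<Label>) ois.readObject();
        }

        // Rebuild the trainer and data source from the model's provenance and retrain.
        // The trainer is fast-forwarded to the recorded invocation count, which is what
        // the setInvocationCount changes in the diffs below make possible.
        ReproUtil<Label> repro = new ReproUtil<>(original);
        ReproUtil.ModelReproduction<Label> reproduction = repro.reproduceFromModel();
        Model<Label> copy = reproduction.model();

        // Diff the two provenance objects to see where (if anywhere) the runs differ.
        String provDiff = ReproUtil.diffProvenance(original.getProvenance(), copy.getProvenance());
        System.out.println(provDiff);
    }
}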
@@ -86,9 +86,17 @@ public void postConfig() {

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) {
return(train(examples, instanceProvenance, INCREMENT_INVOCATION_COUNT));
}

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance, int invocationCount) {
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
this.invocationCount = invocationCount;
}
ModelProvenance provenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), instanceProvenance);
ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
invocationCount++;
this.invocationCount++;
switch (dummyType) {
case CONSTANT:
MutableOutputInfo<Label> labelInfo = examples.getOutputInfo().generateMutableOutputInfo();
@@ -114,6 +122,15 @@ public int getInvocationCount() {
return invocationCount;
}

@Override
public synchronized void setInvocationCount(int invocationCount){
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

this.invocationCount = invocationCount;
}

@Override
public String toString() {
switch (dummyType) {
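
The hunks above establish the pattern every trainer in this PR follows: the existing two-argument train delegates to a new three-argument overload with the sentinel INCREMENT_INVOCATION_COUNT, and the overload only overwrites the stored count when a real value is supplied, so ordinary callers see no behaviour change. A hedged sketch of the shared contract this implies (the method set is inferred from the overrides in these diffs; the real org.tribuo.Trainer interface carries javadoc and details not shown here):

import java.util.Map;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;

public interface SketchedTrainer<T extends Output<T>> {

    // Sentinel meaning "do not override the count, just increment it as usual" (value assumed here).
    int INCREMENT_INVOCATION_COUNT = -1;

    // Existing entry point; implementations now delegate to the overload below.
    Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance);

    // New overload: if invocationCount is not the sentinel, the trainer is first
    // fast-forwarded to that count (resetting its RNG state) before training.
    Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount);

    // Number of times train has been invoked on this trainer.
    int getInvocationCount();

    // Rewinds or fast-forwards the counter and RNG state; negative values are rejected.
    void setInvocationCount(int invocationCount);
}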

@@ -137,13 +137,21 @@ public String toString() {
*/
@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
return(train(examples, runProvenance, INCREMENT_INVOCATION_COUNT));
}

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
if (examples.getOutputInfo().getUnknownCount() > 0) {
throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
}
// Creates a new RNG, adds one to the invocation count.
SplittableRandom localRNG;
TrainerProvenance trainerProvenance;
synchronized(this) {
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
localRNG = rng.split();
trainerProvenance = getProvenance();
trainInvocationCounter++;
@@ -195,7 +203,7 @@ public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runPr
EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance, ListProvenance.createListProvenance(models));
return new WeightedEnsembleModel<>("boosted-ensemble",provenance,featureIDs,labelIDs,models,new VotingCombiner(),newModelWeights);
}

//
// Update the weights
for (int j = 0; j < predictions.size(); j++) {
@@ -221,6 +229,22 @@ public int getInvocationCount() {
return trainInvocationCounter;
}

@Override
public synchronized void setInvocationCount(int invocationCount){
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;
for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
}

}

private float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
float correctSum = 0;
float total = 0;
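
The setInvocationCount implementation above shows how the RNG-backed trainers fast-forward: they re-seed a fresh SplittableRandom and call split() once per past invocation, discarding the result (localRNG is assigned but deliberately unused), because each real train call consumes exactly one split() from the shared RNG. A self-contained illustration of why this replay reproduces the original state, using nothing but the JDK:

import java.util.SplittableRandom;

public class SplitReplayDemo {
    public static void main(String[] args) {
        long seed = 12345L;

        // A trainer that has genuinely been invoked three times: each train call
        // takes one split() from the trainer's shared RNG.
        SplittableRandom original = new SplittableRandom(seed);
        for (int i = 0; i < 3; i++) {
            original.split();
        }

        // setInvocationCount(3) on a fresh trainer does the equivalent: re-seed,
        // then discard the same number of splits.
        SplittableRandom replayed = new SplittableRandom(seed);
        for (int i = 0; i < 3; i++) {
            replayed.split();
        }

        // The next split (i.e. the next train call) now draws identical randomness.
        SplittableRandom nextOriginal = original.split();
        SplittableRandom nextReplayed = replayed.split();
        System.out.println(nextOriginal.nextLong() == nextReplayed.nextLong()); // prints true
    }
}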

@@ -161,4 +161,46 @@ public void testEmptyExample() {
public void testRandomEmptyExample() {
runEmptyExample(randomt);
}

@Test
public void testSetInvocationCount() {
// Create new trainer and dataset so as not to mess with the other tests
CARTClassificationTrainer originalTrainer = new CARTClassificationTrainer(5,2, 0.0f,1.0f, true,
new GiniIndex(), Trainer.DEFAULT_SEED);
Pair<Dataset<Label>,Dataset<Label>> data = LabelledDataGenerator.denseTrainTest();
DatasetView<Label> trainingData = DatasetView.createView(data.getA(),(Example<Label> e) -> e.getOutput().getLabel().equals("Foo"), "Foo selector");

// The number of times to call train before final training.
// Original trainer will be trained numOfInvocations + 1 times
// New trainer will have it's invocation count set to numOfInvocations then trained once
int numOfInvocations = 2;

// Create the first model and train it numOfInvocations + 1 times
TreeModel<Label> originalModel = null;
for(int invocationCounter = 0; invocationCounter < numOfInvocations + 1; invocationCounter++){
originalModel = originalTrainer.train(trainingData);
}

// Create a new model with same configuration, but set the invocation count to numOfInvocations
// Assert that this succeeded, this means RNG will be at state where originalTrainer was before
// it performed its last train.
CARTClassificationTrainer newTrainer = new CARTClassificationTrainer(5,2, 0.0f,1.0f, true,
new GiniIndex(), Trainer.DEFAULT_SEED);
newTrainer.setInvocationCount(numOfInvocations);
assertEquals(numOfInvocations,newTrainer.getInvocationCount());

// Training newTrainer should now have the same result as if it
// had trained numOfInvocations times previously even though it hasn't
TreeModel<Label> newModel = newTrainer.train(trainingData);
assertEquals(originalTrainer.getInvocationCount(),newTrainer.getInvocationCount());
}

@Test
public void testNegativeInvocationCount(){
assertThrows(IllegalArgumentException.class, () -> {
CARTClassificationTrainer t = new CARTClassificationTrainer(5,2, 0.0f,1.0f, true,
new GiniIndex(), Trainer.DEFAULT_SEED);
t.setInvocationCount(-1);
});
}
}

@@ -78,6 +78,11 @@ public MultinomialNaiveBayesTrainer(double alpha) {

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT);
}

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
if (examples.getOutputInfo().getUnknownCount() > 0) {
throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
}
@@ -101,7 +106,9 @@ public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runPr
featureMap.merge(featureInfos.getID(feat.getName()), curWeight*feat.getValue(), Double::sum);
}
}

if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
TrainerProvenance trainerProvenance = getProvenance();
ModelProvenance provenance = new ModelProvenance(MultinomialNaiveBayesModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
invocationCount++;
@@ -125,6 +132,15 @@ public int getInvocationCount() {
return invocationCount;
}

@Override
public void setInvocationCount(int invocationCount) {
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

this.invocationCount = invocationCount;
}

@Override
public String toString() {
return "MultinomialNaiveBayesTrainer(alpha=" + alpha + ")";

@@ -134,13 +134,21 @@ public void setShuffle(boolean shuffle) {

@Override
public KernelSVMModel train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
return(train(examples, runProvenance, INCREMENT_INVOCATION_COUNT));
}

@Override
public KernelSVMModel train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
if (examples.getOutputInfo().getUnknownCount() > 0) {
throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
}
// Creates a new RNG, adds one to the invocation count.
TrainerProvenance trainerProvenance;
SplittableRandom localRNG;
synchronized(this) {
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
localRNG = rng.split();
trainerProvenance = getProvenance();
trainInvocationCounter++;
@@ -223,6 +231,23 @@ public int getInvocationCount() {
return trainInvocationCounter;
}

@Override
public synchronized void setInvocationCount(int invocationCount){
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;

for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
}

}

@Override
public String toString() {
return "KernelSVMTrainer(kernel="+kernel.toString()+",lambda="+lambda+",epochs="+epochs+",seed="+seed+")";

@@ -24,17 +24,21 @@
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.common.sgd.AbstractLinearSGDModel;
import org.tribuo.math.kernel.RBF;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.test.Helpers;

import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

public class TestSGDKernel {
@@ -97,4 +101,42 @@ public void testEmptyExample() {
});
}

@Test
public void testSetInvocationCount() {
// Create new trainer and dataset so as not to mess with the other tests
KernelSVMTrainer originalTrainer = new KernelSVMTrainer(new RBF(1.0),1,5,1000, Trainer.DEFAULT_SEED);
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();

// The number of times to call train before final training.
// Original trainer will be trained numOfInvocations + 1 times
// New trainer will have it's invocation count set to numOfInvocations then trained once
int numOfInvocations = 2;

// Create the first model and train it numOfInvocations + 1 times
Model<Label> originalModel = null;
for(int invocationCounter = 0; invocationCounter < numOfInvocations + 1; invocationCounter++){
originalModel = originalTrainer.train(p.getA());
}

// Create a new model with same configuration, but set the invocation count to numOfInvocations
// Assert that this succeeded, this means RNG will be at state where originalTrainer was before
// it performed its last train.
KernelSVMTrainer newTrainer = new KernelSVMTrainer(new RBF(1.0),1,5,1000, Trainer.DEFAULT_SEED);
newTrainer.setInvocationCount(numOfInvocations);
assertEquals(numOfInvocations,newTrainer.getInvocationCount());

// Training newTrainer should now have the same result as if it
// had trained numOfInvocations times previously even though it hasn't
Model<Label> newModel = newTrainer.train(p.getA());
assertEquals(originalTrainer.getInvocationCount(),newTrainer.getInvocationCount());
}

@Test
public void testNegativeInvocationCount(){
assertThrows(IllegalArgumentException.class, () -> {
KernelSVMTrainer t = new KernelSVMTrainer(new RBF(1.0),1,5,1000, Trainer.DEFAULT_SEED);
t.setInvocationCount(-1);
});
}

}

@@ -176,11 +176,19 @@ public void postConfig() {

@Override
public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance) {
return (train(examples, runProvenance, INCREMENT_INVOCATION_COUNT));
}

@Override
public synchronized XGBoostModel<Label> train(Dataset<Label> examples, Map<String, Provenance> runProvenance, int invocationCount) {
if (examples.getOutputInfo().getUnknownCount() > 0) {
throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
}
ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
ImmutableOutputInfo<Label> outputInfo = examples.getOutputIDInfo();
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
TrainerProvenance trainerProvenance = getProvenance();
trainInvocationCounter++;
parameters.put("num_class", outputInfo.size());

@@ -184,10 +184,18 @@ public synchronized void postConfig() {

@Override
public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance) {
return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT);
}

@Override
public KMeansModel train(Dataset<ClusterID> examples, Map<String, Provenance> runProvenance, int invocationCount) {
// Creates a new local RNG and adds one to the invocation count.
TrainerProvenance trainerProvenance;
SplittableRandom localRNG;
synchronized (this) {
if(invocationCount != INCREMENT_INVOCATION_COUNT) {
setInvocationCount(invocationCount);
}
localRNG = rng.split();
trainerProvenance = getProvenance();
trainInvocationCounter++;
@@ -303,6 +311,23 @@ public int getInvocationCount() {
return trainInvocationCounter;
}

@Override
public synchronized void setInvocationCount(int invocationCount){
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;

for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
}

}

/**
* Initialisation method called at the start of each train call when using the default centroid initialisation.
* Centroids are initialised using a uniform random sample from the feature domain.