Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provenance updates #115

Merged
merged 2 commits into from
Feb 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
* Model provenance for ensemble models.
Expand All @@ -38,40 +36,110 @@ public class EnsembleModelProvenance extends ModelProvenance {

private final ListProvenance<? extends ModelProvenance> memberProvenance;

public EnsembleModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, ListProvenance<? extends ModelProvenance> memberProvenance) {
/**
 * Builds the provenance of an ensemble model from its class name, creation time,
 * dataset provenance and trainer provenance, together with the provenances of the
 * individual ensemble members.
 * <p>
 * The first four arguments are handed to the superclass constructor; the member
 * provenances are stored on this object.
 * @param className The model class name.
 * @param time The model creation time.
 * @param datasetProvenance The dataset provenance.
 * @param trainerProvenance The trainer provenance.
 * @param memberProvenance The ensemble member provenances.
 */
public EnsembleModelProvenance(
        String className,
        OffsetDateTime time,
        DatasetProvenance datasetProvenance,
        TrainerProvenance trainerProvenance,
        ListProvenance<? extends ModelProvenance> memberProvenance) {
    super(className, time, datasetProvenance, trainerProvenance);
    // The list is stored as supplied, no copy is made here.
    this.memberProvenance = memberProvenance;
}

public EnsembleModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, Map<String, Provenance> instanceProvenance, ListProvenance<? extends ModelProvenance> memberProvenance) {
/**
 * Builds the provenance of an ensemble model from its class name, creation time,
 * dataset provenance, trainer provenance and instance specific provenance, together
 * with the provenances of the individual ensemble members.
 * <p>
 * The first five arguments are handed to the superclass constructor; the member
 * provenances are stored on this object.
 * @param className The model class name.
 * @param time The model creation time.
 * @param datasetProvenance The dataset provenance.
 * @param trainerProvenance The trainer provenance.
 * @param instanceProvenance Provenance for this specific model training run.
 * @param memberProvenance The ensemble member provenances.
 */
public EnsembleModelProvenance(
        String className,
        OffsetDateTime time,
        DatasetProvenance datasetProvenance,
        TrainerProvenance trainerProvenance,
        Map<String, Provenance> instanceProvenance,
        ListProvenance<? extends ModelProvenance> memberProvenance) {
    super(className, time, datasetProvenance, trainerProvenance, instanceProvenance);
    // The list is stored as supplied, no copy is made here.
    this.memberProvenance = memberProvenance;
}

/**
 * Builds the provenance of an ensemble model from its class name, creation time,
 * dataset provenance, trainer provenance and instance specific provenance, together
 * with the provenances of the individual ensemble members.
 * <p>
 * The {@code trackSystem} flag is passed through to the superclass constructor along
 * with the first five arguments; the member provenances are stored on this object.
 * @param className The model class name.
 * @param time The model creation time.
 * @param datasetProvenance The dataset provenance.
 * @param trainerProvenance The trainer provenance.
 * @param instanceProvenance Provenance for this specific model training run.
 * @param trackSystem If true then store the java version, os name and os arch in the provenance.
 * @param memberProvenance The ensemble member provenances.
 */
public EnsembleModelProvenance(
        String className,
        OffsetDateTime time,
        DatasetProvenance datasetProvenance,
        TrainerProvenance trainerProvenance,
        Map<String, Provenance> instanceProvenance,
        boolean trackSystem,
        ListProvenance<? extends ModelProvenance> memberProvenance) {
    super(className, time, datasetProvenance, trainerProvenance, instanceProvenance, trackSystem);
    // The list is stored as supplied, no copy is made here.
    this.memberProvenance = memberProvenance;
}

/**
* Used by the provenance unmarshalling system.
* <p>
* Throws {@link com.oracle.labs.mlrg.olcut.provenance.ProvenanceException} if there are missing
* fields.
* @param map The provenance map.
*/
@SuppressWarnings("unchecked") // member provenance cast.
public EnsembleModelProvenance(Map<String, Provenance> map) {
super(map);
this.memberProvenance = (ListProvenance<? extends ModelProvenance>) ObjectProvenance.checkAndExtractProvenance(map,MEMBERS,ListProvenance.class, EnsembleModelProvenance.class.getSimpleName());
}

/**
 * Gets the list of provenances, one per ensemble member, held by this object.
 * @return The member provenances.
 */
public ListProvenance<? extends ModelProvenance> getMemberProvenance() {
    return this.memberProvenance;
}

/**
 * Two ensemble provenances are equal when they are the same runtime class, the
 * superclass considers them equal, and their member provenances are equal.
 * @param o The object to compare against.
 * @return True if the objects are equal.
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null || getClass() != o.getClass()) {
        return false;
    }
    if (!super.equals(o)) {
        return false;
    }
    EnsembleModelProvenance other = (EnsembleModelProvenance) o;
    return memberProvenance.equals(other.memberProvenance);
}

/**
 * Combines the superclass hash with the hash of the member provenances.
 * @return The hash code.
 */
@Override
public int hashCode() {
    final int parentHash = super.hashCode();
    return Objects.hash(parentHash, memberProvenance);
}

/**
 * Delegates to {@code generateString} using the label {@code "EnsembleModel"}.
 * @return The string form of this provenance.
 */
@Override
public String toString() {
    final String label = "EnsembleModel";
    return generateString(label);
}

@Override
public Iterator<Pair<String, Provenance>> iterator() {
ArrayList<Pair<String,Provenance>> iterable = new ArrayList<>();
iterable.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className)));
iterable.add(new Pair<>(DATASET,datasetProvenance));
iterable.add(new Pair<>(TRAINER,trainerProvenance));
iterable.add(new Pair<>(TRAINING_TIME,new DateTimeProvenance(TRAINING_TIME,time)));
iterable.add(new Pair<>(INSTANCE_VALUES,instanceProvenance));
iterable.add(new Pair<>(MEMBERS,memberProvenance));
return iterable.iterator();
protected List<Pair<String, Provenance>> internalProvenances() {
List<Pair<String, Provenance>> superList = super.internalProvenances();
superList.add(new Pair<>(MEMBERS,memberProvenance));
return superList;
}
}
174 changes: 159 additions & 15 deletions Core/src/main/java/org/tribuo/provenance/ModelProvenance.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,29 @@
import com.oracle.labs.mlrg.olcut.provenance.MapProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Tribuo;

import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
* Contains provenance information for an instance of a {@link org.tribuo.Model}.
* <p>
* Made up of the class name of the model object, the date and time it was trained, the provenance of
* the training data, and the provenance of the trainer.
* <p>
* In addition, by default it also records the Java version, the OS name and the system
* architecture, along with the running Tribuo version.
*/
public class ModelProvenance implements ObjectProvenance {
private static final long serialVersionUID = 1L;
Expand All @@ -43,7 +50,17 @@ public class ModelProvenance implements ObjectProvenance {
protected static final String TRAINER = "trainer";
protected static final String TRAINING_TIME = "trained-at";
protected static final String INSTANCE_VALUES = "instance-values";
// Key under which the Tribuo version is stored. Declared once, as protected;
// the earlier private declaration of the same name is removed because a field
// cannot be declared twice.
protected static final String TRIBUO_VERSION_STRING = "tribuo-version";

// Note these have been added due to a discrepancy between java.lang.Math
// and java.lang.StrictMath on x64 and aarch64 platforms (and between Java 8 and 9+).
// Training a linear SGD predictor can create different models on different platforms
// due to this discrepancy.
protected static final String JAVA_VERSION_STRING = "java-version";
protected static final String OS_STRING = "os-name";
protected static final String ARCH_STRING = "os-arch";

// Placeholder stored for the java version, OS name and architecture when they
// are not tracked or are absent from an unmarshalled provenance.
protected static final String UNKNOWN_VERSION = "unknown-version";

protected final String className;

Expand All @@ -57,31 +74,115 @@ public class ModelProvenance implements ObjectProvenance {

protected final String versionString;

public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance) {
this.className = className;
this.time = time;
this.datasetProvenance = datasetProvenance;
this.trainerProvenance = trainerProvenance;
this.instanceProvenance = new MapProvenance<>();
this.versionString = Tribuo.VERSION;
protected final String javaVersionString;

protected final String osString;

protected final String archString;

/**
 * Builds a model provenance from the class name, creation time, dataset provenance
 * and trainer provenance.
 * <p>
 * Delegates to the five argument constructor with an empty instance provenance map.
 * @param className The model class name.
 * @param time The model creation time.
 * @param datasetProvenance The dataset provenance.
 * @param trainerProvenance The trainer provenance.
 */
public ModelProvenance(
        String className,
        OffsetDateTime time,
        DatasetProvenance datasetProvenance,
        TrainerProvenance trainerProvenance) {
    this(className, time, datasetProvenance, trainerProvenance, Collections.emptyMap());
}

/**
 * Builds a model provenance from the class name, creation time, dataset provenance,
 * trainer provenance and any instance specific provenance.
 * <p>
 * Delegates to the six argument constructor with system tracking switched on.
 * @param className The model class name.
 * @param time The model creation time.
 * @param datasetProvenance The dataset provenance.
 * @param trainerProvenance The trainer provenance.
 * @param instanceProvenance Provenance for this specific model training run.
 */
public ModelProvenance(
        String className,
        OffsetDateTime time,
        DatasetProvenance datasetProvenance,
        TrainerProvenance trainerProvenance,
        Map<String,Provenance> instanceProvenance) {
    this(className, time, datasetProvenance, trainerProvenance, instanceProvenance, true);
}

public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance, TrainerProvenance trainerProvenance, Map<String,Provenance> instanceProvenance) {
/**
 * Creates a model provenance tracking the class name, creation time, dataset provenance,
 * trainer provenance and any instance specific provenance.
 * <p>
 * The Tribuo version is always recorded. The java version, os name and os architecture
 * are read from the system properties only when {@code trackSystem} is true; otherwise
 * each of the three is set to {@link #UNKNOWN_VERSION}.
 * @param className The model class name.
 * @param time The model creation time.
 * @param datasetProvenance The dataset provenance.
 * @param trainerProvenance The trainer provenance.
 * @param instanceProvenance Provenance for this specific model training run.
 * @param trackSystem If true then store the java version, os name and os arch in the provenance.
 */
public ModelProvenance(String className, OffsetDateTime time, DatasetProvenance datasetProvenance,
                       TrainerProvenance trainerProvenance, Map<String,Provenance> instanceProvenance,
                       boolean trackSystem) {
    this.className = className;
    this.time = time;
    this.datasetProvenance = datasetProvenance;
    this.trainerProvenance = trainerProvenance;
    // Single assignment of the final field: an empty input uses the no-arg
    // MapProvenance constructor, anything else is wrapped. The earlier
    // unconditional assignment is removed, a final field cannot be set twice.
    this.instanceProvenance = instanceProvenance.isEmpty() ? new MapProvenance<>() : new MapProvenance<>(instanceProvenance);
    this.versionString = Tribuo.VERSION;
    if (trackSystem) {
        this.javaVersionString = System.getProperty("java.version");
        this.osString = System.getProperty("os.name");
        this.archString = System.getProperty("os.arch");
    } else {
        this.javaVersionString = UNKNOWN_VERSION;
        this.osString = UNKNOWN_VERSION;
        this.archString = UNKNOWN_VERSION;
    }
}

/**
* Used by the provenance unmarshalling system.
* <p>
* Throws {@link com.oracle.labs.mlrg.olcut.provenance.ProvenanceException} if there are missing
* fields.
* @param map The provenance map.
*/
public ModelProvenance(Map<String,Provenance> map) {
this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME,StringProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
this.datasetProvenance = ObjectProvenance.checkAndExtractProvenance(map,DATASET,DatasetProvenance.class, ModelProvenance.class.getSimpleName());
this.trainerProvenance = ObjectProvenance.checkAndExtractProvenance(map,TRAINER,TrainerProvenance.class, ModelProvenance.class.getSimpleName());
this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
this.instanceProvenance = (MapProvenance<?>) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName());
this.versionString = ObjectProvenance.checkAndExtractProvenance(map, TRIBUO_VERSION_STRING,StringProvenance.class, DatasetProvenance.class.getSimpleName()).getValue();
this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class,ModelProvenance.class.getSimpleName()).getValue();
this.javaVersionString = maybeExtractProvenance(map,JAVA_VERSION_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.osString = maybeExtractProvenance(map,OS_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.archString = maybeExtractProvenance(map,ARCH_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
}

/**
 * Lenient counterpart of {@link ObjectProvenance#checkAndExtractProvenance(Map, String, Class, String)}:
 * a missing key yields an empty optional instead of an exception, while a value of the
 * wrong type still throws.
 * <p>
 * Note the key is removed from the supplied map as part of the lookup.
 * @param map The map to extract from.
 * @param key The key to find.
 * @param type The class of the value.
 * @param <T> The type of the value.
 * @return An optional containing the value if present.
 * @throws ProvenanceException If the value is the wrong type.
 */
@SuppressWarnings("unchecked") // Guarded by isInstance check
private static <T extends Provenance> Optional<T> maybeExtractProvenance(Map<String,Provenance> map, String key, Class<T> type) throws ProvenanceException {
    Provenance value = map.remove(key);
    if (value == null) {
        // Key absent (or mapped to null): nothing to report.
        return Optional.empty();
    }
    if (!type.isInstance(value)) {
        throw new ProvenanceException("Failed to cast " + key + " when constructing ModelProvenance, found " + value);
    }
    return Optional.of((T) value);
}

/**
Expand Down Expand Up @@ -124,6 +225,30 @@ public String getTribuoVersion() {
return versionString;
}

/**
 * Gets the Java version string recorded in this provenance.
 * @return The Java version.
 */
public String getJavaVersion() {
    return this.javaVersionString;
}

/**
 * Gets the operating system name recorded in this provenance.
 * @return The OS name.
 */
public String getOS() {
    return this.osString;
}

/**
 * Gets the CPU architecture string recorded in this provenance.
 * @return The CPU architecture.
 */
public String getArch() {
    return this.archString;
}

@Override
public String toString() {
return generateString("Model");
Expand All @@ -144,23 +269,42 @@ public boolean equals(Object o) {
datasetProvenance.equals(pairs.datasetProvenance) &&
trainerProvenance.equals(pairs.trainerProvenance) &&
instanceProvenance.equals(pairs.instanceProvenance) &&
versionString.equals(pairs.versionString);
versionString.equals(pairs.versionString) &&
javaVersionString.equals(pairs.javaVersionString) &&
osString.equals(pairs.osString) &&
archString.equals(pairs.archString);
}

/**
 * Hashes the class name, training time, dataset and trainer provenances, instance
 * provenance, Tribuo version, java version, OS name and architecture.
 * <p>
 * The java version, OS name and architecture are included so the hash covers the
 * same version and system strings that {@code equals} compares.
 * @return The hash code.
 */
@Override
public int hashCode() {
    return Objects.hash(className, time, datasetProvenance, trainerProvenance, instanceProvenance,
            versionString, javaVersionString, osString, archString);
}

@Override
public Iterator<Pair<String, Provenance>> iterator() {
/**
* Returns a list of all the provenances in this model provenance so subclasses
* can append to the list.
* @return A list of all the provenances in this class.
*/
protected List<Pair<String,Provenance>> internalProvenances() {
ArrayList<Pair<String,Provenance>> iterable = new ArrayList<>();
iterable.add(new Pair<>(CLASS_NAME,new StringProvenance(CLASS_NAME,className)));
iterable.add(new Pair<>(DATASET,datasetProvenance));
iterable.add(new Pair<>(TRAINER,trainerProvenance));
iterable.add(new Pair<>(TRAINING_TIME,new DateTimeProvenance(TRAINING_TIME,time)));
iterable.add(new Pair<>(INSTANCE_VALUES,instanceProvenance));
iterable.add(new Pair<>(TRIBUO_VERSION_STRING,new StringProvenance(TRIBUO_VERSION_STRING,versionString)));
return iterable.iterator();
iterable.add(new Pair<>(JAVA_VERSION_STRING,new StringProvenance(JAVA_VERSION_STRING,javaVersionString)));
iterable.add(new Pair<>(OS_STRING,new StringProvenance(OS_STRING,osString)));
iterable.add(new Pair<>(ARCH_STRING,new StringProvenance(ARCH_STRING,archString)));
return iterable;
}

/**
 * Iterates over the list produced by {@link #internalProvenances()}.
 * @return An iterator over all the provenances.
 */
@Override
public Iterator<Pair<String, Provenance>> iterator() {
    List<Pair<String, Provenance>> provenances = internalProvenances();
    return provenances.iterator();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public Map<String, PrimitiveProvenance<?>> getInstanceValues() {

map.put(TRAIN_INVOCATION_COUNT, invocationCount);
map.put(IS_SEQUENCE, isSequence);
map.put(TRIBUO_VERSION_STRING, version);

return map;
}
Expand Down
Loading