diff --git a/Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java b/Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java
new file mode 100644
index 000000000..07708e0f6
--- /dev/null
+++ b/Core/src/main/java/org/tribuo/datasource/AggregateConfigurableDataSource.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tribuo.datasource;
+
+import com.oracle.labs.mlrg.olcut.config.Config;
+import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
+import com.oracle.labs.mlrg.olcut.provenance.Provenance;
+import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
+import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
+import org.tribuo.ConfigurableDataSource;
+import org.tribuo.Example;
+import org.tribuo.Output;
+import org.tribuo.OutputFactory;
+import org.tribuo.datasource.AggregateDataSource.IterationOrder;
+import org.tribuo.provenance.DataSourceProvenance;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Aggregates multiple {@link ConfigurableDataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the
+ * iteration order.
+ *
+ * Identical to {@link AggregateDataSource} except it can be configured.
+ */
+public class AggregateConfigurableDataSource> implements ConfigurableDataSource {
+
+ @Config(mandatory = true, description = "The iteration order.")
+ private IterationOrder order;
+
+ @Config(mandatory = true, description = "The sources to aggregate.")
+ private List> sources;
+
+ /**
+ * Creates an aggregate data source which will iterate the provided
+ * sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}.
+ * @param sources The sources to aggregate.
+ */
+ public AggregateConfigurableDataSource(List> sources) {
+ this(sources, IterationOrder.SEQUENTIAL);
+ }
+
+ /**
+ * Creates an aggregate data source using the supplied sources and iteration order.
+ * @param sources The sources to iterate.
+ * @param order The iteration order.
+ */
+ public AggregateConfigurableDataSource(List> sources, IterationOrder order) {
+ this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
+ this.order = order;
+ }
+
+ @Override
+ public String toString() {
+ return "AggregateConfigurableDataSource(sources="+sources.toString()+",order="+order+")";
+ }
+
+ @Override
+ public OutputFactory getOutputFactory() {
+ return sources.get(0).getOutputFactory();
+ }
+
+ @Override
+ public Iterator> iterator() {
+ switch (order) {
+ case ROUNDROBIN:
+ return new AggregateDataSource.ADSRRIterator<>(sources);
+ case SEQUENTIAL:
+ return new AggregateDataSource.ADSSeqIterator<>(sources);
+ default:
+ throw new IllegalStateException("Unknown enum value " + order);
+ }
+ }
+
+ @Override
+ public DataSourceProvenance getProvenance() {
+ return new AggregateConfigurableDataSourceProvenance(this);
+ }
+
+ /**
+ * Provenance for the {@link AggregateConfigurableDataSource}.
+ */
+ public static class AggregateConfigurableDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
+ private static final long serialVersionUID = 1L;
+
+ > AggregateConfigurableDataSourceProvenance(AggregateConfigurableDataSource host) {
+ super(host, "DataSource");
+ }
+
+ /**
+ * Deserialization constructor.
+ * @param map The provenance to deserialize.
+ */
+ public AggregateConfigurableDataSourceProvenance(Map map) {
+ this(extractProvenanceInfo(map));
+ }
+
+ private AggregateConfigurableDataSourceProvenance(ExtractedInfo info) {
+ super(info);
+ }
+
+ protected static ExtractedInfo extractProvenanceInfo(Map map) {
+ Map configuredParameters = new HashMap<>(map);
+ String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters,CLASS_NAME, StringProvenance.class, AggregateConfigurableDataSourceProvenance.class.getSimpleName()).getValue();
+ String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, AggregateConfigurableDataSourceProvenance.class.getSimpleName()).getValue();
+ return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
+ }
+ }
+}
diff --git a/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java b/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java
index 2cebddee3..de1fac1c6 100644
--- a/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java
+++ b/Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java
@@ -19,6 +19,7 @@
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
+import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.DataSource;
@@ -26,29 +27,65 @@
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.provenance.DataSourceProvenance;
+import org.tribuo.provenance.ModelProvenance;
+import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
+import java.util.Optional;
/**
- * Aggregates multiple {@link DataSource}s, and round-robins the iterators.
+ * Aggregates multiple {@link DataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the
+ * iteration order.
*/
public class AggregateDataSource> implements DataSource {
-
+
+ /**
+ * Specifies the iteration order of the inner sources.
+ */
+ public enum IterationOrder {
+ /**
+ * Round-robins the iterators (i.e., chooses one from each in turn).
+ */
+ ROUNDROBIN,
+ /**
+ * Iterates each dataset sequentially, in the order of the sources list.
+ */
+ SEQUENTIAL;
+ }
+
+ private final IterationOrder order;
+
private final List> sources;
+ /**
+ * Creates an aggregate data source which will iterate the provided
+ * sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}.
+ * @param sources The sources to aggregate.
+ */
public AggregateDataSource(List> sources) {
+ this(sources,IterationOrder.SEQUENTIAL);
+ }
+
+ /**
+ * Creates an aggregate data source using the supplied sources and iteration order.
+ * @param sources The sources to iterate.
+ * @param order The iteration order.
+ */
+ public AggregateDataSource(List> sources, IterationOrder order) {
this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
+ this.order = order;
}
@Override
public String toString() {
- return "AggregateDataSource(sources="+sources.toString()+")";
+ return "AggregateDataSource(sources="+sources.toString()+",order="+order+")";
}
@Override
@@ -58,7 +95,14 @@ public OutputFactory getOutputFactory() {
@Override
public Iterator> iterator() {
- return new ADSIterator();
+ switch (order) {
+ case ROUNDROBIN:
+ return new ADSRRIterator<>(sources);
+ case SEQUENTIAL:
+ return new ADSSeqIterator<>(sources);
+ default:
+ throw new IllegalStateException("Unknown enum value " + order);
+ }
}
@Override
@@ -66,9 +110,51 @@ public DataSourceProvenance getProvenance() {
return new AggregateDataSourceProvenance(this);
}
- private class ADSIterator implements Iterator> {
- Iterator> si = sources.iterator();
- Iterator> curr = null;
+ static class ADSRRIterator> implements Iterator> {
+ private final Deque>> queue;
+
+ ADSRRIterator(List extends DataSource> sources) {
+ this.queue = new ArrayDeque<>(sources.size());
+ for (DataSource ds : sources) {
+ Iterator> itr = ds.iterator();
+ if (itr.hasNext()) {
+ queue.addLast(itr);
+ }
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return !queue.isEmpty();
+ }
+
+ @Override
+ public Example next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException("Iterator exhausted");
+ }
+ Iterator> itr = queue.pollFirst();
+ if (itr.hasNext()) {
+ Example buff = itr.next();
+ if (itr.hasNext()) {
+ queue.addLast(itr);
+ }
+ return buff;
+ } else {
+ throw new IllegalStateException("Invalid iterator in queue");
+ }
+ }
+ }
+
+ static class ADSSeqIterator> implements Iterator> {
+ private final Iterator extends DataSource> si;
+ private Iterator> curr;
+
+ ADSSeqIterator(List extends DataSource> sources) {
+ this.si = sources.iterator();
+ this.curr = null;
+ }
+
@Override
public boolean hasNext() {
if (curr == null) {
@@ -106,19 +192,25 @@ public static class AggregateDataSourceProvenance implements DataSourceProvenanc
private static final long serialVersionUID = 1L;
private static final String SOURCES = "sources";
+ private static final String ORDER = "order";
private final StringProvenance className;
private final ListProvenance provenances;
+ private EnumProvenance orderProvenance;
> AggregateDataSourceProvenance(AggregateDataSource host) {
this.className = new StringProvenance(CLASS_NAME,host.getClass().getName());
this.provenances = ListProvenance.createListProvenance(host.sources);
+ this.orderProvenance = new EnumProvenance<>(ORDER,host.order);
}
- @SuppressWarnings("unchecked") //ListProvenance cast
+ @SuppressWarnings({"unchecked","rawtypes"}) //ListProvenance cast, EnumProvenance cast
public AggregateDataSourceProvenance(Map map) {
this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
this.provenances = ObjectProvenance.checkAndExtractProvenance(map,SOURCES,ListProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
+ // TODO fix this when we upgrade OLCUT.
+ Optional opt = ModelProvenance.maybeExtractProvenance(map,ORDER,EnumProvenance.class);
+ this.orderProvenance = opt.orElseGet(() -> new EnumProvenance<>(ORDER, IterationOrder.SEQUENTIAL));
}
@Override
@@ -132,22 +224,32 @@ public Iterator> iterator() {
list.add(new Pair<>(CLASS_NAME,className));
list.add(new Pair<>(SOURCES,provenances));
+ list.add(new Pair<>(ORDER,getOrder()));
return list.iterator();
}
+ private EnumProvenance getOrder() {
+ if (orderProvenance != null) {
+ return orderProvenance;
+ } else {
+ return new EnumProvenance<>(ORDER,IterationOrder.SEQUENTIAL);
+ }
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof AggregateDataSourceProvenance)) return false;
AggregateDataSourceProvenance pairs = (AggregateDataSourceProvenance) o;
return className.equals(pairs.className) &&
- provenances.equals(pairs.provenances);
+ provenances.equals(pairs.provenances) &&
+ getOrder().equals(pairs.getOrder());
}
@Override
public int hashCode() {
- return Objects.hash(className, provenances);
+ return Objects.hash(className, provenances, getOrder());
}
@Override
diff --git a/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java b/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java
index af986af91..f914a4772 100644
--- a/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java
+++ b/Core/src/main/java/org/tribuo/provenance/ModelProvenance.java
@@ -156,6 +156,8 @@ public ModelProvenance(Map map) {
this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
this.instanceProvenance = (MapProvenance>) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName());
this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class,ModelProvenance.class.getSimpleName()).getValue();
+
+ // TODO fix this when we upgrade OLCUT.
this.javaVersionString = maybeExtractProvenance(map,JAVA_VERSION_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.osString = maybeExtractProvenance(map,OS_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.archString = maybeExtractProvenance(map,ARCH_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
@@ -164,6 +166,8 @@ public ModelProvenance(Map map) {
/**
* Like {@link ObjectProvenance#checkAndExtractProvenance(Map, String, Class, String)} but doesn't
* throw if it fails to find the key, only if the value is of the wrong type.
+ *
+ * @deprecated Deprecated as it's in OLCUT.
* @param map The map to inspect.
* @param key The key to find.
* @param type The class of the value.
@@ -172,7 +176,8 @@ public ModelProvenance(Map map) {
* @throws ProvenanceException If the value is the wrong type.
*/
@SuppressWarnings("unchecked") // Guarded by isInstance check
- private static Optional maybeExtractProvenance(Map map, String key, Class type) throws ProvenanceException {
+ @Deprecated
+ public static Optional maybeExtractProvenance(Map map, String key, Class type) throws ProvenanceException {
Provenance tmp = map.remove(key);
if (tmp != null) {
if (type.isInstance(tmp)) {
diff --git a/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java b/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java
new file mode 100644
index 000000000..f94ec17df
--- /dev/null
+++ b/Core/src/test/java/org/tribuo/datasource/AggregateDataSourceTest.java
@@ -0,0 +1,177 @@
+/*
+ * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tribuo.datasource;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.tribuo.ConfigurableDataSource;
+import org.tribuo.DataSource;
+import org.tribuo.Example;
+import org.tribuo.Output;
+import org.tribuo.OutputFactory;
+import org.tribuo.impl.ArrayExample;
+import org.tribuo.provenance.DataSourceProvenance;
+import org.tribuo.provenance.SimpleDataSourceProvenance;
+import org.tribuo.test.Helpers;
+import org.tribuo.test.MockOutput;
+import org.tribuo.test.MockOutputFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.StreamSupport;
+
+public class AggregateDataSourceTest {
+
+ @Test
+ public void testADSIterationOrder() {
+ MockOutputFactory factory = new MockOutputFactory();
+ String[] featureNames = new String[] {"X1","X2"};
+ double[] featureValues = new double[] {1.0, 2.0};
+
+ List> first = new ArrayList<>();
+ first.add(new ArrayExample<>(new MockOutput("A"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("B"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("C"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("D"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("E"),featureNames,featureValues));
+ ListDataSource firstSource = new ListDataSource<>(first,factory,new SimpleDataSourceProvenance("First",factory));
+
+ List> second = new ArrayList<>();
+ second.add(new ArrayExample<>(new MockOutput("F"),featureNames,featureValues));
+ second.add(new ArrayExample<>(new MockOutput("G"),featureNames,featureValues));
+ ListDataSource secondSource = new ListDataSource<>(second,factory,new SimpleDataSourceProvenance("Second",factory));
+
+ List> third = new ArrayList<>();
+ third.add(new ArrayExample<>(new MockOutput("H"),featureNames,featureValues));
+ third.add(new ArrayExample<>(new MockOutput("I"),featureNames,featureValues));
+ third.add(new ArrayExample<>(new MockOutput("J"),featureNames,featureValues));
+ third.add(new ArrayExample<>(new MockOutput("K"),featureNames,featureValues));
+ ListDataSource thirdSource = new ListDataSource<>(third,factory,new SimpleDataSourceProvenance("Third",factory));
+
+ List> sources = new ArrayList<>();
+ sources.add(firstSource);
+ sources.add(secondSource);
+ sources.add(thirdSource);
+
+ AggregateDataSource adsSeq = new AggregateDataSource<>(sources, AggregateDataSource.IterationOrder.SEQUENTIAL);
+ String[] expectedSeq = new String[] {"A","B","C","D","E","F","G","H","I","J","K"};
+ String[] actualSeq = StreamSupport.stream(adsSeq.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new);
+ Assertions.assertArrayEquals(expectedSeq,actualSeq);
+ Helpers.testProvenanceMarshalling(adsSeq.getProvenance());
+
+ AggregateDataSource adsRR = new AggregateDataSource<>(sources, AggregateDataSource.IterationOrder.ROUNDROBIN);
+ String[] expectedRR = new String[] {"A","F","H","B","G","I","C","J","D","K","E"};
+ String[] actualRR = StreamSupport.stream(adsRR.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new);
+ Assertions.assertArrayEquals(expectedRR,actualRR);
+ Helpers.testProvenanceMarshalling(adsRR.getProvenance());
+ }
+
+ @Test
+ public void testACDSIterationOrder() {
+ MockOutputFactory factory = new MockOutputFactory();
+ String[] featureNames = new String[] {"X1","X2"};
+ double[] featureValues = new double[] {1.0, 2.0};
+
+ List> first = new ArrayList<>();
+ first.add(new ArrayExample<>(new MockOutput("A"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("B"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("C"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("D"),featureNames,featureValues));
+ first.add(new ArrayExample<>(new MockOutput("E"),featureNames,featureValues));
+ MockListConfigurableDataSource firstSource = new MockListConfigurableDataSource<>(first,factory,new SimpleDataSourceProvenance("First",factory));
+
+ List> second = new ArrayList<>();
+ second.add(new ArrayExample<>(new MockOutput("F"),featureNames,featureValues));
+ second.add(new ArrayExample<>(new MockOutput("G"),featureNames,featureValues));
+ MockListConfigurableDataSource secondSource = new MockListConfigurableDataSource<>(second,factory,new SimpleDataSourceProvenance("Second",factory));
+
+ List> third = new ArrayList<>();
+ third.add(new ArrayExample<>(new MockOutput("H"),featureNames,featureValues));
+ third.add(new ArrayExample<>(new MockOutput("I"),featureNames,featureValues));
+ third.add(new ArrayExample<>(new MockOutput("J"),featureNames,featureValues));
+ third.add(new ArrayExample<>(new MockOutput("K"),featureNames,featureValues));
+ MockListConfigurableDataSource thirdSource = new MockListConfigurableDataSource<>(third,factory,new SimpleDataSourceProvenance("Third",factory));
+
+ List> sources = new ArrayList<>();
+ sources.add(firstSource);
+ sources.add(secondSource);
+ sources.add(thirdSource);
+
+ AggregateConfigurableDataSource acdsSeq = new AggregateConfigurableDataSource<>(sources, AggregateDataSource.IterationOrder.SEQUENTIAL);
+ String[] expectedSeq = new String[] {"A","B","C","D","E","F","G","H","I","J","K"};
+ String[] actualSeq = StreamSupport.stream(acdsSeq.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new);
+ Assertions.assertArrayEquals(expectedSeq,actualSeq);
+ Helpers.testProvenanceMarshalling(acdsSeq.getProvenance());
+
+ AggregateConfigurableDataSource acdsRR = new AggregateConfigurableDataSource<>(sources, AggregateDataSource.IterationOrder.ROUNDROBIN);
+ String[] expectedRR = new String[] {"A","F","H","B","G","I","C","J","D","K","E"};
+ String[] actualRR = StreamSupport.stream(acdsRR.spliterator(), false).map(Example::getOutput).map(MockOutput::toString).toArray(String[]::new);
+ Assertions.assertArrayEquals(expectedRR,actualRR);
+ Helpers.testProvenanceMarshalling(acdsRR.getProvenance());
+
+ }
+
+ /**
+ * This isn't actually configurable, it's used to test {@link AggregateConfigurableDataSource}.
+ * @param The output type.
+ */
+ private static class MockListConfigurableDataSource> implements ConfigurableDataSource {
+
+ private final List> data;
+
+ private final OutputFactory factory;
+
+ private final DataSourceProvenance provenance;
+
+ public MockListConfigurableDataSource(List> list, OutputFactory factory, DataSourceProvenance provenance) {
+ this.data = Collections.unmodifiableList(new ArrayList<>(list));
+ this.factory = factory;
+ this.provenance = provenance;
+ }
+
+ /**
+ * Number of examples.
+ * @return The number of examples.
+ */
+ public int size() {
+ return data.size();
+ }
+
+ @Override
+ public OutputFactory getOutputFactory() {
+ return factory;
+ }
+
+ @Override
+ public DataSourceProvenance getProvenance() {
+ return provenance;
+ }
+
+ @Override
+ public Iterator> iterator() {
+ return data.iterator();
+ }
+
+ @Override
+ public String toString() {
+ return provenance.toString();
+ }
+ }
+
+}
diff --git a/Core/src/test/java/org/tribuo/test/Helpers.java b/Core/src/test/java/org/tribuo/test/Helpers.java
index 60610fe25..bbabfe743 100644
--- a/Core/src/test/java/org/tribuo/test/Helpers.java
+++ b/Core/src/test/java/org/tribuo/test/Helpers.java
@@ -16,6 +16,7 @@
package org.tribuo.test;
+import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import org.junit.jupiter.api.Assertions;
@@ -75,11 +76,15 @@ public static Example mkExample(MockOutput label, String... features
return ex;
}
+ public static void testProvenanceMarshalling(ObjectProvenance inputProvenance) {
+ List provenanceList = ProvenanceUtil.marshalProvenance(inputProvenance);
+ ObjectProvenance unmarshalledProvenance = ProvenanceUtil.unmarshalProvenance(provenanceList);
+ Assertions.assertEquals(unmarshalledProvenance,inputProvenance);
+ }
+
public static > void testModelSerialization(Model model, Class outputClazz) {
// test provenance marshalling
- List provenanceList = ProvenanceUtil.marshalProvenance(model.getProvenance());
- ModelProvenance provenance = (ModelProvenance) ProvenanceUtil.unmarshalProvenance(provenanceList);
- Assertions.assertEquals(provenance,model.getProvenance());
+ testProvenanceMarshalling(model.getProvenance());
// write to byte array
ByteArrayOutputStream baos = new ByteArrayOutputStream();