Skip to content

Commit

Permalink
Allows configurable iteration order in AggregateDataSource, and adds …
Browse files Browse the repository at this point in the history
…a configurable version (#125)

* Adding round robin iteration to AggregateDataSource.

* Adding AggregateConfigurableDataSource and slightly refactoring the way provenance marshalling is tested.
  • Loading branch information
Craigacp authored Apr 14, 2021
1 parent 0adf555 commit 1f832c9
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.datasource;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.datasource.AggregateDataSource.IterationOrder;
import org.tribuo.provenance.DataSourceProvenance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
* Aggregates multiple {@link ConfigurableDataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the
* iteration order.
* <p>
* Identical to {@link AggregateDataSource} except it can be configured.
*/
public class AggregateConfigurableDataSource<T extends Output<T>> implements ConfigurableDataSource<T> {

    @Config(mandatory = true, description = "The iteration order.")
    private IterationOrder order;

    @Config(mandatory = true, description = "The sources to aggregate.")
    private List<ConfigurableDataSource<T>> sources;

    /**
     * For OLCUT.
     * <p>
     * The configuration system builds configurable objects reflectively through a
     * no-argument constructor and then populates the {@link Config} fields, so this
     * constructor must exist even though it is never called directly.
     */
    private AggregateConfigurableDataSource() { }

    /**
     * Creates an aggregate data source which will iterate the provided
     * sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}).
     * @param sources The sources to aggregate.
     */
    public AggregateConfigurableDataSource(List<ConfigurableDataSource<T>> sources) {
        this(sources, IterationOrder.SEQUENTIAL);
    }

    /**
     * Creates an aggregate data source using the supplied sources and iteration order.
     * <p>
     * The source list is defensively copied and wrapped as unmodifiable, so later
     * changes to the caller's list are not observed by this data source.
     * @param sources The sources to iterate.
     * @param order The iteration order.
     * @throws NullPointerException If {@code sources} or {@code order} is null.
     */
    public AggregateConfigurableDataSource(List<ConfigurableDataSource<T>> sources, IterationOrder order) {
        // Fail at construction time rather than later inside iterator(), where
        // switching on a null enum would throw a far less informative NPE.
        if (order == null) {
            throw new NullPointerException("The iteration order must not be null");
        }
        this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
        this.order = order;
    }

    @Override
    public String toString() {
        return "AggregateConfigurableDataSource(sources="+sources.toString()+",order="+order+")";
    }

    /**
     * Returns the output factory of the first source in the list.
     * @return The first source's output factory.
     */
    @Override
    public OutputFactory<T> getOutputFactory() {
        return sources.get(0).getOutputFactory();
    }

    /**
     * Creates a fresh iterator over the examples of all the sources, visiting them
     * according to the configured {@link IterationOrder}.
     * @return An iterator over the aggregated examples.
     * @throws IllegalStateException If the iteration order is not a known enum value.
     */
    @Override
    public Iterator<Example<T>> iterator() {
        switch (order) {
            case ROUNDROBIN:
                return new AggregateDataSource.ADSRRIterator<>(sources);
            case SEQUENTIAL:
                return new AggregateDataSource.ADSSeqIterator<>(sources);
            default:
                throw new IllegalStateException("Unknown enum value " + order);
        }
    }

    @Override
    public DataSourceProvenance getProvenance() {
        return new AggregateConfigurableDataSourceProvenance(this);
    }

    /**
     * Provenance for the {@link AggregateConfigurableDataSource}.
     */
    public static class AggregateConfigurableDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds the provenance from a live data source.
         * @param host The data source to describe.
         * @param <T> The output type of the host.
         */
        <T extends Output<T>> AggregateConfigurableDataSourceProvenance(AggregateConfigurableDataSource<T> host) {
            super(host, "DataSource");
        }

        /**
         * Deserialization constructor.
         * @param map The provenance to deserialize.
         */
        public AggregateConfigurableDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private AggregateConfigurableDataSourceProvenance(ExtractedInfo info) {
            super(info);
        }

        /**
         * Pulls the class name and host short name out of the supplied provenance map
         * and packages them, together with the configured parameters, for the superclass.
         * <p>
         * Operates on a copy of {@code map}, so the caller's map is not modified.
         * @param map The provenance map to read.
         * @return The extracted information; the instance values map is always empty.
         */
        protected static ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String provenanceName = AggregateConfigurableDataSourceProvenance.class.getSimpleName();
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, CLASS_NAME, StringProvenance.class, provenanceName).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, HOST_SHORT_NAME, StringProvenance.class, provenanceName).getValue();
            return new ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
        }
    }
}
122 changes: 112 additions & 10 deletions Core/src/main/java/org/tribuo/datasource/AggregateDataSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,73 @@
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.EnumProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.provenance.ModelProvenance;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;

/**
* Aggregates multiple {@link DataSource}s, and round-robins the iterators.
* Aggregates multiple {@link DataSource}s, uses {@link AggregateDataSource.IterationOrder} to control the
* iteration order.
*/
public class AggregateDataSource<T extends Output<T>> implements DataSource<T> {


/**
* Specifies the iteration order of the inner sources.
*/
public enum IterationOrder {
    /**
     * Takes one example from each source in turn, cycling through the sources.
     */
    ROUNDROBIN,
    /**
     * Reads each source to completion before moving on, following the order of the sources list.
     */
    SEQUENTIAL
}

private final IterationOrder order;

private final List<DataSource<T>> sources;

/**
* Creates an aggregate data source which will iterate the provided
* sources in the order of the list (i.e., using {@link IterationOrder#SEQUENTIAL}).
* @param sources The sources to aggregate.
*/
public AggregateDataSource(List<DataSource<T>> sources) {
this(sources,IterationOrder.SEQUENTIAL);
}

/**
* Creates an aggregate data source using the supplied sources and iteration order.
* <p>
* The supplied list is copied and wrapped as unmodifiable, so later changes to
* the caller's list are not seen by this data source.
* @param sources The sources to iterate.
* @param order The iteration order.
*/
public AggregateDataSource(List<DataSource<T>> sources, IterationOrder order) {
this.sources = Collections.unmodifiableList(new ArrayList<>(sources));
this.order = order;
}

@Override
public String toString() {
return "AggregateDataSource(sources="+sources.toString()+")";
return "AggregateDataSource(sources="+sources.toString()+",order="+order+")";
}

@Override
Expand All @@ -58,17 +95,66 @@ public OutputFactory<T> getOutputFactory() {

@Override
public Iterator<Example<T>> iterator() {
return new ADSIterator();
switch (order) {
case ROUNDROBIN:
return new ADSRRIterator<>(sources);
case SEQUENTIAL:
return new ADSSeqIterator<>(sources);
default:
throw new IllegalStateException("Unknown enum value " + order);
}
}

@Override
public DataSourceProvenance getProvenance() {
return new AggregateDataSourceProvenance(this);
}

private class ADSIterator implements Iterator<Example<T>> {
Iterator<DataSource<T>> si = sources.iterator();
Iterator<Example<T>> curr = null;
/**
 * Round-robin iterator: yields one example from each source in turn, dropping a
 * source from the rotation as soon as it has no more examples.
 */
static class ADSRRIterator<T extends Output<T>> implements Iterator<Example<T>> {
    // Iterators that still have at least one example, in rotation order.
    private final Deque<Iterator<Example<T>>> pending;

    /**
     * Opens an iterator on every source and keeps those which are non-empty.
     * @param sources The sources to interleave.
     */
    ADSRRIterator(List<? extends DataSource<T>> sources) {
        this.pending = new ArrayDeque<>(sources.size());
        for (DataSource<T> source : sources) {
            Iterator<Example<T>> candidate = source.iterator();
            if (candidate.hasNext()) {
                pending.addLast(candidate);
            }
        }
    }

    @Override
    public boolean hasNext() {
        return !pending.isEmpty();
    }

    @Override
    public Example<T> next() {
        if (!hasNext()) {
            throw new NoSuchElementException("Iterator exhausted");
        }
        Iterator<Example<T>> head = pending.pollFirst();
        // Only iterators with remaining elements are ever queued, so an empty one here is a logic error.
        if (!head.hasNext()) {
            throw new IllegalStateException("Invalid iterator in queue");
        }
        Example<T> example = head.next();
        // Send the iterator to the back of the rotation while it still has examples.
        if (head.hasNext()) {
            pending.addLast(head);
        }
        return example;
    }
}

static class ADSSeqIterator<T extends Output<T>> implements Iterator<Example<T>> {
private final Iterator<? extends DataSource<T>> si;
private Iterator<Example<T>> curr;

ADSSeqIterator(List<? extends DataSource<T>> sources) {
this.si = sources.iterator();
this.curr = null;
}

@Override
public boolean hasNext() {
if (curr == null) {
Expand Down Expand Up @@ -106,19 +192,25 @@ public static class AggregateDataSourceProvenance implements DataSourceProvenanc
private static final long serialVersionUID = 1L;

private static final String SOURCES = "sources";
private static final String ORDER = "order";

private final StringProvenance className;
private final ListProvenance<DataSourceProvenance> provenances;
private EnumProvenance<IterationOrder> orderProvenance;

/**
* Builds the provenance from a live data source.
* <p>
* Records the host's runtime class name, a list provenance created from the
* host's sources, and the host's iteration order.
* @param host The data source to describe.
* @param <T> The output type of the host.
*/
<T extends Output<T>> AggregateDataSourceProvenance(AggregateDataSource<T> host) {
this.className = new StringProvenance(CLASS_NAME,host.getClass().getName());
this.provenances = ListProvenance.createListProvenance(host.sources);
this.orderProvenance = new EnumProvenance<>(ORDER,host.order);
}

@SuppressWarnings("unchecked") //ListProvenance cast
@SuppressWarnings({"unchecked","rawtypes"}) //ListProvenance cast, EnumProvenance cast
public AggregateDataSourceProvenance(Map<String,Provenance> map) {
this.className = ObjectProvenance.checkAndExtractProvenance(map,CLASS_NAME, StringProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
this.provenances = ObjectProvenance.checkAndExtractProvenance(map,SOURCES,ListProvenance.class,AggregateDataSourceProvenance.class.getSimpleName());
// TODO fix this when we upgrade OLCUT.
Optional<EnumProvenance> opt = ModelProvenance.maybeExtractProvenance(map,ORDER,EnumProvenance.class);
this.orderProvenance = opt.orElseGet(() -> new EnumProvenance<>(ORDER, IterationOrder.SEQUENTIAL));
}

@Override
Expand All @@ -132,22 +224,32 @@ public Iterator<Pair<String, Provenance>> iterator() {

list.add(new Pair<>(CLASS_NAME,className));
list.add(new Pair<>(SOURCES,provenances));
list.add(new Pair<>(ORDER,getOrder()));

return list.iterator();
}

/**
 * Returns the iteration order provenance, falling back to
 * {@link IterationOrder#SEQUENTIAL} when none has been recorded.
 * @return The order provenance, never null.
 */
private EnumProvenance<IterationOrder> getOrder() {
    // NOTE(review): the field is non-final and may be null, presumably for
    // instances serialized before the order field existed - confirm.
    if (orderProvenance == null) {
        return new EnumProvenance<>(ORDER,IterationOrder.SEQUENTIAL);
    }
    return orderProvenance;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof AggregateDataSourceProvenance)) return false;
AggregateDataSourceProvenance pairs = (AggregateDataSourceProvenance) o;
return className.equals(pairs.className) &&
provenances.equals(pairs.provenances);
provenances.equals(pairs.provenances) &&
getOrder().equals(pairs.getOrder());
}

@Override
public int hashCode() {
return Objects.hash(className, provenances);
return Objects.hash(className, provenances, getOrder());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ public ModelProvenance(Map<String,Provenance> map) {
this.time = ObjectProvenance.checkAndExtractProvenance(map,TRAINING_TIME,DateTimeProvenance.class, ModelProvenance.class.getSimpleName()).getValue();
this.instanceProvenance = (MapProvenance<?>) ObjectProvenance.checkAndExtractProvenance(map,INSTANCE_VALUES,MapProvenance.class, ModelProvenance.class.getSimpleName());
this.versionString = ObjectProvenance.checkAndExtractProvenance(map,TRIBUO_VERSION_STRING,StringProvenance.class,ModelProvenance.class.getSimpleName()).getValue();

// TODO fix this when we upgrade OLCUT.
this.javaVersionString = maybeExtractProvenance(map,JAVA_VERSION_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.osString = maybeExtractProvenance(map,OS_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
this.archString = maybeExtractProvenance(map,ARCH_STRING,StringProvenance.class).map(StringProvenance::getValue).orElse(UNKNOWN_VERSION);
Expand All @@ -164,6 +166,8 @@ public ModelProvenance(Map<String,Provenance> map) {
/**
* Like {@link ObjectProvenance#checkAndExtractProvenance(Map, String, Class, String)} but doesn't
* throw if it fails to find the key, only if the value is of the wrong type.
*
* @deprecated Deprecated as it's in OLCUT.
* @param map The map to inspect.
* @param key The key to find.
* @param type The class of the value.
Expand All @@ -172,7 +176,8 @@ public ModelProvenance(Map<String,Provenance> map) {
* @throws ProvenanceException If the value is the wrong type.
*/
@SuppressWarnings("unchecked") // Guarded by isInstance check
private static <T extends Provenance> Optional<T> maybeExtractProvenance(Map<String,Provenance> map, String key, Class<T> type) throws ProvenanceException {
@Deprecated
public static <T extends Provenance> Optional<T> maybeExtractProvenance(Map<String,Provenance> map, String key, Class<T> type) throws ProvenanceException {
Provenance tmp = map.remove(key);
if (tmp != null) {
if (type.isInstance(tmp)) {
Expand Down
Loading

0 comments on commit 1f832c9

Please sign in to comment.