From acc4edf90764dd8a2ed15696c3f31808e872e9a4 Mon Sep 17 00:00:00 2001 From: John Sullivan Date: Mon, 12 Jul 2021 15:48:14 -0400 Subject: [PATCH 1/7] Added support for multioutputs to ResponseProcesser, with tests. This should allow us to convert CSVLoader to use CSVDataSource as its backend, increasing consistency and simplicity. --- .../tribuo/test/MockMultiOutputFactory.java | 31 +++- .../data/columnar/ResponseProcessor.java | 22 +++ .../tribuo/data/columnar/RowProcessor.java | 9 +- .../response/BinaryResponseProcessor.java | 109 +++++++++++++- .../response/EmptyResponseProcessor.java | 12 ++ .../response/FieldResponseProcessor.java | 114 ++++++++++++-- .../response/QuartileResponseProcessor.java | 87 ++++++++++- .../data/columnar/MockResponseProcessor.java | 7 + .../response/EmptyResponseProcessorTest.java | 2 +- .../MultioutputResponseProcessorTest.java | 140 ++++++++++++++++++ 10 files changed, 500 insertions(+), 33 deletions(-) create mode 100644 Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java diff --git a/Core/src/test/java/org/tribuo/test/MockMultiOutputFactory.java b/Core/src/test/java/org/tribuo/test/MockMultiOutputFactory.java index fd1b3c17c..0321cc1fc 100644 --- a/Core/src/test/java/org/tribuo/test/MockMultiOutputFactory.java +++ b/Core/src/test/java/org/tribuo/test/MockMultiOutputFactory.java @@ -16,6 +16,7 @@ package org.tribuo.test; +import com.oracle.labs.mlrg.olcut.provenance.Provenance; import com.oracle.labs.mlrg.olcut.util.Pair; import org.tribuo.ImmutableOutputInfo; import org.tribuo.MutableOutputInfo; @@ -84,7 +85,7 @@ public boolean equals(Object obj) { @Override public OutputFactoryProvenance getProvenance() { - return null; + return new MockMultiOutputFactoryProvenance(); } /** @@ -163,4 +164,32 @@ public static MockMultiOutput createFromPairList(List> dime } return new MockMultiOutput(labels); } + + public static class MockMultiOutputFactoryProvenance implements OutputFactoryProvenance { + private static final long serialVersionUID=1L; + + MockMultiOutputFactoryProvenance() {} + + public MockMultiOutputFactoryProvenance(Map map) {} + + @Override + public String getClassName() { + return MockMultiOutputFactory.class.getName(); + } + + @Override + public String toString() { + return generateString("MockMultiOutputFactory"); + } + + @Override + public boolean equals(Object other) { + return other instanceof MockMultiOutputFactoryProvenance; + } + + @Override + public int hashCode() { + return 32; + } + } } \ No newline at end of file diff --git a/Data/src/main/java/org/tribuo/data/columnar/ResponseProcessor.java b/Data/src/main/java/org/tribuo/data/columnar/ResponseProcessor.java index a5e06b578..eb045b344 100644 --- a/Data/src/main/java/org/tribuo/data/columnar/ResponseProcessor.java +++ b/Data/src/main/java/org/tribuo/data/columnar/ResponseProcessor.java @@ -22,6 +22,7 @@ import org.tribuo.Output; import org.tribuo.OutputFactory; +import java.util.List; import java.util.Optional; /** @@ -39,6 +40,7 @@ public interface ResponseProcessor> extends Configurable, Pr * Gets the field name this ResponseProcessor uses. * @return The field name. */ + @Deprecated public String getFieldName(); /** @@ -55,5 +57,25 @@ public interface ResponseProcessor> extends Configurable, Pr * @param value The value to process. * @return The response value if found. */ + @Deprecated() public Optional process(String value); + + /** + * Returns Optional.empty() if it failed to process out a response. + * @param values The value to process. + * @return The response values if found. + */ + default Optional process(List values) { + if (values.size() != 1) { + throw new IllegalArgumentException(getClass().getSimpleName() + " does not implement support for multiple response values"); + } else { + return process(values.get(0)); + } + } + + /** + * Gets the field names this ResponseProcessor uses. + * @return The field name. + */ + List getFieldNames(); } diff --git a/Data/src/main/java/org/tribuo/data/columnar/RowProcessor.java b/Data/src/main/java/org/tribuo/data/columnar/RowProcessor.java index 42b94e7f2..5e5e032eb 100644 --- a/Data/src/main/java/org/tribuo/data/columnar/RowProcessor.java +++ b/Data/src/main/java/org/tribuo/data/columnar/RowProcessor.java @@ -289,8 +289,13 @@ public Set getFeatureProcessors() { * @return An Optional containing an Example if the row was valid, an empty Optional otherwise. */ public Optional> generateExample(ColumnarIterator.Row row, boolean outputRequired) { - String responseValue = row.getRowData().get(responseProcessor.getFieldName()); - Optional labelOpt = responseProcessor.process(responseValue); + //String responseValue = row.getRowData().get(responseProcessor.getFieldName()); + //Optional labelOpt = responseProcessor.process(responseValue); + + List responseValues = responseProcessor.getFieldNames().stream() + .map(f -> row.getRowData().get(f)) + .collect(Collectors.toList()); + Optional labelOpt = responseProcessor.process(responseValues); if (!labelOpt.isPresent() && outputRequired) { return Optional.empty(); } diff --git a/Data/src/main/java/org/tribuo/data/columnar/processors/response/BinaryResponseProcessor.java b/Data/src/main/java/org/tribuo/data/columnar/processors/response/BinaryResponseProcessor.java index e281f6221..24ffe62ba 100644 --- a/Data/src/main/java/org/tribuo/data/columnar/processors/response/BinaryResponseProcessor.java +++ b/Data/src/main/java/org/tribuo/data/columnar/processors/response/BinaryResponseProcessor.java @@ -17,12 +17,16 @@ package org.tribuo.data.columnar.processors.response; import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; import org.tribuo.Output; import org.tribuo.OutputFactory; import org.tribuo.data.columnar.ResponseProcessor; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Optional; /** @@ -32,13 +36,14 @@ */ public class BinaryResponseProcessor> implements ResponseProcessor { - @Config(mandatory = true,description="The field name to read.") + @Config(description="The field name to read, you should use only one of this or fieldNames") + @Deprecated private String fieldName; - @Config(mandatory = true,description="The string which triggers a positive response.") + @Config(description="The string which triggers a positive response.") private String positiveResponse; - @Config(mandatory = true,description="Output factory to use to create the response.") + @Config(description="Output factory to use to create the response.") private OutputFactory outputFactory; @Config(description="The positive response to emit.") @@ -47,6 +52,36 @@ public class BinaryResponseProcessor> implements ResponsePro @Config(description="The negative response to emit.") private String negativeName = "0"; + @Config(description = "A list of field names to read, you should use only one of this or fieldName.") + private List fieldNames; + + @Config(description = "A list of strings that trigger positive responses; it should be the same length as fieldNames or empty") + private List positiveResponses; + + @Config(description = "Whether to display field names as part of the generated label, defaults to false") + private boolean displayField = false; + + + @Override + public void postConfig() { + if (fieldName != null && fieldNames != null) { // we can only have one path + throw new PropertyException("fieldName, FieldNames", "only one of fieldName or fieldNames can be populated"); + } else if (fieldNames != null) { + positiveResponses = positiveResponses == null ? Collections.nCopies(fieldNames.size(), positiveResponse) : positiveResponses; + if(positiveResponses.size() != fieldNames.size()) { + throw new PropertyException("positiveResponses", "must either be empty or match the length of fieldNames"); + } + } else if (fieldName != null) { + if(positiveResponses != null) { + throw new PropertyException("positiveResponses", "if fieldName is populated, positiveResponses must be blank"); + } + fieldNames = Collections.singletonList(fieldName); + positiveResponses = Collections.singletonList(positiveName); + } else { + throw new PropertyException("fieldName, fieldNames", "One of fieldName or fieldNames must be populated"); + } + } + /** * for OLCUT. */ @@ -60,9 +95,49 @@ private BinaryResponseProcessor() {} * @param outputFactory The output factory to use. */ public BinaryResponseProcessor(String fieldName, String positiveResponse, OutputFactory outputFactory) { - this.fieldName = fieldName; - this.positiveResponse = positiveResponse; + this(Collections.singletonList(fieldName), positiveResponse, outputFactory); + } + + /** + * Constructs a binary response processor which emits a positive value for a single string + * and a negative value for all other field values. + * @param fieldNames The field names to read. + * @param positiveResponse The positive response to look for. + * @param outputFactory The output factory to use. + */ + public BinaryResponseProcessor(List fieldNames, String positiveResponse, OutputFactory outputFactory) { + this(fieldNames, Collections.nCopies(fieldNames.size(), positiveResponse), outputFactory); + } + + /** + * Constructs a binary response processor which emits a positive value for a single string + * and a negative value for all other field values. the lengths of fieldNames and positiveResponses + * must be the same. + * @param fieldNames The field names to read. + * @param positiveResponses The positive responses to look for. + * @param outputFactory The output factory to use. + */ + public BinaryResponseProcessor(List fieldNames, List positiveResponses, OutputFactory outputFactory) { + this(fieldNames, positiveResponses, outputFactory, false); + } + + /** + * Constructs a binary response processor which emits a positive value for a single string + * and a negative value for all other field values. the lengths of fieldNames and positiveResponses + * must be the same. + * @param fieldNames The field names to read. + * @param positiveResponses The positive responses to look for. + * @param outputFactory The output factory to use. + * @param displayField whether to include field names in the generated labels. + */ + public BinaryResponseProcessor(List fieldNames, List positiveResponses, OutputFactory outputFactory, boolean displayField) { + if(fieldNames.size() != positiveResponses.size()) { + throw new IllegalArgumentException("fieldNames and positiveResponses must be the same length"); + } + this.fieldNames = fieldNames; + this.positiveResponses = positiveResponses; this.outputFactory = outputFactory; + this.displayField = displayField; } @Override @@ -72,7 +147,7 @@ public OutputFactory getOutputFactory() { @Override public String getFieldName() { - return fieldName; + return fieldNames.get(0); } @Deprecated @@ -83,12 +158,30 @@ public void setFieldName(String fieldName) { @Override public Optional process(String value) { - return Optional.of(outputFactory.generateOutput(positiveResponse.equals(value) ? positiveName : negativeName)); + return process(Collections.singletonList(value)); + } + + @Override + public Optional process(List values) { + List responses = new ArrayList<>(); + String prefix = ""; + for(int i=0; i < values.size(); i++) { + if(displayField) { + prefix = fieldNames.get(i) + ":"; + } + responses.add(prefix + (positiveResponses.get(i).equals(values.get(i)) ? positiveName : negativeName)); + } + return Optional.of(outputFactory.generateOutput(fieldNames.size() == 1 ? responses.get(0) : responses)); + } + + @Override + public List getFieldNames() { + return fieldNames; } @Override public String toString() { - return "BinaryResponseProcessor(fieldName="+ fieldName +", positiveResponse="+ positiveResponse +", positiveName="+positiveName +", negativeName="+negativeName+")"; + return "BinaryResponseProcessor(fieldNames="+ fieldNames.toString() +", positiveResponses="+ positiveResponses.toString() +", positiveName="+positiveName +", negativeName="+negativeName+")"; } @Override diff --git a/Data/src/main/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessor.java b/Data/src/main/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessor.java index 153e10e73..4bcd5a42a 100644 --- a/Data/src/main/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessor.java +++ b/Data/src/main/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessor.java @@ -23,6 +23,8 @@ import org.tribuo.OutputFactory; import org.tribuo.data.columnar.ResponseProcessor; +import java.util.Collections; +import java.util.List; import java.util.Optional; /** @@ -84,6 +86,16 @@ public Optional process(String value) { return Optional.empty(); } + @Override + public Optional process(List values) { + return Optional.empty(); + } + + @Override + public List getFieldNames() { + return Collections.singletonList(FIELD_NAME); + } + @Override public String toString() { return "EmptyResponseProcessor(outputFactory="+outputFactory.toString()+")"; diff --git a/Data/src/main/java/org/tribuo/data/columnar/processors/response/FieldResponseProcessor.java b/Data/src/main/java/org/tribuo/data/columnar/processors/response/FieldResponseProcessor.java index dc03a5330..14d4c0411 100644 --- a/Data/src/main/java/org/tribuo/data/columnar/processors/response/FieldResponseProcessor.java +++ b/Data/src/main/java/org/tribuo/data/columnar/processors/response/FieldResponseProcessor.java @@ -17,12 +17,16 @@ package org.tribuo.data.columnar.processors.response; import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; import org.tribuo.Output; import org.tribuo.OutputFactory; import org.tribuo.data.columnar.ResponseProcessor; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Optional; /** @@ -30,15 +34,45 @@ */ public class FieldResponseProcessor> implements ResponseProcessor { - @Config(mandatory = true,description="The field name to read.") + @Config(description="The field name to read.") + @Deprecated private String fieldName; - @Config(mandatory = true,description="Default value to return if one isn't found.") + @Config(description="Default value to return if one isn't found.") private String defaultValue; - @Config(mandatory = true,description="The output factory to use.") + @Config(description="The output factory to use.") private OutputFactory outputFactory; + @Config(description = "A list of field names to read, you should use only one of this or fieldName.") + private List fieldNames; + + @Config(description = "A list of default values to return if one isn't found, one for each field") + private List defaultValues; + + @Config(description = "Whether to display field names as part of the generated label, defaults to false") + private boolean displayField = false; + + @Override + public void postConfig() { + if (fieldName != null && fieldNames != null) { + throw new PropertyException("fieldName, FieldNames", "only one of fieldName or fieldNames can be populated"); + } else if (fieldNames != null) { + defaultValues = defaultValues == null ? Collections.nCopies(fieldNames.size(), defaultValue) : defaultValues; + if(defaultValues.size() != fieldNames.size()) { + throw new PropertyException("defaultValues", "must either be empty or match the length of fieldNames"); + } + } else if (fieldName != null) { + if(defaultValues != null) { + throw new PropertyException("defaultValues", "if fieldName is populated, defaultValues must be blank"); + } + fieldNames = Collections.singletonList(fieldName); + defaultValues = Collections.singletonList(defaultValue); + } else { + throw new PropertyException("fieldName, fieldNames", "One of fieldName or fieldNames must be populated"); + } + } + /** * For olcut. */ @@ -52,9 +86,47 @@ private FieldResponseProcessor() {} * @param outputFactory The output factory to use. */ public FieldResponseProcessor(String fieldName, String defaultValue, OutputFactory outputFactory) { - this.fieldName = fieldName; - this.defaultValue = defaultValue; + this(Collections.singletonList(fieldName), defaultValue, outputFactory); + } + + /** + * Constructs a response processor which passes the field value through the + * output factory. + * @param fieldNames The fields to read. + * @param defaultValue The default value to extract if it's not found. + * @param outputFactory The output factory to use. + */ + public FieldResponseProcessor(List fieldNames, String defaultValue, OutputFactory outputFactory) { + this(fieldNames, Collections.nCopies(fieldNames.size(), defaultValue), outputFactory); + } + + /** + * Constructs a response processor which passes the field value through the + * output factory. fieldNames and defaultValues must be the same length. + * @param fieldNames The field to read. + * @param defaultValues The default value to extract if it's not found. + * @param outputFactory The output factory to use. + */ + public FieldResponseProcessor(List fieldNames, List defaultValues, OutputFactory outputFactory) { + this(fieldNames, defaultValues, outputFactory, false); + } + + /** + * Constructs a response processor which passes the field value through the + * output factory. fieldNames and defaultValues must be the same length. + * @param fieldNames The field to read. + * @param defaultValues The default value to extract if it's not found. + * @param outputFactory The output factory to use. + * @param displayField whether to include field names in the generated labels. + */ + public FieldResponseProcessor(List fieldNames, List defaultValues, OutputFactory outputFactory, boolean displayField) { + if(fieldNames.size() != defaultValues.size()) { + throw new IllegalArgumentException("fieldNames and defaultValues must be the same length"); + } + this.fieldNames = fieldNames; + this.defaultValues = defaultValues; this.outputFactory = outputFactory; + this.displayField = displayField; } @Deprecated @@ -70,27 +142,37 @@ public OutputFactory getOutputFactory() { @Override public String getFieldName() { - return fieldName; + return fieldNames.get(0); } @Override public Optional process(String value) { - String val = value == null ? defaultValue : value; - if (val != null) { - val = val.toUpperCase().trim(); - if (val.isEmpty()) { - return Optional.empty(); - } else{ - return Optional.of(outputFactory.generateOutput(val)); + return process(Collections.singletonList(value)); + } + + @Override + public Optional process(List values) { + List responses = new ArrayList<>(); + String prefix = ""; + for(int i=0; i < values.size(); i++) { + if (displayField) { + prefix = fieldNames.get(i) + ":"; } - } else { - return Optional.empty(); + String val = values.get(i).toUpperCase().trim(); + val = val.isEmpty() ? defaultValues.get(i) : val; + responses.add(prefix + val); } + return Optional.of(outputFactory.generateOutput(fieldNames.size() == 1 ? responses.get(0) : responses)); + } + + @Override + public List getFieldNames() { + return fieldNames; } @Override public String toString() { - return "FieldResponseProcessor(fieldName="+ fieldName +")"; + return "FieldResponseProcessor(fieldNames="+ fieldNames.toString() +")"; } @Override diff --git a/Data/src/main/java/org/tribuo/data/columnar/processors/response/QuartileResponseProcessor.java b/Data/src/main/java/org/tribuo/data/columnar/processors/response/QuartileResponseProcessor.java index fef677a59..b8db95ca0 100644 --- a/Data/src/main/java/org/tribuo/data/columnar/processors/response/QuartileResponseProcessor.java +++ b/Data/src/main/java/org/tribuo/data/columnar/processors/response/QuartileResponseProcessor.java @@ -17,13 +17,19 @@ package org.tribuo.data.columnar.processors.response; import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; import org.tribuo.Output; import org.tribuo.OutputFactory; import org.tribuo.data.columnar.ResponseProcessor; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; /** * Processes the response into quartiles and emits them as classification outputs. @@ -33,9 +39,11 @@ public class QuartileResponseProcessor> implements ResponseProcessor { @Config(mandatory = true,description="The string to emit.") + @Deprecated private String name; @Config(mandatory = true,description="The field name to read.") + @Deprecated private String fieldName; @Config(mandatory = true,description="The quartile to use.") @@ -44,13 +52,39 @@ public class QuartileResponseProcessor> implements ResponseP @Config(mandatory = true,description="The output factory to use.") private OutputFactory outputFactory; + @Config(description = "A list of field names to read, you should use only one of this or fieldName.") + private List fieldNames; + + @Config(description = "A list of quartiles to use, should have the same length as fieldNames") + private List quartiles; + + @Override + public void postConfig() throws PropertyException, IOException { + if (fieldName != null && fieldNames != null) { + throw new PropertyException("fieldName, FieldNames", "only one of fieldName or fieldNames can be populated"); + } else if (fieldNames != null) { + quartiles = quartiles == null ? Collections.nCopies(fieldNames.size(), quartile) : quartiles; + if(quartiles.size() != fieldNames.size()) { + throw new PropertyException("quartiles", "must either be empty or match the length of fieldNames"); + } + } else if (fieldName != null) { + if (quartiles != null) { + throw new PropertyException("quartiles", "if fieldName is populated, quartiles must be blank"); + } + fieldNames = Collections.singletonList(fieldName); + quartiles = Collections.singletonList(quartile); + } else { + throw new PropertyException("fieldName, fieldNames", "One of fieldName or fieldNames must be populated"); + } + } + /** * For olcut. */ private QuartileResponseProcessor() {} /** - * Constructs a repsonse processor which emits 4 distinct bins for the output factory to process. + * Constructs a response processor which emits 4 distinct bins for the output factory to process. *

* This works best with classification outputs as the discrete binning is tricky to do in other output * types. @@ -60,9 +94,22 @@ private QuartileResponseProcessor() {} * @param outputFactory The output factory to use. */ public QuartileResponseProcessor(String name, String fieldName, Quartile quartile, OutputFactory outputFactory) { + this(Collections.singletonList(fieldName), Collections.singletonList(quartile), outputFactory); this.name = name; - this.fieldName = fieldName; - this.quartile = quartile; + } + + /** + * Constructs a response processor which emits 4 distinct bins for the output factory to process. + *

+ * This works best with classification outputs as the discrete binning is tricky to do in other output + * types. + * @param fieldNames The field to read. + * @param quartiles The quartile range to use. + * @param outputFactory The output factory to use. + */ + public QuartileResponseProcessor(List fieldNames, List quartiles, OutputFactory outputFactory) { + this.fieldNames = fieldNames; + this.quartiles = quartiles; this.outputFactory = outputFactory; } @@ -79,7 +126,7 @@ public OutputFactory getOutputFactory() { @Override public String getFieldName() { - return fieldName; + return fieldNames.get(0); } @Override @@ -101,9 +148,39 @@ public Optional process(String value) { return Optional.of(output); } + @Override + public Optional process(List values) { + List response = new ArrayList<>(); + for(int i=0; i< values.size(); i++) { + String value = values.get(i); + String prefix = name == null || name.isEmpty() ? fieldNames.get(i) : getFieldName(); + Quartile q = quartiles.get(i); + if(value == null) { + response.add(prefix + ":NONE"); + } else { + double dv = Double.parseDouble(value); + if (dv <= q.getLowerMedian()) { + response.add(prefix + ":first"); + } else if (dv > q.getLowerMedian() && dv <= q.getMedian()) { + response.add(prefix + ":second"); + } else if (dv > q.getMedian() && dv <= q.getUpperMedian()) { + response.add(prefix + ":third"); + } else { + response.add(prefix + ":fourth"); + } + } + } + return Optional.of(outputFactory.generateOutput(response.size() == 1 ? response.get(0) : response)); + } + + @Override + public List getFieldNames() { + return fieldNames; + } + @Override public String toString() { - return "QuartileResponseProcessor(fieldName="+ fieldName +",quartile="+quartile.toString()+")"; + return "QuartileResponseProcessor(fieldNames="+ fieldNames.toString() +",quartiles=" + quartiles.stream().map(Quartile::toString).collect(Collectors.toList()) + ")"; } @Override diff --git a/Data/src/test/java/org/tribuo/data/columnar/MockResponseProcessor.java b/Data/src/test/java/org/tribuo/data/columnar/MockResponseProcessor.java index a08826b10..e92062419 100644 --- a/Data/src/test/java/org/tribuo/data/columnar/MockResponseProcessor.java +++ b/Data/src/test/java/org/tribuo/data/columnar/MockResponseProcessor.java @@ -23,6 +23,8 @@ import org.tribuo.test.MockOutput; import org.tribuo.test.MockOutputFactory; +import java.util.Collections; +import java.util.List; import java.util.Optional; /** @@ -60,6 +62,11 @@ public Optional process(String value) { return Optional.of(new MockOutput(value)); } + @Override + public List getFieldNames() { + return Collections.singletonList(fieldName); + } + @Override public ConfiguredObjectProvenance getProvenance() { return new ConfiguredObjectProvenanceImpl(this,"ResponseProcessor"); diff --git a/Data/src/test/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessorTest.java b/Data/src/test/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessorTest.java index d1deba0c2..e347e2ecd 100644 --- a/Data/src/test/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessorTest.java +++ b/Data/src/test/java/org/tribuo/data/columnar/processors/response/EmptyResponseProcessorTest.java @@ -28,7 +28,7 @@ public void basicTest() { Assertions.assertFalse(rp.process("!@$#$!").isPresent()); Assertions.assertFalse(rp.process("\n").isPresent()); Assertions.assertFalse(rp.process("\t").isPresent()); - Assertions.assertFalse(rp.process(null).isPresent()); + Assertions.assertFalse(rp.process((String) null).isPresent()); } } diff --git a/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java b/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java new file mode 100644 index 000000000..2709424df --- /dev/null +++ b/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java @@ -0,0 +1,140 @@ +package org.tribuo.data.columnar.processors.response; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.tribuo.Example; +import org.tribuo.Output; +import org.tribuo.data.columnar.FieldProcessor; +import org.tribuo.data.columnar.ResponseProcessor; +import org.tribuo.data.columnar.RowProcessor; +import org.tribuo.data.columnar.processors.field.IdentityProcessor; +import org.tribuo.data.csv.CSVDataSource; +import org.tribuo.test.MockMultiOutput; +import org.tribuo.test.MockMultiOutputFactory; +import org.tribuo.test.MockOutputFactory; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MultioutputResponseProcessorTest { + + private URI dataFile; + private Map fieldProcessors; + + + private static > RowProcessor makeRowProcessor(ResponseProcessor responseProcessor, Map fieldProcessors) { + return new RowProcessor(responseProcessor, fieldProcessors); + } + + private , Label> void doTest(ResponseProcessor responseProcessor, List

- * The emitted outputs are of the form {@code {:first, :second, :third, :fourth} }. + * The emitted outputs for each field are of the form: + * {@code {:first, :second, :third, :fourth} }. */ public class QuartileResponseProcessor> implements ResponseProcessor { + /** + * @deprecated This field causes issues with multidimensional outputs. + * When populated the emitted outputs are of the form: + * {@code {:first, :second, :third, :fourth} }. + */ @Config(mandatory = true,description="The string to emit.") @Deprecated private String name; - @Config(mandatory = true,description="The field name to read.") + @Config(description="The field name to read.") @Deprecated private String fieldName; - @Config(mandatory = true,description="The quartile to use.") + @Config(description="The quartile to use.") private Quartile quartile; @Config(mandatory = true,description="The output factory to use.") @@ -58,23 +64,34 @@ public class QuartileResponseProcessor> implements ResponseP @Config(description = "A list of quartiles to use, should have the same length as fieldNames") private List quartiles; + @ConfigurableName + private String configName; + @Override public void postConfig() throws PropertyException, IOException { if (fieldName != null && fieldNames != null) { - throw new PropertyException("fieldName, FieldNames", "only one of fieldName or fieldNames can be populated"); + throw new PropertyException(configName, "fieldName, FieldNames", "only one of fieldName or fieldNames can be populated"); } else if (fieldNames != null) { - quartiles = quartiles == null ? Collections.nCopies(fieldNames.size(), quartile) : quartiles; + if(quartile != null) { + quartiles = quartiles == null ? Collections.nCopies(fieldNames.size(), quartile) : quartiles; + } else { + throw new PropertyException(configName, "quartile, quartiles", "one of quartile or quartiles must be populated"); + } if(quartiles.size() != fieldNames.size()) { - throw new PropertyException("quartiles", "must either be empty or match the length of fieldNames"); + throw new PropertyException(configName, "quartiles", "must either be empty or match the length of fieldNames"); } } else if (fieldName != null) { if (quartiles != null) { - throw new PropertyException("quartiles", "if fieldName is populated, quartiles must be blank"); + throw new PropertyException(configName, "quartiles", "if fieldName is populated, quartiles must be blank"); } fieldNames = Collections.singletonList(fieldName); - quartiles = Collections.singletonList(quartile); + if(quartile != null) { + quartiles = Collections.singletonList(quartile); + } else { + throw new PropertyException(configName, "quartile", "if fieldName is populated, quartile must be populated"); + } } else { - throw new PropertyException("fieldName, fieldNames", "One of fieldName or fieldNames must be populated"); + throw new PropertyException(configName, "fieldName, fieldNames", "One of fieldName or fieldNames must be populated"); } } @@ -180,7 +197,7 @@ public List getFieldNames() { @Override public String toString() { - return "QuartileResponseProcessor(fieldNames="+ fieldNames.toString() +",quartiles=" + quartiles.stream().map(Quartile::toString).collect(Collectors.toList()) + ")"; + return "QuartileResponseProcessor(fieldNames="+ fieldNames.toString() +",quartiles=" + quartiles.toString() + ")"; } @Override diff --git a/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java b/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java index 9b939708f..ffcad2abe 100644 --- a/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java +++ b/Data/src/test/java/org/tribuo/data/columnar/processors/response/MultioutputResponseProcessorTest.java @@ -39,14 +39,14 @@ private static > RowProcessor makeRowProcessor(ResponsePr return new RowProcessor(responseProcessor, fieldProcessors); } - private , Label> void doTest(ResponseProcessor responseProcessor, List