Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for multioutputs to ResponseProcesser, with tests. #150

Merged
merged 7 commits into from
Jul 15, 2021
31 changes: 30 additions & 1 deletion Core/src/test/java/org/tribuo/test/MockMultiOutputFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.tribuo.test;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.MutableOutputInfo;
Expand Down Expand Up @@ -84,7 +85,7 @@ public boolean equals(Object obj) {

@Override
public OutputFactoryProvenance getProvenance() {
return null;
return new MockMultiOutputFactoryProvenance();
}

/**
Expand Down Expand Up @@ -163,4 +164,32 @@ public static MockMultiOutput createFromPairList(List<Pair<String,Boolean>> dime
}
return new MockMultiOutput(labels);
}

public static class MockMultiOutputFactoryProvenance implements OutputFactoryProvenance {
private static final long serialVersionUID=1L;

MockMultiOutputFactoryProvenance() {}

public MockMultiOutputFactoryProvenance(Map<String, Provenance> map) {}

@Override
public String getClassName() {
return MockMultiOutputFactory.class.getName();
}

@Override
public String toString() {
return generateString("MockMultiOutputFactory");
}

@Override
public boolean equals(Object other) {
return other instanceof MockMultiOutputFactoryProvenance;
}

@Override
public int hashCode() {
return 32;
}
}
}
22 changes: 22 additions & 0 deletions Data/src/main/java/org/tribuo/data/columnar/ResponseProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.tribuo.Output;
import org.tribuo.OutputFactory;

import java.util.List;
import java.util.Optional;

/**
Expand All @@ -39,6 +40,7 @@ public interface ResponseProcessor<T extends Output<T>> extends Configurable, Pr
* Gets the field name this ResponseProcessor uses.
* @return The field name.
*/
@Deprecated
public String getFieldName();

/**
Expand All @@ -55,5 +57,25 @@ public interface ResponseProcessor<T extends Output<T>> extends Configurable, Pr
* @param value The value to process.
* @return The response value if found.
*/
@Deprecated()
JackSullivan marked this conversation as resolved.
Show resolved Hide resolved
public Optional<T> process(String value);

/**
* Returns Optional.empty() if it failed to process out a response.
JackSullivan marked this conversation as resolved.
Show resolved Hide resolved
* @param values The value to process.
* @return The response values if found.
*/
default Optional<T> process(List<String> values) {
if (values.size() != 1) {
throw new IllegalArgumentException(getClass().getSimpleName() + " does not implement support for multiple response values");
} else {
return process(values.get(0));
}
}

/**
* Gets the field names this ResponseProcessor uses.
* @return The field name.
JackSullivan marked this conversation as resolved.
Show resolved Hide resolved
*/
List<String> getFieldNames();
}
7 changes: 5 additions & 2 deletions Data/src/main/java/org/tribuo/data/columnar/RowProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,11 @@ public Set<FeatureProcessor> getFeatureProcessors() {
* @return An Optional containing an Example if the row was valid, an empty Optional otherwise.
*/
public Optional<Example<T>> generateExample(ColumnarIterator.Row row, boolean outputRequired) {
String responseValue = row.getRowData().get(responseProcessor.getFieldName());
Optional<T> labelOpt = responseProcessor.process(responseValue);

List<String> responseValues = responseProcessor.getFieldNames().stream()
.map(f -> row.getRowData().getOrDefault(f, ""))
.collect(Collectors.toList());
Optional<T> labelOpt = responseProcessor.process(responseValues);
if (!labelOpt.isPresent() && outputRequired) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
package org.tribuo.data.columnar.processors.response;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.data.columnar.ResponseProcessor;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
Expand All @@ -32,13 +36,14 @@
*/
public class BinaryResponseProcessor<T extends Output<T>> implements ResponseProcessor<T> {

@Config(mandatory = true,description="The field name to read.")
@Config(description="The field name to read, you should use only one of this or fieldNames")
@Deprecated
private String fieldName;

@Config(mandatory = true,description="The string which triggers a positive response.")
@Config(description="The string which triggers a positive response.")
private String positiveResponse;

@Config(mandatory = true,description="Output factory to use to create the response.")
@Config(description="Output factory to use to create the response.")
JackSullivan marked this conversation as resolved.
Show resolved Hide resolved
private OutputFactory<T> outputFactory;

@Config(description="The positive response to emit.")
Expand All @@ -47,6 +52,36 @@ public class BinaryResponseProcessor<T extends Output<T>> implements ResponsePro
@Config(description="The negative response to emit.")
private String negativeName = "0";

@Config(description = "A list of field names to read, you should use only one of this or fieldName.")
private List<String> fieldNames;

@Config(description = "A list of strings that trigger positive responses; it should be the same length as fieldNames or empty")
private List<String> positiveResponses;

@Config(description = "Whether to display field names as part of the generated label, defaults to false")
JackSullivan marked this conversation as resolved.
Show resolved Hide resolved
private boolean displayField = false;


@Override
public void postConfig() {
if (fieldName != null && fieldNames != null) { // we can only have one path
throw new PropertyException("fieldName, FieldNames", "only one of fieldName or fieldNames can be populated");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first argument to the three string constructor for PropertyException is supposed to be the name of the object in the config file. In practice I don't bother putting in the ConfigurableName annotation just to get that information, though if I'd thought more clearly about it I guess it could have been an argument to postConfig. Either way, you should probably set it to the empty string to mirror the rest of the places in Tribuo where we don't have this information and throw PropertyException.

} else if (fieldNames != null) {
positiveResponses = positiveResponses == null ? Collections.nCopies(fieldNames.size(), positiveResponse) : positiveResponses;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should check that positiveResponse is non-null before copying it.

if(positiveResponses.size() != fieldNames.size()) {
throw new PropertyException("positiveResponses", "must either be empty or match the length of fieldNames");
}
} else if (fieldName != null) {
if(positiveResponses != null) {
throw new PropertyException("positiveResponses", "if fieldName is populated, positiveResponses must be blank");
}
fieldNames = Collections.singletonList(fieldName);
positiveResponses = Collections.singletonList(positiveName);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is positiveName not positiveResponse? Also it should check that it's not null.

} else {
throw new PropertyException("fieldName, fieldNames", "One of fieldName or fieldNames must be populated");
}
}

/**
* for OLCUT.
*/
Expand All @@ -60,9 +95,66 @@ private BinaryResponseProcessor() {}
* @param outputFactory The output factory to use.
*/
public BinaryResponseProcessor(String fieldName, String positiveResponse, OutputFactory<T> outputFactory) {
this.fieldName = fieldName;
this.positiveResponse = positiveResponse;
this(Collections.singletonList(fieldName), positiveResponse, outputFactory);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values.
* @param fieldNames The field names to read.
* @param positiveResponse The positive response to look for.
* @param outputFactory The output factory to use.
*/
public BinaryResponseProcessor(List<String> fieldNames, String positiveResponse, OutputFactory<T> outputFactory) {
this(fieldNames, Collections.nCopies(fieldNames.size(), positiveResponse), outputFactory);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. the lengths of fieldNames and positiveResponses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/the/The/

* must be the same.
* @param fieldNames The field names to read.
* @param positiveResponses The positive responses to look for.
* @param outputFactory The output factory to use.
*/
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory) {
this(fieldNames, positiveResponses, outputFactory, false);
}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. the lengths of fieldNames and positiveResponses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/the/The/

* must be the same.
* @param fieldNames The field names to read.
* @param positiveResponses The positive responses to look for.
* @param outputFactory The output factory to use.
* @param displayField whether to include field names in the generated labels.
*/
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory, boolean displayField) {
this(fieldNames, positiveResponses, outputFactory, "1", "0", displayField);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These new constructors should probably document what they default positiveName and negativeName to, which probably implies that "1" and "0" should be public static final constants so you can easily refer to them from the Javadoc.

}

/**
* Constructs a binary response processor which emits a positive value for a single string
* and a negative value for all other field values. the lengths of fieldNames and positiveResponses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/the/The/

* must be the same.
* @param fieldNames The field names to read.
* @param positiveResponses The positive responses to look for.
* @param outputFactory The output factory to use.
* @param positiveName The value of a 'positive' output
* @param negativeName the value of a 'negative' output
* @param displayField whether to include field names in the generated labels.
*/
public BinaryResponseProcessor(List<String> fieldNames, List<String> positiveResponses, OutputFactory<T> outputFactory, String positiveName, String negativeName, boolean displayField) {
if(fieldNames.size() != positiveResponses.size()) {
throw new IllegalArgumentException("fieldNames and positiveResponses must be the same length");
}
this.fieldNames = fieldNames;
this.positiveResponses = positiveResponses;
this.outputFactory = outputFactory;
this.positiveName = positiveName;
this.negativeName = negativeName;
this.displayField = displayField;
}

@Override
Expand All @@ -72,7 +164,7 @@ public OutputFactory<T> getOutputFactory() {

@Override
public String getFieldName() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we tag this method deprecated too? I think that means that Maven doesn't warn you for using a deprecated method.

return fieldName;
return fieldNames.get(0);
}

@Deprecated
Expand All @@ -83,12 +175,30 @@ public void setFieldName(String fieldName) {

@Override
public Optional<T> process(String value) {
return Optional.of(outputFactory.generateOutput(positiveResponse.equals(value) ? positiveName : negativeName));
return process(Collections.singletonList(value));
}

@Override
public Optional<T> process(List<String> values) {
List<String> responses = new ArrayList<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should check that values.size()==fieldNames.size()?

String prefix = "";
for(int i=0; i < fieldNames.size(); i++) {
if(displayField) {
prefix = fieldNames.get(i) + "=";
}
responses.add(prefix + (positiveResponses.get(i).equals(values.get(i)) ? positiveName : negativeName));
}
return Optional.of(outputFactory.generateOutput(fieldNames.size() == 1 ? responses.get(0) : responses));
}

@Override
public List<String> getFieldNames() {
return fieldNames;
}

@Override
public String toString() {
return "BinaryResponseProcessor(fieldName="+ fieldName +", positiveResponse="+ positiveResponse +", positiveName="+positiveName +", negativeName="+negativeName+")";
return "BinaryResponseProcessor(fieldNames="+ fieldNames.toString() +", positiveResponses="+ positiveResponses.toString() +", positiveName="+positiveName +", negativeName="+negativeName+")";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.tribuo.OutputFactory;
import org.tribuo.data.columnar.ResponseProcessor;

import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
Expand Down Expand Up @@ -84,6 +86,16 @@ public Optional<T> process(String value) {
return Optional.empty();
}

@Override
public Optional<T> process(List<String> values) {
return Optional.empty();
}

@Override
public List<String> getFieldNames() {
return Collections.singletonList(FIELD_NAME);
}

@Override
public String toString() {
return "EmptyResponseProcessor(outputFactory="+outputFactory.toString()+")";
Expand Down
Loading