Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] adding ensemble model objects #47241

Conversation

benwtrent
Copy link
Member

This adds the ensemble model object to core and HLRC.

Changes include:

  • Updating Tree model so that it can support regression and classification target types
  • Adding probability calculation for Tree and Ensemble
  • Adding Ensemble model (and to the HLRC)
  • Adding OutputAggregator models (WeightedSum, WeightedMode).
  • Adding classification_labels

@elasticmachine
Copy link
Collaborator

Pinging @elastic/ml-core

Copy link
Contributor

@valeriy42 valeriy42 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good

List<TrainedModel> trainedModels;
OutputAggregator outputAggregator;
TargetType targetType;
List<String> classificationLabels;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These variables are not private. It's not a big problem at the moment as the setters don't do any critical processing. But in the future if a setter did anything extra then there would be a way to bypass it. So unless there's a really good reason not to I'd make these private.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They should totally be private. Text editing error.

true,
a -> new WeightedMode((List<Double>)a[0]));
static {
PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rest of the class assumes weights can be null. If it was then it couldn't be round-tripped through XContent and back, as this parser requires weights. I think it should be consistent: either enforce weights != null throughout or make the parser tolerate weights not being present.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its tricky for client side, we should probably be lenient, I will make it optional.

OutputAggregator outputAggregator;
TargetType targetType = TargetType.REGRESSION;
List<String> classificationLabels;
boolean modelsAreOrdered;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these variables be private?

}

private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
if ((outputAggregators.size() == 1) == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outputAggregators.size() != 1?

if (nodes.isEmpty()) {
return;
}
Set<Integer> visited = new HashSet<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this one also be initialized with nodes.size()?

}
}

private void detectNullOrMissingNode() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

null is a correct value and you continue in line 274 when you encounter it.
To me this method name suggests that nulls will also be reported as missing. Let me know if I misunderstood it or please rename.

}

private Double maxLeafValue() {
return targetType == TargetType.CLASSIFICATION ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you considered introducing some class hierarchy ("RegressionTree", "ClassificationTree", etc.) to avoid explicit checks against targetType? Just leaving this as an idea. You can decide if it makes the code more readable.
I'm just afraid the more differences there will be between regression and classification, the more ifs of this kind we'll need.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will look into it, this type of thing is a constant issue with OO style programming. Separating out the actions from the data.


public void testEnsembleWithInvalidModel() {
List<String> featureNames = Arrays.asList("foo", "bar");
expectThrows(ElasticsearchException.class, () -> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any meaningful error message to assert on?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not particularly, ensemble could have ANY type of sub-model, I think we just want to make sure it is not considered valid.

// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be written as:

assertThat(tree.classificationProbability(featureMap), contains(0.0, 1.0));

if (outputAggregator != null) {
return outputAggregator.aggregate(processedInferences);
}
return processedInferences.stream().mapToDouble(Double::doubleValue).sum();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also looks like an aggregation. Should it be wrapped in outputAggregator (possibly in the constructor so that outputAggregator is always non-null here)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@przemekwitek, I could add a default aggregator that is just a sum, the reason for a default is that even though the output aggregation is optional, we should still return something for inference, the default being a simple summation.


@Override
public List<Double> classificationProbability(Map<String, Object> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this check needed provided that this method delegates to another one (with a check) in line 134?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, I don't think we want to even parse the field input if we are classification.

@Override
public double aggregate(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
Optional<Double> summation = values.stream().reduce((memo, v) -> memo + v);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could provide Double::sum instead of (memo, v) -> memo + v.

if (summation.isPresent()) {
return summation.get();
}
throw new IllegalArgumentException("values must not contain null values");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When values are empty, summation will be empty as well, right? Then this message can be misleading.

…com:benwtrent/elasticsearch into feature/ml-inference-ensemble-model-parsing
Copy link
Contributor

@przemekwitek przemekwitek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@droberts195 droberts195 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@benwtrent
Copy link
Member Author

@elasticmachine update branch

@benwtrent
Copy link
Member Author

run elasticsearch-ci/2

@benwtrent benwtrent merged commit af4e6ed into elastic:master Oct 1, 2019
@benwtrent benwtrent deleted the feature/ml-inference-ensemble-model-parsing branch October 1, 2019 18:18
benwtrent added a commit to benwtrent/elasticsearch that referenced this pull request Oct 2, 2019
* [ML][Inference] adding ensemble model objects

* addressing PR comments

* Update TreeTests.java

* addressing PR comments

* fixing test
benwtrent added a commit that referenced this pull request Oct 2, 2019
* [ML][Inference] adding ensemble model objects

* addressing PR comments

* Update TreeTests.java

* addressing PR comments

* fixing test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants