-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML][Inference] adding ensemble model objects #47241
[ML][Inference] adding ensemble model objects #47241
Conversation
Pinging @elastic/ml-core |
...test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java
Outdated
Show resolved
Hide resolved
...c/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java
Outdated
Show resolved
Hide resolved
...ugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java
Outdated
Show resolved
Hide resolved
...ck/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
Show resolved
Hide resolved
...ugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java
Show resolved
Hide resolved
...gin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
Outdated
Show resolved
Hide resolved
...gin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java
Show resolved
Hide resolved
...h-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good
List<TrainedModel> trainedModels; | ||
OutputAggregator outputAggregator; | ||
TargetType targetType; | ||
List<String> classificationLabels; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These variables are not private
. It's not a big problem at the moment as the setters don't do any critical processing. But in the future if a setter did anything extra then there would be a way to bypass it. So unless there's a really good reason not to I'd make these private
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They should totally be private
. Text editing error.
true, | ||
a -> new WeightedMode((List<Double>)a[0])); | ||
static { | ||
PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rest of the class assumes weights
can be null
. If it was then it couldn't be round-tripped through XContent and back, as this parser requires weights
. I think it should be consistent: either enforce weights != null
throughout or make the parser tolerate weights
not being present.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its tricky for client side, we should probably be lenient, I will make it optional.
OutputAggregator outputAggregator; | ||
TargetType targetType = TargetType.REGRESSION; | ||
List<String> classificationLabels; | ||
boolean modelsAreOrdered; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can these variables be private
?
} | ||
|
||
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) { | ||
if ((outputAggregators.size() == 1) == false) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outputAggregators.size() != 1
?
if (nodes.isEmpty()) { | ||
return; | ||
} | ||
Set<Integer> visited = new HashSet<>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this one also be initialized with nodes.size()
?
} | ||
} | ||
|
||
private void detectNullOrMissingNode() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
null
is a correct value and you continue
in line 274 when you encounter it.
To me this method name suggests that nulls will also be reported as missing. Let me know if I misunderstood it or please rename.
} | ||
|
||
private Double maxLeafValue() { | ||
return targetType == TargetType.CLASSIFICATION ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you considered introducing some class hierarchy ("RegressionTree", "ClassificationTree", etc.) to avoid explicit checks against targetType
? Just leaving this as an idea. You can decide if it makes the code more readable.
I'm just afraid the more differences there will be between regression and classification, the more if
s of this kind we'll need.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will look into it, this type of thing is a constant issue with OO style programming. Separating out the actions from the data.
...ck/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java
Show resolved
Hide resolved
|
||
public void testEnsembleWithInvalidModel() { | ||
List<String> featureNames = Arrays.asList("foo", "bar"); | ||
expectThrows(ElasticsearchException.class, () -> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any meaningful error message to assert on?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not particularly, ensemble could have ANY type of sub-model, I think we just want to make sure it is not considered valid.
// This feature vector should hit the right child of the root node | ||
List<Double> featureVector = Arrays.asList(0.6, 0.0); | ||
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); | ||
assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be written as:
assertThat(tree.classificationProbability(featureMap), contains(0.0, 1.0));
if (outputAggregator != null) { | ||
return outputAggregator.aggregate(processedInferences); | ||
} | ||
return processedInferences.stream().mapToDouble(Double::doubleValue).sum(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This also looks like an aggregation. Should it be wrapped in outputAggregator (possibly in the constructor so that outputAggregator is always non-null here)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@przemekwitek, I could add a default aggregator that is just a sum, the reason for a default is that even though the output aggregation is optional, we should still return something for inference, the default being a simple summation.
|
||
@Override | ||
public List<Double> classificationProbability(Map<String, Object> fields) { | ||
if ((targetType == TargetType.CLASSIFICATION) == false) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this check needed provided that this method delegates to another one (with a check) in line 134?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, I don't think we want to even parse the field input if we are classification.
@Override | ||
public double aggregate(List<Double> values) { | ||
Objects.requireNonNull(values, "values must not be null"); | ||
Optional<Double> summation = values.stream().reduce((memo, v) -> memo + v); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you could provide Double::sum
instead of (memo, v) -> memo + v
.
if (summation.isPresent()) { | ||
return summation.get(); | ||
} | ||
throw new IllegalArgumentException("values must not contain null values"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When values
are empty, summation
will be empty as well, right? Then this message can be misleading.
…com:benwtrent/elasticsearch into feature/ml-inference-ensemble-model-parsing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@elasticmachine update branch |
run elasticsearch-ci/2 |
* [ML][Inference] adding ensemble model objects * addressing PR comments * Update TreeTests.java * addressing PR comments * fixing test
This adds the ensemble model object to core and HLRC.
Changes include: