Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease #51
…ll resulting issues. Get all tests running.
Small changes, mostly formatting and javadoc. I would like consistent use of "features" rather than "attributes" as we don't use "attributes" to mean features anywhere else in the codebase.
lessThanOrEqual = new ClassifierTrainingNode(impurity, lessThanData, lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
greaterThan = new ClassifierTrainingNode(impurity, greaterThanData, numExamples - lessThanIndices.size, depth + 1, featureIDMap, labelIDMap);
List<AbstractTrainingNode<Label>> output = new ArrayList<>();
I should have sized this arraylist to 2 originally, but now you've moved it we should definitely do that. Ditto for the other places in Regression which create a small array. When we move up from Java 8 we can replace it with a List.of() which will be better.
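The sizing suggestion above could look roughly like the sketch below. It is only an illustration, not the project's code: `SmallListSketch`, `presizedPair`, and `fixedPair` are hypothetical names, and plain strings stand in for the training nodes. `Arrays.asList` is shown as a fixed-size option that already works on Java 8; `List.of` requires Java 9 or later.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper class, for illustration only.
final class SmallListSketch {

    // Presizes the backing array to exactly two slots, so the list
    // never has to grow and no default-capacity array is allocated.
    static <T> List<T> presizedPair(T left, T right) {
        List<T> output = new ArrayList<>(2);
        output.add(left);
        output.add(right);
        return output;
    }

    // Java 8 fixed-size alternative: elements can be replaced,
    // but nothing can be added or removed.
    static <T> List<T> fixedPair(T left, T right) {
        return Arrays.asList(left, right);
    }

    public static void main(String[] args) {
        System.out.println(presizedPair("lessThanOrEqual", "greaterThan"));
        System.out.println(fixedPair("lessThanOrEqual", "greaterThan"));
    }
}
```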
public double getImpurity() {
return impurity.impurity(labelCounts);
}
public double getImpurity() { return impurityScore;}
This shouldn't be on a single line.
@@ -75,7 +75,8 @@ default public double impurity(double[] input) {
}

/**
* Calculates the impurity assuming the input are weighted counts, normalizing by their sum.
* Calculates the impurity assuming the input are weighted counts, normalizing by their sum. The resulting
This javadoc isn't quite right. The counts are assumed to be weighted, they are converted into a probability distribution by dividing by their sum, and then the impurity is multiplied by the sum. It's missing the "probability distribution" bit.
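The three steps described in this comment can be sketched as follows. This is a hedged illustration rather than the project's implementation: `WeightedImpuritySketch`, `gini`, and `impurityWeighted` are hypothetical names, Gini impurity is used only as a convenient stand-in for whichever impurity measure is configured, and returning zero for an empty or zero-weight input is an assumption.

```java
// Hypothetical class, for illustration only.
final class WeightedImpuritySketch {

    // Gini impurity of a probability distribution: 1 - sum(p_i^2).
    static double gini(double[] probabilities) {
        double sumOfSquares = 0.0;
        for (double p : probabilities) {
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    // 1. Treat the input as weighted counts.
    // 2. Divide by their sum to obtain a probability distribution.
    // 3. Compute the impurity of that distribution and multiply it by the sum.
    static double impurityWeighted(double[] weightedCounts) {
        double sum = 0.0;
        for (double c : weightedCounts) {
            sum += c;
        }
        if (sum <= 0.0) {
            return 0.0; // assumption: no weight means no impurity
        }
        double[] probabilities = new double[weightedCounts.length];
        for (int i = 0; i < weightedCounts.length; i++) {
            probabilities[i] = weightedCounts[i] / sum;
        }
        return sum * gini(probabilities);
    }

    public static void main(String[] args) {
        System.out.println(impurityWeighted(new double[]{2.0, 2.0}));
    }
}
```

For the counts {2, 2} the distribution is {0.5, 0.5}, its Gini impurity is 0.5, and scaling by the sum of 4 gives 2.0.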
public void testCART(Pair<Dataset<Label>,Dataset<Label>> p) {
TreeModel<Label> m = t.train(p.getA());
public void testCART(Pair<Dataset<Label>,Dataset<Label>> p, AbstractCARTTrainer<Label> trainer) {
These should be sharply typed (i.e. CARTClassificationTrainer not AbstractCARTTrainer<Label>). I'd prefer nobody ever use the AbstractCARTTrainer type in user code, so we shouldn't do it in the tests unless it's strictly necessary.
public class TestCART {

private static final CARTClassificationTrainer t = new CARTClassificationTrainer();
private static final CARTClassificationTrainer randomt = new CARTClassificationTrainer(5, 2, 0.0f,1.0f, true,
Looks like there's some random whitespace in this line?
(node.getNumExamples() > minChildWeight)) {
if (numFeaturesInSplit != featureIDMap.size()) {
Util.randpermInPlace(originalIndices, localRNG);
System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
}
List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices);
List<AbstractTrainingNode<Regressor>> nodes = node.buildTree(indices, localRNG,
getUseRandomSplitPoints(),getMinImpurityDecrease() * weightSum);
Maybe precompute getMinImpurityDecrease()*weightSum rather than do it every time?
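A minimal sketch of that hoisting, under stated assumptions: `ScaledThresholdSketch` and `countAcceptedSplits` are hypothetical names, the array of impurity decreases stands in for the nodes visited by the tree-building loop, and the `>=` comparison is an assumption about how the threshold is applied.

```java
// Hypothetical class, for illustration only.
final class ScaledThresholdSketch {

    // Counts the candidate splits whose impurity decrease meets the scaled threshold.
    static int countAcceptedSplits(double[] impurityDecreases, float minImpurityDecrease, double weightSum) {
        // Computed once, before the loop, because neither factor changes per node.
        final double scaledMinImpurityDecrease = minImpurityDecrease * weightSum;

        int accepted = 0;
        for (double decrease : impurityDecreases) {
            if (decrease >= scaledMinImpurityDecrease) {
                accepted++;
            }
        }
        return accepted;
    }

    public static void main(String[] args) {
        System.out.println(countAcceptedSplits(new double[]{0.5, 2.0, 3.0}, 0.25f, 8.0));
    }
}
```

The arithmetic saved per iteration is small; the main benefit is that the loop body reads as a comparison against a single named value.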
}

@Override
public double getImpurity() {
public double getImpurity() { return impurityScore;}
Formatting.
* Calculates the impurity score of the node.
* @return the impurity score of the node.
*/
private double calcImpurity(){
Put a space between () and the open curly brace.
public void testJointRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p) {
TreeModel<Regressor> m = t.train(p.getA());
public void testJointRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p, AbstractCARTTrainer<Regressor> trainer) {
Similar to the classification tests, I'd prefer it if the sharp CARTJointRegressionTrainer is used rather than AbstractCARTTrainer<Regressor> unless you're sharing the tests across both types of regression tree trainer.
public void testIndependentRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p) {
Model<Regressor> m = t.train(p.getA());
public void testIndependentRegressionTree(Pair<Dataset<Regressor>,Dataset<Regressor>> p,
AbstractCARTTrainer<Regressor> trainer) {
Sharp type.
Three tiny changes to clean things up. Looks good otherwise.
@@ -126,8 +126,8 @@ public synchronized void postConfig() {
throw new IllegalArgumentException("maxDepth must be greater than or equal to 1");
}

if ((minChildWeight < 0.0f)) {
throw new IllegalArgumentException("minChildWeight must be greater than or equal to 0");
if ((minChildWeight <= 0.0f)) {
There are two sets of parentheses here.
done
@@ -121,20 +124,20 @@ public static void main(String[] args) throws IOException {
SparseTrainer<Regressor> trainer;
switch (o.treeType) {
case CART_INDEPENDENT:
if (o.fraction <= 0) {
trainer = new CARTRegressionTrainer(o.depth,o.minChildWeight,0.0f, 1, false, impurity,
if (o.fraction == 0) {
It's probably better to fix the default value of fraction to be 1.0, and then remove this if clause entirely.
done
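The default-value suggestion above might be sketched like this. Everything here is a hypothetical stand-in rather than the project's API: `SplitFractionSketch`, `Options`, `Trainer`, and `build` are illustrative names, and the accepted range of (0, 1] is an assumption. With the default set to 1.0, the options object can be passed straight through and the special case for a zero fraction is no longer needed.

```java
// Hypothetical classes, for illustration only.
final class SplitFractionSketch {

    // Hypothetical options holder; the default means "consider every feature at each split".
    static final class Options {
        float fraction = 1.0f;
    }

    // Hypothetical stand-in for a trainer: it validates and stores the fraction it was given.
    static final class Trainer {
        final float fractionFeaturesInSplit;

        Trainer(float fractionFeaturesInSplit) {
            // Assumed valid range: greater than 0 and at most 1.
            if (fractionFeaturesInSplit <= 0.0f || fractionFeaturesInSplit > 1.0f) {
                throw new IllegalArgumentException("fraction must be in the range (0, 1]");
            }
            this.fractionFeaturesInSplit = fractionFeaturesInSplit;
        }
    }

    // No branch on the fraction: the default already encodes the "use all features" case.
    static Trainer build(Options o) {
        return new Trainer(o.fraction);
    }

    public static void main(String[] args) {
        System.out.println(build(new Options()).fractionFeaturesInSplit);
    }
}
```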
@@ -69,10 +73,12 @@ public CARTClassificationTrainer getTrainer() {
CARTClassificationTrainer trainer;
switch (cartTreeAlgorithm) {
case CART:
if (cartSplitFraction <= 0) {
trainer = new CARTClassificationTrainer(cartMaxDepth, cartMinChildWeight, 1, impurity, cartSeed);
if (cartSplitFraction == 0) {
Probably best to set the default value of cartSplitFraction to 1.0 and then remove this if statement entirely.
done
Looks good, thanks.
Description
Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease.
Motivation
Adds Extremely Randomized Trees Algorithm and Min Impurity Decrease.
Paper reference
https://link.springer.com/article/10.1007/s10994-006-6226-1
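For context, the split rule in the referenced paper draws each candidate cut-point uniformly at random between the minimum and maximum value of the feature in the current node, instead of searching for the best threshold. The sketch below illustrates only that idea and says nothing about how this pull request implements it: `RandomSplitSketch` and `randomThreshold` are hypothetical names, and returning NaN for a constant or empty feature is an assumption.

```java
import java.util.SplittableRandom;

// Hypothetical class, for illustration only.
final class RandomSplitSketch {

    // Draws a split threshold uniformly from [min, max) of the observed feature values.
    // Returns NaN when the feature is constant or empty, since no split is possible.
    static double randomThreshold(double[] featureValues, SplittableRandom rng) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double v : featureValues) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        if (!(min < max)) {
            return Double.NaN;
        }
        return rng.nextDouble(min, max);
    }

    public static void main(String[] args) {
        SplittableRandom rng = new SplittableRandom(1L);
        System.out.println(randomThreshold(new double[]{1.0, 3.0, 5.0}, rng));
    }
}
```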