From f067d68f0d951e7f0f089419c506fbd5ce2c2fc1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 26 Jan 2014 19:36:21 -0800 Subject: [PATCH] minor cleanup Signed-off-by: Manish Amde --- .../apache/spark/mllib/tree/DecisionTree.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 865a95c5025fc..a9a578c4ac262 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -41,7 +41,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { logDebug("numSplits = " + bins(0).length) strategy.numBins = bins(0).length - //TODO: Level-wise training of tree and obtain Decision Tree model val maxDepth = strategy.maxDepth val maxNumNodes = scala.math.pow(2,maxDepth).toInt - 1 @@ -62,7 +61,6 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { logDebug("#####################################") //Find best split for all nodes at a level - val numNodes= scala.math.pow(2,level).toInt val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters,splits,bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex){ @@ -105,7 +103,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging { private def extractInfoForLowerLevels(level: Int, index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], filters: Array[List[Filter]]) { for (i <- 0 to 1) { - val nodeIndex = (scala.math.pow(2, level + 1)).toInt - 1 + 2 * index + i + val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { @@ -205,7 +203,6 @@ object DecisionTree extends Serializable with Logging { def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = { if (isFeatureContinuous){ - //TODO: Do binary search for (binIndex <- 0 until strategy.numBins) { val bin = bins(featureIndex)(binIndex) val lowThreshold = bin.lowSplit.threshold @@ -250,9 +247,12 @@ object DecisionTree extends Serializable with Logging { val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { //Add to invalid bin index -1 - for (featureIndex <- 0 until numFeatures) { - arr(shift+featureIndex) = -1 - //TODO: Break since marking one bin is sufficient + breakable { + for (featureIndex <- 0 until numFeatures) { + arr(shift+featureIndex) = -1 + //Breaking since marking one bin is sufficient + break() + } } } else { for (featureIndex <- 0 until numFeatures) { @@ -318,7 +318,6 @@ object DecisionTree extends Serializable with Logging { def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { strategy.algo match { case Classification => classificationBinSeqOp(arr, agg) - //TODO: Implement this case Regression => regressionBinSeqOp(arr, agg) } agg @@ -599,7 +598,6 @@ object DecisionTree extends Serializable with Logging { logDebug("maxBins = " + numBins) //Calculate the number of sample for approximate quantile calculation - //TODO: Justify this calculation val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 logDebug("fraction of data used for calculating quantiles = " + fraction) @@ -624,7 +622,6 @@ object DecisionTree extends Serializable with Logging { val stride : Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { - //TODO: Investigate this val sampleIndex = (index+1)*stride.toInt val split = new Split(featureIndex,featureSamples(sampleIndex),Continuous, List()) splits(featureIndex)(index) = split