@@ -811,128 +811,132 @@ object DecisionTree extends Serializable with Logging {
811811 // For each (feature, split), calculate the gain, and select the best (feature, split).
812812 val (bestSplit, bestSplitStats) =
813813 Range (0 , binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
814- val featureIndex = if (featuresForNode.nonEmpty) {
815- featuresForNode.get.apply(featureIndexIdx)
816- } else {
817- featureIndexIdx
818- }
819- val numSplits = binAggregates.metadata.numSplits(featureIndex)
820- if (binAggregates.metadata.isContinuous(featureIndex)) {
821- // Cumulative sum (scanLeft) of bin statistics.
822- // Afterwards, binAggregates for a bin is the sum of aggregates for
823- // that bin + all preceding bins.
824- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
825- var splitIndex = 0
826- while (splitIndex < numSplits) {
827- binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1 , splitIndex)
828- splitIndex += 1
814+ val featureIndex = if (featuresForNode.nonEmpty) {
815+ featuresForNode.get.apply(featureIndexIdx)
816+ } else {
817+ featureIndexIdx
829818 }
830- // Find best split.
831- val (bestFeatureSplitIndex, bestFeatureGainStats) =
832- Range (0 , numSplits).map { case splitIdx =>
833- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
834- val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
835- rightChildStats.subtract(leftChildStats)
836- predictWithImpurity = Some (predictWithImpurity.getOrElse(
837- calculatePredictImpurity(leftChildStats, rightChildStats)))
838- val gainStats = calculateGainForSplit(leftChildStats,
839- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
840- (splitIdx, gainStats)
841- }.maxBy(_._2.gain)
842- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
843- } else if (binAggregates.metadata.isUnordered(featureIndex)) {
844- // Unordered categorical feature
845- val (leftChildOffset, rightChildOffset) =
846- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
847- val (bestFeatureSplitIndex, bestFeatureGainStats) =
848- Range (0 , numSplits).map { splitIndex =>
849- val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
850- val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
851- predictWithImpurity = Some (predictWithImpurity.getOrElse(
852- calculatePredictImpurity(leftChildStats, rightChildStats)))
853- val gainStats = calculateGainForSplit(leftChildStats,
854- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
855- (splitIndex, gainStats)
856- }.maxBy(_._2.gain)
857- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
858- } else {
859- // Ordered categorical feature
860- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
861- val numBins = binAggregates.metadata.numBins(featureIndex)
862-
863- /* Each bin is one category (feature value).
864- * The bins are ordered based on centroidForCategories, and this ordering determines which
865- * splits are considered. (With K categories, we consider K - 1 possible splits.)
866- *
867- * centroidForCategories is a list: (category, centroid)
868- */
869- val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
870- // For categorical variables in multiclass classification,
871- // the bins are ordered by the impurity of their corresponding labels.
872- Range (0 , numBins).map { case featureValue =>
873- val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
874- val centroid = if (categoryStats.count != 0 ) {
875- categoryStats.calculate()
876- } else {
877- Double .MaxValue
878- }
879- (featureValue, centroid)
819+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
820+ if (binAggregates.metadata.isContinuous(featureIndex)) {
821+ // Cumulative sum (scanLeft) of bin statistics.
822+ // Afterwards, binAggregates for a bin is the sum of aggregates for
823+ // that bin + all preceding bins.
824+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
825+ var splitIndex = 0
826+ while (splitIndex < numSplits) {
827+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1 , splitIndex)
828+ splitIndex += 1
880829 }
881- } else { // regression or binary classification
882- // For categorical variables in regression and binary classification,
883- // the bins are ordered by the centroid of their corresponding labels.
884- Range (0 , numBins).map { case featureValue =>
885- val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
886- val centroid = if (categoryStats.count != 0 ) {
887- categoryStats.predict
888- } else {
889- Double .MaxValue
830+ // Find best split.
831+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
832+ Range (0 , numSplits).map { case splitIdx =>
833+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
834+ val rightChildStats =
835+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
836+ rightChildStats.subtract(leftChildStats)
837+ predictWithImpurity = Some (predictWithImpurity.getOrElse(
838+ calculatePredictImpurity(leftChildStats, rightChildStats)))
839+ val gainStats = calculateGainForSplit(leftChildStats,
840+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
841+ (splitIdx, gainStats)
842+ }.maxBy(_._2.gain)
843+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
844+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
845+ // Unordered categorical feature
846+ val (leftChildOffset, rightChildOffset) =
847+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
848+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
849+ Range (0 , numSplits).map { splitIndex =>
850+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
851+ val rightChildStats =
852+ binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
853+ predictWithImpurity = Some (predictWithImpurity.getOrElse(
854+ calculatePredictImpurity(leftChildStats, rightChildStats)))
855+ val gainStats = calculateGainForSplit(leftChildStats,
856+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
857+ (splitIndex, gainStats)
858+ }.maxBy(_._2.gain)
859+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
860+ } else {
861+ // Ordered categorical feature
862+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
863+ val numBins = binAggregates.metadata.numBins(featureIndex)
864+
865+ /* Each bin is one category (feature value).
866+ * The bins are ordered based on centroidForCategories, and this ordering determines which
867+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
868+ *
869+ * centroidForCategories is a list: (category, centroid)
870+ */
871+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
872+ // For categorical variables in multiclass classification,
873+ // the bins are ordered by the impurity of their corresponding labels.
874+ Range (0 , numBins).map { case featureValue =>
875+ val categoryStats =
876+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
877+ val centroid = if (categoryStats.count != 0 ) {
878+ categoryStats.calculate()
879+ } else {
880+ Double .MaxValue
881+ }
882+ (featureValue, centroid)
883+ }
884+ } else { // regression or binary classification
885+ // For categorical variables in regression and binary classification,
886+ // the bins are ordered by the impurity of their corresponding labels.
887+ Range (0 , numBins).map { case featureValue =>
888+ val categoryStats =
889+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
890+ val centroid = if (categoryStats.count != 0 ) {
891+ categoryStats.calculate()
892+ } else {
893+ Double .MaxValue
894+ }
895+ (featureValue, centroid)
890896 }
891- (featureValue, centroid)
892897 }
893- }
894898
895- logDebug(" Centroids for categorical variable: " + centroidForCategories.mkString(" ," ))
899+ logDebug(" Centroids for categorical variable: " + centroidForCategories.mkString(" ," ))
896900
897- // bins sorted by centroids
898- val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
901+ // bins sorted by centroids
902+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
899903
900- logDebug(" Sorted centroids for categorical variable = " +
901- categoriesSortedByCentroid.mkString(" ," ))
904+ logDebug(" Sorted centroids for categorical variable = " +
905+ categoriesSortedByCentroid.mkString(" ," ))
902906
903- // Cumulative sum (scanLeft) of bin statistics.
904- // Afterwards, binAggregates for a bin is the sum of aggregates for
905- // that bin + all preceding bins.
906- var splitIndex = 0
907- while (splitIndex < numSplits) {
908- val currentCategory = categoriesSortedByCentroid(splitIndex)._1
909- val nextCategory = categoriesSortedByCentroid(splitIndex + 1 )._1
910- binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
911- splitIndex += 1
907+ // Cumulative sum (scanLeft) of bin statistics.
908+ // Afterwards, binAggregates for a bin is the sum of aggregates for
909+ // that bin + all preceding bins.
910+ var splitIndex = 0
911+ while (splitIndex < numSplits) {
912+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
913+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1 )._1
914+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
915+ splitIndex += 1
916+ }
917+ // lastCategory = index of bin with total aggregates for this (node, feature)
918+ val lastCategory = categoriesSortedByCentroid.last._1
919+ // Find best split.
920+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
921+ Range (0 , numSplits).map { splitIndex =>
922+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
923+ val leftChildStats =
924+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
925+ val rightChildStats =
926+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
927+ rightChildStats.subtract(leftChildStats)
928+ predictWithImpurity = Some (predictWithImpurity.getOrElse(
929+ calculatePredictImpurity(leftChildStats, rightChildStats)))
930+ val gainStats = calculateGainForSplit(leftChildStats,
931+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
932+ (splitIndex, gainStats)
933+ }.maxBy(_._2.gain)
934+ val categoriesForSplit =
935+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0 , bestFeatureSplitIndex + 1 )
936+ val bestFeatureSplit =
937+ new Split (featureIndex, Double .MinValue , Categorical , categoriesForSplit)
938+ (bestFeatureSplit, bestFeatureGainStats)
912939 }
913- // lastCategory = index of bin with total aggregates for this (node, feature)
914- val lastCategory = categoriesSortedByCentroid.last._1
915- // Find best split.
916- val (bestFeatureSplitIndex, bestFeatureGainStats) =
917- Range (0 , numSplits).map { splitIndex =>
918- val featureValue = categoriesSortedByCentroid(splitIndex)._1
919- val leftChildStats =
920- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
921- val rightChildStats =
922- binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
923- rightChildStats.subtract(leftChildStats)
924- predictWithImpurity = Some (predictWithImpurity.getOrElse(
925- calculatePredictImpurity(leftChildStats, rightChildStats)))
926- val gainStats = calculateGainForSplit(leftChildStats,
927- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
928- (splitIndex, gainStats)
929- }.maxBy(_._2.gain)
930- val categoriesForSplit =
931- categoriesSortedByCentroid.map(_._1.toDouble).slice(0 , bestFeatureSplitIndex + 1 )
932- val bestFeatureSplit =
933- new Split (featureIndex, Double .MinValue , Categorical , categoriesForSplit)
934- (bestFeatureSplit, bestFeatureGainStats)
935- }
936940 }.maxBy(_._2.gain)
937941
938942 (bestSplit, bestSplitStats, predictWithImpurity.get._1)
0 commit comments