随机森林

1 Bagging

  Bagging采用自助采样法(bootstrap sampling)采样数据。给定包含m个样本的数据集,我们先随机取出一个样本放入采样集中,再把该样本放回初始数据集,使得下次采样时,样本仍可能被选中,
这样,经过m次随机采样操作,我们得到包含m个样本的采样集。

  按照此方式,我们可以采样出T个含m个训练样本的采样集,然后基于每个采样集训练出一个基本学习器,再将这些基本学习器进行结合。这就是Bagging的一般流程。在对预测输出进行结合时,Bagging通常使用简单投票法,
对回归问题使用简单平均法。若分类预测时,出现两个类收到同样票数的情形,则最简单的做法是随机选择一个,也可以进一步考察学习器投票的置信度来确定最终胜者。

  Bagging的算法描述如下图所示。

1.1

2 随机森林

  随机森林是Bagging的一个扩展变体。随机森林在以决策树为基学习器构建Bagging集成的基础上,进一步在决策树的训练过程中引入了随机属性选择。具体来讲,传统决策树在选择划分属性时,
在当前节点的属性集合(假设有d个属性)中选择一个最优属性;而在随机森林中,对基决策树的每个节点,先从该节点的属性集合中随机选择一个包含k个属性的子集,然后再从这个子集中选择一个最优属性用于划分。
这里的参数k控制了随机性的引入程度。若令k=d,则基决策树的构建与传统决策树相同;若令k=1,则是随机选择一个属性用于划分。在MLlib中,有两种选择用于分类,即k=log2(d)k=sqrt(d)
一种选择用于回归,即k=1/3d。在源码分析中会详细介绍。

  可以看出,随机森林对Bagging只做了小改动,但是与Bagging中基学习器的“多样性”仅仅通过样本扰动(通过对初始训练集采样)而来不同,随机森林中基学习器的多样性不仅来自样本扰动,还来自属性扰动。
这使得最终集成的泛化性能可通过个体学习器之间差异度的增加而进一步提升。

3 随机森林在分布式环境下的优化策略

  随机森林算法在单机环境下很容易实现,但在分布式环境下特别是在Spark平台上,传统单机形式的迭代方式必须要进行相应改进才能适用于分布式环境
,这是因为在分布式环境下,数据也是分布式的,算法设计不得当会生成大量的IO操作,例如频繁的网络数据传输,从而影响算法效率。
因此,在Spark上进行随机森林算法的实现,需要进行一定的优化,Spark中的随机森林算法主要实现了三个优化策略:

  • 切分点抽样统计,如下图所示。在单机环境下的决策树对连续变量进行切分点选择时,一般是通过对特征点进行排序,然后取相邻两个数之间的点作为切分点,这在单机环境下是可行的,但如果在分布式环境下如此操作的话,
    会带来大量的网络传输操作,特别是当数据量达到PB级时,算法效率将极为低下。为避免该问题,Spark中的随机森林在构建决策树时,会对各分区采用一定的子特征策略进行抽样,然后生成各个分区的统计数据,并最终得到切分点。
    (从源代码里面看,是先对样本进行抽样,然后根据抽样样本值出现的次数进行排序,然后再进行切分)。
1.2
  • 特征装箱(Binning),如下图所示。决策树的构建过程就是对特征的取值不断进行划分的过程,对于离散的特征,如果有M个值,最多有2^(M-1) - 1个划分。如果值是有序的,那么就最多M-1个划分。
    比如年龄特征,有老,中,少3个值,如果无序有2^2-1=3个划分,即老|中,少;老,中|少;老,少|中。;如果是有序的,即按老,中,少的序,那么只有m-1个,即2种划分,老|中,少;老,中|少
    对于连续的特征,其实就是进行范围划分,而划分的点就是split(切分点),划分出的区间就是bin。对于连续特征,理论上split是无数的,在分布环境下不可能取出所有的值,因此它采用的是切点抽样统计方法。
1.3
  • 逐层训练(level-wise training),如下图所示。单机版本的决策树生成过程是通过递归调用(本质上是深度优先)的方式构造树,在构造树的同时,需要移动数据,将同一个子节点的数据移动到一起。
    此方法在分布式数据结构上无法有效的执行,而且也无法执行,因为数据太大,无法放在一起,所以在分布式环境下采用的策略是逐层构建树节点(本质上是广度优先),这样遍历所有数据的次数等于所有树中的最大层数。
    每次遍历时,只需要计算每个节点所有切分点统计参数,遍历完后,根据节点的特征划分,决定是否切分,以及如何切分。
1.4

4 使用实例

  下面的例子用于分类。

  1. import org.apache.spark.mllib.tree.RandomForest
  2. import org.apache.spark.mllib.tree.model.RandomForestModel
  3. import org.apache.spark.mllib.util.MLUtils
  4. // Load and parse the data file.
  5. val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
  6. // Split the data into training and test sets (30% held out for testing)
  7. val splits = data.randomSplit(Array(0.7, 0.3))
  8. val (trainingData, testData) = (splits(0), splits(1))
  9. // Train a RandomForest model.
  10. // 空的类别特征信息表示所有的特征都是连续的.
  11. val numClasses = 2
  12. val categoricalFeaturesInfo = Map[Int, Int]()
  13. val numTrees = 3 // Use more in practice.
  14. val featureSubsetStrategy = "auto" // Let the algorithm choose.
  15. val impurity = "gini"
  16. val maxDepth = 4
  17. val maxBins = 32
  18. val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  19. numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
  20. // Evaluate model on test instances and compute test error
  21. val labelAndPreds = testData.map { point =>
  22. val prediction = model.predict(point.features)
  23. (point.label, prediction)
  24. }
  25. val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
  26. println("Test Error = " + testErr)
  27. println("Learned classification forest model:\n" + model.toDebugString)

  下面的例子用于回归。

  1. import org.apache.spark.mllib.tree.RandomForest
  2. import org.apache.spark.mllib.tree.model.RandomForestModel
  3. import org.apache.spark.mllib.util.MLUtils
  4. // Load and parse the data file.
  5. val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
  6. // Split the data into training and test sets (30% held out for testing)
  7. val splits = data.randomSplit(Array(0.7, 0.3))
  8. val (trainingData, testData) = (splits(0), splits(1))
  9. // Train a RandomForest model.
  10. // 空的类别特征信息表示所有的特征都是连续的
  11. val numClasses = 2
  12. val categoricalFeaturesInfo = Map[Int, Int]()
  13. val numTrees = 3 // Use more in practice.
  14. val featureSubsetStrategy = "auto" // Let the algorithm choose.
  15. val impurity = "variance"
  16. val maxDepth = 4
  17. val maxBins = 32
  18. val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
  19. numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
  20. // Evaluate model on test instances and compute test error
  21. val labelsAndPredictions = testData.map { point =>
  22. val prediction = model.predict(point.features)
  23. (point.label, prediction)
  24. }
  25. val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
  26. println("Test Mean Squared Error = " + testMSE)
  27. println("Learned regression forest model:\n" + model.toDebugString)

5 源码分析

5.1 训练分析

  训练过程简单可以分为两步,第一步是初始化,第二步是迭代构建随机森林。这两大步还分为若干小步,下面会分别介绍这些内容。

5.1.1 初始化

  1. val retaggedInput = input.retag(classOf[LabeledPoint])
  2. //建立决策树的元数据信息(分裂点位置、箱子数及各箱子包含特征属性的值等)
  3. val metadata =
  4. DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
  5. //找到切分点(splits)及箱子信息(Bins)
  6. //对于连续型特征,利用切分点抽样统计简化计算
  7. //对于离散型特征,如果是无序的,则最多有个 splits=2^(numBins-1)-1 划分
  8. //如果是有序的,则最多有 splits=numBins-1 个划分
  9. val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
  10. //转换成树形的 RDD 类型,转换后,所有样本点已经按分裂点条件分到了各自的箱子中
  11. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
  12. val withReplacement = if (numTrees > 1) true else false
  13. // convertToBaggedRDD 方法使得每棵树就是样本的一个子集
  14. val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
  15. strategy.subsamplingRate, numTrees,
  16. withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
  17. //决策树的深度,最大为30
  18. val maxDepth = strategy.maxDepth
  19. //聚合的最大内存
  20. val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
  21. val maxMemoryPerNode = {
  22. val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
  23. // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
  24. Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
  25. .take(metadata.numFeaturesPerNode).map(_._2))
  26. } else {
  27. None
  28. }
  29. //计算聚合操作时节点的内存
  30. RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
  31. }

  初始化的第一步就是决策树元数据信息的构建。它的代码如下所示。

  1. def buildMetadata(
  2. input: RDD[LabeledPoint],
  3. strategy: Strategy,
  4. numTrees: Int,
  5. featureSubsetStrategy: String): DecisionTreeMetadata = {
  6. //特征数
  7. val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
  8. throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
  9. s"but was given by empty one.")
  10. }
  11. val numExamples = input.count()
  12. val numClasses = strategy.algo match {
  13. case Classification => strategy.numClasses
  14. case Regression => 0
  15. }
  16. //最大可能的装箱数
  17. val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
  18. if (maxPossibleBins < strategy.maxBins) {
  19. logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
  20. s" (= number of training instances)")
  21. }
  22. // We check the number of bins here against maxPossibleBins.
  23. // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
  24. // based on the number of training examples.
  25. //最大分类数要小于最大可能装箱数
  26. //这里categoricalFeaturesInfo是传入的信息,这个map保存特征的类别信息。
  27. //例如,(n->k)表示特征k包含的类别有(0,1,...,k-1)
  28. if (strategy.categoricalFeaturesInfo.nonEmpty) {
  29. val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
  30. val maxCategory =
  31. strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
  32. require(maxCategoriesPerFeature <= maxPossibleBins,
  33. s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
  34. s"number of values in each categorical feature, but categorical feature $maxCategory " +
  35. s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
  36. "features with a large number of values, or add more training examples.")
  37. }
  38. val unorderedFeatures = new mutable.HashSet[Int]()
  39. val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
  40. if (numClasses > 2) {
  41. // 多分类
  42. val maxCategoriesForUnorderedFeature =
  43. ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
  44. strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
  45. //如果类别特征只有1个类,我们把它看成连续的特征
  46. if (numCategories > 1) {
  47. // Decide if some categorical features should be treated as unordered features,
  48. // which require 2 * ((1 << numCategories - 1) - 1) bins.
  49. // We do this check with log values to prevent overflows in case numCategories is large.
  50. // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
  51. if (numCategories <= maxCategoriesForUnorderedFeature) {
  52. unorderedFeatures.add(featureIndex)
  53. numBins(featureIndex) = numUnorderedBins(numCategories)
  54. } else {
  55. numBins(featureIndex) = numCategories
  56. }
  57. }
  58. }
  59. } else {
  60. // 二分类或者回归
  61. strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
  62. //如果类别特征只有1个类,我们把它看成连续的特征
  63. if (numCategories > 1) {
  64. numBins(featureIndex) = numCategories
  65. }
  66. }
  67. }
  68. // 设置每个节点的特征数 (对随机森林而言).
  69. val _featureSubsetStrategy = featureSubsetStrategy match {
  70. case "auto" =>
  71. if (numTrees == 1) {//决策树时,使用所有特征
  72. "all"
  73. } else {
  74. if (strategy.algo == Classification) {//分类时,使用开平方
  75. "sqrt"
  76. } else { //回归时,使用1/3的特征
  77. "onethird"
  78. }
  79. }
  80. case _ => featureSubsetStrategy
  81. }
  82. val numFeaturesPerNode: Int = _featureSubsetStrategy match {
  83. case "all" => numFeatures
  84. case "sqrt" => math.sqrt(numFeatures).ceil.toInt
  85. case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
  86. case "onethird" => (numFeatures / 3.0).ceil.toInt
  87. }
  88. new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
  89. strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
  90. strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
  91. strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
  92. }

  初始化的第二步就是找到切分点(splits)及箱子信息(Bins)。这时,调用了DecisionTree.findSplitsBins方法,进入该方法了解详细信息。

  1. /**
  2. * Returns splits and bins for decision tree calculation.
  3. * Continuous and categorical features are handled differently.
  4. *
  5. * Continuous features:
  6. * For each feature, there are numBins - 1 possible splits representing the possible binary
  7. * decisions at each node in the tree.
  8. * This finds locations (feature values) for splits using a subsample of the data.
  9. *
  10. * Categorical features:
  11. * For each feature, there is 1 bin per split.
  12. * Splits and bins are handled in 2 ways:
  13. * (a) "unordered features"
  14. * For multiclass classification with a low-arity feature
  15. * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
  16. * the feature is split based on subsets of categories.
  17. * (b) "ordered features"
  18. * For regression and binary classification,
  19. * and for multiclass classification with a high-arity feature,
  20. * there is one bin per category.
  21. *
  22. * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
  23. * @param metadata Learning and dataset metadata
  24. * @return A tuple of (splits, bins).
  25. * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
  26. * of size (numFeatures, numSplits).
  27. * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
  28. * of size (numFeatures, numBins).
  29. */
  30. protected[tree] def findSplitsBins(
  31. input: RDD[LabeledPoint],
  32. metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {
  33. //特征数
  34. val numFeatures = metadata.numFeatures
  35. // Sample the input only if there are continuous features.
  36. // 判断特征中是否存在连续特征
  37. val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
  38. val sampledInput = if (continuousFeatures.nonEmpty) {
  39. // Calculate the number of samples for approximate quantile calculation.
  40. //采样样本数量,最少有 10000 个
  41. val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
  42. //计算采样比例
  43. val fraction = if (requiredSamples < metadata.numExamples) {
  44. requiredSamples.toDouble / metadata.numExamples
  45. } else {
  46. 1.0
  47. }
  48. //采样数据,有放回采样
  49. input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
  50. } else {
  51. input.sparkContext.emptyRDD[LabeledPoint]
  52. }
  53. //分裂点策略,目前 Spark 中只实现了一种策略:排序 Sort
  54. metadata.quantileStrategy match {
  55. case Sort =>
  56. findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
  57. case MinMax =>
  58. throw new UnsupportedOperationException("minmax not supported yet.")
  59. case ApproxHist =>
  60. throw new UnsupportedOperationException("approximate histogram not supported yet.")
  61. }
  62. }

  我们进入findSplitsBinsBySorting方法了解Sort分裂策略的实现。

  1. private def findSplitsBinsBySorting(
  2. input: RDD[LabeledPoint],
  3. metadata: DecisionTreeMetadata,
  4. continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
  5. def findSplits(
  6. featureIndex: Int,
  7. featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
  8. //每个特征分别对应一组切分点位置,这里splits是有序的
  9. val splits = {
  10. // findSplitsForContinuousFeature 返回连续特征的所有切分位置
  11. val featureSplits = findSplitsForContinuousFeature(
  12. featureSamples.toArray,
  13. metadata,
  14. featureIndex)
  15. featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
  16. }
  17. //存放切分点位置对应的箱子信息
  18. val bins = {
  19. //采用最小阈值 Double.MinValue 作为最左边的分裂位置并进行装箱
  20. val lowSplit = new DummyLowSplit(featureIndex, Continuous)
  21. //最后一个箱子的计算采用最大阈值 Double.MaxValue 作为最右边的切分位置
  22. val highSplit = new DummyHighSplit(featureIndex, Continuous)
  23. // tack the dummy splits on either side of the computed splits
  24. val allSplits = lowSplit +: splits.toSeq :+ highSplit
  25. //将切分点两两结合成一个箱子
  26. allSplits.sliding(2).map {
  27. case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
  28. }.toArray
  29. }
  30. (featureIndex, (splits, bins))
  31. }
  32. val continuousSplits = {
  33. // reduce the parallelism for split computations when there are less
  34. // continuous features than input partitions. this prevents tasks from
  35. // being spun up that will definitely do no work.
  36. val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
  37. input
  38. .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
  39. .groupByKey(numPartitions)
  40. .map { case (k, v) => findSplits(k, v) }
  41. .collectAsMap()
  42. }
  43. val numFeatures = metadata.numFeatures
  44. //遍历所有特征
  45. val (splits, bins) = Range(0, numFeatures).unzip {
  46. //处理连续特征的情况
  47. case i if metadata.isContinuous(i) =>
  48. val (split, bin) = continuousSplits(i)
  49. metadata.setNumSplits(i, split.length)
  50. (split, bin)
  51. //处理离散特征且无序的情况
  52. case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
  53. // Unordered features
  54. // 2^(maxFeatureValue - 1) - 1 combinations
  55. val featureArity = metadata.featureArity(i)
  56. val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
  57. val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
  58. new Split(i, Double.MinValue, Categorical, categories)
  59. }
  60. // For unordered categorical features, there is no need to construct the bins.
  61. // since there is a one-to-one correspondence between the splits and the bins.
  62. (split.toArray, Array.empty[Bin])
  63. //处理离散特征且有序的情况
  64. case i if metadata.isCategorical(i) =>
  65. //有序特征无需处理,箱子与特征值对应
  66. // Ordered features
  67. // Bins correspond to feature values, so we do not need to compute splits or bins
  68. // beforehand. Splits are constructed as needed during training.
  69. (Array.empty[Split], Array.empty[Bin])
  70. }
  71. (splits.toArray, bins.toArray)
  72. }

  计算连续特征的所有切分位置需要调用方法findSplitsForContinuousFeature方法。

  1. private[tree] def findSplitsForContinuousFeature(
  2. featureSamples: Array[Double],
  3. metadata: DecisionTreeMetadata,
  4. featureIndex: Int): Array[Double] = {
  5. val splits = {
  6. //切分数是bin的数量减1,即m-1
  7. val numSplits = metadata.numSplits(featureIndex)
  8. // (特征,特征出现的次数)
  9. val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
  10. m + ((x, m.getOrElse(x, 0) + 1))
  11. }
  12. // 根据特征进行排序
  13. val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
  14. // if possible splits is not enough or just enough, just return all possible splits
  15. val possibleSplits = valueCounts.length
  16. //如果特征数小于切分数,所有特征均作为切分点
  17. if (possibleSplits <= numSplits) {
  18. valueCounts.map(_._1)
  19. } else {
  20. // 等频切分
  21. // 切分点之间的步长
  22. val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
  23. val splitsBuilder = Array.newBuilder[Double]
  24. var index = 1
  25. // currentCount: sum of counts of values that have been visited
  26. //第一个特征的出现次数
  27. var currentCount = valueCounts(0)._2
  28. // targetCount: target value for `currentCount`.
  29. // If `currentCount` is closest value to `targetCount`,
  30. // then current value is a split threshold.
  31. // After finding a split threshold, `targetCount` is added by stride.
  32. // 如果currentCount离targetCount最近,那么当前值是切分点
  33. var targetCount = stride
  34. while (index < valueCounts.length) {
  35. val previousCount = currentCount
  36. currentCount += valueCounts(index)._2
  37. val previousGap = math.abs(previousCount - targetCount)
  38. val currentGap = math.abs(currentCount - targetCount)
  39. // If adding count of current value to currentCount
  40. // makes the gap between currentCount and targetCount smaller,
  41. // previous value is a split threshold.
  42. if (previousGap < currentGap) {
  43. splitsBuilder += valueCounts(index - 1)._1
  44. targetCount += stride
  45. }
  46. index += 1
  47. }
  48. splitsBuilder.result()
  49. }
  50. }
  51. splits
  52. }

   在if判断里每步前进stride个样本,累加在targetCount中。while循环逐次把每个特征值的个数加到currentCount里,计算前一次previousCount和这次currentCounttargetCount的距离,有3种情况,一种是precur都在target左边,肯定是cur小,继续循环,进入第二种情况;第二种一左一右,如果pre小,肯定是pre是最好的分割点,如果cur还是小,继续循环步进,进入第三种情况;第三种就是都在右边,显然是pre小。因此if的判断条件pre<cur,只要满足肯定就是split。整体下来的效果就能找到离target最近的一个特征值。

5.1.2 迭代构建随机森林

  1. //节点是否使用缓存,节点 ID 从 1 开始,1 即为这颗树的根节点,左节点为 2,右节点为 3,依次递增下去
  2. val nodeIdCache = if (strategy.useNodeIdCache) {
  3. Some(NodeIdCache.init(
  4. data = baggedInput,
  5. numTrees = numTrees,
  6. checkpointInterval = strategy.checkpointInterval,
  7. initVal = 1))
  8. } else {
  9. None
  10. }
  11. // FIFO queue of nodes to train: (treeIndex, node)
  12. val nodeQueue = new mutable.Queue[(Int, Node)]()
  13. val rng = new scala.util.Random()
  14. rng.setSeed(seed)
  15. // Allocate and queue root nodes.
  16. //创建树的根节点
  17. val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
  18. //将(树的索引,树的根节点)入队,树索引从 0 开始,根节点从 1 开始
  19. Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
  20. while (nodeQueue.nonEmpty) {
  21. // Collect some nodes to split, and choose features for each node (if subsampling).
  22. // Each group of nodes may come from one or multiple trees, and at multiple levels.
  23. // 取得每个树所有需要切分的节点,nodesForGroup表示需要切分的节点
  24. val (nodesForGroup, treeToNodeToIndexInfo) =
  25. RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
  26. //找出最优切点
  27. DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
  28. treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
  29. }

  这里有两点需要重点介绍,第一点是取得每个树所有需要切分的节点,通过RandomForest.selectNodesToSplit方法实现;第二点是找出最优的切分,通过DecisionTree.findBestSplits方法实现。下面分别介绍这两点。

  • 取得每个树所有需要切分的节点
  1. private[tree] def selectNodesToSplit(
  2. nodeQueue: mutable.Queue[(Int, Node)],
  3. maxMemoryUsage: Long,
  4. metadata: DecisionTreeMetadata,
  5. rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
  6. // nodesForGroup保存需要切分的节点,treeIndex --> nodes
  7. val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
  8. // mutableTreeToNodeToIndexInfo保存每个节点中选中特征的索引
  9. // treeIndex --> (global) node index --> (node index in group, feature indices)
  10. //(global) node index是树中的索引,组中节点索引的范围是[0, numNodesInGroup)
  11. val mutableTreeToNodeToIndexInfo =
  12. new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
  13. var memUsage: Long = 0L
  14. var numNodesInGroup = 0
  15. while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
  16. val (treeIndex, node) = nodeQueue.head
  17. // Choose subset of features for node (if subsampling).
  18. // 选中特征子集
  19. val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
  20. Some(SamplingUtils.reservoirSampleAndCount(Range(0,
  21. metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
  22. } else {
  23. None
  24. }
  25. // Check if enough memory remains to add this node to the group.
  26. // 检查是否有足够的内存
  27. val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
  28. if (memUsage + nodeMemUsage <= maxMemoryUsage) {
  29. nodeQueue.dequeue()
  30. mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
  31. mutableTreeToNodeToIndexInfo
  32. .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
  33. = new NodeIndexInfo(numNodesInGroup, featureSubset)
  34. }
  35. numNodesInGroup += 1
  36. memUsage += nodeMemUsage
  37. }
  38. // 将可变map转换为不可变map
  39. val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
  40. val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
  41. (nodesForGroup, treeToNodeToIndexInfo)
  42. }
  • 选中最优切分
  1. //所有可切分的节点
  2. val nodes = new Array[Node](numNodes)
  3. nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
  4. nodesForTree.foreach { node =>
  5. nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
  6. }
  7. }
  8. // In each partition, iterate all instances and compute aggregate stats for each node,
  9. // yield an (nodeIndex, nodeAggregateStats) pair for each node.
  10. // After a `reduceByKey` operation,
  11. // stats of a node will be shuffled to a particular partition and be combined together,
  12. // then best splits for nodes are found there.
  13. // Finally, only best Splits for nodes are collected to driver to construct decision tree.
  14. //获取节点对应的特征
  15. val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
  16. val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
  17. val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
  18. input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
  19. // Construct a nodeStatsAggregators array to hold node aggregate stats,
  20. // each node will have a nodeStatsAggregator
  21. val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
  22. //节点对应的特征集
  23. val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
  24. Some(nodeToFeatures(nodeIndex))
  25. }
  26. // DTStatsAggregator,其中引用了 ImpurityAggregator,给出计算不纯度 impurity 的逻辑
  27. new DTStatsAggregator(metadata, featuresForNode)
  28. }
  29. // 迭代当前分区的所有对象,更新聚合统计信息,统计信息即采样数据的权重值
  30. points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
  31. // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
  32. // which can be combined with other partition using `reduceByKey`
  33. nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  34. }
  35. } else {
  36. input.mapPartitions { points =>
  37. // Construct a nodeStatsAggregators array to hold node aggregate stats,
  38. // each node will have a nodeStatsAggregator
  39. val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
  40. //节点对应的特征集
  41. val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
  42. Some(nodeToFeatures(nodeIndex))
  43. }
  44. // DTStatsAggregator,其中引用了 ImpurityAggregator,给出计算不纯度 impurity 的逻辑
  45. new DTStatsAggregator(metadata, featuresForNode)
  46. }
  47. // 迭代当前分区的所有对象,更新聚合统计信息
  48. points.foreach(binSeqOp(nodeStatsAggregators, _))
  49. // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
  50. // which can be combined with other partition using `reduceByKey`
  51. nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  52. }
  53. }
  54. val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
  55. .map { case (nodeIndex, aggStats) =>
  56. val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
  57. nodeToFeatures(nodeIndex)
  58. }
  59. // find best split for each node
  60. val (split: Split, stats: InformationGainStats, predict: Predict) =
  61. binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
  62. (nodeIndex, (split, stats, predict))
  63. }.collectAsMap()

  该方法中的关键是对binsToBestSplit方法的调用,binsToBestSplit方法代码如下:

  1. private def binsToBestSplit(
  2. binAggregates: DTStatsAggregator,
  3. splits: Array[Array[Split]],
  4. featuresForNode: Option[Array[Int]],
  5. node: Node): (Split, InformationGainStats, Predict) = {
  6. // 如果当前节点是根节点,计算预测和不纯度
  7. val level = Node.indexToLevel(node.id)
  8. var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
  9. None
  10. } else {
  11. Some((node.predict, node.impurity))
  12. }
  13. // 对各特征及切分点,计算其信息增益并从中选择最优 (feature, split)
  14. val (bestSplit, bestSplitStats) =
  15. Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
  16. val featureIndex = if (featuresForNode.nonEmpty) {
  17. featuresForNode.get.apply(featureIndexIdx)
  18. } else {
  19. featureIndexIdx
  20. }
  21. val numSplits = binAggregates.metadata.numSplits(featureIndex)
  22. //特征为连续值的情况
  23. if (binAggregates.metadata.isContinuous(featureIndex)) {
  24. // Cumulative sum (scanLeft) of bin statistics.
  25. // Afterwards, binAggregates for a bin is the sum of aggregates for
  26. // that bin + all preceding bins.
  27. val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
  28. var splitIndex = 0
  29. while (splitIndex < numSplits) {
  30. binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
  31. splitIndex += 1
  32. }
  33. // Find best split.
  34. val (bestFeatureSplitIndex, bestFeatureGainStats) =
  35. Range(0, numSplits).map { case splitIdx =>
  36. //计算 leftChild 及 rightChild 子节点的 impurity
  37. val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
  38. val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
  39. rightChildStats.subtract(leftChildStats)
  40. //求 impurity 的预测值,采用的是平均值计算
  41. predictWithImpurity = Some(predictWithImpurity.getOrElse(
  42. calculatePredictImpurity(leftChildStats, rightChildStats)))
  43. //求信息增益 information gain 值,用于评估切分点是否最优,请参考决策树中1.4.4章节的介绍
  44. val gainStats = calculateGainForSplit(leftChildStats,
  45. rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
  46. (splitIdx, gainStats)
  47. }.maxBy(_._2.gain)
  48. (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
  49. }
  50. //无序离散特征时的情况
  51. else if (binAggregates.metadata.isUnordered(featureIndex)) {
  52. // Unordered categorical feature
  53. val (leftChildOffset, rightChildOffset) =
  54. binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
  55. val (bestFeatureSplitIndex, bestFeatureGainStats) =
  56. Range(0, numSplits).map { splitIndex =>
  57. val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
  58. val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
  59. predictWithImpurity = Some(predictWithImpurity.getOrElse(
  60. calculatePredictImpurity(leftChildStats, rightChildStats)))
  61. val gainStats = calculateGainForSplit(leftChildStats,
  62. rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
  63. (splitIndex, gainStats)
  64. }.maxBy(_._2.gain)
  65. (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
  66. } else {//有序离散特征时的情况
  67. // Ordered categorical feature
  68. val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
  69. val numBins = binAggregates.metadata.numBins(featureIndex)
  70. /* Each bin is one category (feature value).
  71. * The bins are ordered based on centroidForCategories, and this ordering determines which
  72. * splits are considered. (With K categories, we consider K - 1 possible splits.)
  73. *
  74. * centroidForCategories is a list: (category, centroid)
  75. */
  76. //多元分类时的情况
  77. val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
  78. // For categorical variables in multiclass classification,
  79. // the bins are ordered by the impurity of their corresponding labels.
  80. Range(0, numBins).map { case featureValue =>
  81. val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  82. val centroid = if (categoryStats.count != 0) {
  83. // impurity 求的就是均方差
  84. categoryStats.calculate()
  85. } else {
  86. Double.MaxValue
  87. }
  88. (featureValue, centroid)
  89. }
  90. } else { // 回归或二元分类时的情况
  91. // For categorical variables in regression and binary classification,
  92. // the bins are ordered by the centroid of their corresponding labels.
  93. Range(0, numBins).map { case featureValue =>
  94. val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  95. val centroid = if (categoryStats.count != 0) {
  96. //求的就是平均值作为 impurity
  97. categoryStats.predict
  98. } else {
  99. Double.MaxValue
  100. }
  101. (featureValue, centroid)
  102. }
  103. }
  104. // bins sorted by centroids
  105. val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
  106. // Cumulative sum (scanLeft) of bin statistics.
  107. // Afterwards, binAggregates for a bin is the sum of aggregates for
  108. // that bin + all preceding bins.
  109. var splitIndex = 0
  110. while (splitIndex < numSplits) {
  111. val currentCategory = categoriesSortedByCentroid(splitIndex)._1
  112. val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
  113. //将两个箱子的状态信息进行合并
  114. binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
  115. splitIndex += 1
  116. }
  117. // lastCategory = index of bin with total aggregates for this (node, feature)
  118. val lastCategory = categoriesSortedByCentroid.last._1
  119. // Find best split.
  120. //通过信息增益值选择最优切分点
  121. val (bestFeatureSplitIndex, bestFeatureGainStats) =
  122. Range(0, numSplits).map { splitIndex =>
  123. val featureValue = categoriesSortedByCentroid(splitIndex)._1
  124. val leftChildStats =
  125. binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  126. val rightChildStats =
  127. binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
  128. rightChildStats.subtract(leftChildStats)
  129. predictWithImpurity = Some(predictWithImpurity.getOrElse(
  130. calculatePredictImpurity(leftChildStats, rightChildStats)))
  131. val gainStats = calculateGainForSplit(leftChildStats,
  132. rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
  133. (splitIndex, gainStats)
  134. }.maxBy(_._2.gain)
  135. val categoriesForSplit =
  136. categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
  137. val bestFeatureSplit =
  138. new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
  139. (bestFeatureSplit, bestFeatureGainStats)
  140. }
  141. }.maxBy(_._2.gain)
  142. (bestSplit, bestSplitStats, predictWithImpurity.get._1)
  143. }

5.2 预测分析

  在利用随机森林进行预测时,调用的predict方法扩展自TreeEnsembleModel,它是树结构组合模型的表示,其核心代码如下所示:

  1. //不同的策略采用不同的预测方法
  2. def predict(features: Vector): Double = {
  3. (algo, combiningStrategy) match {
  4. case (Regression, Sum) =>
  5. predictBySumming(features)
  6. case (Regression, Average) =>
  7. predictBySumming(features) / sumWeights
  8. case (Classification, Sum) => // binary classification
  9. val prediction = predictBySumming(features)
  10. // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
  11. if (prediction > 0.0) 1.0 else 0.0
  12. case (Classification, Vote) =>
  13. predictByVoting(features)
  14. case _ =>
  15. throw new IllegalArgumentException()
  16. }
  17. }
  18. private def predictBySumming(features: Vector): Double = {
  19. val treePredictions = trees.map(_.predict(features))
  20. //两个向量的内集
  21. blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
  22. }
  23. //通过投票选举
  24. private def predictByVoting(features: Vector): Double = {
  25. val votes = mutable.Map.empty[Int, Double]
  26. trees.view.zip(treeWeights).foreach { case (tree, weight) =>
  27. val prediction = tree.predict(features).toInt
  28. votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
  29. }
  30. votes.maxBy(_._2)._1
  31. }

参考文献

【1】机器学习.周志华

【2】Spark 随机森林算法原理、源码分析及案例实战

【3】Scalable Distributed Decision Trees in Spark MLlib