带权最小二乘

1 原理

  给定n个带权的观察样本$(w_i,a_i,b_i)$:

  • $w_i$表示第i个观察样本的权重;
  • $a_i$表示第i个观察样本的特征向量;
  • $b_i$表示第i个观察样本的标签。

  每个观察样本的特征数是m。我们使用下面的带权最小二乘公式作为目标函数:

minimize{x}\frac{1}{2} \sum{i=1}^n \frac{wi(a_i^T x -b_i)^2}{\sum{k=1}^n wk} + \frac{1}{2}\frac{\lambda}{\delta}\sum{j=1}^m(\sigma{j} x{j})^2

  其中$\lambda$是正则化参数,$\delta$是标签的总体标准差,$\sigma_j$是第j个特征列的总体标准差。

  这个目标函数有一个解析解法,它仅仅需要一次处理样本来搜集必要的统计数据去求解。与原始数据集必须存储在分布式系统上不同,
如果特征数相对较小,这些统计数据可以加载进单机的内存中,然后在driver端使用乔里斯基分解求解目标函数。

  spark ml中使用WeightedLeastSquares求解带权最小二乘问题。WeightedLeastSquares仅仅支持L2正则化,并且提供了正则化和标准化
的开关。为了使正太方程(normal equation)方法有效,特征数不能超过4096。如果超过4096,用L-BFGS代替。下面从代码层面介绍带权最小二乘优化算法
的实现。

2 代码解析

  我们首先看看WeightedLeastSquares的参数及其含义。

  1. private[ml] class WeightedLeastSquares(
  2. val fitIntercept: Boolean, //是否使用截距
  3. val regParam: Double, //L2正则化参数,指上面公式中的lambda
  4. val elasticNetParam: Double, // alpha,控制L1和L2正则化
  5. val standardizeFeatures: Boolean, // 是否标准化特征
  6. val standardizeLabel: Boolean, // 是否标准化标签
  7. val solverType: WeightedLeastSquares.Solver = WeightedLeastSquares.Auto,
  8. val maxIter: Int = 100, // 迭代次数
  9. val tol: Double = 1e-6) extends Logging with Serializable
  10. sealed trait Solver
  11. case object Auto extends Solver
  12. case object Cholesky extends Solver
  13. case object QuasiNewton extends Solver

  在上面的代码中,standardizeFeatures决定是否标准化特征,如果为真,则$\sigma_j$是A第j个特征列的总体标准差,否则$\sigma_j$为1。
standardizeLabel决定是否标准化标签,如果为真,则$\delta$是标签b的总体标准差,否则$\delta$为1。solverType指定求解的类型,有AutoCholesky
QuasiNewton三种选择。tol表示迭代的收敛阈值,仅仅在solverTypeQuasiNewton时可用。

2.1 求解过程

  WeightedLeastSquares接收一个包含(标签,权重,特征)的RDD,使用fit方法训练,并返回WeightedLeastSquaresModel

  1. def fit(instances: RDD[Instance]): WeightedLeastSquaresModel

  训练过程分为下面几步。

  • 1 统计样本信息
  1. val summary = instances.treeAggregate(new Aggregator)(_.add(_), _.merge(_))

  使用treeAggregate方法来统计样本信息。统计的信息在Aggregator类中给出了定义。通过展开上面的目标函数,我们可以知道这些统计信息的含义。

  1. private class Aggregator extends Serializable {
  2. var initialized: Boolean = false
  3. var k: Int = _ // 特征数
  4. var count: Long = _ // 样本数
  5. var triK: Int = _ // 对角矩阵保存的元素个数
  6. var wSum: Double = _ // 权重和
  7. private var wwSum: Double = _ // 权重的平方和
  8. private var bSum: Double = _ // 带权标签和
  9. private var bbSum: Double = _ // 带权标签的平方和
  10. private var aSum: DenseVector = _ // 带权特征和
  11. private var abSum: DenseVector = _ // 带权特征标签相乘和
  12. private var aaSum: DenseVector = _ // 带权特征平方和
  13. }

  方法add添加样本的统计信息,方法merge合并不同分区的统计信息。代码很简单,如下所示:

  1. /**
  2. * Adds an instance.
  3. */
  4. def add(instance: Instance): this.type = {
  5. val Instance(l, w, f) = instance
  6. val ak = f.size
  7. if (!initialized) {
  8. init(ak)
  9. }
  10. assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.")
  11. count += 1L
  12. wSum += w
  13. wwSum += w * w
  14. bSum += w * l
  15. bbSum += w * l * l
  16. BLAS.axpy(w, f, aSum)
  17. BLAS.axpy(w * l, f, abSum)
  18. BLAS.spr(w, f, aaSum) // wff^T
  19. this
  20. }
  21. /**
  22. * Merges another [[Aggregator]].
  23. */
  24. def merge(other: Aggregator): this.type = {
  25. if (!other.initialized) {
  26. this
  27. } else {
  28. if (!initialized) {
  29. init(other.k)
  30. }
  31. assert(k == other.k, s"dimension mismatch: this.k = $k but other.k = ${other.k}")
  32. count += other.count
  33. wSum += other.wSum
  34. wwSum += other.wwSum
  35. bSum += other.bSum
  36. bbSum += other.bbSum
  37. BLAS.axpy(1.0, other.aSum, aSum)
  38. BLAS.axpy(1.0, other.abSum, abSum)
  39. BLAS.axpy(1.0, other.aaSum, aaSum)
  40. this
  41. }

  Aggregator类给出了以下一些统计信息:

  1. aBar: 特征加权平均数
  2. bBar: 标签加权平均数
  3. aaBar: 特征平方加权平均数
  4. bbBar: 标签平方加权平均数
  5. aStd: 特征的加权总体标准差
  6. bStd: 标签的加权总体标准差
  7. aVar: 带权的特征总体方差

  计算出这些信息之后,将均值缩放到标准空间,即使每列数据的方差为1。

  1. // 缩放bBar和 bbBar
  2. val bBar = summary.bBar / bStd
  3. val bbBar = summary.bbBar / (bStd * bStd)
  4. val aStd = summary.aStd
  5. val aStdValues = aStd.values
  6. // 缩放aBar
  7. val aBar = {
  8. val _aBar = summary.aBar
  9. val _aBarValues = _aBar.values
  10. var i = 0
  11. // scale aBar to standardized space in-place
  12. while (i < numFeatures) {
  13. if (aStdValues(i) == 0.0) {
  14. _aBarValues(i) = 0.0
  15. } else {
  16. _aBarValues(i) /= aStdValues(i)
  17. }
  18. i += 1
  19. }
  20. _aBar
  21. }
  22. val aBarValues = aBar.values
  23. // 缩放 abBar
  24. val abBar = {
  25. val _abBar = summary.abBar
  26. val _abBarValues = _abBar.values
  27. var i = 0
  28. // scale abBar to standardized space in-place
  29. while (i < numFeatures) {
  30. if (aStdValues(i) == 0.0) {
  31. _abBarValues(i) = 0.0
  32. } else {
  33. _abBarValues(i) /= (aStdValues(i) * bStd)
  34. }
  35. i += 1
  36. }
  37. _abBar
  38. }
  39. val abBarValues = abBar.values
  40. // 缩放aaBar
  41. val aaBar = {
  42. val _aaBar = summary.aaBar
  43. val _aaBarValues = _aaBar.values
  44. var j = 0
  45. var p = 0
  46. // scale aaBar to standardized space in-place
  47. while (j < numFeatures) {
  48. val aStdJ = aStdValues(j)
  49. var i = 0
  50. while (i <= j) {
  51. val aStdI = aStdValues(i)
  52. if (aStdJ == 0.0 || aStdI == 0.0) {
  53. _aaBarValues(p) = 0.0
  54. } else {
  55. _aaBarValues(p) /= (aStdI * aStdJ)
  56. }
  57. p += 1
  58. i += 1
  59. }
  60. j += 1
  61. }
  62. _aaBar
  63. }
  64. val aaBarValues = aaBar.values
  • 2 处理L2正则项
  1. val effectiveRegParam = regParam / bStd
  2. val effectiveL1RegParam = elasticNetParam * effectiveRegParam
  3. val effectiveL2RegParam = (1.0 - elasticNetParam) * effectiveRegParam
  4. // 添加L2正则项到对角矩阵中
  5. var i = 0
  6. var j = 2
  7. while (i < triK) {
  8. var lambda = effectiveL2RegParam
  9. if (!standardizeFeatures) {
  10. val std = aStdValues(j - 2)
  11. if (std != 0.0) {
  12. lambda /= (std * std) //正则项标准化
  13. } else {
  14. lambda = 0.0
  15. }
  16. }
  17. if (!standardizeLabel) {
  18. lambda *= bStd
  19. }
  20. aaBarValues(i) += lambda
  21. i += j
  22. j += 1
  23. }
  • 3 选择solver

  WeightedLeastSquares实现了CholeskySolverQuasiNewtonSolver两种不同的求解方法。当没有正则化项时,
选择CholeskySolver求解,否则用QuasiNewtonSolver求解。

  1. val solver = if ((solverType == WeightedLeastSquares.Auto && elasticNetParam != 0.0 &&
  2. regParam != 0.0) || (solverType == WeightedLeastSquares.QuasiNewton)) {
  3. val effectiveL1RegFun: Option[(Int) => Double] = if (effectiveL1RegParam != 0.0) {
  4. Some((index: Int) => {
  5. if (fitIntercept && index == numFeatures) {
  6. 0.0
  7. } else {
  8. if (standardizeFeatures) {
  9. effectiveL1RegParam
  10. } else {
  11. if (aStdValues(index) != 0.0) effectiveL1RegParam / aStdValues(index) else 0.0
  12. }
  13. }
  14. })
  15. } else {
  16. None
  17. }
  18. new QuasiNewtonSolver(fitIntercept, maxIter, tol, effectiveL1RegFun)
  19. } else {
  20. new CholeskySolver
  21. }

  CholeskySolverQuasiNewtonSolver的详细分析会在另外的专题进行描述。

  • 4 处理结果
  1. val solution = solver match {
  2. case cholesky: CholeskySolver =>
  3. try {
  4. cholesky.solve(bBar, bbBar, ab, aa, aBar)
  5. } catch {
  6. // if Auto solver is used and Cholesky fails due to singular AtA, then fall back to
  7. // Quasi-Newton solver.
  8. case _: SingularMatrixException if solverType == WeightedLeastSquares.Auto =>
  9. logWarning("Cholesky solver failed due to singular covariance matrix. " +
  10. "Retrying with Quasi-Newton solver.")
  11. // ab and aa were modified in place, so reconstruct them
  12. val _aa = getAtA(aaBarValues, aBarValues)
  13. val _ab = getAtB(abBarValues, bBar)
  14. val newSolver = new QuasiNewtonSolver(fitIntercept, maxIter, tol, None)
  15. newSolver.solve(bBar, bbBar, _ab, _aa, aBar)
  16. }
  17. case qn: QuasiNewtonSolver =>
  18. qn.solve(bBar, bbBar, ab, aa, aBar)
  19. }
  20. val (coefficientArray, intercept) = if (fitIntercept) {
  21. (solution.coefficients.slice(0, solution.coefficients.length - 1),
  22. solution.coefficients.last * bStd)
  23. } else {
  24. (solution.coefficients, 0.0)
  25. }

  上面代码的异常处理需要注意一下。在AtA是奇异矩阵的情况下,乔里斯基分解会报错,这时需要用拟牛顿方法求解。

  以上的结果是在标准空间中,所以我们需要将结果从标准空间转换到原来的空间。

  1. // convert the coefficients from the scaled space to the original space
  2. var q = 0
  3. val len = coefficientArray.length
  4. while (q < len) {
  5. coefficientArray(q) *= { if (aStdValues(q) != 0.0) bStd / aStdValues(q) else 0.0 }
  6. q += 1
  7. }