[Spark Streaming on Angel] FTRL

随着近几年online learning的火热发展,FTRL这种优化算法不光更能适应海量数据的要求,同时还能比较轻松的学习到一个有效且稀疏的模型,自问世以来在学术界和工业界都倍受关注和好评。基于此,我们在Spark on Angel平台实现了在线与离线方式的以FTRL进行优化的分布式LR算法。下面介绍该算法的原理与使用。

1. 算法介绍

FTRL算法兼顾了FOBOSRDA两种算法的优势,既能同FOBOS保证比较高的精度,又能在损失一定精度的情况下产生更好的稀疏性。

该算法的特征权重的更新公式(参考文献1)为:

FTRL - 图1

其中

  • G函数表示损失函数的梯度

FTRL - 图2

  • w的更新公式(针对特征权重的各个维度将其拆解成N个独立的标量最小化问题)

FTRL - 图3

  • 如果对每一维度的学习率单独考虑,w的更新公式:

FTRL - 图4

2.分布式实现

Google给出的带有L1和L2正则项的基于FTRL优化的逻辑回归算法的工程实现

FTRL - 图5

为了加快收敛速度,算法还提供了基于SVRG的方差约减FTRL算法,即在梯度更新时对梯度进行方差约减。

  • SVRG的一般过程为:

FTRL - 图6

为此,算法在损失函数的梯度g处增加了一步基于SVRG的更新,同时为了符合SVRG算法的原理,增加了两个参数rho1,rho2,近似计算每个阶段的权重和梯度(参考文献2)。给出基于SVRG算法的FTRL算法(后文简称”FTRL_VRG”)的一般过程:

FTRL - 图7

参考实现,结合Spark Streaming和Angel的特点,FTRL的分布式实现的框架图如下:

FTRL - 图8

FTRL_VRG的分布式实现框架图如下:

FTRL - 图9

3. 运行 & 性能

提供了两种数据接入方式:在线与离线方式,其中离线方式的详情查看这里

<在线方式>

说明

在线方式以kafka为消息发送机制,使用时需要填写kafka的配置信息。优化方式包括FTRL和FTRL_VRG两种

输入格式说明

  • 消息格式仅支持标准的“libsvm”数据格式或者“Dummy”格式
  • 为了模型的准确性,算法内部都自动对每个样本增加了index为0,value为1的特征值,以实现偏置效果,因此该算法的输入数据中index从1开始

参数说明

  • 算法参数

    • alpha:w更新公式中的alpha
    • beta: w更新公式中的beta
    • lambda1: w更新公式中的lambda1
    • lambda2: w更新公式中的lambda2
    • rho1:FTRL_VRG中的权重更新系数
    • rho2:FTRL_VRG中的梯度更新系数
  • 输入输出参数

    • checkPointPath:streaming流数据的checkpoint路径
    • zkQuorum:Zookeeper的配置信息,格式:”hostname:port”
    • topic:kafka的topic信息
    • group:kafka的group信息
    • dim:输入数据的维度,特征ID默认从0开始计数
    • isOneHot:数据格式是否为One-Hot,若是则为true
    • receiverNum:kafka receiver的个数
    • streamingWindow:控制spark streaming流中每批数据的持续时间
    • modelPath:训练时模型的保存路
    • logPath:每个batch的平均loss输出路径
    • partitionNum:streaming中的分区数
    • optMethod:选择采用ftrl还是ftrlVRG进行优化
    • isIncrementLearn:是否增量学习
    • batch2Save:间隔多少个batch对模型进行一次保存
  • 资源参数
    • num-executors:executor个数
    • executor-cores:executor的核数
    • executor-memory:executor的内存
    • driver-memory:driver端内存
    • spark.ps.instances:Angel PS节点数
    • spark.ps.cores:每个PS节点的Core数
    • spark.ps.memory:每个PS节点的Memory大小

提交命令

可以通过下面命令向Yarn集群提交FTRL_SparseLR算法的训练任务:

  1. ./bin/spark-submit \
  2. --master yarn-cluster \
  3. --conf spark.hadoop.angel.ps.ha.replication.number=2 \
  4. --conf fs.default.name=$defaultFS \
  5. --conf spark.yarn.allocation.am.maxMemory=55g \
  6. --conf spark.yarn.allocation.executor.maxMemory=55g \
  7. --conf spark.ps.jars=$SONA_ANGEL_JARS \
  8. --conf spark.ps.instances=20 \
  9. --conf spark.ps.cores=2 \
  10. --conf spark.ps.memory=6g \
  11. --jars $SONA_SPARK_JARS \
  12. --name $name \
  13. --driver-memory 5g \
  14. --num-executors 10 \
  15. --executor-cores 2 \
  16. --executor-memory 12g \
  17. --class com.tencent.angel.spark.ml.online_learning.FTRLRunner \
  18. spark-on-angel-mllib-<version>.jar \
  19. partitionNum:10 \
  20. modelPath:$modelPath \
  21. checkPointPath:$checkPointPath \
  22. logPath:$logPath \
  23. zkQuorum:<zookeeper IP> \
  24. group:<kafka group> \
  25. topic:<kafka topic> \
  26. rho1:0.2 \
  27. rho2:0.2 \
  28. alpha:0.1 \
  29. isIncrementLearn:false \
  30. lambda1:0.3 \
  31. lambda2:0.3 \
  32. dim:175835 \
  33. streamingWindow:10 \
  34. receiverNum:10 \
  35. batch2Save:10 \
  36. optMethod:ftrlVRG

4. 参考文献

  1. H. Brendan McMahan, Gary Holt, D. Sculley, Michael Young. Ad Click Prediction: a View from the Trenches.KDD’13, August 11–14, 2013
  2. 腾讯大数据技术峰会2017-广告中的大数据与机器学习