RFormula

  RFormula通过一个R model formula选择一个特定的列。
目前我们支持R算子的一个受限的子集,包括~,.,:,+,-。这些基本的算子是:

  • ~ 分开targetterms
  • + 连接term,+ 0表示删除截距(intercept)
  • - 删除term,- 1表示删除截距
  • : 交集
  • . 除了target之外的所有列

  假设abdouble列,我们用下面简单的例子来证明RFormula的有效性。

  • y ~ a + b 表示模型 y ~ w0 + w1 * a + w2 * b,其中w0是截距,w1w2是系数
  • y ~ a + b + a:b - 1表示模型y ~ w1 * a + w2 * b + w3 * a * b,其中w1,w2,w3是系数

  RFormula产生一个特征向量列和一个doublestring类型的标签列。比如在线性回归中使用R中的公式时,
字符串输入列是one-hot编码,数值列强制转换为double类型。如果标签列是字符串类型,它将使用StringIndexer转换为double
类型。如果DataFrame中不存在标签列,输出的标签列将通过公式中指定的返回变量来创建。

例子

  假设我们有一个DataFrame,它的列名是id, country, hourclicked

  1. id | country | hour | clicked
  2. ---|---------|------|---------
  3. 7 | "US" | 18 | 1.0
  4. 8 | "CA" | 12 | 0.0
  5. 9 | "NZ" | 15 | 0.0

  如果我们用clicked ~ country + hour(基于countryhour来预测clicked)来作用于RFormula,将会得到下面的结果。

  1. id | country | hour | clicked | features | label
  2. ---|---------|------|---------|------------------|-------
  3. 7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0
  4. 8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0
  5. 9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0

  下面是代码调用的例子。

  1. import org.apache.spark.ml.feature.RFormula
  2. val dataset = spark.createDataFrame(Seq(
  3. (7, "US", 18, 1.0),
  4. (8, "CA", 12, 0.0),
  5. (9, "NZ", 15, 0.0)
  6. )).toDF("id", "country", "hour", "clicked")
  7. val formula = new RFormula()
  8. .setFormula("clicked ~ country + hour")
  9. .setFeaturesCol("features")
  10. .setLabelCol("label")
  11. val output = formula.fit(dataset).transform(dataset)
  12. output.select("features", "label").show()