StringIndexer

  StringIndexer将标签列的字符串编码为标签索引。这些索引是[0,numLabels),通过标签频率排序,所以频率最高的标签的索引为0。
如果输入列是数字,我们把它强转为字符串然后在编码。

例子

  假设我们有下面的DataFrame,它的列名是idcategory

  1. id | category
  2. ----|----------
  3. 0 | a
  4. 1 | b
  5. 2 | c
  6. 3 | a
  7. 4 | a
  8. 5 | c

  category是字符串列,拥有三个标签a,b,c。把category作为输入列,categoryIndex作为输出列,使用StringIndexer我们可以得到下面的结果。

  1. id | category | categoryIndex
  2. ----|----------|---------------
  3. 0 | a | 0.0
  4. 1 | b | 2.0
  5. 2 | c | 1.0
  6. 3 | a | 0.0
  7. 4 | a | 0.0
  8. 5 | c | 1.0

  a的索引号为0是因为它的频率最高,c次之,b最后。

  另外,StringIndexer处理未出现的标签的策略有两个:

  • 抛出一个异常(默认情况)
  • 跳过出现该标签的行

  让我们回到上面的例子,但是这次我们重用上面的StringIndexer到下面的数据集。

  1. id | category
  2. ----|----------
  3. 0 | a
  4. 1 | b
  5. 2 | c
  6. 3 | d

  如果我们没有为StringIndexer设置怎么处理未见过的标签或者设置为error,它将抛出异常,否则若设置为skip,它将得到下面的结果。

  1. id | category | categoryIndex
  2. ----|----------|---------------
  3. 0 | a | 0.0
  4. 1 | b | 2.0
  5. 2 | c | 1.0

  下面是程序调用的例子。

  1. import org.apache.spark.ml.feature.StringIndexer
  2. val df = spark.createDataFrame(
  3. Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
  4. ).toDF("id", "category")
  5. val indexer = new StringIndexer()
  6. .setInputCol("category")
  7. .setOutputCol("categoryIndex")
  8. val indexed = indexer.fit(df).transform(df)
  9. indexed.show()