User Defined Aggregate Functions (UDAFs)

Description

User-Defined Aggregate Functions (UDAFs) are user-programmable routines that act on multiple rows at once and return a single aggregated value as a result. This documentation lists the classes that are required for creating and registering UDAFs. It also contains examples that demonstrate how to define and register UDAFs in Scala and invoke them in Spark SQL.

Aggregator[-IN, BUF, OUT]

A base class for user-defined aggregations, which can be used in Dataset operations to take all of the elements of a group and reduce them to a single value.

IN - The input type for the aggregation.

BUF - The type of the intermediate value of the reduction.

OUT - The type of the final output result.

  • bufferEncoder: Encoder[BUF]

    Specifies the Encoder for the intermediate value type.

  • finish(reduction: BUF): OUT

    Transform the output of the reduction.

  • merge(b1: BUF, b2: BUF): BUF

    Merge two intermediate values.

  • outputEncoder: Encoder[OUT]

    Specifies the Encoder for the final output value type.

  • reduce(b: BUF, a: IN): BUF

    Aggregate input value a into the current intermediate value. For performance, the function may modify b and return it instead of constructing a new object for b.

  • zero: BUF

    The initial value of the intermediate result for this aggregation. It should satisfy the property that combining any intermediate value b with zero yields b.
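
To make these members concrete, here is a minimal sketch (not taken from the Spark examples) of an Aggregator that simply counts its input rows; the Event case class is a hypothetical input type introduced only for illustration:

  import org.apache.spark.sql.{Encoder, Encoders}
  import org.apache.spark.sql.expressions.Aggregator

  // Hypothetical input type, used only for this sketch
  case class Event(id: String)

  // Counts input rows: the buffer (BUF) and the output (OUT) are both Long
  object CountEvents extends Aggregator[Event, Long, Long] {
    def zero: Long = 0L                                    // initial intermediate value
    def reduce(b: Long, a: Event): Long = b + 1            // fold one input row into the buffer
    def merge(b1: Long, b2: Long): Long = b1 + b2          // combine two partial buffers
    def finish(reduction: Long): Long = reduction          // produce the final output
    def bufferEncoder: Encoder[Long] = Encoders.scalaLong  // Encoder for the buffer type
    def outputEncoder: Encoder[Long] = Encoders.scalaLong  // Encoder for the output type
  }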

Examples

Type-Safe User-Defined Aggregate Functions

User-defined aggregations for strongly typed Datasets revolve around the Aggregator abstract class. For example, a type-safe user-defined average can look like:

  import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
  import org.apache.spark.sql.expressions.Aggregator

  case class Employee(name: String, salary: Long)
  case class Average(var sum: Long, var count: Long)

  object MyAverage extends Aggregator[Employee, Average, Double] {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    def zero: Average = Average(0L, 0L)
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    def reduce(buffer: Average, employee: Employee): Average = {
      buffer.sum += employee.salary
      buffer.count += 1
      buffer
    }
    // Merge two intermediate values
    def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count += b2.count
      b1
    }
    // Transform the output of the reduction
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // Specifies the Encoder for the intermediate value type
    def bufferEncoder: Encoder[Average] = Encoders.product
    // Specifies the Encoder for the final output value type
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
  ds.show()
  // +-------+------+
  // |   name|salary|
  // +-------+------+
  // |Michael|  3000|
  // |   Andy|  4500|
  // | Justin|  3500|
  // |  Berta|  4000|
  // +-------+------+

  // Convert the function to a `TypedColumn` and give it a name
  val averageSalary = MyAverage.toColumn.name("average_salary")
  val result = ds.select(averageSalary)
  result.show()
  // +--------------+
  // |average_salary|
  // +--------------+
  // |        3750.0|
  // +--------------+

Find full example code at “examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala” in the Spark repo.
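
The TypedColumn produced by toColumn can also be applied per group rather than to the whole Dataset. A short sketch, assuming the MyAverage object and the ds Dataset defined in the example above:

  // Average salary per employee name, using the typed groupByKey API
  val perName = ds.groupByKey(_.name).agg(MyAverage.toColumn.name("average_salary"))
  perName.show()
  // Each row pairs a name with that group's average salary,
  // e.g. (Michael, 3000.0); row order is not guaranteed.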

  import java.io.Serializable;

  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Encoder;
  import org.apache.spark.sql.Encoders;
  import org.apache.spark.sql.SparkSession;
  import org.apache.spark.sql.TypedColumn;
  import org.apache.spark.sql.expressions.Aggregator;

  public static class Employee implements Serializable {
    private String name;
    private long salary;

    // Constructors, getters, setters...
  }

  public static class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
  }

  public static class MyAverage extends Aggregator<Employee, Average, Double> {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
      return new Average(0L, 0L);
    }
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Employee employee) {
      long newSum = buffer.getSum() + employee.getSalary();
      long newCount = buffer.getCount() + 1;
      buffer.setSum(newSum);
      buffer.setCount(newCount);
      return buffer;
    }
    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
      long mergedSum = b1.getSum() + b2.getSum();
      long mergedCount = b1.getCount() + b2.getCount();
      b1.setSum(mergedSum);
      b1.setCount(mergedCount);
      return b1;
    }
    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
      return ((double) reduction.getSum()) / reduction.getCount();
    }
    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
      return Encoders.bean(Average.class);
    }
    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
      return Encoders.DOUBLE();
    }
  }

  Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class);
  String path = "examples/src/main/resources/employees.json";
  Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
  ds.show();
  // +-------+------+
  // |   name|salary|
  // +-------+------+
  // |Michael|  3000|
  // |   Andy|  4500|
  // | Justin|  3500|
  // |  Berta|  4000|
  // +-------+------+

  MyAverage myAverage = new MyAverage();
  // Convert the function to a `TypedColumn` and give it a name
  TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
  Dataset<Double> result = ds.select(averageSalary);
  result.show();
  // +--------------+
  // |average_salary|
  // +--------------+
  // |        3750.0|
  // +--------------+

Find full example code at “examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java” in the Spark repo.

Untyped User-Defined Aggregate Functions

Typed aggregations, as described above, may also be registered as untyped aggregating UDFs for use with DataFrames. For example, a user-defined average for untyped DataFrames can look like:

  import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
  import org.apache.spark.sql.expressions.Aggregator
  import org.apache.spark.sql.functions

  case class Average(var sum: Long, var count: Long)

  object MyAverage extends Aggregator[Long, Average, Double] {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    def zero: Average = Average(0L, 0L)
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    def reduce(buffer: Average, data: Long): Average = {
      buffer.sum += data
      buffer.count += 1
      buffer
    }
    // Merge two intermediate values
    def merge(b1: Average, b2: Average): Average = {
      b1.sum += b2.sum
      b1.count += b2.count
      b1
    }
    // Transform the output of the reduction
    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
    // Specifies the Encoder for the intermediate value type
    def bufferEncoder: Encoder[Average] = Encoders.product
    // Specifies the Encoder for the final output value type
    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }

  // Register the function to access it
  spark.udf.register("myAverage", functions.udaf(MyAverage))

  val df = spark.read.json("examples/src/main/resources/employees.json")
  df.createOrReplaceTempView("employees")
  df.show()
  // +-------+------+
  // |   name|salary|
  // +-------+------+
  // |Michael|  3000|
  // |   Andy|  4500|
  // | Justin|  3500|
  // |  Berta|  4000|
  // +-------+------+

  val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
  result.show()
  // +--------------+
  // |average_salary|
  // +--------------+
  // |        3750.0|
  // +--------------+

Find full example code at “examples/src/main/scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala” in the Spark repo.
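
Besides being registered for SQL, the untyped wrapper returned by functions.udaf can be applied directly as a Column through the DataFrame API. A short sketch, assuming the MyAverage object and the df DataFrame from the example above:

  import org.apache.spark.sql.functions
  import org.apache.spark.sql.functions.col

  // Wrap the aggregator once and use it like any other aggregate Column
  val averageSalaryUdf = functions.udaf(MyAverage)
  val result2 = df.agg(averageSalaryUdf(col("salary")).as("average_salary"))
  result2.show()
  // A single row with average_salary = 3750.0, matching the SQL query above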

  import java.io.Serializable;

  import org.apache.spark.sql.Dataset;
  import org.apache.spark.sql.Encoder;
  import org.apache.spark.sql.Encoders;
  import org.apache.spark.sql.Row;
  import org.apache.spark.sql.SparkSession;
  import org.apache.spark.sql.expressions.Aggregator;
  import org.apache.spark.sql.functions;

  public static class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
    public Average() {
    }

    public Average(long sum, long count) {
      this.sum = sum;
      this.count = count;
    }

    public long getSum() {
      return sum;
    }

    public void setSum(long sum) {
      this.sum = sum;
    }

    public long getCount() {
      return count;
    }

    public void setCount(long count) {
      this.count = count;
    }
  }

  public static class MyAverage extends Aggregator<Long, Average, Double> {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
      return new Average(0L, 0L);
    }
    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Long data) {
      long newSum = buffer.getSum() + data;
      long newCount = buffer.getCount() + 1;
      buffer.setSum(newSum);
      buffer.setCount(newCount);
      return buffer;
    }
    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
      long mergedSum = b1.getSum() + b2.getSum();
      long mergedCount = b1.getCount() + b2.getCount();
      b1.setSum(mergedSum);
      b1.setCount(mergedCount);
      return b1;
    }
    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
      return ((double) reduction.getSum()) / reduction.getCount();
    }
    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
      return Encoders.bean(Average.class);
    }
    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
      return Encoders.DOUBLE();
    }
  }

  // Register the function to access it
  spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));

  Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
  df.createOrReplaceTempView("employees");
  df.show();
  // +-------+------+
  // |   name|salary|
  // +-------+------+
  // |Michael|  3000|
  // |   Andy|  4500|
  // | Justin|  3500|
  // |  Berta|  4000|
  // +-------+------+

  Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
  result.show();
  // +--------------+
  // |average_salary|
  // +--------------+
  // |        3750.0|
  // +--------------+

Find full example code at “examples/src/main/java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java” in the Spark repo.

  -- Compile and place UDAF MyAverage in a JAR file called `MyAverage.jar` in /tmp.
  CREATE FUNCTION myAverage AS 'MyAverage' USING JAR '/tmp/MyAverage.jar';

  SHOW USER FUNCTIONS;
  +------------------+
  |          function|
  +------------------+
  | default.myAverage|
  +------------------+

  CREATE TEMPORARY VIEW employees
  USING org.apache.spark.sql.json
  OPTIONS (
    path "examples/src/main/resources/employees.json"
  );

  SELECT * FROM employees;
  +-------+------+
  |   name|salary|
  +-------+------+
  |Michael|  3000|
  |   Andy|  4500|
  | Justin|  3500|
  |  Berta|  4000|
  +-------+------+

  SELECT myAverage(salary) as average_salary FROM employees;
  +--------------+
  |average_salary|
  +--------------+
  |        3750.0|
  +--------------+