Cumulative Calculations: updateStateByKey()

To keep track of the log statistics for all of time, state must be maintained between the RDDs processed in a DStream.

To maintain state for key-value pairs, the data may be too big to fit in memory on one machine, but Spark Streaming can maintain the state for you. To do that, call the
updateStateByKey function of the Spark Streaming library.

First, in order to use updateStateByKey, checkpointing must be enabled on the streaming context. To do that, call checkpoint
on the streaming context with a directory where the checkpoint data can be written. Here is
part of the main function of a streaming application that saves state for all of time:

    public class LogAnalyzerStreamingTotal {
      public static void main(String[] args) throws InterruptedException {
        SparkConf conf = new SparkConf().setAppName("Log Analyzer Streaming Total");
        JavaSparkContext sc = new JavaSparkContext(conf);
        JavaStreamingContext jssc = new JavaStreamingContext(sc,
            new Duration(10000));  // This sets the batch interval to 10 seconds.

        // Checkpointing must be enabled to use the updateStateByKey function.
        jssc.checkpoint("/tmp/log-analyzer-streaming");

        // TODO: Insert code for computing log stats.

        // Start the streaming server.
        jssc.start();              // Start the computation
        jssc.awaitTermination();   // Wait for the computation to terminate
      }
    }
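
The snippets below all operate on an accessLogDStream of parsed ApacheAccessLog objects, which is built where the TODO marker sits above. Its construction was covered in the earlier streaming section; as a reminder, a minimal sketch, assuming log lines arrive on a local socket and reusing the ApacheAccessLog.parseFromLogLine helper from earlier sections, looks roughly like this:

    // Sketch only: the host, port, and parsing helper are assumptions carried
    // over from earlier sections; adjust them to match your own input source.
    JavaDStream<String> logDataDStream = jssc.socketTextStream("localhost", 9999);
    JavaDStream<ApacheAccessLog> accessLogDStream =
        logDataDStream.map(ApacheAccessLog::parseFromLogLine).cache();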

To compute the content size statistics, simply use static variables
to save the current running sum, count, min, and max of the content sizes. This works because
foreachRDD runs its function on the driver, and the actions called on each RDD return their
results to the driver, where the static variables are updated.

    // These static variables store the running content size values.
    private static final AtomicLong runningCount = new AtomicLong(0);
    private static final AtomicLong runningSum = new AtomicLong(0);
    private static final AtomicLong runningMin = new AtomicLong(Long.MAX_VALUE);
    private static final AtomicLong runningMax = new AtomicLong(Long.MIN_VALUE);

To update those values, first call map on the accessLogDStream to retrieve a contentSizeDStream. Then update the static variables by calling
foreachRDD on the contentSizeDStream and calling actions on each RDD:

    JavaDStream<Long> contentSizeDStream =
        accessLogDStream.map(ApacheAccessLog::getContentSize).cache();
    contentSizeDStream.foreachRDD(rdd -> {
      if (rdd.count() > 0) {
        runningSum.getAndAdd(rdd.reduce(SUM_REDUCER));
        runningCount.getAndAdd(rdd.count());
        runningMin.set(Math.min(runningMin.get(), rdd.min(Comparator.naturalOrder())));
        runningMax.set(Math.max(runningMax.get(), rdd.max(Comparator.naturalOrder())));
        System.out.print("Content Size Avg: " + runningSum.get() / runningCount.get());
        System.out.print(", Min: " + runningMin.get());
        System.out.println(", Max: " + runningMax.get());
      }
    });

For the other statistics, which make use of key-value pairs, static variables
can't be used anymore: the amount of state that needs to be maintained
is potentially too big to fit in memory. So
for those stats, we'll make use of updateStateByKey, so that Spark Streaming maintains
a value for every key in our dataset.

But before we can call updateStateByKey, we need to create a function to pass into it. updateStateByKey takes a different kind of reduce function.
While our previous sum reducer simply took in two values and output their sum, this
update function takes in the new values for a key along with its current state (if any),
and outputs the new state.

    private static Function2<List<Long>, Optional<Long>, Optional<Long>>
        COMPUTE_RUNNING_SUM = (nums, current) -> {
          long sum = current.or(0L);
          for (long i : nums) {
            sum += i;
          }
          return Optional.of(sum);
        };
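
To make the update semantics concrete, here is a small stand-alone sanity check of the same running-sum logic, run outside of Spark. The class name and batch values are made up for illustration, and the import shown is the Spark Java API's Optional; older Spark versions used Guava's Optional in the Java API, which exposes the same absent/or/of methods.

    import org.apache.spark.api.java.Optional;

    import java.util.Arrays;
    import java.util.List;

    // Hypothetical stand-alone check of the running-sum update logic.
    public class RunningSumCheck {
      // Same logic as COMPUTE_RUNNING_SUM, written as a plain method.
      private static Optional<Long> runningSum(List<Long> nums, Optional<Long> current) {
        long sum = current.or(0L);
        for (long n : nums) {
          sum += n;
        }
        return Optional.of(sum);
      }

      public static void main(String[] args) {
        // First batch: three new counts for a key, no previous state yet.
        Optional<Long> state = runningSum(Arrays.asList(1L, 1L, 1L), Optional.absent());
        System.out.println(state.get()); // prints 3
        // Second batch: two more counts folded into the existing state.
        state = runningSum(Arrays.asList(1L, 1L), state);
        System.out.println(state.get()); // prints 5
      }
    }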

Finally, we can compute the keyed statistics for all of time with this code:

    // Compute Response Code to Count.
    // Note the use of updateStateByKey.
    JavaPairDStream<Integer, Long> responseCodeCountDStream = accessLogDStream
        .mapToPair(s -> new Tuple2<>(s.getResponseCode(), 1L))
        .reduceByKey(SUM_REDUCER)
        .updateStateByKey(COMPUTE_RUNNING_SUM);
    responseCodeCountDStream.foreachRDD(rdd -> {
      System.out.println("Response code counts: " + rdd.take(100));
    });

    // A DStream of ipAddresses accessed > 10 times.
    JavaDStream<String> ipAddressesDStream = accessLogDStream
        .mapToPair(s -> new Tuple2<>(s.getIpAddress(), 1L))
        .reduceByKey(SUM_REDUCER)
        .updateStateByKey(COMPUTE_RUNNING_SUM)
        .filter(tuple -> tuple._2() > 10)
        .map(Tuple2::_1);
    ipAddressesDStream.foreachRDD(rdd -> {
      List<String> ipAddresses = rdd.take(100);
      System.out.println("All IPAddresses > 10 times: " + ipAddresses);
    });

    // A DStream of endpoint to count.
    JavaPairDStream<String, Long> endpointCountsDStream = accessLogDStream
        .mapToPair(s -> new Tuple2<>(s.getEndpoint(), 1L))
        .reduceByKey(SUM_REDUCER)
        .updateStateByKey(COMPUTE_RUNNING_SUM);
    endpointCountsDStream.foreachRDD(rdd -> {
      List<Tuple2<String, Long>> topEndpoints =
          rdd.top(10, new ValueComparator<>(Comparator.<Long>naturalOrder()));
      System.out.println("Top Endpoints: " + topEndpoints);
    });
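
The endpoint snippet sorts with a ValueComparator helper defined alongside the batch examples earlier in this reference app. If you are reading this section on its own, a minimal sketch of such a comparator, assuming it simply orders Tuple2 pairs by their second element, could look like this:

    import java.io.Serializable;
    import java.util.Comparator;

    import scala.Tuple2;

    // Sketch of a comparator that orders key/value tuples by their value.
    // It must be Serializable so rdd.top() can ship it to the executors.
    public class ValueComparator<K, V> implements Comparator<Tuple2<K, V>>, Serializable {
      private final Comparator<V> comparator;

      public ValueComparator(Comparator<V> comparator) {
        this.comparator = comparator;
      }

      @Override
      public int compare(Tuple2<K, V> o1, Tuple2<K, V> o2) {
        return comparator.compare(o1._2(), o2._2());
      }
    }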

Run LogAnalyzerStreamingTotal.java
now for yourself.