Per-partition operations in Spark

Versions: Spark 2.0.0

Spark was developed to work on large amounts of data. If "large" means millions of items and one or several costly operations are executed for every item, performance problems appear quickly. It's one of the reasons why Spark provides operations executed once per partition.

This post focuses on operations that are applied once on a collection of items. The first part describes the available methods. The second part shows how they can be used through several test cases.

Per-partition operations

First of all, it's important to understand the difference between normal and per-partition operations. Let's imagine that we must implement this algorithm in Spark:

# connect to database
# get stats of the record
# map stats to received item

All of these 3 steps can be defined in the map(...) function. Concretely, it means that the number of database connections will be equal to the number of items. With per-partition mapping we could keep one connection per partition. The comparison below shows the difference:

map
# connect to database
# get stats of the record
# map stats to received item

per-partition map
# connect to database
# iterate over items
#   get stats of the record
#   map stats to received item

As you can see in this pseudocode example, per-partition functions operate on a collection of items while normal functions operate on each item separately. That's why, if our processing needs some costly setup (such as the database connection in the previous example), it can be more efficient to use per-partition functions.
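
To make it more concrete, here is a minimal sketch of the per-partition variant in Java, assuming an existing JavaRDD<String> called itemsRDD; the DbClient class and its getStats(...) method are hypothetical placeholders standing in for any resource that is costly to create:

JavaRDD<String> enrichedItems = itemsRDD.mapPartitions(itemsIterator -> {
  // connect to database - once per partition instead of once per item
  DbClient dbClient = new DbClient("db-host"); // hypothetical client
  List<String> enriched = new ArrayList<>();
  // iterate over items
  itemsIterator.forEachRemaining(item -> {
    // get stats of the record and map them to the received item
    enriched.add(item + "=" + dbClient.getStats(item)); // hypothetical lookup
  });
  dbClient.close();
  return enriched.iterator();
});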

The two available per-partition methods concern mapping and the foreach loop. For the first one, we can choose mapPartitions(FlatMapFunction) (with all its variants) or mapPartitionsWithIndex(Function2, Boolean). Both mapping functions return an Iterator typed to the target class. The first difference is that the mapPartitionsWithIndex mapping function is able to recognize which partition executes the given computation. The second difference is the Boolean parameter. If it's set to true, the RDD created by the mapping function will keep the partitioner defined in the parent RDD. According to the scaladoc:

`preservesPartitioning` (...) should be `false` unless this is a pair RDD and the input function doesn't modify the keys.

The second method is foreachPartition(VoidFunction). It works similarly to the mapping methods. The difference is that it returns nothing. Instead, it applies some changes on the collection of objects stored in the partition.
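
As a quick illustration, below is a sketch (assuming an already created JavaSparkContext called sparkContext) that uses mapPartitionsWithIndex(...) to tag every element with the index of the partition that computed it:

JavaRDD<String> taggedNumbers = sparkContext.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 3)
  .mapPartitionsWithIndex((partitionIndex, numbersIterator) -> {
    List<String> tagged = new ArrayList<>();
    // partitionIndex identifies the partition executing this computation
    numbersIterator.forEachRemaining(number -> tagged.add(partitionIndex + "=>" + number));
    return tagged.iterator();
  }, false); // plain (non-pair) RDD, so there is no partitioner to preserve
// taggedNumbers.collect() returns values like "0=>1", "0=>2", "1=>3", ...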

Examples of per-partition operations

Now we can see some tests using per-partition operations. Instead of mapPartitionsWithIndex(...), the tests use a similar method - mapPartitionsToPair(PairFlatMapFunction, Boolean). It also accepts a boolean parameter indicating if partitioning must be preserved and, unlike mapPartitionsWithIndex, it works directly on a pair RDD, which makes it shorter to implement. You can observe that in the tests below:

private static final SparkConf CONFIGURATION =
  new SparkConf().setAppName("PerPartitionOperation Test").setMaster("local[1]");
private static final JavaSparkContext CONTEXT = new JavaSparkContext(CONFIGURATION);

private static final List<Integer> TEST_VALUES = 
  IntStream.rangeClosed(1, 100).boxed().collect(Collectors.toList());

@Test
public void should_map_per_partition_with_only_one_randomizer_created_on_each_partition() {
  List<Integer> replacedNumbers = CONTEXT.parallelize(TEST_VALUES, 10)
    .mapPartitions(numbersIterator -> {
      List<Integer> inputMultiplied = new ArrayList<>();
      // We want one random integer to replace all numbers
      // within a partition
      Random randomizer = new Random();
      int replacingValue = randomizer.nextInt(100000);
      numbersIterator.forEachRemaining(number -> inputMultiplied.add(replacingValue));
      return inputMultiplied.iterator();
    }).collect();

  Set<Integer> differentReplacingNumbers = new HashSet<>(replacedNumbers);
  assertThat(differentReplacingNumbers).hasSize(10);
}

@Test
public void should_iterate_per_partition_elements() {
  CollectionAccumulator<Tuple2<Integer, Integer>> broadcastAccumulator =
    CONTEXT.sc().collectionAccumulator("broadcast accumulator");

  CONTEXT.parallelize(TEST_VALUES, 10).foreachPartition(numbersIterator -> {
    Random randomizer = new Random();
    int replacingValue = randomizer.nextInt(100000);
    // Accumulator used only for test purposes
    numbersIterator.forEachRemaining(number ->
      broadcastAccumulator.add(new Tuple2<>(number, replacingValue)));
  });

  int previous = broadcastAccumulator.value().get(0)._2();
  for (int i = 0; i < broadcastAccumulator.value().size(); i++) {
    Tuple2<Integer, Integer> tuple = broadcastAccumulator.value().get(i);
    assertThat(tuple._2()).isEqualTo(previous);
    if (tuple._1()%10 == 0 && i+1 < broadcastAccumulator.value().size()) {
      // Check modulo of 10 because we expect to have
      // 10 partitions of 10 elements each
      previous = broadcastAccumulator.value().get(i+1)._2();
    }
  }
}

@Test
public void should_preserve_partitioning_when_keys_don_t_change() {
  // The boolean flag in mapPartitionsToPair(...) is there for optimization purposes.
  // By setting it to true, we inform Spark that the provided function preserves
  // the keys and that it operates on a pair RDD.
  // Thanks to this information, Spark knows if the shuffle for subsequent
  // transformations (*ByKey, join...) can be avoided. If this flag is true,
  // Spark knows that we, developers, guarantee that all data for a given key
  // resides in the same partition. So Spark won't shuffle it again.

  // In this test we set the flag to true. In consequence, we should observe
  // only 1 shuffle, caused by the partitionBy(...) call.
  // And since the used partitioner doesn't keep all values of a key in the
  // same partition, it risks giving incorrect results
  JavaPairRDD<Integer, Integer> integerIntegerJavaPairRDD = CONTEXT.parallelizePairs(Arrays.asList(
    new Tuple2<>(11, 101), new Tuple2<>(11, 111),
    new Tuple2<>(31, 131), new Tuple2<>(31, 151)
  )).partitionBy(new DummySwappingPartitioner());

  // To illustrate the 'preserve partitioning' flag, we use mapPartitionsToPair
  // since under-the-hood it creates a MapPartitionsRDD in the same
  // way as mapPartitionsWithIndex
  JavaPairRDD<Integer, Integer> tuple2JavaRDD =
    integerIntegerJavaPairRDD.mapPartitionsToPair(new TwoValuesMapFunction(), true);

  // now, call a transformation supposed to cause a shuffle
  JavaPairRDD<Integer, Integer> valuesByKeyPairRDD = tuple2JavaRDD
    .reduceByKey((v1, v2) -> v1+v2);

  tuple2JavaRDD.collect();

  List<Tuple2<Integer, Integer>>[] dataLists = 
    tuple2JavaRDD.collectPartitions(new int[]{0, 1});
  assertThat(dataLists[0]).hasSize(2);
  assertThat(dataLists[0]).contains(new Tuple2<>(11, 2), new Tuple2<>(31, 1));
  assertThat(dataLists[1]).hasSize(2);
  assertThat(dataLists[1]).contains(new Tuple2<>(11, 2), new Tuple2<>(31, 1));
  // one shuffle is expected in the first RDD because of explicit partitioner use
  assertThat(integerIntegerJavaPairRDD.toDebugString()).contains("ShuffledRDD[1]");

  assertThat(valuesByKeyPairRDD.toDebugString()).containsOnlyOnce("ShuffledRDD");
  dataLists = valuesByKeyPairRDD.collectPartitions(new int[]{0, 1});
  assertThat(dataLists[0]).hasSize(2);
  assertThat(dataLists[0]).containsOnly(new Tuple2<>(11, 2), new Tuple2<>(31, 1));
  assertThat(dataLists[1]).hasSize(2);
  assertThat(dataLists[1]).containsOnly(new Tuple2<>(11, 2), new Tuple2<>(31, 1));
}

@Test
public void should_shuffle_when_preserve_partitioning_is_set_to_false() {
  // This test doesn't prevent the shuffle because the flag
  // is set to false, telling Spark that the pair RDD keys may change
  JavaPairRDD<Integer, Integer> integerIntegerJavaPairRDD = CONTEXT.parallelizePairs(Arrays.asList(
    new Tuple2<>(11, 101), new Tuple2<>(11, 111),
    new Tuple2<>(31, 131), new Tuple2<>(31, 151)
  )).partitionBy(new DummySwappingPartitioner());

  JavaPairRDD<Integer, Integer> tuple2JavaRDD =
    integerIntegerJavaPairRDD.mapPartitionsToPair(new TwoValuesMapFunction(), false);

  // now, call a transformation supposed to cause a shuffle
  JavaPairRDD<Integer, Integer> valuesByKeyPairRDD = tuple2JavaRDD
    .reduceByKey((v1, v2) -> v1+v2);

  tuple2JavaRDD.collect();

  List<Tuple2<Integer, Integer>>[] dataLists = 
    tuple2JavaRDD.collectPartitions(new int[]{0, 1});
  assertThat(dataLists[0]).hasSize(2);
  assertThat(dataLists[0]).contains(new Tuple2<>(11, 2), new Tuple2<>(31, 1));
  assertThat(dataLists[1]).hasSize(2);
  assertThat(dataLists[1]).contains(new Tuple2<>(11, 2), new Tuple2<>(31, 1));
  // one shuffle is expected in the first RDD because of explicit partitioner use
  assertThat(integerIntegerJavaPairRDD.toDebugString()).contains("ShuffledRDD[1]");

  assertThat(valuesByKeyPairRDD.toDebugString()).contains("ShuffledRDD[1]", "ShuffledRDD[3]");
  dataLists = valuesByKeyPairRDD.collectPartitions(new int[]{0, 1});
  // Because of the shuffle, reduceByKey works correctly. But the shuffle
  // means we don't know in which partitions our tuples are stored, so
  // we make a greedy check
  List<Tuple2<Integer, Integer>> sums = new ArrayList<>();
  sums.addAll(dataLists[0]);
  sums.addAll(dataLists[1]);
  assertThat(sums).hasSize(2);
  assertThat(sums).containsOnly(new Tuple2<>(11, 4), new Tuple2<>(31, 2));
}


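// Maps the value of every pair to a constant: 2 for keys smaller than 12,
// 1 otherwise. The keys themselves are never modified, which is what makes
// setting the 'preserve partitioning' flag to true legitimate.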
private static class TwoValuesMapFunction
        implements PairFlatMapFunction<Iterator<Tuple2<Integer, Integer>>, Integer, Integer>, Serializable {

  @Override
  public Iterator<Tuple2<Integer, Integer>> call(Iterator<Tuple2<Integer, Integer>> pairNumbersIterator) throws Exception {
    List<Tuple2<Integer, Integer>> outputPaired = new ArrayList<>();
    pairNumbersIterator.forEachRemaining(numberTuple -> {
      int pairValue = 1;
      if (numberTuple._1() < 12) {
        pairValue = 2;
      }
      outputPaired.add(new Tuple2<>(numberTuple._1(), pairValue));
    });
    return outputPaired.iterator();
  }
}

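// Deliberately broken partitioner: it alternates the returned partition on
// every call, so two pairs having the same key can land in different
// partitions. It violates the partitioner's contract (same key = same
// partition) on purpose, to show the risk of misusing the preserve flag.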
private static class DummySwappingPartitioner extends Partitioner {

  private int keyBigNext = 1;
  private int keySmallNext = 0;

  @Override
  public int numPartitions() {
    return 2;
  }

  @Override
  public int getPartition(Object key) { 
    Integer keyInt = (Integer) key;
    if (keyInt > 11) {
      int p = keyBigNext;
      keyBigNext = p == 0 ? 1 : 0;
      return p;
    }
    int p = keySmallNext;
    keySmallNext = p == 0 ? 1 : 0;
    return p;
  }
}

To conclude, we can say that per-partition operations improve performance when some heavy operations can be shared among several processed items. But Spark doesn't provide a per-partition version of every transformation. Currently only mapping is handled. In addition, it's also possible to apply some internal changes through the foreachPartition method. However, there are some dangers, such as the flag telling Spark not to shuffle: if an incorrectly implemented partitioner is used, it can lead to wrong results.

