Joins in Apache Beam

Versions: Apache Beam 2.2.0 https://github.com/bartosz25/beam-learning

Dealing with joins in relational databases is quite straightforward thanks to the underlying data structures (e.g. trees). However, joins are less convenient in the data processing world, where schemaless and denormalized data rule.

This post introduces joins in Apache Beam. The first 2 parts show 2 different ways to execute them. The first one uses a group-by-key transform while the second one calls a Beam extension. The third part shows how to implement joins in code.

Group-by-key-based joins

From the posts about Spark SQL joins we can learn that Spark SQL provides different strategies to deal with joins: sort-merge, shuffle and broadcast. In Apache Beam we can reproduce some of them with the methods provided by the Java SDK.

The first type, the broadcast join, consists of sending an additional input to the main processed dataset. In Apache Beam it can be achieved with the help of side inputs (you can read more about them in the post Side input in Apache Beam). However, this solution has some limitations. The additional input must be small enough to fit in the worker's memory and must not slow down the processing because of object serialization/deserialization overhead. A good candidate for a broadcast join is the list of countries mapped by their ISO codes. There are only a couple of hundred countries, so the overhead should be small. The implementation details are similar to the ones described in that post. The only requirement is that the additional data must be a map (View.<K, V>asMap()).
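The lookup performed by a broadcast join can be modelled in plain Java. The sketch below doesn't use Beam at all: BroadcastJoinSketch, summarize and the unknownCountry fallback are hypothetical names, and java.util.Map.Entry stands in for KV. It also shows one detail worth deciding up front, namely what to emit when an ISO code is missing from the map.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Plain-Java model of a broadcast (map-side) join; no Beam API is used here.
final class BroadcastJoinSketch {

  // Looks every order's ISO code up in the small map that each worker would hold.
  // unknownCountry is a hypothetical fallback for codes missing from the map.
  static List<String> summarize(List<Map.Entry<String, String>> ordersWithIsoCode,
      Map<String, String> countriesByIso, String unknownCountry) {
    List<String> summaries = new ArrayList<>();
    for (Map.Entry<String, String> order : ordersWithIsoCode) {
      String country = countriesByIso.getOrDefault(order.getValue(), unknownCountry);
      summaries.add(order.getKey() + " (" + country + ")");
    }
    return summaries;
  }

  public static void main(String[] args) {
    Map<String, String> countriesByIso = new HashMap<>();
    countriesByIso.put("fr", "France");
    countriesByIso.put("pl", "Poland");
    List<Map.Entry<String, String>> orders = Arrays.asList(
        new SimpleEntry<>("order_1", "fr"), new SimpleEntry<>("order_2", "us"));
    // Prints [order_1 (France), order_2 (unknown)]
    System.out.println(summarize(orders, countriesByIso, "unknown"));
  }
}
```

Without such a fallback, a plain get on the map returns null for an unmatched code, so the join behaves like a left outer join with null on the right side.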

The second join category is the shuffle join. It relies on the shuffle operation, which places all entries sharing the same key on the same worker, where they are exposed as instances of org.apache.beam.sdk.transforms.join.CoGbkResult. A typical use case is joining customers with their orders, which logically (if you're working with Beam, it's most likely because your dataset doesn't fit on a single machine) should represent far more than 200 entries to join. Technically the shuffle join translates into 4 steps:

  1. Define the PCollections to join
  2. Define the TupleTags corresponding to the created PCollections
  3. Merge the PCollections with the org.apache.beam.sdk.transforms.join.CoGroupByKey transform
  4. Process the received org.apache.beam.sdk.transforms.join.CoGbkResult with an appropriate transform

Thanks to the TupleTags, which define the types of the joined datasets, we can join datasets whose values have different types.
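To make the grouping step more tangible, here is a plain-Java model of what the shuffle produces for each key. It is only a sketch under assumptions: it doesn't use Beam, Grouped is a hypothetical stand-in for CoGbkResult, and coGroup and kv are illustrative helpers rather than SDK methods. The two inputs deliberately carry values of different types.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Plain-Java model of the grouping step; it only mimics the result of a
// co-grouping and doesn't use or reproduce the Beam API.
final class CoGroupSketch {

  // Hypothetical stand-in for CoGbkResult: the values of both inputs for one key.
  static final class Grouped<L, R> {
    final List<L> left = new ArrayList<>();
    final List<R> right = new ArrayList<>();
  }

  static <K, V> Map.Entry<K, V> kv(K key, V value) {
    return new SimpleEntry<>(key, value);
  }

  // Collects, per key, every value found in either input.
  static <K extends Comparable<K>, L, R> Map<K, Grouped<L, R>> coGroup(
      List<Map.Entry<K, L>> leftInput, List<Map.Entry<K, R>> rightInput) {
    Map<K, Grouped<L, R>> groups = new TreeMap<>();
    for (Map.Entry<K, L> entry : leftInput) {
      groups.computeIfAbsent(entry.getKey(), k -> new Grouped<L, R>()).left.add(entry.getValue());
    }
    for (Map.Entry<K, R> entry : rightInput) {
      groups.computeIfAbsent(entry.getKey(), k -> new Grouped<L, R>()).right.add(entry.getValue());
    }
    return groups;
  }

  public static void main(String[] args) {
    List<Map.Entry<String, String>> customers =
        Arrays.asList(kv("user1", "Alice"), kv("user2", "Bob"));
    List<Map.Entry<String, Integer>> orders =
        Arrays.asList(kv("user1", 1100), kv("user1", 1200), kv("user3", 50));
    // Every key appears once, with possibly empty lists on either side
    coGroup(customers, orders).forEach((user, grouped) ->
        System.out.println(user + " -> " + grouped.left + " / " + grouped.right));
  }
}
```

Because a key present in only one input still gets a group with an empty list on the other side, the code consuming the groups decides whether the join is inner or outer.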

Extension-based joins

The steps described above are good to know, but in real-world use it may be simpler to rely on a clearer abstraction: joins. Apache Beam brings them with the join-library extension.

The library supports inner join, left outer join, right outer join and full outer join operations. Unsurprisingly, under the hood it uses the same concept as the one described in the first section for the shuffle join: the CoGroupByKey transform. The only difference is that the join-library extension lets us define the missing values. For instance, in a left outer join where the right dataset doesn't contain any data matching a given key, we can simply define how the missing value is represented. This shows up in the way the join methods are called in the tests of this post: leftOuterJoin and rightOuterJoin take one extra argument for the missing value, and fullOuterJoin takes two.
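The role of the missing-value argument can be modelled in plain Java. This is a sketch only: it doesn't use Beam, leftOuterJoin and kv are hypothetical helpers whose parameter order follows the calls made in the tests of this post, and java.util.Map.Entry stands in for KV.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Plain-Java model of a left outer join with a default for the missing side;
// it isn't the extension's implementation.
final class OuterJoinSketch {

  static <K, V> Map.Entry<K, V> kv(K key, V value) {
    return new SimpleEntry<>(key, value);
  }

  // Keeps every left entry; rightNullValue fills the right side when the key has no match.
  static <K, L, R> List<Map.Entry<K, Map.Entry<L, R>>> leftOuterJoin(
      List<Map.Entry<K, L>> left, List<Map.Entry<K, R>> right, R rightNullValue) {
    Map<K, List<R>> rightByKey = new HashMap<>();
    for (Map.Entry<K, R> entry : right) {
      rightByKey.computeIfAbsent(entry.getKey(), k -> new ArrayList<R>()).add(entry.getValue());
    }
    List<Map.Entry<K, Map.Entry<L, R>>> joined = new ArrayList<>();
    for (Map.Entry<K, L> entry : left) {
      List<R> matches = rightByKey.get(entry.getKey());
      if (matches == null) {
        joined.add(kv(entry.getKey(), kv(entry.getValue(), rightNullValue)));
      } else {
        for (R match : matches) {
          joined.add(kv(entry.getKey(), kv(entry.getValue(), match)));
        }
      }
    }
    return joined;
  }

  public static void main(String[] args) {
    List<Map.Entry<String, Integer>> left = Arrays.asList(kv("user1", 1000), kv("user6", 100));
    List<Map.Entry<String, Integer>> right = Arrays.asList(kv("user1", 1100), kv("user7", 100));
    // Prints "user1 -> (1000, 1100)" then "user6 -> (100, 0)"
    for (Map.Entry<String, Map.Entry<Integer, Integer>> row : leftOuterJoin(left, right, 0)) {
      System.out.println(row.getKey() + " -> (" + row.getValue().getKey()
          + ", " + row.getValue().getValue() + ")");
    }
  }
}
```

Choosing 0 as the default keeps a later sum unchanged, which is why the outer join tests of this post can reuse the same amounts calculator as the inner join ones.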

Another difference between the native and the extension-based solutions is that the former lets us customize the processing of 1:n joins. By default the extension produces n joined pairs, where n is the number of matching entries on the right side of the join. With native code we have more flexibility, as shown in the should_join_2_datasets_which_all_have_1_matching_key_with_native_sdk_code test case.
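The difference for a 1:n relationship can be illustrated with a few lines of plain Java. The sketch below is a model under assumptions, not Beam code: sumPerJoinedPair imitates summing each pair returned by the extension, while sumPerKey imitates custom processing of the grouped values. Both method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java model of the two ways to handle a 1:n match for a single key.
final class OneToManySketch {

  // One output per (left, right) combination, as when each joined pair is summed separately.
  static List<Integer> sumPerJoinedPair(List<Integer> leftValues, List<Integer> rightValues) {
    List<Integer> sums = new ArrayList<>();
    for (int left : leftValues) {
      for (int right : rightValues) {
        sums.add(left + right);
      }
    }
    return sums;
  }

  // A single output for the key, as when all grouped values are combined in one step.
  static int sumPerKey(List<Integer> leftValues, List<Integer> rightValues) {
    int total = 0;
    for (int value : leftValues) {
      total += value;
    }
    for (int value : rightValues) {
      total += value;
    }
    return total;
  }

  public static void main(String[] args) {
    List<Integer> leftAmounts = Arrays.asList(200);
    List<Integer> rightAmounts = Arrays.asList(210, 300);
    System.out.println(sumPerJoinedPair(leftAmounts, rightAmounts)); // [410, 500]
    System.out.println(sumPerKey(leftAmounts, rightAmounts)); // 710
  }
}
```

With one left value and two right values, the pair-based approach emits two results for the same key, whereas the grouped approach emits a single total.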

Join examples

The following tests show how to use the joining methods described above:

@Test
public void should_join_2_datasets_which_all_have_1_matching_key_with_native_sdk_code() {
  Pipeline pipeline = BeamFunctions.createPipeline("1:n join with native SDK");
  List<KV<String, Integer>> ordersPerUser1 = Arrays.asList(
    KV.of("user1", 1000), KV.of("user2", 200), KV.of("user3", 100)
  );
  List<KV<String, Integer>> ordersPerUser2 = Arrays.asList(
    KV.of("user1", 1100), KV.of("user2", 210), KV.of("user3", 110),
    KV.of("user1", 1200), KV.of("user2", 220), KV.of("user3", 120)
  );

  PCollection<KV<String, Integer>> ordersPerUser1Dataset = pipeline.apply(Create.of(ordersPerUser1));
  PCollection<KV<String, Integer>> ordersPerUser2Dataset = pipeline.apply(Create.of(ordersPerUser2));

  final TupleTag<Integer> amountTagDataset1 = new TupleTag<>();
  final TupleTag<Integer> amountTagDataset2 = new TupleTag<>();
  PCollection<KV<String, CoGbkResult>> groupedCollection = KeyedPCollectionTuple
    .of(amountTagDataset1, ordersPerUser1Dataset)
    .and(amountTagDataset2, ordersPerUser2Dataset)
    .apply(CoGroupByKey.create());

  PCollection<KV<String, Integer>> totalAmountsPerUser = groupedCollection.apply(ParDo.of(new DoFn<KV<String, CoGbkResult>, KV<String, Integer>>() {
    @ProcessElement
    public void processElement(ProcessContext processContext) {
      KV<String, CoGbkResult> element = processContext.element();
      Iterable<Integer> dataset1Amounts = element.getValue().getAll(amountTagDataset1);
      Iterable<Integer> dataset2Amounts = element.getValue().getAll(amountTagDataset2);
      Integer sumAmount = StreamSupport.stream(Iterables.concat(dataset1Amounts, dataset2Amounts).spliterator(), false)
        .collect(Collectors.summingInt(n -> n));
      processContext.output(KV.of(element.getKey(), sumAmount));
    }
  }));

  PAssert.that(totalAmountsPerUser).containsInAnyOrder(KV.of("user1", 3300), KV.of("user2", 630),
    KV.of("user3", 330));
  pipeline.run().waitUntilFinish();
}


@Test
public void should_join_2_datasets_with_side_inputs() {
  Pipeline pipeline = BeamFunctions.createPipeline("Broadcast join with side input");
  List<KV<String, String>> ordersWithCountry = Arrays.asList(
    KV.of("order_1", "fr"), KV.of("order_2", "fr"), KV.of("order_3", "pl")
  );
  List<KV<String, String>> countriesWithIsoCode = Arrays.asList(
    KV.of("fr", "France"), KV.of("pl", "Poland"), KV.of("de", "Germany")
  );

  PCollection<KV<String, String>> ordersWithCountriesDataset = pipeline.apply(Create.of(ordersWithCountry));
  PCollection<KV<String, String>> countriesMapDataset = pipeline.apply(Create.of(countriesWithIsoCode));
  PCollectionView<Map<String, String>> countriesSideInput = countriesMapDataset.apply(View.asMap());
  PCollection<String> ordersSummaries = ordersWithCountriesDataset.apply(ParDo.of(new DoFn<KV<String, String>, String>() {
    @ProcessElement
    public void processElement(ProcessContext context) {
      Map<String, String> countriesByIso = context.sideInput(countriesSideInput);
      KV<String, String> processedElement = context.element();
      String orderCountry = countriesByIso.get(processedElement.getValue());
      String orderSummary = processedElement.getKey() + " (" + orderCountry + ")";
      context.output(orderSummary);
    }
  }).withSideInputs(countriesSideInput));

  PAssert.that(ordersSummaries).containsInAnyOrder("order_1 (France)", "order_2 (France)", "order_3 (Poland)");
  pipeline.run().waitUntilFinish();
}

@Test
public void should_do_inner_join_on_2_datasets_with_sdk_extension() {
  Pipeline pipeline = BeamFunctions.createPipeline("Inner join with the extension");
  List<KV<String, Integer>> ordersPerUser1 = Arrays.asList(
    KV.of("user1", 1000), KV.of("user2", 200), KV.of("user3", 100), KV.of("user6", 100)
  );
  List<KV<String, Integer>> ordersPerUser2 = Arrays.asList(
    KV.of("user1", 1100), KV.of("user2", 210), KV.of("user3", 110), KV.of("user7", 100)
  );

  PCollection<KV<String, Integer>> ordersPerUser1Dataset = pipeline.apply(Create.of(ordersPerUser1));
  PCollection<KV<String, Integer>> ordersPerUser2Dataset = pipeline.apply(Create.of(ordersPerUser2));

  PCollection<KV<String, KV<Integer, Integer>>> joinedDatasets = Join.innerJoin(ordersPerUser1Dataset, ordersPerUser2Dataset);
  PCollection<KV<String, Integer>> amountsPerUser = joinedDatasets.apply(ParDo.of(new AmountsCalculator()));

  // user6 and user7 are ignored because they're not included in both datasets
  PAssert.that(amountsPerUser).containsInAnyOrder(KV.of("user1", 2100), KV.of("user2", 410), KV.of("user3", 210));
  pipeline.run().waitUntilFinish();
}


@Test
public void should_output_2_pairs_for_1_to_2_relationship_join() {
  Pipeline pipeline = BeamFunctions.createPipeline("Extension inner join for 1:n relationship");
  List<KV<String, Integer>> ordersPerUser1 = Arrays.asList(
    KV.of("user1", 1000), KV.of("user2", 200), KV.of("user3", 100)
  );
  List<KV<String, Integer>> ordersPerUser2 = Arrays.asList(
    KV.of("user1", 1100), KV.of("user2", 210), KV.of("user3", 110), KV.of("user2", 300)
  );

  PCollection<KV<String, Integer>> ordersPerUser1Dataset = pipeline.apply(Create.of(ordersPerUser1));
  PCollection<KV<String, Integer>> ordersPerUser2Dataset = pipeline.apply(Create.of(ordersPerUser2));
  PCollection<KV<String, KV<Integer, Integer>>> joinedDatasets = Join.innerJoin(ordersPerUser1Dataset, ordersPerUser2Dataset);
  PCollection<KV<String, Integer>> amountsPerUser = joinedDatasets.apply(ParDo.of(new AmountsCalculator()));

  // The join extension is a bit less flexible than custom join processing for 1:n joins.
  // It can't combine multiple values into a single output. Instead it returns one pair
  // for every combination of joined values
  PAssert.that(amountsPerUser).containsInAnyOrder(KV.of("user1", 2100), KV.of("user2", 410), KV.of("user3", 210),
    KV.of("user2", 500));
  pipeline.run().waitUntilFinish();
}


@Test
public void should_do_outer_full_join_on_2_datasets_with_sdk_extension() {
  Pipeline pipeline = BeamFunctions.createPipeline("Extension outer full join");
  List<KV<String, Integer>> ordersPerUser1 = Arrays.asList(
    KV.of("user1", 1000), KV.of("user2", 200), KV.of("user3", 100), KV.of("user6", 100)
  );
  List<KV<String, Integer>> ordersPerUser2 = Arrays.asList(
    KV.of("user1", 1100), KV.of("user2", 210), KV.of("user3", 110), KV.of("user7", 100)
  );

  PCollection<KV<String, Integer>> ordersPerUser1Dataset = pipeline.apply(Create.of(ordersPerUser1));
  PCollection<KV<String, Integer>> ordersPerUser2Dataset = pipeline.apply(Create.of(ordersPerUser2));

  PCollection<KV<String, KV<Integer, Integer>>> joinedDatasets = Join.fullOuterJoin(ordersPerUser1Dataset, ordersPerUser2Dataset,
    0, 0);
  PCollection<KV<String, Integer>> amountsPerUser = joinedDatasets.apply(ParDo.of(new AmountsCalculator()));
  // user6 and user7 are both kept: the full outer join fills the missing side with the default value 0
  PAssert.that(amountsPerUser).containsInAnyOrder(KV.of("user1", 2100), KV.of("user2", 410), KV.of("user3", 210),
    KV.of("user6", 100), KV.of("user7", 100));
  pipeline.run().waitUntilFinish();
}

@Test
public void should_do_outer_left_join_on_2_datasets_with_sdk_extension() {
  Pipeline pipeline = BeamFunctions.createPipeline("Extension outer left join");
  List<KV<String, Integer>> ordersPerUser1 = Arrays.asList(
    KV.of("user1", 1000), KV.of("user2", 200), KV.of("user3", 100), KV.of("user6", 100)
  );
  List<KV<String, Integer>> ordersPerUser2 = Arrays.asList(
    KV.of("user1", 1100), KV.of("user2", 210), KV.of("user3", 110), KV.of("user7", 100)
  );

  PCollection<KV<String, Integer>> ordersPerUser1Dataset = pipeline.apply(Create.of(ordersPerUser1));
  PCollection<KV<String, Integer>> ordersPerUser2Dataset = pipeline.apply(Create.of(ordersPerUser2));
  PCollection<KV<String, KV<Integer, Integer>>> joinedDatasets = Join.leftOuterJoin(ordersPerUser1Dataset, ordersPerUser2Dataset,
    0);
  PCollection<KV<String, Integer>> amountsPerUser = joinedDatasets.apply(ParDo.of(new AmountsCalculator()));
  // user6 is kept with the default value 0 on the right side; user7 is dropped because it exists only in the right dataset
  PAssert.that(amountsPerUser).containsInAnyOrder(KV.of("user1", 2100), KV.of("user2", 410), KV.of("user3", 210),
    KV.of("user6", 100));
  pipeline.run().waitUntilFinish();
}

@Test
public void should_do_outer_right_join_on_2_datasets_with_sdk_extension() {
  Pipeline pipeline = BeamFunctions.createPipeline("Extension outer right join");
  List<KV<String, Integer>> ordersPerUser1 = Arrays.asList(
    KV.of("user1", 1000), KV.of("user2", 200), KV.of("user3", 100), KV.of("user6", 100)
  );
  List<KV<String, Integer>> ordersPerUser2 = Arrays.asList(
    KV.of("user1", 1100), KV.of("user2", 210), KV.of("user3", 110), KV.of("user7", 100)
  );

  PCollection<KV<String, Integer>> ordersPerUser1Dataset = pipeline.apply(Create.of(ordersPerUser1));
  PCollection<KV<String, Integer>> ordersPerUser2Dataset = pipeline.apply(Create.of(ordersPerUser2));
  PCollection<KV<String, KV<Integer, Integer>>> joinedDatasets = Join.rightOuterJoin(ordersPerUser1Dataset, ordersPerUser2Dataset,
    0);
  PCollection<KV<String, Integer>> amountsPerUser = joinedDatasets.apply(ParDo.of(new AmountsCalculator()));
  // user7 is kept with the default value 0 on the left side; user6 is dropped because it exists only in the left dataset
  PAssert.that(amountsPerUser).containsInAnyOrder(KV.of("user1", 2100), KV.of("user2", 410), KV.of("user3", 210),
    KV.of("user7", 100));
  pipeline.run().waitUntilFinish();
}

class AmountsCalculator extends DoFn<KV<String, KV<Integer, Integer>>, KV<String, Integer>> {
  @ProcessElement
  public void processElement(ProcessContext processContext) {
    KV<String, KV<Integer, Integer>> element = processContext.element();
    int totalAmount = element.getValue().getKey() + element.getValue().getValue();
    processContext.output(KV.of(element.getKey(), totalAmount));
  }
} 

Joining data in a distributed environment is not an easy task. We rarely have the support of data structures comparable to those behind keyed joins in an RDBMS. The available solutions are mostly based on shuffling the data between workers. The first of them, presented in the first section, uses the CoGroupByKey transform to bring all entries sharing the join key to the same node. The grouped elements can then be freely processed either as a set of items or as n joined pairs. The second solution is based on the Apache Beam extension called join-library. Under the hood it does basically the same thing as the CoGroupByKey transform, except that for 1:n relationships it always returns n pairs of joined entries. The other difference is that the library lets us define a default value for outer joins.

