Stateful processing in Apache Beam

Versions: Apache Beam 2.2.0 https://github.com/bartosz25/beam-learning

Real-time processing is most of the time somehow related to stateful processing. Either we need to solve some sessionization problem, count the number of visitors per minute etc. Not surprisingly Apache Beam comes with the API adapted to put in place the solutions to them.

This post introduces the idea of stateful processing in Apache Beam. In the first section we'll discover the main points about this category of processing implemented in the framework. The second part focuses on the code and presents the objects involved in stateful applications. The last section gives some examples of stateful processing.

State in Apache Beam

The state in Apache Beam represents a situation occurred between 2 different points in the time. It can for instance group all events produced between this interval, aggregate, reduce them etc. Thus, the state is mutable since it evolves with time and new incoming elements. Technically speaking, Beam's stateful code allows each processing function to access persistent mutable state for each processed entry.

State storage

I've spent some time on looking for how the state storage is currently implemented in Apache Beam. The answer can be found in the org.apache.beam.sdk.state package and the implementations of the interfaces representing data structures accepted in stateful transform: MapState, SetState or ValueState.

The default implementation used by direct runner stores the state in memory. However, as Kenneth Knowles presented it during the Stateful processing of massive out of order streams with Apache Beam talk, Apache Beam is used to describe the computation pipeline and the state storage details rely on the runner executing the code.

For instance the state implementation in Spark relies on in-memory storage (stateTable attribute in org.apache.beam.runners.spark.stateful.SparkStateInternals). The objects are represented as an array of bytes.

Though it's quite possible that the state is stored in less volatile place, as for instance a key-value data store or simply on disk. It's up to runner to decide about the used strategy.

We can develop that short description and characterize the state with the following properties:

State API

The state definition consists on the use of @StateId annotation on a private and final field being the instance of org.apache.beam.sdk.state.StateSpec<StateT extends State>. The state is identified by the name that must be unique through the transform. The state must be final because it can't be modified after the declaration. Similarly to its private access - the state must remain local to the transform. Regarding to the unique name, it's required to an efficient identification. This requirement to declare the state makes the processing fail fast, i.e. we don't even need to start the pipeline to see that it's incorrectly defined.

The identification of state by name is used later, when the state is attached to the processing (@ProcessElement-annotated) method. The same @StateId annotation is used to specify the state in this context as an additional argument. The arguments of @ProcessElement method are later analyzed by DoFnSignatures#analyzeExtraParameter(ErrorReporter methodErrors, FnAnalysisContext fnContext, MethodAnalysisContext methodContext, TypeDescriptor<? extends DoFn<?, ?>> fnClass, ParameterDescription param, TypeDescriptor<?> inputT, TypeDescriptor<?> outputT).

The type of state passed in processing method is not StateSpec but org.apache.beam.sdk.state.State. The state can be one of the following types:

State examples

The points described above are presented in the following tests:

private static final Instant NOW = new Instant(0);
private static final Instant SEC_1_DURATION = NOW.plus(Duration.standardSeconds(1));
private static final Instant SEC_2_DURATION = NOW.plus(Duration.standardSeconds(2));
private static final Instant SEC_5_DURATION = NOW.plus(Duration.standardSeconds(5));
private static final Instant SEC_6_DURATION = NOW.plus(Duration.standardSeconds(6));

@Test
public void should_fail_on_applying_stateful_transform_for_not_key_value_pairs() {
  Pipeline pipeline = BeamFunctions.createPipeline("Shared state example");
  Coder<String> utfCoder = StringUtf8Coder.of();
  TestStream<String> words = TestStream.create(utfCoder).addElements(
      TimestampedValue.of("cat", SEC_1_DURATION), TimestampedValue.of("hat", SEC_1_DURATION)
    )
    .advanceWatermarkToInfinity();
  Duration windowDuration = Duration.standardSeconds(5);
  Window<String> window = Window.<String>into(FixedWindows.of(windowDuration));

  assertThatThrownBy(() -> {
    pipeline.apply(words).apply(window).apply(ParDo.of(new MinLengthTextSharedStateProcessing(2)));
    pipeline.run().waitUntilFinish();
  }).isInstanceOf(IllegalArgumentException.class)
    .hasMessage("ParDo requires its input to use KvCoder in order to use state and timers.");
}

@Test
public void should_correctly_increment_counter_for_each_encountered_item() {
  Pipeline pipeline = BeamFunctions.createPipeline("Counter example");
  Coder<String> utf8Coder = StringUtf8Coder.of();
  Coder<Integer> varIntCoder = VarIntCoder.of();
  KvCoder<String, Integer> keyValueCoder = KvCoder.of(utf8Coder, varIntCoder);
  TestStream<KV<String, Integer>> words = TestStream.create(keyValueCoder).addElements(
      TimestampedValue.of(KV.of("a", 1), SEC_1_DURATION), TimestampedValue.of(KV.of("a", 2), SEC_1_DURATION),
      TimestampedValue.of(KV.of("b", 5), SEC_1_DURATION), TimestampedValue.of(KV.of("a", 6), SEC_1_DURATION),
      TimestampedValue.of(KV.of("c", 2), SEC_1_DURATION), TimestampedValue.of(KV.of("a", 7), SEC_1_DURATION),
      TimestampedValue.of(KV.of("a", 3), SEC_1_DURATION), TimestampedValue.of(KV.of("a", 9), SEC_1_DURATION),
      TimestampedValue.of(KV.of("d", 2), SEC_1_DURATION), TimestampedValue.of(KV.of("a", 1), SEC_1_DURATION),
      TimestampedValue.of(KV.of("a", 2), SEC_1_DURATION)
    )
    .advanceWatermarkToInfinity();
  Duration windowDuration = Duration.standardSeconds(15);
  Window<KV<String, Integer>> window = Window.<KV<String, Integer>>into(FixedWindows.of(windowDuration))
    .triggering(AfterPane.elementCountAtLeast(20))
    .withAllowedLateness(windowDuration, Window.ClosingBehavior.FIRE_ALWAYS)
    .accumulatingFiredPanes();

  PCollection<String> results = pipeline.apply(words).apply(window)
          .apply(ParDo.of(new DoFn<KV<String, Integer>, String>() {
    private static final String COUNTER_NAME = "counter";

    // Its definition is required, otherwise the following exception is thrown:
    // parameter of type MapState<String, Integer> at index 1: reference to undeclared StateId: "counter"
    @StateId(COUNTER_NAME)
    private final StateSpec<MapState<String, Integer>> mapState = StateSpecs.map();

    @ProcessElement
    public void process(ProcessContext processContext,
            @StateId(COUNTER_NAME) MapState<String, Integer> letterCounterState) {
      KV<String, Integer> element = processContext.element();
      ReadableState<Integer> letterSumState = letterCounterState.get(element.getKey());
      int currentSum = letterSumState.read() != null ? letterSumState.read() : 0;
      int letterSum = currentSum + element.getValue();
      letterCounterState.put(element.getKey(), letterSum);
      processContext.output(element.getKey()+"="+letterSum);
    } 
  }));
  IntervalWindow window1 = new IntervalWindow(NOW, NOW.plus(windowDuration));
  // It also shows that state is not adapted to be returned every time because instead of having a, b,c and d
  // with final values, we have intermediary results
  PAssert.that(results).inFinalPane(window1).containsInAnyOrder("a=31", "a=29", "a=28", "a=19", "a=16", "a=9",
    "a=3", "b=5", "a=1", "c=2", "d=2");
  pipeline.run().waitUntilFinish();
}

@Test
public void should_show_the_use_of_combine_state() {
  Pipeline pipeline = BeamFunctions.createPipeline("Combined state example");
  Coder<String> utf8Coder = StringUtf8Coder.of();
  KvCoder<String, String> keyValueCoder = KvCoder.of(utf8Coder, utf8Coder);
  TestStream<KV<String, String>> words = TestStream.create(keyValueCoder).addElements(
      TimestampedValue.of(KV.of("10:00", "/index.html"), SEC_1_DURATION),
      TimestampedValue.of(KV.of("10:01", "/cart.html"), SEC_2_DURATION),
      TimestampedValue.of(KV.of("10:01", "/cancel_order.html"), SEC_1_DURATION),
      TimestampedValue.of(KV.of("10:05", "/delivery.html"), SEC_5_DURATION),
      TimestampedValue.of(KV.of("10:02", "/index.html"), SEC_2_DURATION),
      TimestampedValue.of(KV.of("10:06", "/cart.html"), SEC_6_DURATION),
      TimestampedValue.of(KV.of("10:01", "/login.html"), SEC_6_DURATION),
      TimestampedValue.of(KV.of("10:07", "/payment.html"), SEC_6_DURATION.plus(Duration.standardSeconds(1))),
      TimestampedValue.of(KV.of("10:08", "/order_confirmation.html"), SEC_6_DURATION.plus(Duration.standardSeconds(2)))
  )
  .advanceWatermarkToInfinity();
  Duration windowDuration = Duration.standardSeconds(5);
  Window<KV<String, String>> window = Window.<KV<String, String>>into(FixedWindows.of(windowDuration))
    .triggering(AfterPane.elementCountAtLeast(20))
    .withAllowedLateness(windowDuration, Window.ClosingBehavior.FIRE_ALWAYS)
    .accumulatingFiredPanes();

  PCollection<KV<String, Integer>> results = pipeline.apply(words).apply(window).apply(ParDo.of(new SessionTransform()));
  IntervalWindow window1 = new IntervalWindow(NOW, NOW.plus(windowDuration));
  // This assert shows that the stateful processing doesn't work well when we output the computed values every time
  // The output for 10:01 is returned twice
  PAssert.that(results).inFinalPane(window1).containsInAnyOrder(KV.of("10:00", 1), KV.of("10:01", 2), KV.of("10:02", 1), KV.of("10:01", 1));
  // The presence of 10:01 with different value than in previous window proves that the state is bounded
  // to the window duration
  IntervalWindow window2 = new IntervalWindow(NOW.plus(windowDuration), NOW.plus(windowDuration).plus(windowDuration));
  PAssert.that(results).inFinalPane(window2).containsInAnyOrder(KV.of("10:05", 1), KV.of("10:06", 1),
    KV.of("10:07", 1), KV.of("10:08", 1), KV.of("10:01", 1));
  pipeline.run().waitUntilFinish();
}

@Test
public void should_accumulate_encountered_words_with_bagstate() {
  Pipeline pipeline = BeamFunctions.createPipeline("Bag state example");
  Coder<String> utf8Coder = StringUtf8Coder.of();
  KvCoder<String, String> keyValueCoder = KvCoder.of(utf8Coder, utf8Coder);
  TestStream<KV<String, String>> words = TestStream.create(keyValueCoder).addElements(
    TimestampedValue.of(KV.of("p", "paradigm"), SEC_1_DURATION),
    TimestampedValue.of(KV.of("p", "programming"), SEC_1_DURATION)
  )
  .advanceWatermarkToInfinity();
  Duration windowDuration = Duration.standardSeconds(15);
  Window<KV<String, String>> window = Window.<KV<String, String>>into(FixedWindows.of(windowDuration))
    .triggering(AfterPane.elementCountAtLeast(20))
    .withAllowedLateness(windowDuration, Window.ClosingBehavior.FIRE_ALWAYS)
    .accumulatingFiredPanes();

  PCollection<String> results = pipeline.apply(words).apply(window)
          .apply(ParDo.of(new DoFn<KV<String, String>, String>() {
    private static final String ACCUMULATOR_NAME = "accumulator";

    @StateId(ACCUMULATOR_NAME)
    private final StateSpec<BagState<String>> accumulatorStateSpec = StateSpecs.bag();

    @ProcessElement
    public void processElement(ProcessContext processContext,
            @StateId(ACCUMULATOR_NAME) BagState<String> wordsAccumulator) {
      KV<String, String> letterWordPair = processContext.element();
      wordsAccumulator.add(letterWordPair.getValue());
      processContext.output(Joiner.on(",").join(wordsAccumulator.read()));
    }
  }));

  IntervalWindow window1 = new IntervalWindow(NOW, NOW.plus(windowDuration));
  // Since the ordering is not deterministic, we check against 2 possible outputs
  PAssert.that(results).inFinalPane(window1).satisfies((SerializableFunction<Iterable<String>, Void>) input -> {
    List<String> expectedWords = Arrays.asList("paradigm", "paradigm,programming",
      "programming", "programming,paradigm");
    int matchedWords = 0;
    for (String generatedWord : input) {
      if (expectedWords.contains(generatedWord)) {
        matchedWords++;
      }
    }
    Assert.assertEquals(matchedWords, 2);
    return null;
  });
  pipeline.run().waitUntilFinish();
}

@Test
public void should_show_that_stateful_processing_order_is_also_not_deterministic() {
  List<KV<String, String>> clubsPerCountry = Arrays.asList(
          KV.of("Germany", "VfB Stuttgart"), KV.of("Germany", "Bayern Munich"), KV.of("Germany", "FC Koln"));

  for (int i = 0; i < 10; i++) {
    Pipeline pipeline = BeamFunctions.createPipeline("Not deterministic stateful processing");
    pipeline.apply(Create.of(clubsPerCountry))
      .apply(ParDo.of(new DoFn<KV<String, String>, String>() {
        private static final String ACCUMULATOR_NAME = "accumulator";

        @StateId(ACCUMULATOR_NAME)
        private final StateSpec<BagState<String>> accumulatorStateSpec = StateSpecs.bag();

        @ProcessElement
        public void processElement(ProcessContext processContext,
                                    @StateId(ACCUMULATOR_NAME) BagState<String> clubsAccumulator) {
          clubsAccumulator.add(processContext.element().getValue());
          String clubs = Joiner.on("-").join(clubsAccumulator.read());
          StringsAccumulator.CLUBS.add(clubs);
          processContext.output(clubs);
        }
      }));
    pipeline.run().waitUntilFinish();
  }
  // We deal with 3 keys. If the processing order would be guaranteed, we'd always retrieve the same pairs.
  // It means we'd have the size of accumulated entries equal to 3. But since the processing order is not
  // deterministic, obviously we have more items than that.
  assertThat(StringsAccumulator.CLUBS.getEntries().size()).isGreaterThan(3);
}

@Test
public void should_show_that_stateful_processing_is_key_based() {
  List<KV<String, String>> clubsPerCountry = Arrays.asList(
    KV.of("Germany", "VfB Stuttgart"), KV.of("Germany", "Bayern Munich"), KV.of("Germany", "FC Koln"),
    KV.of("Holland", "Ajax Amsterdam"), KV.of("Holland", "Sparta Rotterdam"),
    KV.of("Spain", "FC Barcelona"), KV.of("Spain", "Real Madrid"));

  Pipeline pipeline = BeamFunctions.createPipeline("Not deterministic stateful processing");
  PCollection<String> results = pipeline.apply(Create.of(clubsPerCountry))
    .apply(ParDo.of(new DoFn<KV<String, String>, String>() {

      private static final String COUNTER_NAME = "occurrences_counter";

      @StateId(COUNTER_NAME)
      private final StateSpec<ValueState<Integer>> counter = StateSpecs.value(VarIntCoder.of());

      @ProcessElement
      public void processElement(ProcessContext processContext,
                                  @StateId(COUNTER_NAME) ValueState<Integer> counterState) {
        int currentValue = Optional.ofNullable(counterState.read()).orElse(0);
        int incrementedCounter = currentValue + 1;
        counterState.write(incrementedCounter);
        processContext.output(processContext.element().getKey()+"="+incrementedCounter);
      }
    }));

  // If the state cell wasn't be key-based, the number of accumulated entries would probably be incremental,
  // i.e. 1, 2, 3, 4, 5, 6, 7
  PAssert.that(results).containsInAnyOrder("Germany=1", "Germany=2", "Germany=3", "Holland=1", "Holland=2",
          "Spain=1", "Spain=2");
  pipeline.run().waitUntilFinish();
}

@Test
public void should_fail_when_two_states_with_the_same_ids_are_defined() {
  Pipeline pipeline = BeamFunctions.createPipeline("Two states with the same name");
  List<KV<String, String>> clubsPerCountry = Arrays.asList(
    KV.of("Germany", "VfB Stuttgart"), KV.of("Germany", "Bayern Munich"), KV.of("Germany", "FC Koln"));

  assertThatThrownBy(() -> {
      pipeline.apply(Create.of(clubsPerCountry))
        .apply(ParDo.of(new DoFn<KV<String, String>, String>() {

          private static final String ACCUMULATOR_NAME = "accumulator";

          @StateId(ACCUMULATOR_NAME)
          private final StateSpec<BagState<String>> accumulatorStateSpec = StateSpecs.bag();

          @StateId(ACCUMULATOR_NAME)
          private final StateSpec<BagState<String>> accumulatorStateSpecDuplicated = StateSpecs.bag();

          @ProcessElement
          public void processElement(ProcessContext processContext,
                                      @StateId(ACCUMULATOR_NAME) BagState<String> clubsAccumulator) {
            clubsAccumulator.add(processContext.element().getValue());
            String clubs = Joiner.on(", ").join(clubsAccumulator.read());
            processContext.output(clubs);
          }
        }));  
      pipeline.run().waitUntilFinish();
  }).isInstanceOf(IllegalArgumentException.class)
    .hasMessageContaining("Duplicate StateId \"accumulator\", used on both of " +
            "[private final org.apache.beam.sdk.state.StateSpec ")
    .hasMessageContaining("and [private final org.apache.beam.sdk.state.StateSpec");
}

@Test
public void should_fail_when_the_state_spec_declaration_is_not_final() {
  Pipeline pipeline = BeamFunctions.createPipeline("Two states with the same name");
  List<KV<String, String>> clubsPerCountry = Arrays.asList(
    KV.of("Germany", "VfB Stuttgart"), KV.of("Germany", "Bayern Munich"), KV.of("Germany", "FC Koln"));

  assertThatThrownBy(() -> {
      pipeline.apply(Create.of(clubsPerCountry))
        .apply(ParDo.of(new DoFn<KV<String, String>, String>() {

          private static final String ACCUMULATOR_NAME = "accumulator";

          @StateId(ACCUMULATOR_NAME)
          private StateSpec<BagState<String>> accumulatorStateSpec = StateSpecs.bag();

          @ProcessElement
          public void processElement(ProcessContext processContext,
                                      @StateId(ACCUMULATOR_NAME) BagState<String> clubsAccumulator) {
            clubsAccumulator.add(processContext.element().getValue());
            String clubs = Joiner.on(", ").join(clubsAccumulator.read());
            processContext.output(clubs);
          }
        })); 
      pipeline.run().waitUntilFinish();
  }).isInstanceOf(IllegalArgumentException.class)
    .hasMessageContaining("Non-final field private").hasMessageContaining("org.apache.beam.sdk.state.StateSpec")
    .hasMessageContaining("annotated with StateId. State declarations must be final.");
}

enum StringsAccumulator {
  CLUBS;

  private Set<String> entries = new HashSet<>();

  public void add(String entry) {
    System.out.println("Adding " + entry);
    entries.add(entry);
  }

  public Set<String> getEntries() {
    return entries;
  }
}


public static class MinLengthTextSharedStateProcessing extends DoFn<String, String> {

  private static final String COUNTER_NAME = "counter";

  @StateId(COUNTER_NAME)
  private final StateSpec<ValueState<Integer>> counter = StateSpecs.value(VarIntCoder.of());

  private int textMinLength;

  public MinLengthTextSharedStateProcessing(int textMinLength) {
    this.textMinLength = textMinLength;
  }

  @ProcessElement
  public void process(ProcessContext processContext, @StateId(COUNTER_NAME) ValueState<Integer> counterState) {
    String word = processContext.element();
    if (word.length() >= textMinLength) {
      int counterValue = counterState.read();
      int newCounterValue = counterValue + 1;
      counterState.write(newCounterValue);
      processContext.output(word+"="+newCounterValue);
    }
  }

}

public static class SessionCombiner extends Combine.CombineFn<String, List<String>, Integer> {

  @Override
  public List<String> createAccumulator() {
    return new ArrayList<>();
  }

  @Override
  public List<String> addInput(List<String> accumulator, String log) {
    accumulator.add(log);
    return accumulator;
  }

  @Override
  public List<String> mergeAccumulators(Iterable<List<String>> accumulators) {
    List<String> mergedAccumulator = new ArrayList<>();
    accumulators.forEach(accumulatorToMerge -> mergedAccumulator.addAll(accumulatorToMerge));
    return mergedAccumulator;
  }

  @Override
  public Integer extractOutput(List<String> accumulator) {
    return accumulator.size();
  }
}

public static class SessionTransform extends DoFn<KV<String, String>, KV<String, Integer>> {

  private static final String COUNTER_NAME = "counter";

  @StateId(COUNTER_NAME)
  private final StateSpec<CombiningState<String, List<String>, Integer>> sessionStateSpec =
    StateSpecs.combining(ListCoder.of(StringUtf8Coder.of()), new SessionCombiner());

  @ProcessElement
  public void process(ProcessContext processContext,
                      @StateId(COUNTER_NAME) CombiningState<String, List<String>, Integer> sessionState) {
    KV<String, String> element = processContext.element();
    sessionState.add(element.getValue());
    processContext.output(KV.of(element.getKey(), sessionState.read()));
  }

}

The stateful processing is available in Apache Beam as it's available in the most of data processing frameworks (e.g. Spark). The state represents a situation happened between 2 different points in time. Moreover, it applies only to keyed datasets and it ends when the window expires. Its use is pretty straightforward since it's based on annotating the state object and defining it later in the processing method. As shown in the 2nd and the 3rd sections, the state can be of different type and a given key can have multiple state objects that can be of different types.