Fast Joins in Apache Beam


At work, I was given the task of joining two different event types arriving on a real-time stream in Apache Beam. While joins are simple in a SQL batch job, they are significantly more challenging in real-time streaming systems. In a batch job, the data is bounded: it is finite and will eventually be exhausted. In a real-time streaming system, the data is potentially infinite, so it must be broken up into bounded sets of records, or your join will never finish. In addition, the two sides of a join arrive at different times: how long should you wait before deciding the other side is never coming? One solution to these problems is windowing, which breaks a real-time event stream into bounded pieces. One example of windowing is a fixed window.

A fixed window captures all events within a fixed, non-overlapping period of time, such as 2 minutes. The problem is that matching events near a window boundary may not be joined, because they can fall into different windows.

Illustration of Fixed Windows
In this illustration, both events for key A fall in the same window, as do both events for key C, so those pairs will be joined. However, the two events for key B fall in different windows, so they cannot be joined.
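The boundary problem is easy to reproduce with a little arithmetic. The sketch below is a runner-free model, not Beam code: `fixed_window` applies the same `start = timestamp - (timestamp % size)` rule that Beam's FixedWindows uses with its default offset of 0, and the event timestamps (in seconds) are hypothetical values chosen to mimic the illustration.

```python
WINDOW_SIZE = 120  # 2-minute fixed windows, in seconds


def fixed_window(timestamp, size=WINDOW_SIZE):
    """Return the [start, end) fixed window that contains `timestamp`."""
    start = timestamp - (timestamp % size)
    return (start, start + size)


# Hypothetical (left, right) arrival times for each key, in seconds.
events = {
    "A": (70, 100),
    "B": (115, 125),  # straddles the window boundary at t=120
    "C": (130, 200),
}

for key, (left_ts, right_ts) in events.items():
    # A windowed join only pairs events that land in the same window.
    joined = fixed_window(left_ts) == fixed_window(right_ts)
    print(key, "joined" if joined else "not joined")
# prints: A joined / B not joined / C joined (one per line)
```

In a pipeline, the same assignment is done by `beam.WindowInto(beam.window.FixedWindows(120))` ahead of the join.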

An alternative approach is to use a sliding window, where consecutive windows overlap, as illustrated below. This has a different tradeoff: because each event belongs to several windows, many of the joined results will be duplicates. In exchange, a pair of events that straddles one window's boundary can still be joined in an overlapping window.

Illustration of Sliding Windows
With sliding windows, windows overlap. In this example, key B will be joined because both of its events fall within one of the overlapping windows. However, key A will be joined twice, because both of its events fall in Window 1 and in Window 2.
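A similar runner-free model shows both sides of the sliding-window tradeoff. `sliding_windows` enumerates every window containing a timestamp, following the assignment rule of Beam's SlidingWindows (windows of length `size` starting every `period` seconds, offset 0), and `shared_windows` lists the windows in which both halves of a pair land, which is how many times that pair would be joined. The 2-minute size, 1-minute period, and timestamps are assumptions for illustration.

```python
WINDOW_SIZE = 120  # each window lasts 2 minutes
PERIOD = 60        # a new window starts every minute


def sliding_windows(timestamp, size=WINDOW_SIZE, period=PERIOD):
    """Return every [start, end) sliding window that contains `timestamp`."""
    windows = []
    start = timestamp - (timestamp % period)  # latest window start <= timestamp
    while start > timestamp - size:           # this window still covers timestamp
        windows.append((start, start + size))
        start -= period
    return windows


def shared_windows(ts1, ts2):
    """Windows that contain both events; each one produces a join result."""
    return sorted(set(sliding_windows(ts1)) & set(sliding_windows(ts2)))


# Hypothetical (left, right) arrival times for each key, in seconds.
events = {"A": (70, 100), "B": (115, 125), "C": (130, 200)}

for key, (left_ts, right_ts) in events.items():
    print(key, "joined", len(shared_windows(left_ts, right_ts)), "time(s)")
# prints: A joined 2 time(s) / B joined 1 time(s) / C joined 1 time(s)
```

In a pipeline this corresponds to `beam.window.SlidingWindows(120, 60)`. Note that a pair is only guaranteed a shared window when its two events are at most `size - period` apart (60 seconds here); key C still shares one only because of how its timestamps align.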

In any case, I found that either windowing method added too much latency to my join: you have to wait for the window to close, plus the time CoGroupByKey takes to join the objects together. What I wanted was to cache the first object I saw for a key, wait for the matching object to come in, and then do the join. I considered using Redis as an external cache, but then I discovered BagStateSpec in Apache Beam.

BagStateSpec allows you to store state for each key in a PCollection (strictly speaking, for each key and window). First you map your records to (key, value) pairs and pass them to a DoFn. Each time a key comes up, Beam loads the bag state for that key: you can add to it with cache.add(), read it with cache.read(), and empty it with cache.clear().
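To see what the per-key bag gives you without running a pipeline, here is a toy model. `BagState` and `run_cached_join` are hypothetical names, not Beam APIs: a `defaultdict` creates one bag per key on first use, roughly the way a runner hands your DoFn the state for the current key, and the join follows the cache-then-join idea described above. Records may arrive in any order.

```python
from collections import defaultdict


class BagState:
    """Toy stand-in for the per-key bag a stateful DoFn receives."""

    def __init__(self):
        self._items = []

    def add(self, item):
        self._items.append(item)

    def read(self):
        return list(self._items)

    def clear(self):
        self._items.clear()


def run_cached_join(records):
    """Inner-join (key, value) records whose value has a left/right "schema"."""
    state = defaultdict(BagState)  # one bag per key, created on first use
    joined = []
    for key, value in records:
        cache = state[key]
        other_schema = "right" if value["schema"] == "left" else "left"
        matches = [v for v in cache.read() if v["schema"] == other_schema]
        if matches:
            other = matches[-1]
            cache.clear()  # the pair is complete, so free the key's state
            left, right = (value, other) if value["schema"] == "left" else (other, value)
            row = {"key": key, **left, **right}
            del row["schema"]
            joined.append(row)
        else:
            cache.add(value)  # wait for the other side of the join
    return joined


records = [
    ("Apple", {"schema": "left", "icon": "🍎"}),
    ("Tomato", {"schema": "right", "schedule": "annual"}),
    ("Apple", {"schema": "right", "schedule": "perennial"}),
    ("Tomato", {"schema": "left", "icon": "🍅"}),
]
rows = run_cached_join(records)
# rows == [{"key": "Apple", "icon": "🍎", "schedule": "perennial"},
#          {"key": "Tomato", "icon": "🍅", "schedule": "annual"}]
```

Each result is produced the moment its second half arrives, which is where the latency win over windowing comes from.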

Inner Join

Here's the overall code with example data. It assumes that each record has a "schema" field indicating whether it's the left or right side of the join.

import apache_beam as beam
from apache_beam.coders.coders import TupleCoder, PickleCoder, StrUtf8Coder
from apache_beam.transforms.userstate import BagStateSpec

class CachedJoin(beam.DoFn):
    # Per-key bag of (key, value) tuples for records still waiting for a match
    CACHE = BagStateSpec('cache', TupleCoder((StrUtf8Coder(), PickleCoder())))

    def process(self, record, cache=beam.DoFn.StateParam(CACHE)):
        key, value = record
        schema_name = value["schema"]
        other_schema = "right" if schema_name == "left" else "left"
        # Look for a cached record from the other side of the join
        other_records = [v for _, v in cache.read() if v["schema"] == other_schema]
        if other_records:
            other_record = other_records[-1]
            cache.clear()
            # Merge into a new dict (don't mutate the input) and drop the schema marker
            left, right = (value, other_record) if schema_name == "left" else (other_record, value)
            joined = {"key": key, **left, **right}
            del joined["schema"]
            yield joined
        else:
            # No match yet: cache this record until the other side arrives
            cache.add((key, value))

with beam.Pipeline() as pipeline:
    icons = pipeline | 'Create icons' >> beam.Create([
        ('Apple',  {"schema": "left", "icon": '🍎'}),
        ('Grape',  {"schema": "left", "icon": '🍇'}),
        ('Tomato', {"schema": "left", "icon": '🍅'})
    ])

    durations = pipeline | 'Create durations' >> beam.Create([
        ('Apple',  {"schema": "right", "schedule": 'perennial'}),
        ('Grape',  {"schema": "right", "schedule": 'perennial'}),
        ('Tomato', {"schema": "right", "schedule": 'annual'})
    ])

    joined = ([icons, durations]
        | "Flatten" >> beam.Flatten()
        | "Join" >> beam.ParDo(CachedJoin())
        | "Print" >> beam.Map(print)
    )

The joined PCollection should contain the following elements (in no particular order):

[
	{"key": "Apple", "icon": "🍎", "schedule": "perennial"},
	{"key": "Grape", "icon": "🍇", "schedule": "perennial"},
	{"key": "Tomato", "icon": "🍅", "schedule": "annual"}
]

With this approach, I was getting 100% of the data joined with under 100 ms of latency for my job, not counting the time spent waiting for data to arrive. That was fast enough for my needs.

Left Join

This works for an inner join, but there are two problems:

  1. What do you do with data that doesn't get joined and never will? It will stick around in state indefinitely.
  2. What if you want to do a left/right/full outer join?

Left joins are trickier than inner joins. In an inner join, you simply wait until the other half of the join comes in, and then do the join. With a left join, however, you can't be sure you are ever going to get the other half: is it just late, or is it never coming? What you can do is wait for some time and then give up on receiving a record with a matching key.

You can use a timer to expire cached records once you've decided you're done waiting. You declare a timer with TimerSpec as a class attribute, set it in process() through a beam.DoFn.TimerParam, and annotate another method with @on_timer(TIMER); Beam calls that method when the timer fires.
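Because timers only fire inside a runner, here is a runner-free simulation of the same give-up logic. Each event carries a hypothetical arrival time in seconds; a cached record gets a deadline `wait` seconds after it arrives (standing in for a processing-time timer set 5 seconds ahead), and expired entries are handled before each new arrival and once more at end of input. `left_join_stream` and the event tuple layout are assumptions for illustration, not Beam APIs.

```python
import math

WAIT_SECONDS = 5  # how long a cached record waits for its partner


def left_join_stream(events, wait=WAIT_SECONDS):
    """Simulate a streaming left join over (arrival_time, key, side, fields) events.

    A matching pair is joined as soon as its second half arrives. When a
    cached record's deadline passes, a left record is emitted with
    schedule=None and a right record is dropped.
    """
    cache = {}   # key -> (side, fields, deadline)
    output = []

    def expire(now):
        # Fire every "timer" whose deadline has been reached.
        for key in [k for k, (_, _, deadline) in cache.items() if deadline <= now]:
            side, fields, _ = cache.pop(key)
            if side == "left":
                output.append({"key": key, **fields, "schedule": None})

    for now, key, side, fields in sorted(events, key=lambda e: e[0]):
        expire(now)
        if key in cache and cache[key][0] != side:
            _, other_fields, _ = cache.pop(key)
            left, right = (fields, other_fields) if side == "left" else (other_fields, fields)
            output.append({"key": key, **left, **right})
        else:
            cache[key] = (side, fields, now + wait)  # (re)start the wait
    expire(math.inf)  # end of input: every remaining wait runs out
    return output


events = [
    (0.0, "Apple", "left", {"icon": "🍎"}),
    (0.5, "Carrot", "left", {"icon": "🥕"}),
    (1.0, "Apple", "right", {"schedule": "perennial"}),
    (2.0, "Tomato", "right", {"schedule": "annual"}),
    (9.0, "Tomato", "left", {"icon": "🍅"}),  # its right half expired at t=7.0
]
rows = left_join_stream(events)
# rows == [{"key": "Apple", "icon": "🍎", "schedule": "perennial"},
#          {"key": "Carrot", "icon": "🥕", "schedule": None},
#          {"key": "Tomato", "icon": "🍅", "schedule": None}]
```

The Tomato case shows the cost of giving up: a record that arrives after its partner's wait has run out is treated as unmatched. Emitting timed-out right records too (with icon=None) would turn this model into a full outer join.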

Here's an example of using a timer to accomplish a left join:

import time

import apache_beam as beam
from apache_beam.coders.coders import TupleCoder, PickleCoder, StrUtf8Coder
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import on_timer, TimerSpec
from apache_beam.transforms.timeutil import TimeDomain

class CachedLeftJoin(beam.DoFn):
    # Per-key bag of (key, value) tuples for records still waiting for a match
    CACHE = BagStateSpec('cache', TupleCoder((StrUtf8Coder(), PickleCoder())))
    # Processing-time timer used to give up on a missing right-hand record
    STALE_TIMER = TimerSpec('stale', TimeDomain.REAL_TIME)

    def process(self, record,
                cache=beam.DoFn.StateParam(CACHE),
                stale_timer=beam.DoFn.TimerParam(STALE_TIMER)):
        key, value = record
        schema_name = value["schema"]
        other_schema = "right" if schema_name == "left" else "left"
        other_records = [v for _, v in cache.read() if v["schema"] == other_schema]
        if other_records:
            other_record = other_records[-1]
            cache.clear()
            left, right = (value, other_record) if schema_name == "left" else (other_record, value)
            joined = {"key": key, **left, **right}
            del joined["schema"]
            yield joined
        else:
            stale_timer.set(time.time() + 5)  # Set the timer to 5 seconds past the current time
            cache.add((key, value))

    @on_timer(STALE_TIMER)
    def expire(self, cache=beam.DoFn.StateParam(CACHE)):
        right_dummy_record = {"schedule": None}
        cached = list(cache.read())
        cache.clear()
        # The cache is already empty if the join happened before the timer fired
        if not cached:
            return
        key, record = cached[-1]
        # Emit unmatched left records with empty right-side fields; drop unmatched right records
        if record["schema"] == "left":
            joined = {"key": key, **record, **right_dummy_record}
            del joined["schema"]
            yield joined

with beam.Pipeline() as pipeline:
    icons = pipeline | 'Create icons' >> beam.Create([
        ('Apple',  {"schema": "left", "icon": '🍎'}),
        ('Grape',  {"schema": "left", "icon": '🍇'}),
        ('Carrot', {"schema": "left", "icon": '🥕'}),
        ('Tomato', {"schema": "left", "icon": '🍅'})
    ])

    durations = pipeline | 'Create durations' >> beam.Create([
        ('Apple',  {"schema": "right", "schedule": 'perennial'}),
        ('Grape',  {"schema": "right", "schedule": 'perennial'}),
        ('Tomato', {"schema": "right", "schedule": 'annual'})
    ])

    joined = ([icons, durations]
        | "Flatten" >> beam.Flatten()
        | "Join" >> beam.ParDo(CachedLeftJoin())
        | "Print" >> beam.Map(print)
    )

The joined PCollection should contain the following elements (in no particular order):

[
	{"key": "Apple", "icon": "🍎", "schedule": "perennial"},
	{"key": "Grape", "icon": "🍇", "schedule": "perennial"},
	{"key": "Tomato", "icon": "🍅", "schedule": "annual"},
	{"key": "Carrot", "icon": "🥕", "schedule": None}
]


