The Bytewax framework is a flexible tool designed to meet the challenges Python developers face in today's data-driven world. It aims to provide seamless integrations and time-saving shortcuts for data engineers working with streaming data, making their work more efficient and effective. Input connectors are an important part of Bytewax. They establish the connection between an external system and a Bytewax dataflow so that users can import data from that system.
Here we're going to show how to write a custom input connector by walking through how we wrote our built-in Kafka input connector.
Writing input connectors for arbitrary systems while supporting failure recovery and strong delivery guarantees requires a solid understanding of how recovery works, both inside Bytewax and in the external system you are reading from. We strongly encourage you to use the connectors we have built into bytewax.connectors if possible, and to read the documentation on their limits.
If you are interested in writing your own, this article introduces some of the decisions involved in writing an input connector for an ordered, partitioned input stream.
If you need any help at all writing a connector, come say "hi" and ask questions in the Bytewax community Slack! We are happy to help!
Partitions
Subclassing bytewax.inputs.PartitionedInput is the core API for writing an input connector when your input has a fixed number of partitions. A partition is a "sub-stream" of data that can be read concurrently and independently.
To write a PartitionedInput subclass, you need to answer three questions:
- How many partitions are there?
- How can I build a source that reads a single partition?
- How can I rewind a partition and read from a specific item?
The first two are answered by the abstract methods list_parts and build_part. The third is answered by the resume_state argument that Bytewax passes to build_part.
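The sketch below makes the three questions concrete without a Kafka cluster. ListInput and ListSource are hypothetical names invented for this illustration. They mirror the shape of PartitionedInput (list_parts, build_part) and StatefulSource (next, snapshot). They deliberately do not subclass the real Bytewax classes, so the snippet runs on its own.

```python
from typing import Dict, List, Optional, Set


class ListSource:
    """Hypothetical stand-in for a StatefulSource over one in-memory partition."""

    def __init__(self, items: List[str], resume_state: Optional[int]):
        self._items = items
        # Question three: rewind to the snapshotted position, else start at 0.
        self._pos = resume_state if resume_state is not None else 0

    def next(self) -> str:
        if self._pos >= len(self._items):
            raise StopIteration()  # this partition is complete
        item = self._items[self._pos]
        self._pos += 1
        return item

    def snapshot(self) -> int:
        # Index of the next item to read.
        return self._pos


class ListInput:
    """Hypothetical stand-in for a PartitionedInput over named lists."""

    def __init__(self, data: Dict[str, List[str]]):
        self._data = data

    def list_parts(self) -> Set[str]:
        # Question one: the fixed set of partition IDs.
        return set(self._data)

    def build_part(self, for_part: str, resume_state: Optional[int]) -> ListSource:
        # Question two: a source that reads exactly one partition.
        return ListSource(self._data[for_part], resume_state)


inp = ListInput({"a": ["a0", "a1", "a2"], "b": ["b0"]})
src = inp.build_part("a", resume_state=None)
first = src.next()                             # "a0"
resumed = inp.build_part("a", src.snapshot())  # rewinds to index 1
```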
We're going to use the confluent-kafka package to actually communicate with the Kafka cluster. Let's import all the things we'll need for this input source.
```python
from typing import Dict, Iterable

from confluent_kafka import (
    Consumer,
    KafkaError,
    OFFSET_BEGINNING,
    TopicPartition,
)
from confluent_kafka.admin import AdminClient

from bytewax.inputs import PartitionedInput, StatefulSource
```
Our KafkaInput connector is going to read from a specific set of topics on a cluster. First, let's define our class and write a constructor that takes all the arguments that make sense for configuring this specific kind of input source. This is going to be the public entry point to this connector, and is what you'll pass to the bytewax.dataflow.Dataflow.input operator.
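```python
class KafkaInput(PartitionedInput):
    def __init__(
        self,
        brokers: Iterable[str],
        topics: Iterable[str],
        tail: bool = True,
        starting_offset: int = OFFSET_BEGINNING,
        add_config: Dict[str, str] = None,
    ):
        add_config = add_config or {}

        if isinstance(brokers, str):
            raise TypeError("brokers must be an iterable and not a string")
        if isinstance(topics, str):
            raise TypeError("topics must be an iterable and not a string")

        # Store the configuration used later by list_parts and build_part.
        self._brokers = brokers
        self._topics = topics
        self._tail = tail
        self._starting_offset = starting_offset
        self._add_config = add_config
```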
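The string check on brokers matters because both list_parts and build_part later build the bootstrap.servers setting with ",".join(self._brokers). A Python str is itself an iterable of characters, so a bare string would be silently split into single characters. The broker addresses and the bootstrap_servers helper below are hypothetical and for illustration only.

```python
def bootstrap_servers(brokers):
    """Hypothetical helper: the same check and join KafkaInput relies on."""
    if isinstance(brokers, str):
        raise TypeError("brokers must be an iterable and not a string")
    return ",".join(brokers)


print(bootstrap_servers(["localhost:9092", "localhost:9093"]))
# localhost:9092,localhost:9093

# Without the check, a single string is joined character by character:
print(",".join("localhost:9092"))
# l,o,c,a,l,h,o,s,t,:,9,0,9,2
```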
Listing Partitions
Next, let's answer question one: how many partitions are there? Conveniently, confluent-kafka provides AdminClient.list_topics, which gives you the partition count of each topic, packed deep in a metadata object. The signature of PartitionedInput.list_parts says it must return a set of strings with the IDs of all the partitions. Let's build the AdminClient using the configuration stored in our instance variables, and then delegate to a _list_parts function so we can reuse it if necessary.
```python
# Continued
# class KafkaInput(PartitionedInput):
    def list_parts(self):
        config = {
            "bootstrap.servers": ",".join(self._brokers),
        }
        config.update(self._add_config)
        client = AdminClient(config)

        return set(_list_parts(client, self._topics))
```
This function unpacks the nested metadata returned from AdminClient.list_topics and yields one string per partition, such as "3-my_topic" for partition 3 of the topic my_topic.
```python
def _list_parts(client, topics):
    for topic in topics:
        # List topics one-by-one so if auto-create is turned on,
        # we respect that.
        cluster_metadata = client.list_topics(topic)
        topic_metadata = cluster_metadata.topics[topic]
        if topic_metadata.error is not None:
            raise RuntimeError(
                f"error listing partitions for Kafka topic `{topic!r}`: "
                f"{topic_metadata.error.str()}"
            )
        part_idxs = topic_metadata.partitions.keys()
        for i in part_idxs:
            yield f"{i}-{topic}"
```
How do you decide what the partition ID string should be? It should be something that globally identifies this partition, hence combining partition number and topic name.
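The partition ID format only has to round-trip between list_parts and build_part. The sketch below shows that round trip using hypothetical helpers, make_part_id and parse_part_id. The integer index comes first, and build_part splits with maxsplit=1. As a result, topic names that themselves contain "-" (which Kafka allows) are still recovered intact.

```python
from typing import Tuple


def make_part_id(part_idx: int, topic: str) -> str:
    # Same format _list_parts yields: "<partition index>-<topic name>".
    return f"{part_idx}-{topic}"


def parse_part_id(for_part: str) -> Tuple[int, str]:
    # Same parsing as build_part: split on the first "-" only.
    part_idx, topic = for_part.split("-", 1)
    return int(part_idx), topic


pid = make_part_id(3, "user-events")
print(pid)                 # 3-user-events
print(parse_part_id(pid))  # (3, 'user-events')
print(pid.split("-"))      # ['3', 'user', 'events']: why maxsplit=1 matters
```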
PartitionedInput.list_parts might be called multiple times from multiple workers as a Bytewax cluster is set up and resumed. It must therefore return exactly the same set of partitions on every call in order to work correctly. Changing the number of partitions is not currently supported with recovery.
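_list_parts only touches a few attributes of the metadata returned by list_topics (.topics, .error, .partitions), so it can be exercised offline with a stub. FakeAdminClient below is hypothetical and mimics only those attributes. It is not part of confluent-kafka. Returning a set makes the result independent of call order, which helps satisfy the same-set-every-call requirement.

```python
from types import SimpleNamespace


def _list_parts(client, topics):
    # Same logic as the generator above, repeated so this sketch runs alone.
    for topic in topics:
        cluster_metadata = client.list_topics(topic)
        topic_metadata = cluster_metadata.topics[topic]
        if topic_metadata.error is not None:
            raise RuntimeError(
                f"error listing partitions for Kafka topic `{topic!r}`: "
                f"{topic_metadata.error.str()}"
            )
        for i in topic_metadata.partitions.keys():
            yield f"{i}-{topic}"


class FakeAdminClient:
    """Hypothetical stub mimicking only the metadata shape _list_parts reads."""

    def __init__(self, partition_counts):
        self._counts = partition_counts

    def list_topics(self, topic):
        topic_metadata = SimpleNamespace(
            error=None,
            partitions={i: None for i in range(self._counts[topic])},
        )
        return SimpleNamespace(topics={topic: topic_metadata})


client = FakeAdminClient({"clicks": 2, "orders": 3})
first = set(_list_parts(client, ["clicks", "orders"]))
second = set(_list_parts(client, ["orders", "clicks"]))
print(sorted(first))  # ['0-clicks', '0-orders', '1-clicks', '1-orders', '2-orders']
```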
Building Partitions
Next, let's answer question two: how can I build a source that reads a single partition? We can use confluent-kafka's Consumer to make a Kafka consumer that reads a specific topic and partition, starting from an offset. PartitionedInput.build_part takes a specific partition ID (we'll ignore the resume state for now) and must return a stateful source.
We parse the partition ID to determine which Kafka partition we should be consuming from (hence the importance of having a globally unique partition ID). Then we build a Consumer that connects to the Kafka cluster, and build our custom _KafkaSource stateful source. That is where the actual reading of input items happens.
```python
# Continued
# class KafkaInput(PartitionedInput):
    def build_part(self, for_part, resume_state):
        part_idx, topic = for_part.split("-", 1)
        part_idx = int(part_idx)
        assert topic in self._topics, "Can't resume from different set of Kafka topics"

        config = {
            # We'll manage our own "consumer group" via recovery
            # system.
            "group.id": "BYTEWAX_IGNORED",
            "enable.auto.commit": "false",
            "bootstrap.servers": ",".join(self._brokers),
            "enable.partition.eof": str(not self._tail),
        }
        config.update(self._add_config)
        consumer = Consumer(config)
        return _KafkaSource(
            consumer, topic, part_idx, self._starting_offset, resume_state
        )
```
Stateful Input Source
What is a stateful source? It is defined by subclassing bytewax.inputs.StatefulSource. You can think of it as a "snapshot-able Python iterator": something that produces a stream of items via StatefulSource.next, and also lets the Bytewax runtime ask for a snapshot of the source's position via StatefulSource.snapshot.
Our _KafkaSource is going to read items from a specific Kafka topic's partition. Let's define that class and give it a constructor that takes all the details needed to start reading that partition:

- the consumer (already configured to connect to the correct Kafka cluster)
- the topic
- the specific partition index
- the default starting offset (beginning or end of the topic)

Again, we'll ignore the resume state for just another moment.
```python
class _KafkaSource(StatefulSource):
    def __init__(self, consumer, topic, part_idx, starting_offset, resume_state):
        self._offset = resume_state or starting_offset
        # Assign does not activate consumer grouping.
        consumer.assign([TopicPartition(topic, part_idx, self._offset)])
        self._consumer = consumer
        self._topic = topic
```
The beating heart of the input source is the StatefulSource.next method. Bytewax calls it periodically, and it behaves similarly to a built-in Python iterator's __next__ method. It must do one of three things:

- return a new item to send into the dataflow
- return None to signal that there is no data right now but there might be later
- raise StopIteration when the partition is complete
Consumer.poll lets us ask whether there are any new messages on the partition this consumer was set up to follow. If there are, we unpack the message and return it. Otherwise, we handle the no-data case, the end-of-stream case, or an exceptional error case.
```python
# Continued
# class _KafkaSource(StatefulSource):
    def next(self):
        msg = self._consumer.poll(0.001)  # seconds
        if msg is None:
            # No message available right now; Bytewax will call again later.
            return None
        elif msg.error() is not None:
            if msg.error().code() == KafkaError._PARTITION_EOF:
                # Only reported when enable.partition.eof is on (tail=False).
                raise StopIteration()
            else:
                raise RuntimeError(
                    f"error consuming from Kafka topic `{self._topic!r}`: {msg.error()}"
                )
        else:
            item = (msg.key(), msg.value())
            # Resume reading from the next message, not this one.
            self._offset = msg.offset() + 1
            return item
```
An important thing to note here is that StatefulSource.next must never block. The Bytewax runtime employs a sort of cooperative multitasking, and so each operator must return quickly, even if it has nothing to do, so other operators in the dataflow that do have work can run. Unfortunately, currently there is no way in the Bytewax API to prevent polling of input sources (as input comes from outside the dataflow, Bytewax has no way of knowing when more data is available, so must constantly check). The best practice here is to pause briefly if there is no data to prevent a full spin-loop on no new data, but not so long you block other operators from doing their work.
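The toy example below illustrates why next must return quickly. run_round_robin is a hypothetical stand-in for a cooperative scheduler, not Bytewax's actual runtime, and QueueSource is an invented in-memory source. The idle source returns None immediately instead of waiting, so the busy source keeps making progress. The Kafka source gets its brief pause from the 0.001-second timeout passed to Consumer.poll.

```python
from collections import deque
from typing import Deque, List, Optional


class QueueSource:
    """Hypothetical source fed from an in-memory queue; next() never blocks."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._queue: Deque[str] = deque()
        self._done = False

    def push(self, item: str) -> None:
        self._queue.append(item)

    def finish(self) -> None:
        self._done = True

    def next(self) -> Optional[str]:
        if self._queue:
            return self._queue.popleft()  # an item is ready
        if self._done:
            raise StopIteration()  # partition complete
        return None  # no data right now, maybe later


def run_round_robin(sources: List[QueueSource], max_ticks: int) -> List[str]:
    """Hypothetical cooperative scheduler: one next() call per source per tick."""
    out: List[str] = []
    live = list(sources)
    for _ in range(max_ticks):
        for src in list(live):
            try:
                item = src.next()
            except StopIteration:
                live.remove(src)
                continue
            if item is not None:
                out.append(f"{src.name}:{item}")
    return out


busy, idle = QueueSource("busy"), QueueSource("idle")
for x in ["a", "b", "c"]:
    busy.push(x)
busy.finish()
# `idle` never receives data, but since its next() returns None at once,
# `busy` still makes progress on every tick.
result = run_round_robin([busy, idle], max_ticks=5)
print(result)  # ['busy:a', 'busy:b', 'busy:c']
```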
There is also a StatefulSource.close method, which lets you perform a clean shutdown when EOF is reached. It is not guaranteed to be called in a failure situation, so nothing crucial to the connected system should depend on it. Here, Consumer.close performs a graceful shutdown.
```python
# Continued
# class _KafkaSource(StatefulSource):
    def close(self):
        self._consumer.close()
```
Resume State
Let's explain how failure recovery works for input connectors. Bytewax's recovery system allows the dataflow to quickly resume processing and output without needing to replay all input. It does this by periodically snapshotting all internal state, input positions, and output positions of the dataflow. When it needs to recover after a failure, it loads all state from a recent snapshot. It then replays input items in the same order from the instant of the snapshot, overwriting output items. This causes the state and output of the dataflow to evolve in the same way during the resumed execution as during the previous execution.
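The recovery model above can be simulated in a few lines. This is a deliberately simplified sketch, not how Bytewax implements recovery. It uses:

- a replayable input list
- one stateful step (a running sum)
- a snapshot of (next input position, state) every two items
- output keyed by input position, so a replay overwrites rather than appends

Resuming from the last snapshot after a simulated crash produces exactly the same output as a run with no failure.

```python
from typing import Dict, Optional, Tuple

INPUT = [3, 1, 4, 1, 5, 9, 2, 6]  # replayable: can be re-read from any position

Snapshot = Tuple[int, int]  # (next input position, running total)


def run(
    snap: Snapshot,
    outputs: Dict[int, int],
    crash_at: Optional[int] = None,
    snapshot_every: int = 2,
) -> Tuple[Snapshot, bool]:
    """Process INPUT starting from a snapshot; return (last snapshot, crashed?)."""
    start, total = snap
    last = snap
    for pos in range(start, len(INPUT)):
        if pos == crash_at:
            return last, True  # simulated failure before handling this item
        total += INPUT[pos]
        outputs[pos] = total  # keyed write: a replay overwrites, never duplicates
        if (pos + 1) % snapshot_every == 0:
            last = (pos + 1, total)
    return last, False


clean: Dict[int, int] = {}
run((0, 0), clean)

recovered: Dict[int, int] = {}
snap, crashed = run((0, 0), recovered, crash_at=5)  # fails after item 4
run(snap, recovered)  # resume: restore state, replay from the snapshot position
```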
Snapshotting
So, we need to keep track of the current position somewhere in each partition. Kafka has the concept of message offsets: an immutable, incrementing integer that identifies the position of each message within a partition. In _KafkaSource.next, we kept track of the offset of the next message the partition will read via self._offset = msg.offset() + 1.
Bytewax calls StatefulSource.snapshot when it needs to record the partition's position. Our implementation returns the internally stored offset of the next message.
```python
# Continued
# class _KafkaSource(StatefulSource):
    def snapshot(self):
        return self._offset
```
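The sketch below shows why the source stores msg.offset() + 1 rather than msg.offset(). FakePartitionLog and OffsetTrackingSource are hypothetical stand-ins, with no Kafka involved. The log starts at offset 100 to show that offsets are positions in the partition, not list indices. Resuming from the snapshot continues with the next unread message instead of re-reading the last one.

```python
from typing import List, Optional, Tuple


class FakePartitionLog:
    """Hypothetical stand-in for one Kafka partition: (offset, value) records."""

    def __init__(self, first_offset: int, values: List[str]):
        self.records = [(first_offset + i, v) for i, v in enumerate(values)]

    def read_from(self, offset: int) -> List[Tuple[int, str]]:
        return [rec for rec in self.records if rec[0] >= offset]


class OffsetTrackingSource:
    """Hypothetical source that snapshots the offset of the next message."""

    def __init__(self, log: FakePartitionLog, start_offset: int):
        self._pending = log.read_from(start_offset)
        self._offset = start_offset

    def next(self) -> Optional[str]:
        if not self._pending:
            return None
        offset, value = self._pending.pop(0)
        # Resume reading from the next message, not this one.
        self._offset = offset + 1
        return value

    def snapshot(self) -> int:
        return self._offset


log = FakePartitionLog(100, ["x", "y", "z"])
src = OffsetTrackingSource(log, 100)
read = [src.next(), src.next()]  # reads offsets 100 and 101
snap = src.snapshot()            # 102: the next offset to read
resumed = OffsetTrackingSource(log, snap)
```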
Resume
On resume after a failure, Bytewax's recovery machinery does the hard work. It collects all the snapshots and finds the ones that represent a coherent set of states across the previous execution's cluster. It then threads each bit of snapshot data back into PartitionedInput.build_part for the same partition. To take proper advantage of that, your resulting partition must resume reading from the same spot represented by that snapshot.
We stored the Kafka offset of the next message to be read in _KafkaSource._offset. We now need to thread that offset back into the Consumer when it is built. That happens by passing resume_state into the _KafkaSource constructor, which assigns the consumer to start reading from that offset. Looking at that code again:
```python
# Continued
# class _KafkaSource(StatefulSource):
#     def __init__(self, consumer, topic, part_idx, starting_offset, resume_state):
        self._offset = resume_state or starting_offset
        # Assign does not activate consumer grouping.
        consumer.assign([TopicPartition(topic, part_idx, self._offset)])
        ...
```
As one extra wrinkle, if there is no resume state for this partition because it is being built for the first time, None will be passed for resume_state in PartitionedInput.build_part. In that case, we need to fall back to the requested default starting offset: either "beginning of topic" or "end of topic". When we do have resume state, we must ignore that default. We need to start from the exact snapshotted offset to uphold the recovery contract.
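The start-offset decision can be isolated into a tiny function. choose_start_offset is a hypothetical helper, and the two constants are assumed to mirror confluent-kafka's OFFSET_BEGINNING and OFFSET_END sentinels. It uses an explicit `is None` test rather than `or`, so a resume state of 0 can never be mistaken for "no resume state".

```python
from typing import Optional

# Assumed to mirror confluent-kafka's OFFSET_BEGINNING / OFFSET_END sentinels.
OFFSET_BEGINNING = -2
OFFSET_END = -1


def choose_start_offset(resume_state: Optional[int], starting_offset: int) -> int:
    """Hypothetical helper: where a partition's consumer should start reading."""
    if resume_state is None:
        # First build of this partition: fall back to the configured default.
        return starting_offset
    # Resuming: the snapshotted offset always wins, whatever the default is.
    return resume_state


print(choose_start_offset(None, OFFSET_END))  # -1: first run, start at the end
print(choose_start_offset(102, OFFSET_END))   # 102: resume where we left off
```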
Delivery Guarantees
Let's talk for a moment about how this recovery model with snapshots impacts delivery guarantees. A well-designed input connector on its own can only guarantee that the dataflow's output to a downstream system is at-least-once. The recovery system ensures that we replay any input that might not have been output, depending on where the execution cluster failed. Exactly-once processing additionally requires coordination with the output connector (via something like transactions or two-phase commits), so that the replay does not result in duplicated writes downstream.
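A toy replay shows the difference between at-least-once delivery and effectively-once output. Here the snapshot was taken at offset 2, but offsets 2 and 3 had already been written before the crash. A naive appending sink duplicates them on replay. A sink that overwrites by a key derived from the input position does not. Keyed overwrites are a simpler technique than the transactions or two-phase commits mentioned above and are shown only as an illustration. They rely on the replay reproducing the same output for the same key.

```python
from typing import Dict, List, Tuple

Record = Tuple[int, str]  # (input offset, value)


def write(records: List[Record], append_sink: List[str], keyed_sink: Dict[int, str]) -> None:
    for offset, value in records:
        append_sink.append(value)   # naive sink: every write adds a row
        keyed_sink[offset] = value  # keyed sink: a replayed write overwrites


records = [(0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e")]
append_sink: List[str] = []
keyed_sink: Dict[int, str] = {}

write(records[:4], append_sink, keyed_sink)  # offsets 0-3 written, then a crash
write(records[2:], append_sink, keyed_sink)  # resume from a snapshot at offset 2

print(append_sink)                # ['a', 'b', 'c', 'd', 'c', 'd', 'e']
print(list(keyed_sink.values()))  # ['a', 'b', 'c', 'd', 'e']
```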
Non-Replay-Able Sources
If your input source does not have the ability to replay old data, you can still use it with Bytewax, but your delivery guarantees are limited to at-most-once. Take an ephemeral SSE or WebSocket stream: you can always start listening, but the request API often does not let you replay missed events. When Bytewax attempts to resume, all the other operators will have their internal state restored to the last coherent snapshot. The input sources, however, do not rewind. It will therefore appear that the dataflow has missed all input between when that snapshot was taken and the resume.
In this case, your StatefulSource.snapshot can return None and no recovery data will be saved. You can then ignore the resume_state argument of PartitionedInput.build_part, because it will always be None.
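The toy simulation below shows the at-most-once behavior of a non-replayable source. LiveFeed and EphemeralSource are hypothetical in-memory stand-ins for something like an SSE or WebSocket stream, where events reach only whoever is listening at that moment. The source's snapshot returns None, so on resume it simply starts listening again. Anything published while it was down is lost.

```python
from typing import List, Optional


class LiveFeed:
    """Hypothetical ephemeral stream: events reach only current listeners."""

    def __init__(self) -> None:
        self.listeners: List["EphemeralSource"] = []

    def publish(self, event: str) -> None:
        for listener in self.listeners:
            listener.buffer.append(event)


class EphemeralSource:
    """Hypothetical non-replayable source."""

    def __init__(self, feed: LiveFeed, resume_state: None) -> None:
        # resume_state is always None (snapshot() saves nothing), so there is
        # no position to rewind to: just start listening from "now".
        self.buffer: List[str] = []
        feed.listeners.append(self)

    def next(self) -> Optional[str]:
        return self.buffer.pop(0) if self.buffer else None

    def snapshot(self) -> None:
        return None  # no recovery data is saved


feed = LiveFeed()
src = EphemeralSource(feed, None)
feed.publish("e1")
seen = [src.next()]

feed.listeners.clear()  # simulated crash: the connection is gone
feed.publish("e2")      # published while the dataflow is down

src = EphemeralSource(feed, src.snapshot())  # resume with resume_state None
feed.publish("e3")
seen.append(src.next())
print(seen)  # ['e1', 'e3']: e2 was never delivered
```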