Oli Makhasoeva for bytewax

Originally published at bytewax.io

Lessons we learned while building a stateful Kafka connector and tips for creating yours


The Bytewax framework is a flexible tool designed to meet the challenges Python developers face in today's data-driven world. It aims to provide seamless integrations and time-saving shortcuts for data engineers working with streaming data. One important part of developing Bytewax is input connectors: the pieces that establish the connection between external systems and Bytewax so users can import data from them.

Here we're going to show how to write a custom input connector by walking through how we wrote our built-in Kafka input connector.

Writing input connectors for arbitrary systems while supporting failure recovery and strong delivery guarantees requires a solid understanding of how recovery works internally in Bytewax and how the connected system behaves. We strongly encourage you to use the connectors built into bytewax.connectors if possible, and to read the documentation on their limits.

If you are interested in writing your own, this article gives an introduction to some of the decisions involved in writing an input connector for an ordered, partitioned input stream.

If you need any help at all writing a connector, come say "hi" and ask questions in the Bytewax community Slack! We are happy to help!

Partitions

Subclassing bytewax.inputs.PartitionedInput is the core API for writing an input connector when your input has a fixed number of partitions. A partition is a "sub-stream" of data that can be read concurrently and independently.

To write a PartitionedInput subclass, you need to answer three questions:

  1. How many partitions are there?
  2. How can I build a source that reads a single partition?
  3. How can I rewind a partition and read from a specific item?

These are answered via the abstract methods list_parts and build_part, and the resume_state argument to build_part, respectively.
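Before we get into Kafka specifics, here is a minimal sketch of how those three questions map onto the API, with the method bodies elided; the class names here are placeholders, not part of Bytewax itself:

from bytewax.inputs import PartitionedInput, StatefulSource


class MyInput(PartitionedInput):
    def list_parts(self):
        # Question 1: return the set of partition ID strings.
        ...

    def build_part(self, for_part, resume_state):
        # Questions 2 and 3: build a StatefulSource for partition
        # `for_part`, resuming from `resume_state` (None on first build).
        ...


class _MySource(StatefulSource):
    def next(self):
        # Return the next item, return None if there is no data right
        # now, or raise StopIteration when the partition is complete.
        ...

    def snapshot(self):
        # Return this partition's current position so it can be passed
        # back as `resume_state` after a failure.
        ...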

We're going to use the confluent-kafka package to actually communicate with the Kafka cluster. Let's import all the things we'll need for this input source.

from typing import Dict, Iterable

from confluent_kafka import (
    Consumer,
    KafkaError,
    OFFSET_BEGINNING,
    TopicPartition,
)
from confluent_kafka.admin import AdminClient

from bytewax.inputs import PartitionedInput, StatefulSource


Our KafkaInput connector is going to read from a specific set of topics on a cluster. First, let's define our class and write a constructor that takes all the arguments that make sense for configuring this specific kind of input source. This is going to be the public entry point to this connector, and is what you'll pass to the bytewax.dataflow.Dataflow.input operator.

class KafkaInput(PartitionedInput):
    def __init__(
        self,
        brokers: Iterable[str],
        topics: Iterable[str],
        tail: bool = True,
        starting_offset: int = OFFSET_BEGINNING,
        add_config: Dict[str, str] = None,
    ):
        add_config = add_config or {}

        if isinstance(brokers, str):
            raise TypeError("brokers must be an iterable and not a string")
        self._brokers = brokers
        if isinstance(topics, str):
            raise TypeError("topics must be an iterable and not a string")
        self._topics = topics
        self._tail = tail
        self._starting_offset = starting_offset
        self._add_config = add_config

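For reference, here is a hedged sketch of wiring this connector into a dataflow; the step name, broker address, and topic are placeholders you would replace with your own:

from bytewax.dataflow import Dataflow

flow = Dataflow()
# Placeholders: point these at your own cluster and topic.
flow.input("kafka_inp", KafkaInput(brokers=["localhost:9092"], topics=["my_topic"]))
# ...downstream operators and an output would follow here.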

Listing Partitions

Next, let's answer question one: how many partitions are there? Conveniently, confluent-kafka provides an AdminClient.list_topics which gives you the partition count of each topic, packed deep in a metadata object. The signature of PartitionedInput.list_parts says it must return a set of strings with IDs of all the partitions. Let's build the AdminClient using the instance variables from our constructor and then delegate to a _list_parts function so we can reuse it if necessary.

# Continued
# class KafkaInput(PartitionedInput):
    def list_parts(self):
        config = {
            "bootstrap.servers": ",".join(self._brokers),
        }
        config.update(self._add_config)
        client = AdminClient(config)

        return set(_list_parts(client, self._topics))


This generator unpacks the nested metadata returned from AdminClient.list_topics and yields strings like "3-my_topic" for partition 3 of the topic my_topic.

def _list_parts(client, topics):
    for topic in topics:
        # List topics one-by-one so if auto-create is turned on,
        # we respect that.
        cluster_metadata = client.list_topics(topic)
        topic_metadata = cluster_metadata.topics[topic]
        if topic_metadata.error is not None:
            raise RuntimeError(
                f"error listing partitions for Kafka topic `{topic!r}`: "
                f"{topic_metadata.error.str()}"
            )
        part_idxs = topic_metadata.partitions.keys()
        for i in part_idxs:
            yield f"{i}-{topic}"


How do you decide what the partition ID string should be? It should be something that globally identifies this partition, hence combining partition number and topic name.

PartitionedInput.list_parts might be called multiple times from multiple workers as a Bytewax cluster is set up and resumed, so it must return exactly the same set of partitions on every call in order to work correctly. Changing the number of partitions is not currently supported with recovery.
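For example, assuming a hypothetical cluster where my_topic has three partitions and other_topic has one, every call to list_parts must return exactly this set:

{"0-my_topic", "1-my_topic", "2-my_topic", "0-other_topic"}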

Building Partitions

Next, let's answer question two: how can I build a source that reads a single partition? We can use confluent-kafka's Consumer to make a Kafka consumer that will read a specific topic and partition starting from an offset. The signature of PartitionedInput.build_part takes a specific partition ID (we'll ignore the resume state for now) and must return a stateful source.

We parse the partition ID to determine which Kafka partition we should be consuming from. (Hence the importance of having a globally unique partition ID.) Then we build a Consumer that connects to the Kafka cluster, and build our custom _KafkaSource stateful source. That is where the actual reading of input items happens.

# Continued
# class KafkaInput(PartitionedInput):
    def build_part(self, for_part, resume_state):
        part_idx, topic = for_part.split("-", 1)
        part_idx = int(part_idx)
        assert topic in self._topics, "Can't resume from different set of Kafka topics"

        config = {
            # We'll manage our own "consumer group" via recovery
            # system.
            "group.id": "BYTEWAX_IGNORED",
            "enable.auto.commit": "false",
            "bootstrap.servers": ",".join(self._brokers),
            "enable.partition.eof": str(not self._tail),
        }
        config.update(self._add_config)
        consumer = Consumer(config)
        return _KafkaSource(
            consumer, topic, part_idx, self._starting_offset, resume_state
        )


Stateful Input Source

What is a stateful source? It is defined by subclassing bytewax.inputs.StatefulSource. You can think about it as a "snapshot-able Python iterator": something that produces a stream of items via StatefulSource.next, and also lets the Bytewax runtime ask for a snapshot of the position of the source via StatefulSource.snapshot.

Our _KafkaSource is going to read items from a specific Kafka topic's partition. Let's define that class and have a constructor that takes in all the details to start reading that partition: the consumer (already configured to connect to the correct Kafka cluster), the topic, the specific partition index, the default starting offset (beginning or end of the topic), and again we'll ignore the resume state for just another moment.

class _KafkaSource(StatefulSource):
    def __init__(self, consumer, topic, part_idx, starting_offset, resume_state):
        self._offset = resume_state or starting_offset
        # Assign does not activate consumer grouping.
        consumer.assign([TopicPartition(topic, part_idx, self._offset)])
        self._consumer = consumer
        self._topic = topic


The beating heart of the input source is the StatefulSource.next method. It is periodically called by Bytewax and behaves similarly to a built-in Python iterator's __next__ method. It must do one of three things: return a new item to send into the dataflow, return None signaling that there is no data currently but might be later, or raise StopIteration when the partition is complete.

Consumer.poll lets us ask whether there are any new messages on the partition we set up this consumer to follow. If there are, we unpack the message and return it; otherwise we handle the no-data case, the end-of-stream case, or an exceptional error case.

# Continued
# class _KafkaSource(StatefulSource):
    def next(self):
        msg = self._consumer.poll(0.001) # seconds
        if msg is None:
            return
        elif msg.error() is not None:
            if msg.error().code() == KafkaError._PARTITION_EOF:
                raise StopIteration()
            else:
                raise RuntimeError(
                    f"error consuming from Kafka topic `{self.topic!r}`: {msg.error()}"
                )
        else:
            item = (msg.key(), msg.value())
            # Resume reading from the next message, not this one.
            self._offset = msg.offset() + 1
            return item


An important thing to note here is that StatefulSource.next must never block. The Bytewax runtime employs a sort of cooperative multitasking, so each operator must return quickly, even if it has nothing to do, so other operators in the dataflow that do have work can run. Unfortunately, there is currently no way in the Bytewax API to pause polling of input sources (since input comes from outside the dataflow, Bytewax has no way of knowing when more data is available, so it must check constantly). The best practice is to pause briefly when there is no data so you don't spin in a tight loop, but not so long that you block other operators from doing their work.
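In the Kafka source above, the 0.001 second poll timeout serves as that brief pause. For a system whose client library has no built-in timeout, a minimal sketch of the same idea might look like the following; _PollingSource and client.try_read are hypothetical stand-ins, not a real API:

import time

from bytewax.inputs import StatefulSource


class _PollingSource(StatefulSource):
    """Hypothetical source whose client has no built-in poll timeout."""

    def __init__(self, client):
        # `client.try_read()` stands in for any non-blocking read that
        # returns None when nothing is ready.
        self._client = client

    def next(self):
        item = self._client.try_read()
        if item is None:
            # Sleep briefly so a quiet source doesn't spin at 100% CPU,
            # but not so long that other operators are delayed.
            time.sleep(0.001)
            return None
        return item

    def snapshot(self):
        # Position tracking is omitted in this sketch.
        return None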

There is also a StatefulSource.close method, which lets you do any well-behaved shutdown when EOF is reached. It is not guaranteed to be called in a failure situation, so the connected system should not depend on it. In this case, Consumer.close performs the graceful shutdown.

# Continued
# class _KafkaSource(StatefulSource):
    def close(self):
        self._consumer.close()


Resume State

Let's explain how failure recovery works for input connectors. Bytewax's recovery system allows the dataflow to quickly resume processing and output without needing to replay all input. It does this by periodically snapshotting all internal state, input positions, and output positions of the dataflow. When it needs to recover after a failure, it loads all state from a recent snapshot and starts replaying input items in the same order from the instant of that snapshot, overwriting output items. This causes the state and output of the dataflow to evolve in the same way during the resume execution as during the previous execution.

Snapshotting

So, we need to keep track of the current position somewhere in each partition. Kafka has the concept of message offsets: an immutable, incrementing integer marking the position of each message in a partition. In _KafkaSource.next, we kept track of the offset of the next message the partition will read via self._offset = msg.offset() + 1.

Bytewax calls StatefulSource.snapshot when it needs to record the partition's position; our implementation returns that internally stored next-message offset.

# Continued
# class _KafkaSource(StatefulSource):
    def snapshot(self):
        return self._offset


Resume

On resume after a failure, Bytewax's recovery machinery does the hard work of collecting all the snapshots, finding the ones that represent a coherent set of states across the previous execution's cluster, and threading each bit of snapshot data back into PartitionedInput.build_part for the same partition. To properly take advantage of that, your resulting partition must resume reading from the same spot represented by that snapshot.

Since we were storing the Kafka message offset of the next message to be read in _KafkaSource._offset, we need to thread that message offset back into the Consumer when it is built. That happens by passing resume_state into the _KafkaSource constructor and having it assign the consumer to start reading from that offset. Looking at that code again:

# Continued
# class _KafkaSource(StatefulSource):
# def __init__(self, consumer, topic, part_idx, starting_offset, resume_state):
        self._offset = resume_state or starting_offset
        # Assign does not activate consumer grouping.
        consumer.assign([TopicPartition(topic, part_idx, self._offset)])
        ...


As one extra wrinkle, if there is no resume state for this partition because it is being built for the first time, None will be passed for resume_state to PartitionedInput.build_part. In that case, we need to fall back to the requested "default starting offset": either the beginning or the end of the topic. When we do have resume state, we should ignore the default starting offset, since we must start from the specific snapshotted offset to uphold the recovery contract.
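Written out as an explicit helper, that choice looks like the sketch below; the `resume_state or starting_offset` shorthand in the constructor behaves the same way in this connector because snapshotted offsets are never falsy, but a connector whose positions could legitimately be 0 or empty should branch on None explicitly:

def _choose_offset(resume_state, starting_offset):
    # First build of this partition: no snapshot exists yet, so fall
    # back to the requested default starting offset.
    if resume_state is None:
        return starting_offset
    # Resuming: start from the exact snapshotted offset to uphold the
    # recovery contract.
    return resume_state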

Delivery Guarantees

Let's talk for a moment about how this recovery model with snapshots affects delivery guarantees. A well-designed input connector on its own can only guarantee that a dataflow's output to a downstream system is at-least-once: the recovery system ensures we replay any input that might not have produced output before the execution cluster failed, but achieving exactly-once processing requires coordination with the output connector (via something like transactions or two-phase commits) so the replay does not result in duplicated writes downstream.

Non-Replay-Able Sources

If your input source does not have the ability to replay old data, you can still use it with Bytewax, but your delivery guarantees are weaker. For example, when listening to an ephemeral SSE or WebSocket stream, you can always start listening, but often the request API gives you no way to replay missed events. When Bytewax attempts to resume, all the other operators will have their internal state returned to the last coherent snapshot, but since the input sources do not rewind, the dataflow will have missed all input between when that snapshot was taken and the resume, making processing of that data effectively at-most-once.

In this case, your StatefulSource.snapshot can return None and no recovery data will be saved. You can then ignore the resume_state argument of PartitionedInput.build_part because it will always be None.
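As a hedged sketch of that pattern, here is a hypothetical single-partition connector wrapping a live, non-rewindable event stream; LiveInput, _LiveSource, and connect_events are illustrative placeholders, not a real API:

from bytewax.inputs import PartitionedInput, StatefulSource


class _LiveSource(StatefulSource):
    def __init__(self, events):
        # `events` is a live iterator that cannot be rewound.
        self._events = events

    def next(self):
        # Returns the next event or raises StopIteration when the
        # stream closes. A production version should also return None
        # promptly when no event is ready instead of blocking here.
        return next(self._events)

    def snapshot(self):
        # There is no position to rewind to, so save no recovery data.
        return None


class LiveInput(PartitionedInput):
    def __init__(self, url: str):
        self._url = url

    def list_parts(self):
        # The one live stream is the only "partition".
        return {"single"}

    def build_part(self, for_part, resume_state):
        # `resume_state` is always None because snapshot() returns None.
        assert for_part == "single"
        return _LiveSource(connect_events(self._url))  # hypothetical helper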
