Originally posted on creativcoder.dev
The other day, I was reading about LSM Tree based database storage engines (specifically, the LevelDB docs) where I came to know about a phase called compaction. LSM Tree based storage engines are key-value storage systems where every operation is append-only, which favours lower write amplification and reduces latency. LevelDB persists data file segments to disk (when the in-memory table hits a threshold) in sorted order to support efficient read queries. Even operations such as deletes or updates append new entries to the underlying storage, which often leaves obsolete keys lying around on disk and increases disk usage. To overcome this, these engines use a phase called compaction, where several sorted files are merged into a single file in the background to remove stale data records. In order to implement a similar compaction strategy in my toy database project, I used the merge sub-routine from merge sort, generalizing it to k sorted arrays. In this post, I write about implementing the merge k algorithm in Rust. If you know the merge sub-routine from merge sort, you should feel right at home understanding the algorithm.
The problem
Before formalizing the solution, we'll restate the problem with examples and reason up from there to arrive at an implementation in Rust.
We are given k arrays of items, each sorted. The problem is to implement a function that, given k sorted arrays, merges them and returns a single array where all the elements are in sorted order.
For example, if we are given 2 sorted arrays:
a = [3, 5]
b = [2, 7]
where k = 2
Then the merged array we get would be c = [2, 3, 5, 7]
Obvious naive solution
It's always a good idea to start with what you already know, as that kind of thinking tends to reveal more about the problem at hand. So, the obvious approach would be to just concatenate the two arrays and sort the result. Here's a solution in Rust:
fn main() {
    let a = vec![3, 5];
    let b = vec![2, 7];
    // Concatenate both vectors, then sort the combined result.
    let mut c = [a, b].concat();
    c.sort();
    dbg!(c);
}
That was quite easy, but it has O(n*k * log(n*k)) worst case time complexity, where n is the size of each array and k is the number of arrays (so n*k is the size of the resulting vector). Can we do better?
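For reference, the same concat-and-sort idea generalizes to k arrays in a couple of lines (a quick sketch of mine; merge_naive is just an illustrative name):

fn merge_naive(arrays: Vec<Vec<i32>>) -> Vec<i32> {
    // Flatten all k vectors into one, then sort the whole thing:
    // O(n*k * log(n*k)) over the n*k total elements.
    let mut merged: Vec<i32> = arrays.concat();
    merged.sort();
    merged
}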
Let's try a different approach. Without thinking about the implementation, how would you normally put these elements in order by hand? Let's distill the above example down and think it through with a case-by-case analysis.
Case 1:
Let's say, for example, that the arrays have only 1 element each:
a = [3]
b = [2]
In this case, it's pretty trivial. We just compare the first item with the second, take the smaller of the two, and insert it as the first element in the resulting array: [2]. What remains is 3 from a, which we then append, giving us [2, 3] as the merged sequence.
Case 2:
Now, let's consider the case where one of the arrays has more than one element:
a = [3]
b = [2, 5]
In this case, we'll again start with the first item in both arrays and pick the smaller of the two, which is 2. The next items to compare are 3 (from a) and 5 (from b), among which 3 is chosen as the next smallest element. At this point, we have exhausted all items in array a. Now, whatever remains to be put in the merged array has to come from array b. This is because of the invariant that we always pick the smallest item in each iteration: if all items from array a are exhausted, all remaining elements in b must be greater than the last element taken from a. So we append 5 from b to our merged array, giving us [2, 3, 5].
From the dry run of the above two cases, you may have already thought of using two variables as indexes over the arrays and walking over them one by one, taking the next smallest element in each iteration until one of the arrays is exhausted, and then copying all the items from the remaining array into our merged sequence. Sounds simple? Let's implement that:
Implementation of merge k, where k = 2
To keep the implementation readable, we'll limit the items to integer (i32) values. Once we have the implementation in place, it can easily be refactored to be generic over any T.
Here's how we can implement the above index pointer based solution in Rust:
fn merge(a: &[i32], b: &[i32]) -> Vec<i32> {
    // Return early if either slice is empty, so the indexing below can't panic.
    if a.is_empty() {
        return b.to_vec();
    }
    if b.is_empty() {
        return a.to_vec();
    }
    let (mut i, mut j) = (0, 0);
    let mut sorted = vec![];
    let remaining;
    let remaining_idx;
    loop {
        // Pick the smaller of the two current items.
        if a[i] < b[j] {
            sorted.push(a[i]);
            i += 1;
            if i == a.len() { remaining = b; remaining_idx = j; break; }
        } else {
            sorted.push(b[j]);
            j += 1;
            if j == b.len() { remaining = a; remaining_idx = i; break; }
        }
    }
    // Copy over whatever is left in the non-exhausted array.
    for i in remaining_idx..remaining.len() {
        sorted.push(remaining[i]);
    }
    sorted
}
We have the merge function defined, which takes two slices of integers (i.e. references to arrays of integers) and returns a Vec<i32> (a heap-allocated value). Within merge, we first return early if either slice is empty. We then create two indexes, i and j, that start at 0. We also create remaining and remaining_idx to point to the array that is left over after all items from the other array are exhausted. Next, we run a loop {}, where we pick the smaller item, push it to sorted, and increment the respective index. We also check whether we have reached the end of one of the arrays and assign remaining and remaining_idx accordingly. After the loop, we iterate over the remaining array's items and push them to sorted.
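To try it out, a minimal main (my own snippet, not from the original post) could look like this:

fn main() {
    let a = vec![3, 5];
    let b = vec![2, 7];
    // Prints [2, 3, 5, 7].
    dbg!(merge(&a, &b));
}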
But, the above solution works only for 2 arrays. We need to generalize the solution to k sorted arrays.
Generalizing to k sorted arrays
How would one extend the above solution to k sorted collections of items?
Well, it turns out that we now need to keep k pointers into the k arrays and pick the smallest item out of k candidates. It's easy enough to write code that keeps track of the pointers if we have, say, fewer than 10 or 100 arrays to merge. Or, instead of keeping pointers, we could store an array of indexes into the arrays. That's another solution that works for smaller values of k, but it's not convenient or maintainable code to write when we have a large number of arrays to merge. Moreover, the number of comparisons per output element increases linearly as k increases, as the sketch below makes concrete.
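Here's a quick sketch of that index-array approach (my own illustration, not from the original post); note the inner scan over all k cursors for every element we output:

fn merge_linear(arrays: &[Vec<i32>]) -> Vec<i32> {
    // One cursor per array.
    let mut idx = vec![0usize; arrays.len()];
    let mut sorted = Vec::new();
    loop {
        // Scan all k cursors for the smallest current element:
        // O(k) comparisons for each of the n*k elements we output.
        let mut min: Option<(usize, i32)> = None;
        for (k, arr) in arrays.iter().enumerate() {
            if idx[k] < arr.len() {
                let v = arr[idx[k]];
                if min.map_or(true, |(_, m)| v < m) {
                    min = Some((k, v));
                }
            }
        }
        match min {
            Some((k, v)) => {
                sorted.push(v);
                idx[k] += 1;
            }
            // Every array is exhausted.
            None => break,
        }
    }
    sorted
}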
Following along the same lines, let's think about what we need to find from the k arrays in the first iteration of the loop. That's right: we need to find the minimum of k items. This might hint at using some data structure that supports efficiently getting the minimum of k items. One data structure that comes to mind is the heap; in particular, the min heap.
Enter the min heap
A heap is a complete binary tree (nodes are inserted breadth-first, left to right) where there's a relation between parent and child nodes; this is the heap property. There are two variants of heap: the min heap and the max heap. In a min heap, every parent node is smaller than its children, and conversely for a max heap. At a minimum, any implementation supports three key APIs:
- inserting - adds an element to the heap
- removing - removes an element from the heap
- heapify - rebalances the tree while maintaining the heap property
At all times, the heap property must be maintained when inserting or removing items from the heap. Internally, the heap uses the siftDown and bubbleUp sub-routines to maintain the heap property.
So, using a min heap, we can insert the first item of each of the k arrays into the heap. Then we keep popping the smallest item, advancing the corresponding array's index, until the heap is empty.
To implement this, we need to modify our solution a bit.
Our merge method now takes a Vec<Vec<i32>> as a parameter (a vector of vectors of signed integers):
fn merge(arrays: Vec<Vec<i32>>) {
}
Next, let's create an Item struct that will hold a reference to one of the k arrays, along with an index recording its iteration state. The index tells us where in that array we currently are. These Item instances will be the elements that go into our min heap.
#[derive(Debug, Eq)]
struct Item<'a> {
    arr: &'a Vec<i32>,
    idx: usize,
}
Now, Item doesn't convey how its values should be compared, as it's a new data type we've defined. So we need to tell the type system how to compare them. Doing this is easy: we just need to implement a few traits and defer the comparison to the array element at the current index, using a helper method get_item (which we define below). On the Item struct we'll implement the required traits (PartialEq, PartialOrd, Ord) so that it can be inserted into the min heap:
impl<'a> PartialEq for Item<'a> {
    fn eq(&self, other: &Self) -> bool {
        self.get_item() == other.get_item()
    }
}

impl<'a> PartialOrd for Item<'a> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.get_item().partial_cmp(&other.get_item())
    }
}

impl<'a> Ord for Item<'a> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.get_item().cmp(&other.get_item())
    }
}
We also have two helper methods, new and get_item, purely for convenience:
impl<'a> Item<'a> {
    fn new(arr: &'a Vec<i32>, idx: usize) -> Self {
        Self { arr, idx }
    }

    // Returns the element this Item currently points at.
    fn get_item(&self) -> i32 {
        self.arr[self.idx]
    }
}
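As a quick sanity check (my own example, not from the original post), comparisons between Items now defer to the elements they point at:

fn main() {
    let a = vec![3, 5];
    let b = vec![2, 7];
    // Comparisons use the current element of each array.
    assert!(Item::new(&b, 0) < Item::new(&a, 0)); // 2 < 3
    assert!(Item::new(&a, 1) < Item::new(&b, 1)); // 5 < 7
}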
Once we have the required traits implemented for Item, we can implement the new merge algorithm:
use std::collections::BinaryHeap;
use std::cmp::Reverse;
use std::cmp::Ordering;

fn merge(arrays: Vec<Vec<i32>>) -> Vec<i32> {
    let mut sorted = vec![];

    // Seed the heap with the first element of each (non-empty) array.
    let mut heap = BinaryHeap::with_capacity(arrays.len());
    for arr in &arrays {
        if !arr.is_empty() {
            heap.push(Reverse(Item::new(arr, 0)));
        }
    }

    while !heap.is_empty() {
        // Pop the smallest current element across all arrays.
        let mut it = heap.pop().unwrap();
        sorted.push(it.0.get_item());
        it.0.idx += 1;
        // If that array still has items left, push its cursor back in.
        if it.0.idx < it.0.arr.len() {
            heap.push(it)
        }
    }
    sorted
}
We create a Vec that will hold our merged items. We then create a BinaryHeap instance and push the first element of each array in arrays as an Item with start index 0, skipping any empty arrays. Each item is wrapped in Reverse (as BinaryHeap is a max heap by default) before being pushed onto the heap.

Next, we run a loop while we still have elements in the heap. Within the loop, we get the smallest item with heap.pop(), push its current element to sorted, and increment the item's index by 1. Because that item's array might still have elements left, we check for that and, if so, push the item back onto the heap in the last line.

Finally, once we are done, we return the sorted array.
With this solution, our complexity reduces to O(n*k * log(k)): the heap never holds more than k items, and each of the n*k elements is pushed and popped once, at a cost of O(log k) per heap operation.
Finally, let's test this out on a sample dataset:
fn main() {
    let a = vec![1, 5, 7];
    let b = vec![-2, 3, 4];
    let v = vec![a, b];
    // Prints the merged result: [-2, 1, 3, 4, 5, 7].
    dbg!(merge(v));
}
Final code: https://gist.github.com/creativcoder/20fda0f1a947bf2af3b93ab394d6bacb
Taking it a step further
An optimization on top of this would be to stream sorted values as they are produced, by exposing an iterator over the k items in the heap.
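As a rough illustration, here's a minimal sketch of that idea, reusing the Item struct from above (the KMerge name and structure are my own, not from the original post):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

struct KMerge<'a> {
    heap: BinaryHeap<Reverse<Item<'a>>>,
}

impl<'a> KMerge<'a> {
    fn new(arrays: &'a [Vec<i32>]) -> Self {
        let mut heap = BinaryHeap::with_capacity(arrays.len());
        for arr in arrays {
            if !arr.is_empty() {
                heap.push(Reverse(Item::new(arr, 0)));
            }
        }
        Self { heap }
    }
}

impl<'a> Iterator for KMerge<'a> {
    type Item = i32;

    fn next(&mut self) -> Option<i32> {
        // Pop the smallest cursor, yield its value, and re-insert the
        // advanced cursor if its array has more items.
        let Reverse(mut it) = self.heap.pop()?;
        let value = it.get_item();
        it.idx += 1;
        if it.idx < it.arr.len() {
            self.heap.push(Reverse(it));
        }
        Some(value)
    }
}

Callers can then consume values lazily, e.g. for x in KMerge::new(&arrays) { ... }, or collect the iterator into a Vec to get the full merged result.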
With that said, I am open to suggestions and improvements to the solution; you can share them in the comments below. Until next time!
Top comments (9)
This is a good problem, got asked it in a Facebook interview! A few suggestions on the implementation:
- get_item() does a bounds check every time it's called, which shouldn't be necessary. You can store the next element in the Item struct to fix this.
- You can avoid heap.is_empty() and using .unwrap() if you use while let Some(it) = heap.pop() instead.
- Use heap.peek_mut() to avoid sifting the heap twice each time through the loop.
- Allocate sorted with the correct capacity.
Here's what I came up with:
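(The commenter's code block isn't preserved here; the following is an illustrative sketch of mine showing what those four suggestions might look like combined, not the commenter's actual code:)

use std::cmp::Reverse;
use std::collections::binary_heap::PeekMut;
use std::collections::BinaryHeap;

fn merge_improved(arrays: Vec<Vec<i32>>) -> Vec<i32> {
    // Allocate the output with the exact final capacity.
    let total: usize = arrays.iter().map(|a| a.len()).sum();
    let mut sorted = Vec::with_capacity(total);

    // Each heap entry carries its next value directly, so comparing
    // entries needs no bounds check. Tuples compare element-wise, so
    // ordering is primarily by the value.
    let mut heap = BinaryHeap::with_capacity(arrays.len());
    for arr in &arrays {
        if let Some(&first) = arr.first() {
            heap.push(Reverse((first, arr, 1usize)));
        }
    }

    // peek_mut sifts the heap once per iteration, instead of the
    // pop-then-push pair doing it twice.
    while let Some(mut top) = heap.peek_mut() {
        let Reverse((value, arr, idx)) = *top;
        sorted.push(value);
        match arr.get(idx) {
            Some(&next) => *top = Reverse((next, arr, idx + 1)),
            None => {
                PeekMut::pop(top);
            }
        }
    }
    sorted
}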
Hey Caleb,
Thank you for improving the solution. The idea to store the rest of the items in Item is neat!

I think a mention of the kmerge and kmerge_by methods in the itertools crate would have been useful, for those just looking to use such functionality in their projects.

Great!
Your last code block seems broken, btw :D
Hi Frank, and thank you.
Oh, there's an issue with rendering that specific code block, even though I have triple backticks at both the start and the end. Got no clue about that!

I can't see the source code of the post, unfortunately. But if you could take a screenshot of that last code snippet's markdown I might be able to help you :)

It's the same as the other code blocks above. I also tried copying a random code block (rendered properly) from above and pasting it in place of that one. Still the same.

The problem is that the backticks from O(n*k * Log(k)) are not closed :)

Ah, good catch. Thank you. Totally missed it, as the dev.to editor does not have a colored syntax highlighter. I usually edit in my code editor :)