Getting information about Spark partitions is often essential when tuning performance. All the samples below are in Python.
Partition Count
Getting the number of partitions of a DataFrame is easy, but none of the relevant members are part of the DataFrame class itself; you need to go through .rdd. In PySpark the call is:
df.rdd.getNumPartitions()
(the Scala equivalents are df.rdd.partitions.size and df.rdd.partitions.length).
For example:
In: rbe_s1.rdd.getNumPartitions()
Out: 13
But how do I find out a bit more about the partitions, i.e. at least their sizes and how the data was split between them? One could write the data down to disk and inspect the files, but that's tedious and often won't work in high-security environments.
Partition Sizes
Getting partition sizes is also not obvious, and there is no built-in function to do that. Again, one can do it with the low-level RDD API, for instance .mapPartitions, which is defined as follows:
def mapPartitions(self, f, preservesPartitioning=False):
    """
    Return a new RDD by applying a function to each partition of this RDD.

    >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
    >>> def f(iterator): yield sum(iterator)
    >>> rdd.mapPartitions(f).collect()
    [3, 7]
    """
    def func(s, iterator):
        return f(iterator)
    return self.mapPartitionsWithIndex(func, preservesPartitioning)
The example in the documentation comment sort of gives it away already. Getting a list of partition sizes:
lengths = rdd.mapPartitions(get_partition_len, True).collect()
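Here get_partition_len is just a small generator that yields one number per partition, the row count of that partition (it's the same helper the utility function below uses); a minimal sketch:

def get_partition_len(iterator):
    # consume the partition's iterator and yield how many rows it contained
    yield sum(1 for _ in iterator)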
Utility Function
Putting it all together, here is a helper function that displays basic DataFrame statistics:
from pyspark import RDD
from pyspark.sql import DataFrame

def print_partition_info(df: DataFrame):
    import statistics

    def get_partition_len(iterator):
        yield sum(1 for _ in iterator)

    rdd: RDD = df.rdd
    count = rdd.getNumPartitions()
    # lengths = rdd.glom().map(len).collect()  # much more memory hungry than next line
    lengths = rdd.mapPartitions(get_partition_len, True).collect()

    print("")
    print(f"{count} partition(s) total.")
    print("size stats")
    print(f" min: {min(lengths)}")
    print(f" max: {max(lengths)}")
    print(f" avg: {sum(lengths) / len(lengths)}")
    print(f" stddev: {statistics.stdev(lengths)}")
    print("")
    print("detailed info")
    for i, pl in enumerate(lengths):
        print(f" {i}. {pl}")
Sample output:
5 partition(s) total.
size stats
min: 13
max: 4403
avg: 1277.4
stddev: 1929.5741239973136
detailed info
0. 4403
1. 1914
2. 38
3. 19
4. 13
As you can see, partition 0 has most of the data, so it's definitely going to screw things up, or already does.
Repartitioning Data
Now that it's (hopefully) clear which partitions are the bad boys, you might want to re-partition the DataFrame. This part is more obvious than the previous one. Basically there are two functions on the DataFrame itself - coalesce and repartition. The documentation for them is very similar, and it's really confusing which to use when:
Coalesce
Changes DataFrame partitioning, but doesn't quite do what it says on the tin. Coalesce does not physically repartition the data, but rather just changes the number of partitions: some partitions claim ownership of others to reach the requested count. For instance, if you coalesce the DataFrame
0. 4403
1. 1914
2. 38
3. 19
4. 13
to 2 partitions (.coalesce(2)) you will get:
0. 4454
1. 1933
so yeah, you did get 2 partitions, but that didn't make much of a difference to performance, as one of the partitions is more than 2 times bigger than the other. Note that coalesce does not shuffle data, so nothing actually happens physically. It's really useful in cases where you don't want the data to be moved, but still want to process it with some parallelism involved.
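As a quick sanity check, you can coalesce and print the stats again; a minimal sketch, assuming df is the five-partition DataFrame above and print_partition_info is the helper from earlier:

coalesced = df.coalesce(2)        # no shuffle: existing partitions are merged in place
print_partition_info(coalesced)   # expect 2 partitions of very uneven size
coalesced.explain()               # the physical plan shows a Coalesce step rather than a shuffle Exchange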
Repartition Function
This function does repartition the data, shuffling it between nodes and physically moving it around. Calling repartition(2) on the DataFrame above results in the following:
0. 3198
1. 3195
As you can see, partitions are almost equal in size!
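The same sketch for repartition, again assuming df and the helper from earlier; unlike coalesce, this triggers a full shuffle:

repartitioned = df.repartition(2)     # full shuffle: rows are redistributed across 2 new partitions
print_partition_info(repartitioned)   # expect 2 partitions of roughly equal size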
You can also supply columns as function arguments, and they will be used to calculate the resulting hash for the partition. This can be useful if your calculations take particular columns into account and work best when similar values are close to each other. The function is defined as follows:
@since(1.3)
def repartition(self, numPartitions, *cols):
and cols is actually a Column type, meaning you can pass either a column name or any expression that results in a column. This is particularly important: by supplying an expression you are essentially creating a partition hashing function, so it's not limited to just a dumb column name.
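For illustration, here is a hedged sketch of both flavours; the column names and the expression are made up, so substitute your own schema:

from pyspark.sql import functions as F

# partition by a plain column: rows with the same country end up in the same partition
by_country = df.repartition(10, "country")

# partition by an expression: e.g. bucket rows by the calendar date of a timestamp column
by_day = df.repartition(10, F.to_date(F.col("event_ts")))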
P.S. Originally published on my own blog.