The Rise (Return?) Of Specialized Computing Environments
Isn't it interesting that the key advancements in computing over the last 20 years have all been driven by specialized, limited computing environments? I'm thinking specifically of:
- GPU computing
- Blockchain VMs
- Stateless functions in cloud computing environments
- Distributed computing
There have been attempts to do "general purpose programming" in some of these environments (like CUDA for GPUs and using Rust to write programs for Solana), but the seams show. While it's getting better, there are still important limitations.
The reality is that ever since CPUs stopped getting appreciably better each year (and ever since society developed a keen interest in "world computers"), the cutting edge has been in these limited, special purpose computing environments. But it's not all bad news.
Here's the bright side for people that like math and algorithms. When you don't have GBs of RAM and fast and cheap disk access at your fingertips, you have to be careful, and algorithms matter again!
The "Streaming Context"
The common denominator between these limited computing environments is what I would call the "streaming context". In a streaming context, you see each observation in a dataset once, and never again. Keeping the entire dataset in memory or re-fetching it from storage is out of the question, because the number of observations is potentially very large, in-process memory may be limited, and data storage / retrieval may be prohibitively costly.
What kinds of algorithms are useful in the streaming context? Almost by definition, any algorithm that can compute meaningful results in "one pass".
A running total might be the simplest example of a one-pass algorithm. It's not too much more difficult to compute the mean in one pass. But I'd like to take the data-description direction as far as it can possibly go.
What if we could capture the entire data distribution in just one pass through the data, without actually storing the data?
The Statistics Bucket
Our goal is to create a "stats bucket" that gathers all sorts of useful information about a dataset - not just the usual sample statistics (mean, variance, skewness, ...), but also a detailed description of the shape of the distribution. It must gather this information in one pass over the data.
But the Stats Bucket should be update-able. When new data is streamed to the bucket, the sample statistics and the shape of the distribution are updated.
It will take several blog posts to develop the Stats Bucket.
Here's how we'll proceed:
In Part 1 (this post), we'll derive some recurrence relations to efficiently compute summary statistics and arbitrary order central moments.
In Part 2, we'll use a mixture distribution to approximately reconstruct the original data distribution based on the central moments computed in Part 1.
In Part 3, we will conduct simulation studies to assess the accuracy of the technique.
In Part 4, we'll port it over to Rust and/or a Solana program. I haven't decided yet. Stay tuned!
Let's Play A Game
To make the limitations a bit more concrete and to analogize the "streaming context", let's treat it as a game, called the "Number Reading Game". Here's how you play:
I begin reading a list of numbers to you. I'll read them to you one by one. You don't know how long the list of numbers is in advance, but it's going to be long. When I'm done, you have to tell me the mean. But here's the catch: You can't write down the numbers I've read.
No problem, right? All you have to do is keep track of the total and the count. There's no rule against that!
total = 0
N = 0
for y in ys:
    total += y
    N += 1
μ = total / N
If you're not a fan of keeping track of a large sum in total, there's an alternative formula you can use. We'll call it the "mean update formula":

$$\mu_n = \mu_{n-1} + \frac{y_n - \mu_{n-1}}{n}$$

(Hint: To see why this works, multiply both sides by n and rewrite the terms with summations.)
Did you notice something interesting about the mean update formula? It expresses the updated mean in terms of the current mean and the next observation. In mathematics, that's called a recurrence relation. Make a mental note of the idea, because we're going to be using oodles of those.
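In code, the recurrence looks like this (a small sketch, with ys standing in for the stream of numbers being read aloud):

mean = 0.0
n = 0
for y in ys:                      # each number is seen exactly once
    n += 1
    mean = mean + (y - mean) / n  # the mean update formula
print(mean)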
Let's Play Another Game
Okay, maybe that was a bit too easy. Let's try something else. When I'm done reading the numbers, you have to tell me the variance - a measure of how spread out the numbers are.
For a refresher, here's the formula for the (sample size corrected) variance, where μ is the mean of the data:

$$s^2 = \frac{1}{N-1} \sum_{i=1}^{N} (y_i - \mu)^2$$
Here's some pseudo-code for computing the variance in the most straightforward way possible:
# First pass through `ys` - compute the mean
total = 0
N = 0
for y in ys:
    total += y
    N += 1
μ = total / N

# Second pass through `ys` - calculate "2nd central moment"*
M2 = 0
for y in ys:
    M2 += (y - μ) ** 2

# Compute variance from `M2`
variance = M2 / (N - 1)
Uh-oh! There are two loops here - the first loop computes the mean (μ), and the second loop computes M2*. That's two passes through the data, and that's against the rules!

*Please note that here and elsewhere I am calling quantities like $M_p = \sum_i (y_i - \mu)^p$ the "central moments". Technically the central moments are $\frac{1}{N}\sum_i (y_i - \mu)^p$, not $\sum_i (y_i - \mu)^p$. It's just less of a mouthful than saying "the sum of the p'th power of the differences from the mean".
Anyway, remember our recurrence relation for the mean? Wouldn't it be nice if there was a recurrence relation for the variance? Well lucky for us, there is! That means we can compute the variance in one pass.
The formula is as follows:

$$M_2' = M_2 + (y - \mu)(y - \mu'), \qquad \text{variance} = \frac{M_2'}{N - 1}$$

(Psst... here's a derivation)

In the above formula, we use a recurrence relation to compute the updated second central moment $M_2'$ based on the current second central moment $M_2$, and then divide the updated second central moment by $(N-1)$ (where $N$ is the updated count) to get the updated variance.
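For the curious, here is a condensed sketch of why the recurrence holds (my own derivation; y is the new observation, μ the mean before the update, μ' the mean after, and N the updated count, so that $\mu' - \mu = (y - \mu)/N$):

$$
\begin{aligned}
M_2' - M_2 &= (y-\mu')^2 + \sum_{i<N}\left[(y_i-\mu')^2 - (y_i-\mu)^2\right] \\
           &= (y-\mu')^2 + (\mu-\mu')\sum_{i<N}(2y_i - \mu' - \mu) \\
           &= (y-\mu')^2 + (N-1)(\mu'-\mu)^2 \\
           &= \tfrac{(N-1)^2}{N^2}(y-\mu)^2 + \tfrac{N-1}{N^2}(y-\mu)^2
            = \tfrac{N-1}{N}(y-\mu)^2
            = (y-\mu)(y-\mu').
\end{aligned}
$$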
You'll notice that the formula makes use of both the current and updated mean, μ and μ', so a complete code implementation would also need to include an implementation of the mean update formula from earlier.
Even so, the implementation is pretty straightforward:
M2 = 0
mean = 0
old_mean = 0
N = 0
for y in ys:
    N = N + 1
    old_mean = mean
    mean = old_mean + (y - mean) / N
    M2 = M2 + (y - mean) * (y - old_mean)
variance = M2 / (N - 1)
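As a quick sanity check (assuming numpy is available), the one-pass result should agree with numpy's corrected variance:

import numpy as np

ys = np.random.randn(1_000)

M2, mean, N = 0.0, 0.0, 0
for y in ys:
    N += 1
    old_mean = mean
    mean = old_mean + (y - mean) / N         # the mean update formula
    M2 = M2 + (y - mean) * (y - old_mean)    # the second-moment recurrence

print(M2 / (N - 1), np.var(ys, ddof=1))      # the two values should agree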
Central Moments
As you might have guessed, if there is a second central moment, there are also first, third, fourth, fifth (and so on) central moments ($M_1, M_3, M_4, M_5, \ldots$).

The general formula for central moments is:

$$M_p = \sum_{i=1}^{N} (y_i - \mu)^p$$

where p is the order of the moment. Again, I'm diverging from convention and calling these quantities the moments, even though the stats types will rightly point out I'm missing the division by N.
One common use for the second and third central moments is calculating the "skewness" of a distribution. Somewhat counterintuitively, a "right skewed" distribution looks like it "leans left": the peak sits on the left and the long tail stretches out to the right. That's because the bulk of the data (or its "mass") is located to the right of the peak.
In terms of the second and third central moments, the skewness is:

$$g_1 = \frac{M_3 / N}{\left(M_2 / N\right)^{3/2}}$$
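As a quick check (a sketch, assuming numpy and scipy are available; the variable names are mine), this matches scipy's default skewness estimator:

import numpy as np
import scipy.stats

ys = np.random.randn(1_000)
mu = np.mean(ys)
N = len(ys)
M2 = np.sum((ys - mu) ** 2)   # second central moment (sum convention)
M3 = np.sum((ys - mu) ** 3)   # third central moment (sum convention)

skewness = (M3 / N) / (M2 / N) ** 1.5
print(skewness, scipy.stats.skew(ys))   # the two values should agree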
But if we are going to play our number reading game with skewness, we would need a recurrence relation for the third central moment. And just to future-proof ourselves, while we're at it, why don't we work out a recurrence relation for arbitrary central moments?
Recurrence Relations for Arbitrary Central Moments
It's difficult to calculate the central moment of order p in a single pass for the same reason that p = 2 and p = 3 were difficult - the presence of μ in the middle of the equation:

$$M_p = \sum_{i=1}^{N} (y_i - \mu)^p$$

You can't compute the deviations from the mean until you know the mean.
However, there have been several research papers dedicated to the incremental (or recursive) computation of central moments. More recent papers have adapted older approaches to address numerical foot-guns like catastrophic cancellation. In this post I am largely adapting this paper from the mid-2000s.
The authors describe pairwise update formulas for means and central moments. The idea is that in order to compute the central moments of a dataset A ∪ B, you can combine the moments of datasets A and B according to their formula. And to compute the moments of, say, dataset A, you could divide that dataset further, and so on, until we reach the base case of single elements, which have known moments.

That divide-and-conquer approach isn't the one we want to take, but we can adapt it rather easily by taking A to be the current dataset and B to be the set containing the single new observation y, B = {y}.
Here is the pairwise formula, adapted from the paper:

$$M_p^{A \cup B} = \sum_{k=0}^{p} \binom{p}{k} \left[ M_{p-k}^{A} \left( \frac{-n_2\, \delta_{2,1}}{n} \right)^{k} + M_{p-k}^{B} \left( \frac{n_1\, \delta_{2,1}}{n} \right)^{k} \right]$$

where $n_1 = |A|$, $n_2 = |B|$, $n = n_1 + n_2$, and $\delta_{2,1} = \mu_B - \mu_A$ is the difference between the two means (under our "sum" convention, $M_0 = n$ and $M_1 = 0$ for each set).
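To convince ourselves, we can check the formula numerically against brute-force moments of the concatenated dataset (a quick sketch; the helper names here are mine, not from the paper):

import math
import numpy as np

def moment_sums(ys, P):
    # Brute-force "sum convention" central moments M_0 .. M_P of a dataset
    mu = np.mean(ys)
    return [float(np.sum((ys - mu) ** p)) for p in range(P + 1)]

def combined_moment(p, MA, MB, mean_a, mean_b):
    # Pairwise formula: order-p moment of A ∪ B from the moments of A and B
    n1, n2 = MA[0], MB[0]
    n = n1 + n2
    d21 = mean_b - mean_a
    return sum(math.comb(p, k) * (MA[p - k] * (-n2 * d21 / n) ** k
                                  + MB[p - k] * (n1 * d21 / n) ** k)
               for k in range(p + 1))

A, B = np.random.randn(50), np.random.randn(70) + 3.0
P = 6
MA, MB = moment_sums(A, P), moment_sums(B, P)
combined = [combined_moment(p, MA, MB, A.mean(), B.mean()) for p in range(P + 1)]
print(np.allclose(combined, moment_sums(np.concatenate([A, B]), P)))   # expect True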
Now, let's adapt this formula to the incremental case, where B = {y}. The formula simplifies quite a bit, because all of the central moments of a singleton B = {y} are zero except for $M_0^B$.

So, in other words, the central moments of B when B is just the single element {y} are:

$$M_0^B = 1, \qquad M_k^B = 0 \ \text{ for } k \geq 1$$

(Since B has a single element, we also have $n_2 = 1$ and $\delta_{2,1} = y - \mu$, where $\mu$ is the current mean.)

With that in mind, the term beginning with $M_{p-k}^B$ drops out of the summation except when k = p, and we are left with "almost" the answer:

$$M_p' \approx \sum_{k=0}^{p} \binom{p}{k}\, M_{p-k} \left( \frac{-\delta_{2,1}}{n} \right)^{k}$$

"Almost" because we are missing one final term: the k = p term from the B half of the summation (where $M_{p-k}^B = M_0^B = 1$).

That term simplifies to:

$$\binom{p}{p}\, M_0^B \left( \frac{n_1\, \delta_{2,1}}{n} \right)^{p} = \left( \frac{n_1\, \delta_{2,1}}{n} \right)^{p}$$

Add the "almost" part and this last term together, and we've got the complete update equation:

$$M_p' = \sum_{k=0}^{p} \binom{p}{k}\, M_{p-k} \left( \frac{-\delta_{2,1}}{n} \right)^{k} + \left( \frac{n_1\, \delta_{2,1}}{n} \right)^{p}$$

where the $M_{p-k}$ are the current central moments, $n_1 = N$ is the current count, and $n = N + 1$ is the updated count.
And now we're ready to start coding!
Yea! Enough Math, Show Me Some Code!
I've been a bit math-heavy, and that's because there's so much ground to cover. But the implementation is interesting as well - and not completely straightforward!
We would like to create a data structure that stores the bounds (minimum and maximum) as well as the mean and first 10 central moments. We will provide convenience methods that give the variance, standard deviation, skewness, and kurtosis.
An important note: in this implementation I am storing the "zero'th" moment (p = 0) in index 0 of self.moments. This makes the indexing of the array line up with the 1-based indexing in the papers, which is nice. For example, the "second central moment" is stored in self.moments[2], as opposed to the somewhat confusing self.moments[1]. What's more, because of how the central moments are defined, self.moments[0] is always N, the count of the dataset. So we get a "count tracker" for N for free.
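To make the layout concrete, here is the convention worked out by hand for a tiny dataset:

# Storage convention, illustrated for ys = [1.0, 2.0, 3.0] (mean = 2.0):
# index 0: sum((y - 2.0)**0 for y in ys) = 3     <- N, the count
# index 1: sum((y - 2.0)**1 for y in ys) = 0.0   <- always ~0, by definition of the mean
# index 2: sum((y - 2.0)**2 for y in ys) = 2.0   <- feeds the variance
# index 3: sum((y - 2.0)**3 for y in ys) = 0.0   <- feeds the skewness
moments = [3, 0.0, 2.0, 0.0]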
import math
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class StatsBucket:
    # How many orders of central moments to compute
    n_moments : int = 0
    # The central moments, with index 0 containing the count (N); index 1 holds
    # the first central moment, which is always ~0
    moments : List[float] = field(default_factory = list)
    # The mean - must be separately updated because we are storing central moments
    mean : float = None
    # minimum and maximum - separately updated and useful for bounds
    minimum : float = None
    maximum : float = None

    def __init__(self,
                 n_moments : int,
                 moments : Optional[List[float]] = None,
                 mean : Optional[float] = None):
        self.n_moments = n_moments
        if moments is not None:
            self.moments = moments
            self.mean = mean
        else:
            self.moments = [0] + [0]*(self.n_moments)
The implementations for the summary statistics are all straightforward and based on definitions you can find on Wikipedia:
def n(self):
    return self.moments[0]

def sample_mean(self):
    return self.mean

def sample_variance(self):
    return self.moments[2]/self.n()

def corrected_sample_variance(self):
    return self.moments[2]/(self.n()-1)

def corrected_sample_stdev(self):
    return self.corrected_sample_variance()**0.5

def sample_stdev(self):
    return self.sample_variance()**0.5

def sample_skewness(self):
    return (self.moments[3]/self.n()) / ((self.moments[2]/self.n())**1.5)

def sample_excess_kurtosis(self):
    return (self.moments[4]/self.n()) / (self.moments[2]/self.n())**2.0 - 3
Now, we'll need a few methods:

- initialize(ys : List[float]) for initializing an empty StatsBucket with the observations ys, computing all the central moments and statistics in one pass.
- update(self, y : float) for incorporating a new observation into an initialized StatsBucket instance and updating all relevant statistics and central moments.
- combine(self, other) for combining two StatsBuckets representing sets A and B into the central moments of A ∪ B. The existing StatsBucket is updated in place.

(We'll sketch how these three methods fit together after the static helpers below.)
Here is the implementation of the update formula for the moment of order p, where all of the current central moments are stored in the array Ms:
@staticmethod
def calculate_updated_moment(p : int,
                             Ms: List[float],
                             mean : float,
                             y : float) -> float:
    # delta_{2,1}: the new observation minus the current mean (here B = {y})
    s21 = y - mean
    # n1 = current count, n2 = 1 (the singleton B), n = updated count
    n = Ms[0] + 1
    n1 = Ms[0]
    n2 = 1
    Σ = 0
    for k in range(0, p+1):
        res = math.comb(p, k) * (Ms[p-k] * (s21*(-n2/n))**k)
        Σ += res
    # the single surviving term from the B half of the summation (k = p)
    Σ += (s21*n1/n)**p
    return Σ
Now, we would like to update all of the moments of order 0 through P. And presumably we will be doing this in a loop over a large number of observations, so we would like to avoid list allocations if possible.

So, it would be a good idea to update self.moments in place, through something like:
for i in range(P + 1):
    self.moments[i] = ...calculate updated moment i...
However, if you take a look at the recurrence relation, the formula for the central moment of order p uses all of the central moments of orders 0 through p. So if we update self.moments in-place in increasing order, we'll "clobber" values that the higher-order updates still need and get an incorrect result.

So, we have to update the moments in reverse order. The correct loop looks like this:
for i in range(P, -1, -1):
    self.moments[i] = ...calculate updated moment i...
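To see the clobbering in action, here is a quick check (a sketch that assumes the StatsBucket pieces above have been assembled, plus numpy and the brute-force moment_sums helper from earlier; stream_moments is my own throwaway wrapper):

def stream_moments(ys, P, reverse):
    # Seed the moments with those of the singleton {ys[0]}, then stream in the rest
    Ms = [1.0] + [0.0] * P
    mean = float(ys[0])
    for y in ys[1:]:
        order = range(P, -1, -1) if reverse else range(P + 1)
        for p in order:
            Ms[p] = StatsBucket.calculate_updated_moment(p, Ms, mean, y)
        mean = mean + (y - mean) / Ms[0]   # Ms[0] already holds the updated count
    return Ms

ys = np.random.randn(200)
print(np.allclose(stream_moments(ys, 6, reverse=True),  moment_sums(ys, 6)))   # expect True
print(np.allclose(stream_moments(ys, 6, reverse=False), moment_sums(ys, 6)))   # expect False (clobbered)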
Here is the Python code, which updates Ms in-place and returns an updated mean, min_, and max_ back to the caller:
@staticmethod
def update_stats(Ms : List[float],
                 mean : float,
                 min_ : float,
                 max_ : float,
                 y : float):
    # backwards iteration here super important to avoid self-clobbering during computation
    P = len(Ms) - 1
    for p in range(P, -1, -1):
        Ms[p] = StatsBucket.calculate_updated_moment(p,
                                                     Ms,
                                                     mean,
                                                     y)
    # update mean (Ms[0] now holds the updated count)
    n = Ms[0]
    mean = mean + (y - mean) / n
    # update bounds
    min_ = min(y, min_)
    max_ = max(y, max_)
    return mean, min_, max_
The implementation of initialize is essentially that of update, applied to each observation in turn. The combine method calls a static method which uses the pairwise formulas from the paper: M1s are the central moments from set A, M2s are the central moments from set B, and s21 is the mean of set A subtracted from the mean of set B (that is, μ_B − μ_A).
@staticmethod
def calculate_combined_moment(p : int,
                              M1s : List[float],
                              M2s: List[float],
                              s21 : float) -> float:
    # n = combined count, n1 = count of set A, n2 = count of set B
    n, n1, n2 = M1s[0] + M2s[0], M1s[0], M2s[0]
    Σ = 0
    for k in range(0, p+1):
        res = math.comb(p, k) * ((M1s[p-k] * (s21*(-n2/n))**k) + (M2s[p-k] * (s21*n1/n)**k))
        Σ += res
    return Σ
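For completeness, here is a minimal sketch of how initialize, update, and combine can wire the static helpers together (the bookkeeping details - seeding the bounds at ±infinity and the combined-mean formula - are one reasonable choice, not necessarily the exact implementation):

def initialize(self, ys : List[float]):
    # Start from an "empty" bucket and stream every observation through update()
    self.moments = [0.0] * (self.n_moments + 1)
    self.mean = 0.0
    self.minimum = float("inf")
    self.maximum = float("-inf")
    for y in ys:
        self.update(y)

def update(self, y : float):
    # update_stats mutates self.moments in place and hands back the scalars
    self.mean, self.minimum, self.maximum = StatsBucket.update_stats(
        self.moments, self.mean, self.minimum, self.maximum, y)

def combine(self, other : "StatsBucket"):
    # Merge another bucket (set B) into this one (set A) using the pairwise formulas
    n1, n2 = self.moments[0], other.moments[0]
    s21 = other.mean - self.mean   # mean of A subtracted from mean of B
    self.moments = [StatsBucket.calculate_combined_moment(p, self.moments, other.moments, s21)
                    for p in range(self.n_moments + 1)]
    self.mean = (n1 * self.mean + n2 * other.mean) / (n1 + n2)
    self.minimum = min(self.minimum, other.minimum)
    self.maximum = max(self.maximum, other.maximum)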
Does It Work?
You bet it does. Let's create a StatsBucket that computes the first 10 central moments, and initialize it with 100 observations from a standard normal distribution:
import numpy as np
import scipy.stats

bucket = StatsBucket(n_moments = 10)
ys = np.random.randn(100)
bucket.initialize(ys)
Then we'll print out various sample statistics and compare them to the corresponding calculations from numpy and scipy:
print(np.mean(ys), '=', bucket.sample_mean())
print(np.var(ys, correction = 1), '=', bucket.corrected_sample_variance())
print(np.std(ys, correction = 1), '=', bucket.corrected_sample_stdev())
print(scipy.stats.skew(ys), '=', bucket.sample_skewness())
print(scipy.stats.kurtosis(ys), '=', bucket.sample_excess_kurtosis())
Here are the results - they all agree to at least the 11th decimal place:
0.04460728719077849 = 0.04460728719077859
0.9989105114779977 = 0.9989105114780001
0.9994551072849635 = 0.9994551072849647
0.04379407058692495 = 0.04379407058692474
-0.022496024676501136 = -0.022496024676511794
Now, let's take a look at the central moments. First, we can compute them in the traditional way:
u = np.mean(ys)
expected_moments = []
for m in range(10+1):
    expected_moments.append(sum((y-u)**m for y in ys))
And then we can test to see if they all agree to within 6 decimal places, which they do:
expected_moments = np.asarray(expected_moments)
stats_bucket_moments = np.asarray(bucket.moments)
abs_diff = np.abs(expected_moments - stats_bucket_moments)
print(np.all(abs_diff < 1e-6))
Result:
> True
Summary
In Part 1 of this blog post series, we showed how we can incrementally update arbitrary order central moments using a one-pass, in-place algorithm. In the next part we'll use the central moments to approximate the shape of the data distribution. We are well on our way to having a powerful one-pass data analyzer.
Stay tuned!
Shameless Plug
After doing the founder mode grind for a year, I'm looking for work. I've got over 15 years of experience in a wide variety of technical fields and programming languages, as well as experience managing teams. Math and statistics are focus areas. DM me and let's talk!
Final Note
John D. Cook deserves a shoutout for making me aware that it was even possible to compute central moments through recurrence relations.