# Calculating linear recurrences using matrices
Detailed explanation of calculating linear recurrence relations in logarithmic time using matrix exponentiation.
## Introduction
This blog post is dedicated to one of my favorite optimization techniques ever: using matrix exponentiation to calculate linear recurrence relations in logarithmic time.
Basically, the technique boils down to constructing a special transformation matrix that, when exponentiated, produces the coefficients needed to compute the $n$-th term of the linear recurrence relation from its initial values.
I first learned about this technique from this CodeForces tutorial some years ago. My “problem” with it is that, while it gets the general point across, it doesn’t explain why it works, assumes a lot of prior knowledge, and is therefore inaccessible for beginners. This post’s goal is to provide a more detailed explanation through step-by-step examples so that you can build an intuition for it and understand how to apply and adapt this technique to various problems.
## Prerequisites
### Linear recurrence relations
A linear recurrence relation is an equation that relates a term in a sequence to previous terms using recursion. The use of the word linear refers to the fact that previous terms are arranged as a first-degree polynomial in the recurrence relation.
A linear recurrence relation has the following form:

$$a_n = c_1 a_{n-1} + c_2 a_{n-2} + \dots + c_k a_{n-k}$$

$a_n$ refers to the $n$-th term in the sequence, and $c_i$ refers to the (constant) coefficient in front of the $(n-i)$-th term in the linear recurrence relation.
All linear recurrence relations must have a “base case”, i.e. the first $k$ values of the sequence must be known.
A famous example of a sequence defined by a linear recurrence relation is the Fibonacci sequence, where we have the following:

$$F_n = F_{n-1} + F_{n-2}, \qquad F_0 = 0, \quad F_1 = 1$$

In the Fibonacci case, we have $k = 2$, $c_1 = 1$, and $c_2 = 1$.
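Expanding the recurrence from the base case gives the familiar sequence $0, 1, 1, 2, 3, 5, 8, 13, 21, \dots$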
Note that $c_i$ can be equal to $0$ as well, which would correspond to skipping a term, e.g.:

$$a_n = a_{n-1} + 0 \cdot a_{n-2} + a_{n-3}$$

In this case, we have $k = 3$, $c_1 = 1$, $c_2 = 0$, and $c_3 = 1$.
### Matrices
A matrix is a rectangular array of numbers, arranged in rows and columns. In computer terms, it’s an array of arrays of numbers, e.g. `int[][]`.
For example,

$$\begin{bmatrix} 1 & 9 & -13 \\ 20 & 5 & 6 \end{bmatrix}$$

is a matrix with two rows and three columns, or a “two by three” ($2 \times 3$) matrix. Again, in computer terms, it’d be an `int[2][3]`:

```cpp
int matrix[2][3] = {
    {1, 9, -13},
    {20, 5, 6},
};
```
Matrices are used in a wide range of mathematical areas, including, but not limited to, linear algebra and graph theory. If you’ve worked with graphs, you might’ve heard of adjacency matrices before.
### Matrix multiplication
There’s a lot of complicated math surrounding matrices, but all you need to know to understand this technique is what matrix multiplication is.
Matrix multiplication is an operation that takes two matrices, e.g. $A$ and $B$, one with dimensions $m \times n$ and the other with dimensions $n \times p$ (i.e. the number of columns in the first matrix must be the same as the number of rows in the second matrix), and produces a new matrix with dimensions $m \times p$. Each entry in the new matrix is obtained by multiplying term-by-term the entries of a row of $A$ with a column of $B$.

In other words, matrix multiplication is an operation that takes the $i$-th row of $A$ and the $j$-th column of $B$, multiplies their entries term-by-term, sums them up, and puts the number that comes out in the $i$-th row and $j$-th column of the new matrix.
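Written out as a formula, the entry in the $i$-th row and $j$-th column of the product $C = AB$ (where $A$ is $m \times n$ and $B$ is $n \times p$) is:

$$C_{ij} = \sum_{t=1}^{n} A_{it} B_{tj}$$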
It’s a mouthful, but it’s not as complicated as it sounds. Here’s a small worked example so you know what’s what:

$$\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 7 & 1 \cdot 6 + 2 \cdot 8 \\ 3 \cdot 5 + 4 \cdot 7 & 3 \cdot 6 + 4 \cdot 8 \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix}$$
Got it? Good.
There are a few other things I need to mention that will become important later:
- Matrix multiplication is not commutative, i.e. $AB \neq BA$ in general.
- Matrix multiplication is associative, i.e. $(AB)C = A(BC)$.
- The identity matrix is the matrix equivalent of $1$ in regular multiplication. It is a square matrix that has $1$ on the main diagonal and $0$ everywhere else, e.g.:

$$I_3 = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix}$$
When you multiply any matrix by the identity matrix with the corresponding size, you get the same matrix back.
If you understood how matrix multiplication works, it should be clear why it works that way. If not, I encourage you to grab a pen and paper and multiply some matrices by hand. Before proceeding with the rest of this post, you must know how matrices are multiplied; otherwise, it won’t make much sense.
Here’s an implementation of matrix multiplication:

```cpp
#include <cassert>
#include <vector>

template <class T> using vec = std::vector<T>;
template <class T> using mat = vec<vec<T>>;

template <class T>
mat<T> operator*(mat<T> const& a, mat<T> const& b) {
    int m = a.size(), n1 = a[0].size(), n2 = b.size(), p = b[0].size();
    assert(n1 == n2);
    mat<T> res(m, vec<T>(p, T(0)));
    for (auto i = 0; i < m; ++i) {
        for (auto j = 0; j < p; ++j) {
            for (auto k = 0; k < n1; ++k) {
                res[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return res;
}
```
It should be easy to see that its complexity is $O(m \cdot n \cdot p)$, or $O(n^3)$ for two square $n \times n$ matrices.
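As a quick sanity check, here’s what multiplying the two matrices from the worked example above looks like with this `operator*` (a hypothetical usage snippet, not part of the original code):

```cpp
mat<int> a = {{1, 2}, {3, 4}};
mat<int> b = {{5, 6}, {7, 8}};
auto c = a * b; // {{19, 22}, {43, 50}}
```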
### Binary exponentiation
Binary exponentiation (also known as exponentiation by squaring) is a trick that allows us to calculate $a^n$ using $O(\log n)$ multiplications instead of the $n - 1$ multiplications required by the naive approach.

It also has important applications in many tasks unrelated to arithmetic, since it can be used with any operation that has the associativity property.
To keep this post from becoming overly long and all over the place, here are some good resources on binary exponentiation:
- This cp-algorithms.com entry
- This video by Errichto
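That said, here’s a minimal sketch of the idea for plain integers, just so the rest of this post is self-contained (overflow and modular arithmetic are deliberately ignored):

```cpp
#include <cstdint>

// Computes base^exp in O(log exp) multiplications: square the base at each
// step, and multiply it into the result whenever the lowest bit of the
// exponent is set.
uint64_t binpow(uint64_t base, uint64_t exp) {
    uint64_t result = 1;
    while (exp > 0) {
        if (exp & 1) {
            result *= base;
        }
        base *= base;
        exp >>= 1;
    }
    return result;
}
```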
## Computing linear recurrences
To build up to the main point of this post, let’s first look at some other ways one might implement a function that computes a linear recurrence relation.
For the sake of simplicity, let’s take the Fibonacci sequence example (defined earlier).
### The naive way
The most obvious way to compute the $n$-th Fibonacci number is to write a function like this one:
```cpp
#include <cstdint>

using uint = unsigned int;

uint64_t fibonacci(uint n) {
    if (n <= 1) {
        return n;
    }
    return fibonacci(n - 1) + fibonacci(n - 2);
}
```
On my hardware, this function can compute Fibonacci numbers up to about $n = 45$ before becoming noticeably slow. But why’s that?

The answer is that we’re computing the same values multiple times, over and over again. For example, to compute `fibonacci(46)`, we’re calling `fibonacci(2)` $1{,}134{,}903{,}170$ times in the process, i.e. over a billion times! The number of calls grows exponentially with $n$, so it is obvious that this approach cannot work for large values of $n$.
### Memoization
To avoid recomputing the same values over and over again, we can save them in some data structure like an array or a hash map, so that we can compute them once and simply retrieve them from said data structure the next time we need them. This is an idea you might recognize if you’re familiar with the concept of dynamic programming (hence why the array below is called `dp`).
For the sake of simplicity, we’ll define the upper bound on $n$ to be $93$ (as $F_{93}$ is the largest Fibonacci number that can fit in a 64-bit unsigned integer data type) and use a fixed-size array:
```cpp
#define MAX_N 93

uint64_t dp[MAX_N + 1] = {0};

uint64_t fibonacci(uint n) {
    if (n <= 1) {
        return n;
    }
    auto& result = dp[n];
    if (result == 0) {
        result = fibonacci(n - 1) + fibonacci(n - 2);
    }
    return result;
}
```
Since all Fibonacci numbers (except the $0$-th number) are greater than $0$, we can use $0$ to signify a missing value that needs to be computed.
On my hardware, this computes all values of Fibonacci up to $n = 93$ inclusive pretty much instantly. So, problem solved?
Well, yes and no. While memoization solves the recomputation problem, it introduces a new problem: memory. In this case, it’s not that big of a deal, as we’re only going up to $n = 93$, but if we go further, e.g. if we want to compute Fibonacci numbers up to $n = 10^9$ modulo some number, we’d need to store $10^9$ 64-bit integers in memory, which would take up approximately 8 gigabytes.
### Iterative approach
We can come up with an iterative approach for computing Fibonacci numbers that requires a constant amount of memory if we notice the following: to compute the next Fibonacci number, we only need to know what the previous two Fibonacci numbers are. We don’t need to store all previous values in memory.
So, we can keep three variables: the first two store the previous two Fibonacci numbers, the third stores the new number, and each time we compute the next Fibonacci number, we “shift” their values to the left:
```cpp
uint64_t fibonacci(uint n) {
    if (n <= 1) {
        return n;
    }
    uint64_t prev = 0, curr = 1, next;
    for (auto i = 2u; i <= n; ++i) {
        next = prev + curr;
        prev = curr;
        curr = next;
    }
    return curr;
}
```
On each iteration, `prev` contains the value of $F_{i-2}$ and `curr` contains the value of $F_{i-1}$, and they’re added together and stored in `next`, which results in $F_i$. After that, `prev` gets assigned to `curr`, and `curr` gets assigned to `next`, which shifts the values to the left, replacing $F_{i-2}$ and $F_{i-1}$ with $F_{i-1}$ and $F_i$, so that $F_{i+1}$ can be computed on the next iteration.

Due to the order in which the operations are executed, at the end of the loop, both `curr` and `next` contain the value of $F_n$, so either one can be used as the final result.
This approach can be generalized to compute the $n$-th term of any linear recurrence relation in $O(n \cdot k)$ time and $O(k)$ space (where $k$ is the number of initial values/coefficients and is a constant):
```cpp
#include <algorithm> // std::shift_left (requires C++20)
#include <cassert>
#include <cstdint>
#include <numeric> // std::inner_product
#include <vector>

uint64_t compute(uint n, std::vector<uint64_t> const& initial_values,
                 std::vector<uint64_t> const& coefficients) {
    assert(initial_values.size() == coefficients.size());
    auto k = initial_values.size();
    if (n < k) {
        return initial_values[n];
    }
    auto values = initial_values;
    values.push_back(0);
    for (auto i = k; i <= n; ++i) {
        values.back() = std::inner_product(values.begin(), values.end() - 1,
                                           coefficients.begin(), 0ULL);
        std::shift_left(values.begin(), values.end(), 1);
    }
    return values.back();
}
```
The `compute` function above does the same thing as our iterative `fibonacci` function, except it supports arbitrary linear recurrence relations, defined by their initial values and coefficients. We can make our new function compute Fibonacci numbers by passing `{0, 1}` as the initial values and `{1, 1}` for the coefficients.
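For example, a quick test of the function above (a hypothetical `main`, not part of the original code):

```cpp
#include <iostream>

int main() {
    // Fibonacci: initial values {0, 1}, coefficients {1, 1}.
    std::cout << compute(10, {0, 1}, {1, 1}) << '\n'; // prints 55
    std::cout << compute(93, {0, 1}, {1, 1}) << '\n'; // largest Fibonacci number that fits in 64 bits
}
```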
For those of you who aren’t too familiar with the C++ STL, the call to `std::inner_product` multiplies each of the previous values (represented by the range from `values.begin()` to `values.end() - 1`) by the corresponding coefficient, and sums them up.
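If it helps, that call is roughly equivalent to this hand-rolled sketch (the function name is mine, for illustration only):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// What std::inner_product computes in the loop above, written out by hand.
uint64_t next_term(std::vector<uint64_t> const& values,
                   std::vector<uint64_t> const& coefficients) {
    uint64_t sum = 0;
    for (std::size_t i = 0; i < coefficients.size(); ++i) {
        sum += values[i] * coefficients[i];
    }
    return sum;
}
```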
### Using matrix exponentiation
Let’s think about the previous approach a bit more, in particular, the step that computes the next value. To compute the next value, you need to multiply the entries from `values` and `coefficients` term-by-term, then sum them up. Does that sound familiar?
If you recall, this is exactly how each element in the resulting matrix is computed when doing matrix multiplication. So, what if we could construct a matrix that, when multiplied by the initial values of the sequence, produces the next value and simultaneously shifts the previous values to the left?
Since matrix multiplication operates on the rows of the first matrix and the columns of the second matrix, we can go in one of two directions: we can either organize the initial values as a $1 \times k$ row matrix and set up the columns of our “transformation” matrix in a way that does what we want, in which case the initial values must be on the left side of the multiplication, or we can rotate both matrices by 90 degrees, organizing the initial values as a $k \times 1$ column matrix on the right side of the multiplication. I will go with the latter, as it’s the way I usually do it, but both options yield the same results.
I will refer to the so-called transformation matrix as $T$ and the initial values as $F$.
So, let’s think about the “shifting” part of things first. We want each row in the resulting matrix to be shifted by one. We can achieve this by taking the identity matrix and shifting it to the right, like this:

$$T = \begin{bmatrix} 0 & 1 & 0 & \cdots & 0 \\ 0 & 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & 1 \\ ? & ? & ? & \cdots & ? \end{bmatrix}$$

I think the best way to see what this does is through an example, so let’s look at what would happen for $k = 3$:

$$\begin{bmatrix} 0 & 1 & 0 \\ 0 & 0 & 1 \\ ? & ? & ? \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ a_2 \end{bmatrix} = \begin{bmatrix} a_1 \\ a_2 \\ ? \end{bmatrix}$$
Our strategically placed ones do exactly what we want! So, all we need to do is fill the last row, and if you’re still with me, it should be obvious how: we fill it with the coefficients. So, our final matrix looks like this:

$$T = \begin{bmatrix} 0 & 1 & 0 & \cdots & 0 \\ 0 & 0 & 1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & \cdots & 1 \\ c_k & c_{k-1} & c_{k-2} & \cdots & c_1 \end{bmatrix}$$
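To make this concrete, here’s what this gives for the Fibonacci sequence (a worked example of my own, with $k = 2$ and $c_1 = c_2 = 1$):

$$T = \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix}, \qquad \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix} \begin{bmatrix} F_{n-2} \\ F_{n-1} \end{bmatrix} = \begin{bmatrix} F_{n-1} \\ F_{n-2} + F_{n-1} \end{bmatrix} = \begin{bmatrix} F_{n-1} \\ F_n \end{bmatrix}$$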
Now we have a matrix that produces the next values in our sequence. So, to compute the $n$-th term in the sequence, we need to multiply the values by $T$ $n$ times, then take the first entry. And now, for the part that makes this fast: remember that matrix multiplication is associative? This means we can multiply $T$ by itself $n$ times (in other words, compute $T^n$) first, then multiply by the values. And how can we compute $T^n$ quickly? That’s right, binary exponentiation!
This approach has a total time complexity of $O(k^3 \log n)$, with the $k^3$ coming from the matrix multiplication algorithm and the $\log n$ coming from the binary exponentiation algorithm. Since $k$ is a constant in the context of a specific linear recurrence relation problem, the execution time only grows logarithmically with $n$, which, as you know, is much better than linear.
Since the full implementation is long, I put it in a gist.
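If you just want the general shape, here’s a condensed, unoptimized sketch (it reuses `mat`, `vec`, and the `operator*` from earlier; `matpow` and `compute_fast` are my own names, and modular arithmetic is left out):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Computes m^exp by binary exponentiation, starting from the identity
// matrix (the matrix equivalent of 1).
template <class T>
mat<T> matpow(mat<T> m, uint64_t exp) {
    auto n = m.size();
    mat<T> result(n, vec<T>(n, T(0)));
    for (std::size_t i = 0; i < n; ++i) {
        result[i][i] = T(1);
    }
    while (exp > 0) {
        if (exp & 1) {
            result = result * m;
        }
        m = m * m;
        exp >>= 1;
    }
    return result;
}

uint64_t compute_fast(uint64_t n, std::vector<uint64_t> const& initial_values,
                      std::vector<uint64_t> const& coefficients) {
    auto k = initial_values.size();
    if (n < k) {
        return initial_values[n];
    }
    // The transformation matrix: an identity matrix shifted to the right,
    // with the coefficients in the last row.
    mat<uint64_t> t(k, vec<uint64_t>(k, 0));
    for (std::size_t i = 0; i + 1 < k; ++i) {
        t[i][i + 1] = 1;
    }
    for (std::size_t i = 0; i < k; ++i) {
        t[k - 1][i] = coefficients[i];
    }
    // Arrange the initial values as a k x 1 column matrix, multiply by T^n,
    // and take the first entry.
    mat<uint64_t> values(k, vec<uint64_t>(1, 0));
    for (std::size_t i = 0; i < k; ++i) {
        values[i][0] = initial_values[i];
    }
    return (matpow(t, n) * values)[0][0];
}
```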
## Final notes
If this still feels like magic, I encourage you to sit down with a pen and paper. Come up with some linear recurrence relation, calculate some values by hand, try repeatedly substituting into the recurrence relation up to some fixed $n$, then exponentiate the transformation matrix by hand, observe the values in the intermediate results, and see if you can notice how everything ties together.
This technique seemed like voodoo to me at first as well, and going through things by hand was the thing that made it click in my head. I hope the way I built up to it makes it more obvious what the relationship between this approach and the easier-to-understand linear approach is.