Calculating linear recurrences using matrices

Detailed explanation of calculating linear recurrence relations in logarithmic time using matrix exponentiation.

## Introduction

This blog post is dedicated to one of my favorite optimization techniques ever: using matrix exponentiation to calculate linear recurrence relations in logarithmic time.

Basically, the technique boils down to constructing a special transformation matrix that, when exponentiated, produces the coefficients of the $n$-th term in the linear recurrence relation.

I first learned about this technique from this CodeForces tutorial some years ago. My “problem” with it is that, while it gets the general point across, it doesn’t explain why it works, assumes a lot of prior knowledge, and is therefore inaccessible for beginners. This post’s goal is to provide a more detailed explanation through step-by-step examples so that you can build an intuition for it and understand how to apply and adapt this technique to various problems.

## Prerequisites

### Linear recurrence relations

A linear recurrence relation is an equation that relates a term in a sequence to previous terms using recursion. The use of the word linear refers to the fact that previous terms are arranged as a first-degree polynomial in the recurrence relation.

A linear recurrence relation has the following form:

$$x_n = c_1 \times x_{n - 1} + c_2 \times x_{n - 2} + \cdots + c_{k} \times x_{n - k}$$

$x_n$ refers to the $n$-th term in the sequence, and $c_i$ refers to the (constant) coefficient in front of the $x_{n - i}$ term in the linear recurrence relation.

All linear recurrence relations must have a “base case”, i.e. the first $k$ values of the sequence must be known.

A famous example of a sequence defined by a linear recurrence relation is the Fibonacci sequence, where we have the following:

$$\begin{gathered} x_0 = 0 \\ x_1 = 1 \\ x_n = x_{n - 1} + x_{n - 2} \end{gathered}$$

In the Fibonacci case, we have $k = 2$, $c_1 = 1$ and $c_2 = 1$.

Note that $c_i$ can be equal to $0$ as well, which would correspond to skipping a term, e.g.:

$$x_n = x_{n - 1} + 3 \times x_{n - 3}$$

In this case, we have $k = 3$, $c_1 = 1$, $c_2 = 0$, and $c_3 = 3$.
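To see the definition in action, this last recurrence can be unrolled by hand. The base case used here, $x_0 = x_1 = x_2 = 1$, is an assumption made purely for illustration, since the text does not fix one:

$$\begin{aligned} x_3 &= x_2 + 3 \times x_0 = 1 + 3 \times 1 = 4 \\ x_4 &= x_3 + 3 \times x_1 = 4 + 3 \times 1 = 7 \\ x_5 &= x_4 + 3 \times x_2 = 7 + 3 \times 1 = 10 \\ x_6 &= x_5 + 3 \times x_3 = 10 + 3 \times 4 = 22 \end{aligned}$$

Note how $x_{n - 2}$ never appears on the right-hand side, which is exactly what $c_2 = 0$ means.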

### Matrices

A matrix is a rectangular array of numbers, arranged in rows and columns. In computer terms, it’s an array of arrays of numbers, e.g. int[][].

For example,

$$\begin{bmatrix} 1 & 9 & -13\\ 20 & 5 & 6 \end{bmatrix}$$

is a matrix with two rows and three columns, or a “two by three” matrix. Again, in computer terms, it’d be an int[2][3]:

int matrix[2][3] = {
    {1, 9, -13},
    {20, 5, 6}
};  

Matrices are used in a wide range of mathematical areas, including, but not limited to, linear algebra and graph theory. If you’ve worked with graphs, you might’ve heard of adjacency matrices before.

### Matrix multiplication

There’s a lot of complicated math surrounding matrices, but all you need to know to understand this technique is what matrix multiplication is.

Matrix multiplication is an operation that takes two matrices, e.g. $A$ and $B$, one with dimensions $M \times N$ and the other with dimensions $N \times P$ (i.e. the number of columns in the first matrix must be the same as the number of rows in the second matrix), and produces a new matrix with dimensions $M \times P$. Each entry in the new matrix is obtained by multiplying term-by-term the entries of a row of $A$ with a column of $B$ and summing the products.

In other words, matrix multiplication is an operation that takes the $i$-th row of $A$ and the $j$-th column of $B$, multiplies their entries term-by-term, sums them up, and puts the number that comes out in the $i$-th row and $j$-th column of the new matrix.

It’s a mouthful, but it’s not as complicated as it sounds. Here’s a color-coded example so you know what’s what:

$$\begin{bmatrix} \color{#dc2626}{1} & \color{#dc2626}{2} \\ \color{#2563eb}{3} & \color{#2563eb}{4} \end{bmatrix} \times \begin{bmatrix} \color{#16a34a}{5} & \color{#ca8a04}{6} \\ \color{#16a34a}{7} & \color{#ca8a04}{8} \end{bmatrix} = \begin{bmatrix} \color{#dc2626}{1} \color{white}{\times} \color{#16a34a}{5} \color{white}{+} \color{#dc2626}{2} \color{white}{\times} \color{#16a34a}{7} & \color{#dc2626}{1} \color{white}{\times} \color{#ca8a04}{6} \color{white}{+} \color{#dc2626}{2} \color{white}{\times} \color{#ca8a04}{8} \\ \color{#2563eb}{3} \color{white}{\times} \color{#16a34a}{5} \color{white}{+} \color{#2563eb}{4} \color{white}{\times} \color{#16a34a}{7} & \color{#2563eb}{3} \color{white}{\times} \color{#ca8a04}{6} \color{white}{+} \color{#2563eb}{4} \color{white}{\times} \color{#ca8a04}{8} \end{bmatrix} = \begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix}$$

Got it? Good.

There are a few other things I need to mention that will become important later:

  • Matrix multiplication is not commutative, i.e. in general $A \times B \neq B \times A$.
  • Matrix multiplication is associative, i.e. $(A \times B) \times C = A \times (B \times C)$.
  • The identity matrix is the matrix equivalent of $1$ in regular multiplication. It is a square matrix that has $1$ on the main diagonal and $0$ everywhere else, e.g.:

$$\begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix}$$

When you multiply any matrix by the identity matrix with the corresponding size, you get the same matrix back.
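As a quick check of the non-commutativity claim, here is the color-coded example from earlier with the two operands swapped (worked by hand, without the colors):

$$\begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \times \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 5 \times 1 + 6 \times 3 & 5 \times 2 + 6 \times 4 \\ 7 \times 1 + 8 \times 3 & 7 \times 2 + 8 \times 4 \end{bmatrix} = \begin{bmatrix} 23 & 34 \\ 31 & 46 \end{bmatrix}$$

The result differs from the earlier product, so the order of the operands matters.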

If you understood how matrix multiplication works, it should be clear why it works that way. If not, I encourage you to grab a pen and paper and multiply some matrices by hand. Before proceeding with the rest of this post, you must know how matrices are multiplied; otherwise, it won’t make much sense.

Here’s an implementation of matrix multiplication:

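#include <cassert>
#include <cstddef>
#include <vector>

template <class T>
using vec = std::vector<T>;

template <class T>
using mat = vec<vec<T>>;

// Multiplies an M x N matrix `a` by an N x P matrix `b` (both must be non-empty).
template <class T>
mat<T> operator*(mat<T> const& a, mat<T> const& b) {
  auto m = a.size(), n1 = a[0].size(), n2 = b.size(), p = b[0].size();
  assert(n1 == n2);  // columns of `a` must match rows of `b`

  mat<T> res(m, vec<T>(p, T(0)));
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < p; ++j) {
      // Entry (i, j) is the sum of products of row i of `a` and column j of `b`.
      for (std::size_t k = 0; k < n1; ++k) {
        res[i][j] += a[i][k] * b[k][j];
      }
    }
  }

  return res;
}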

It should be easy to see that its complexity is $O(M \times N \times P)$.

### Binary exponentiation

Binary exponentiation (also known as exponentiation by squaring) is a trick that allows us to calculate $a^n$ using $O(\log_2{n})$ multiplications instead of the $O(n)$ multiplications required by the naive approach.

It also has important applications in many tasks unrelated to arithmetic, since it can be used with any operations that have the associativity property.
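To make the idea concrete, here is a minimal sketch of binary exponentiation on plain integers. The function name `power_mod` and the choice to work modulo a number below $2^{32}$ (so that intermediate products fit in 64 bits) are assumptions made for this illustration:

```cpp
#include <cassert>
#include <cstdint>

// Computes (base^exp) mod `mod` using O(log exp) multiplications.
// Assumes 0 < mod < 2^32 so that every intermediate product fits in 64 bits.
std::uint64_t power_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t mod) {
  assert(mod > 0 && mod < (1ULL << 32));

  std::uint64_t result = 1 % mod;
  base %= mod;
  while (exp > 0) {
    if (exp & 1) {
      // The lowest remaining bit of the exponent is set: fold this power in.
      result = result * base % mod;
    }
    // Square the base: base^(2^i) becomes base^(2^(i + 1)).
    base = base * base % mod;
    exp >>= 1;
  }

  return result;
}
```

The loop runs once per bit of the exponent, which is where the logarithm comes from. Nothing in it relies on the operands being numbers, only on the multiplication being associative and having an identity element.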

To keep this post from becoming overly long and all over the place, here are some good resources on binary exponentiation:

## Computing linear recurrences

To build up to the main point of this post, let’s first look at some other ways one might implement a function that computes a linear recurrence relation.

For the sake of simplicity, let’s take the Fibonacci sequence example (defined earlier).

### The naive way

The most obvious way to compute the $n$-th Fibonacci number is to write a function like this one:

#include <cstdint>

// Naive recursion, straight from the definition.
uint64_t fibonacci(unsigned n) {
  if (n <= 1) {
    return n;
  }

  return fibonacci(n - 1) + fibonacci(n - 2);
}

On my hardware, this function can compute Fibonacci numbers for about $n \leq 40$ before becoming noticeably slow. But why’s that?

The answer is that we’re computing the same values multiple times, over and over again. For example, to compute fibonacci(46), we’re calling fibonacci(2) $1,134,903,170$ times in the process! It is obvious that this approach cannot work for large values of $n$.
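The repeated work is easy to measure on a smaller input. The sketch below instruments the naive recursion with a per-argument call counter; the names `counted_fibonacci` and `count_calls` are made up for this illustration:

```cpp
#include <cstdint>
#include <vector>

// Same recursion as the naive version, but records how often each argument is visited.
std::uint64_t counted_fibonacci(unsigned n, std::vector<std::uint64_t>& calls) {
  ++calls[n];
  if (n <= 1) {
    return n;
  }
  return counted_fibonacci(n - 1, calls) + counted_fibonacci(n - 2, calls);
}

// Returns a vector where entry i is the number of calls made with argument i
// while computing the n-th Fibonacci number.
std::vector<std::uint64_t> count_calls(unsigned n) {
  std::vector<std::uint64_t> calls(n + 1, 0);
  counted_fibonacci(n, calls);
  return calls;
}
```

For example, `count_calls(20)[2]` is 4181, which is the 19th Fibonacci number. The same pattern explains the figure quoted above: 1,134,903,170 is the 45th Fibonacci number, so the number of redundant calls grows as fast as the sequence itself.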

### Memoization

To avoid recomputing the same values over and over again, we can save them in some data structure like an array or a hash map, so that we can compute them once and simply retrieve them from said data structure the next time we need them. This is an idea you might recognize if you’re familiar with the concept of dynamic programming (hence why the array below is called dp).

For the sake of simplicity, we’ll define the upper bound on $n$ to be $93$ (the index of the largest Fibonacci number that can fit in a 64-bit unsigned integer data type) and use a fixed-size array:

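#include <cstdint>

#define MAX_N 93

// dp[n] caches fibonacci(n); 0 means "not computed yet".
uint64_t dp[MAX_N + 1] = {0};

// Requires n <= MAX_N.
uint64_t fibonacci(unsigned n) {
  if (n <= 1) {
    return n;
  }

  auto& result = dp[n];
  if (result == 0) {
    result = fibonacci(n - 1) + fibonacci(n - 2);
  }

  return result;
}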

Since all Fibonacci numbers (except the $0$-th number) are greater than $0$, we can use $0$ to signify a missing value that needs to be computed.

On my hardware, this computes all values of Fibonacci up to $93$ inclusive pretty much instantly. So, problem solved?

Well, yes and no. While memoization solves the recomputation problem, it introduces a new problem: memory. In this case, it’s not that big of a deal, as we’re only going up to $n \leq 93$, but if we go further, e.g. if we want to compute Fibonacci numbers up to $n \leq 10^9$ modulo some number, we’d need to store $10^9$ 64-bit integers in memory, which would take up approximately 8 gigabytes.

### Iterative approach

We can come up with an iterative approach for computing Fibonacci numbers that requires a constant amount of memory if we notice the following: to compute the next Fibonacci number, we only need to know what the previous two Fibonacci numbers are. We don’t need to store all previous values in memory.

So, we can have three variables, out of which the first two would store the previous two Fibonacci numbers, and the third would store the new number, and each time we compute the next Fibonacci number, we “shift” their values to the left:

#include <cstdint>

uint64_t fibonacci(unsigned n) {
  if (n <= 1) {
    return n;
  }

  // prev and curr hold the two most recent Fibonacci numbers.
  uint64_t prev = 0, curr = 1, next;
  for (auto i = 2u; i <= n; ++i) {
    next = prev + curr;
    prev = curr;
    curr = next;
  }

  return curr;
}

On each iteration, prev contains the value of $F_{i - 2}$ and curr contains the value of $F_{i - 1}$; they’re added together and stored in next, which results in $F_{i}$. After that, prev takes the value of curr, and curr takes the value of next, which shifts the values to the left, replacing $F_{i-2}$ and $F_{i-1}$ with $F_{i-1}$ and $F_{i}$, so that $F_{i + 1}$ can be computed on the next iteration.

Due to the order in which the operations are executed, at the end of the loop, both curr and next contain the value of FnF_{n}, so either one can be used as the final result.
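Since the memoization section brought up computing Fibonacci numbers modulo some number, here is a sketch of how the same loop might look in that setting. The function name `fibonacci_mod` and the restriction on `mod` are assumptions made for this example:

```cpp
#include <cassert>
#include <cstdint>

// Iterative Fibonacci modulo `mod`, using constant memory.
// Assumes 0 < mod < 2^63 so that the sum of two residues cannot overflow.
std::uint64_t fibonacci_mod(std::uint64_t n, std::uint64_t mod) {
  assert(mod > 0 && mod < (1ULL << 63));

  std::uint64_t prev = 0 % mod, curr = 1 % mod;
  if (n == 0) {
    return prev;
  }

  for (std::uint64_t i = 2; i <= n; ++i) {
    std::uint64_t next = (prev + curr) % mod;
    prev = curr;
    curr = next;
  }

  return curr;
}
```

This keeps memory constant no matter how large $n$ gets, but the loop still performs $n - 1$ additions, so the running time remains linear in $n$.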

This approach can be generalized and can be used to compute the $n$-th term of any linear recurrence relation in $O(nk)$ time and $O(k)$ space (where $k$ is the number of initial values/coefficients and is a constant):

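#include <algorithm>  // std::shift_left (C++20)
#include <cassert>
#include <cstdint>
#include <numeric>  // std::inner_product
#include <vector>

// initial_values holds x_0, ..., x_{k-1}; coefficients holds c_1, ..., c_k.
uint64_t compute(unsigned n,
                 std::vector<uint64_t> const& initial_values,
                 std::vector<uint64_t> const& coefficients) {
  assert(initial_values.size() == coefficients.size());

  auto k = initial_values.size();
  if (n < k) {
    return initial_values[n];
  }

  // One extra slot at the end receives the newly computed term.
  auto values = initial_values;
  values.push_back(0);

  for (auto i = k; i <= n; ++i) {
    // values[0..k-1] hold x_{i-k}, ..., x_{i-1}. The oldest value pairs with
    // c_k, so the coefficients are traversed in reverse.
    values.back() = std::inner_product(values.begin(), values.end() - 1,
                                       coefficients.rbegin(), 0ULL);
    std::shift_left(values.begin(), values.end(), 1);
  }

  // After the final shift, x_n sits in slot k - 1.
  return values[k - 1];
}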

The compute function above does the same thing as our iterative fibonacci function, except it supports arbitrary linear recurrence relations, defined by their initial values and coefficients. We can make our new function compute Fibonacci numbers by passing {0, 1} as the initial values and {1, 1} for the coefficients.

For those of you who aren’t too familiar with the C++ STL, the call to std::inner_product multiplies each of the previous values (represented by the range from values.begin() to values.end() - 1) by the corresponding coefficient, and sums them up.

### Using matrix exponentiation

Let’s think about the previous approach a bit more, in particular, the step that computes the next value. To compute the next value, you need to multiply the entries from values and coefficients term-by-term, then sum them up. Does that sound familiar?

If you recall, this is exactly how each element in the resulting matrix is computed when doing matrix multiplication. So, what if we can construct a matrix that, when multiplied by the initial values of the sequence, produces the next value and simultaneously shifts the previous values to the left?

Since matrix multiplication operates on the rows of the first matrix and the columns of the second matrix, we can go in one of two directions: we can either organize the initial values as a $1 \times k$ matrix and have the columns of our “transformation” matrix set up in a way that does what we want, in which case the initial values must be on the left side of the multiplication, or we can transpose both matrices (flipping rows and columns) and have the initial values, now a $k \times 1$ matrix, on the right side of the multiplication. I will go with the latter, as it’s the way I usually do it, but both options yield the same results.

I will refer to the so-called transformation matrix as $T$ and the initial values as $F$.

So, let’s think about the “shifting” part of things first. We want each row in the resulting matrix to be shifted by one. We can achieve this by taking the identity matrix and shifting it to the right, like this:

$$\begin{bmatrix} 0 & \color{#16a34a}{1} & 0 & 0 & \cdots & 0 \\ 0 & 0 & \color{#16a34a}{1} & 0 & \cdots & 0 \\ 0 & 0 & 0 & \color{#16a34a}{1} & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \color{#16a34a}{1} \\ 0 & 0 & 0 & 0 & \cdots & 0 \end{bmatrix}$$

I think the best way to see what this does is through an example, so let’s look at what would happen for $k = 3$:

$$\begin{bmatrix} 0 & \color{#16a34a}{1} & 0 \\ 0 & 0 & \color{#16a34a}{1} \\ 0 & 0 & 0 \end{bmatrix} \times \begin{bmatrix} x_0 \\ x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} 0 {\times} x_0 + \color{#16a34a}{1} \color{white}{\times} x_1 + 0 {\times} x_2 \\ 0 {\times} x_0 + 0 {\times} x_1 + \color{#16a34a}{1} \color{white}{\times} x_2 \\ 0 {\times} x_0 + 0 {\times} x_1 + 0 {\times} x_2 \end{bmatrix} = \begin{bmatrix} x_1 \\ x_2 \\ 0 \end{bmatrix}$$

Our strategically placed ones do exactly what we want! So, all we need to do is fill the last row, and if you’re still with me, it should be obvious how: we fill it with the coefficients. So, our final matrix looks like this:

$$\begin{bmatrix} 0 & \color{#16a34a}{1} & 0 & 0 & \cdots & 0 \\ 0 & 0 & \color{#16a34a}{1} & 0 & \cdots & 0 \\ 0 & 0 & 0 & \color{#16a34a}{1} & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & 0 & 0 & \cdots & \color{#16a34a}{1} \\ \color{#16a34a}{c_k} & \color{#16a34a}{c_{k-1}} & \color{#16a34a}{c_{k-2}} & \color{#16a34a}{c_{k-3}} & \cdots & \color{#16a34a}{c_1} \end{bmatrix}$$
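In code, building this matrix and applying it once might look roughly like the sketch below. The names `build_transformation` and `apply`, the `Vector`/`Matrix` aliases, and the convention that the coefficients are passed in the order $c_1, \ldots, c_k$ are all assumptions made for this illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Vector = std::vector<std::uint64_t>;
using Matrix = std::vector<Vector>;

// Builds the k x k transformation matrix: a right-shifted identity in the
// first k - 1 rows and the coefficients c_k, ..., c_1 in the last row.
// `coefficients` is assumed to hold c_1, ..., c_k in that order.
Matrix build_transformation(Vector const& coefficients) {
  std::size_t k = coefficients.size();
  assert(k > 0);

  Matrix t(k, Vector(k, 0));
  for (std::size_t i = 0; i + 1 < k; ++i) {
    t[i][i + 1] = 1;  // the "shifting" ones
  }
  for (std::size_t j = 0; j < k; ++j) {
    t[k - 1][j] = coefficients[k - 1 - j];  // last row, reversed order
  }
  return t;
}

// Multiplies the matrix by a column of values (a k x 1 matrix stored as a vector).
Vector apply(Matrix const& t, Vector const& column) {
  assert(!t.empty() && t[0].size() == column.size());

  Vector result(t.size(), 0);
  for (std::size_t i = 0; i < t.size(); ++i) {
    for (std::size_t j = 0; j < column.size(); ++j) {
      result[i] += t[i][j] * column[j];
    }
  }
  return result;
}
```

For the Fibonacci coefficients `{1, 1}` this produces the matrix with rows `{0, 1}` and `{1, 1}`, and applying it to the values `{0, 1}` gives `{1, 1}`: the old values moved up by one slot, and the new term appeared at the bottom.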

Now we have a matrix that produces the next values in our sequence. So, to compute the $n$-th term in the sequence, we need to multiply $T$ by the values $n$ times, then take the first entry. And now, for the part that makes this fast: remember that matrix multiplication is associative? This means we can multiply $T$ by itself $n$ times (in other words, $T^n$) first, then multiply by the values. And how can we compute $T^n$ quickly? That’s right, binary exponentiation!
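Here is a small example worked by hand for the Fibonacci sequence with $n = 5$, where $T$ has rows $(0, 1)$ and $(1, 1)$ and $F$ holds $x_0 = 0$ and $x_1 = 1$. Binary exponentiation builds $T^5$ as $T^4 \times T$, squaring twice along the way:

$$\begin{aligned} T^2 &= \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix} \times \begin{bmatrix} 0 & 1 \\ 1 & 1 \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & 2 \end{bmatrix} \\ T^4 &= T^2 \times T^2 = \begin{bmatrix} 2 & 3 \\ 3 & 5 \end{bmatrix} \\ T^5 &= T^4 \times T = \begin{bmatrix} 3 & 5 \\ 5 & 8 \end{bmatrix} \\ T^5 \times F &= \begin{bmatrix} 3 & 5 \\ 5 & 8 \end{bmatrix} \times \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 5 \\ 8 \end{bmatrix} \end{aligned}$$

The first entry, $5$, is $x_5$, and the second entry is $x_6 = 8$. Only three matrix multiplications were needed to build $T^5$, instead of the four it takes to multiply by $T$ one step at a time.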

This approach has a total time complexity of $O(k^3 \times \log_2{n})$, $k^3$ coming from the matrix multiplication algorithm, and $\log_2{n}$ coming from the binary exponentiation algorithm. Since $k$ is a constant in the context of a specific linear recurrence relation problem, the execution time only grows logarithmically with $n$, which, as you know, is much better than linear.

Since the full implementation is long, I put it in a gist.
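For a quick feel of how the pieces fit together, here is a condensed sketch. It is not the code from the gist: the names `multiply`, `identity`, `power`, `nth_term` and `MOD`, the decision to work modulo $10^9 + 7$ (so that results fit in 64 bits for large $n$), and the convention that the coefficients are passed as $c_1, \ldots, c_k$ are all assumptions made for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<std::uint64_t>>;

// Assumed modulus; all entries stay below it, so products fit in 64 bits.
constexpr std::uint64_t MOD = 1000000007ULL;

// Square-matrix product modulo MOD.
Matrix multiply(Matrix const& a, Matrix const& b) {
  std::size_t k = a.size();
  assert(b.size() == k);
  Matrix res(k, std::vector<std::uint64_t>(k, 0));
  for (std::size_t i = 0; i < k; ++i) {
    for (std::size_t j = 0; j < k; ++j) {
      for (std::size_t l = 0; l < k; ++l) {
        res[i][j] = (res[i][j] + a[i][l] * b[l][j]) % MOD;
      }
    }
  }
  return res;
}

// k x k identity matrix, the starting value for exponentiation.
Matrix identity(std::size_t k) {
  Matrix res(k, std::vector<std::uint64_t>(k, 0));
  for (std::size_t i = 0; i < k; ++i) {
    res[i][i] = 1;
  }
  return res;
}

// Binary exponentiation: O(log exp) matrix multiplications.
Matrix power(Matrix base, std::uint64_t exp) {
  Matrix result = identity(base.size());
  while (exp > 0) {
    if (exp & 1) {
      result = multiply(result, base);
    }
    base = multiply(base, base);
    exp >>= 1;
  }
  return result;
}

// Returns x_n modulo MOD, given x_0..x_{k-1} and c_1..c_k.
std::uint64_t nth_term(std::uint64_t n,
                       std::vector<std::uint64_t> const& initial_values,
                       std::vector<std::uint64_t> const& coefficients) {
  std::size_t k = coefficients.size();
  assert(k > 0 && initial_values.size() == k);

  // Shifted identity on top, coefficients c_k..c_1 in the last row.
  Matrix t(k, std::vector<std::uint64_t>(k, 0));
  for (std::size_t i = 0; i + 1 < k; ++i) {
    t[i][i + 1] = 1;
  }
  for (std::size_t j = 0; j < k; ++j) {
    t[k - 1][j] = coefficients[k - 1 - j] % MOD;
  }

  // Only the first entry of T^n x F is needed: first row of T^n times F.
  Matrix tn = power(t, n);
  std::uint64_t result = 0;
  for (std::size_t j = 0; j < k; ++j) {
    result = (result + tn[0][j] * (initial_values[j] % MOD)) % MOD;
  }
  return result;
}
```

With these assumptions, `nth_term(10, {0, 1}, {1, 1})` gives 55, the 10th Fibonacci number, and the same call works for an $n$ as large as $10^{18}$ with only about sixty squarings.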

## Final notes

If this still feels like magic, I encourage you to sit down with a pen and paper. Come up with some linear recurrence relation, calculate some values by hand, try repeatedly substituting into the recurrence relation up to some fixed $n$, then exponentiate the transformation matrix by hand, observe the values in the intermediate results, and see if you can notice how everything ties together.

This technique seemed like voodoo to me at first as well, and going through things by hand was the thing that made it click in my head. I hope the way I built up to it makes it more obvious what the relationship between this approach and the easier-to-understand linear approach is.