Efficient matrix-to-matrix power calculations

We are looking for an efficient way to do matrix-to-matrix power calculations.

We have matrices A and B of the same size.
For each row i, we want to calculate

a_{i1}^{b_{i1}} \cdot a_{i2}^{b_{i2}} \cdots a_{ik}^{b_{ik}} = \prod_{j=1}^{k} a_{ij}^{b_{ij}}
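If every a_{ij} is strictly positive, a standard identity turns each row product into a row sum. The whole computation then becomes one elementwise log, one elementwise multiply by B, a row sum, and an exp. This only holds for positive entries of A:

```latex
\prod_{j=1}^{k} a_{ij}^{b_{ij}}
  = \exp\!\Big( \sum_{j=1}^{k} b_{ij} \log a_{ij} \Big),
  \qquad a_{ij} > 0 .
```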

Using NumPy in Python, you can write the following for row i:

np.prod(A[i,:]**B[i,:])
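Here is a minimal sketch of doing all rows at once. It assumes `A` and `B` are NumPy arrays of the same shape, and the helper names `row_power_products` and `row_power_log_products` are made up for illustration. The idea is to take the elementwise power once for the whole matrix and reduce with `np.prod(..., axis=1)` instead of looping over `i`. For strictly positive `A`, the log-scale version replaces the product with a row sum. That also lets you keep the result on the log scale if the product itself would overflow or underflow.

```python
import numpy as np


def row_power_products(A, B):
    """Return p with p[i] = prod_j A[i, j] ** B[i, j], for every row i at once."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    if A.shape != B.shape:
        raise ValueError("A and B must have the same shape")
    # Elementwise power over the whole matrix, then one product per row.
    return np.prod(A ** B, axis=1)


def row_power_log_products(A, B):
    """Return log p[i] = sum_j B[i, j] * log(A[i, j]); requires all A > 0."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    if np.any(A <= 0):
        raise ValueError("the log-scale form needs strictly positive A")
    return np.sum(B * np.log(A), axis=1)


A = np.array([[1.0, 2.0, 3.0],
              [0.5, 4.0, 2.0]])
B = np.array([[2.0, 3.0, 1.0],
              [1.0, 0.5, -1.0]])

# row 0: 1**2 * 2**3 * 3**1 = 24;  row 1: 0.5 * 4**0.5 * 2**-1 = 0.5
direct = row_power_products(A, B)
via_log = np.exp(row_power_log_products(A, B))
# The per-row expression from the question, looped, for comparison.
per_row = np.array([np.prod(A[i, :] ** B[i, :]) for i in range(A.shape[0])])
print(direct, via_log, per_row)
```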

We would be grateful if you could tell us how to calculate this efficiently.
Thank you in advance for your cooperation.


Sorry to be the bearer of bad news, but computing matrix powers is expensive! Each matrix multiply is \mathcal{O}(N^3) for an N \times N matrix a.

The naive approach takes n - 1 multiplies to evaluate matrix_pow(a, n). NumPy is much more clever: it not only peels off powers of 2, it simultaneously deals with the remainder. You can see their code here:
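To make the baseline concrete, here is a sketch of the naive approach in NumPy. `naive_matrix_pow` is a hypothetical helper written for illustration, not a library function. It counts multiplies: for n = 31 it performs 30 of them, each \mathcal{O}(N^3).

```python
import numpy as np


def naive_matrix_pow(a, n):
    """Compute a**n for n >= 1 by n - 1 successive multiplies; also return the count."""
    a = np.asarray(a, dtype=float)
    if a.ndim != 2 or a.shape[0] != a.shape[1]:
        raise ValueError("a must be square")
    if n < 1:
        raise ValueError("this sketch only handles n >= 1")
    result = a.copy()
    multiplies = 0
    for _ in range(n - 1):
        result = result @ a   # one O(N^3) multiply per step
        multiplies += 1
    return result, multiplies


a = np.array([[1.0, 0.4], [0.6, 1.1]])
p, count = naive_matrix_pow(a, 31)
print(count)  # 30
```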

Their approach exploits the associativity of matrix multiplication, namely that

((a \cdot a) \cdot a) \cdot a = (a \cdot a) \cdot (a \cdot a).

Evaluating the right side, you can get away with two multiplies if you save a \cdot a, rather than the three the naive approach needs. You can keep scaling that up through powers of 2, but then you’ll usually have something left over. For instance, if I take n = 31, I can get up to 16, and then I have 15 left over. Rather than start over, the power of 15 reuses the powers 8, 4, 2, and 1 computed on the way up, since 31 = 16 + 8 + 4 + 2 + 1. That’s what the remainder arithmetic is doing in the NumPy code.
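Here is a small sketch of both ideas, using NumPy and a 2 × 2 example matrix. First, it checks the associativity identity and counts multiplies. Second, the hypothetical helper `squaring_plan` lists which powers of 2 get combined for a given n. It also counts multiplies with the same structure as a bit-by-bit loop: one squaring for each bit above the lowest, plus one combining multiply for each set bit after the first. For n = 31 that comes to 8 multiplies, versus 30 for the naive approach.

```python
import numpy as np

# Associativity: ((a a) a) a equals (a a)(a a); the right side needs only
# two multiplies when a @ a is computed once and reused.
a = np.array([[1.0, 0.4], [0.6, 1.1]])
lhs = ((a @ a) @ a) @ a   # 3 multiplies
a2 = a @ a                # 1 multiply
rhs = a2 @ a2             # 1 more multiply, 2 in total
assert np.allclose(lhs, rhs)


def squaring_plan(n):
    """Return the powers of 2 combined to reach a**n (n >= 1) and the multiply count."""
    if n < 1:
        raise ValueError("n must be >= 1")
    # Set bits of n, lowest first: these are the powers of 2 that get multiplied in.
    powers = [1 << k for k in range(n.bit_length()) if (n >> k) & 1]
    squarings = n.bit_length() - 1   # a -> a^2 -> a^4 -> ...
    combines = len(powers) - 1       # the first set bit is copied, not multiplied
    return powers, squarings + combines


powers, count = squaring_plan(31)
print(powers, count)  # [1, 2, 4, 8, 16] 8
```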

Here’s a Stan translation of the NumPy algorithm, along with code in the data and transformed data blocks that you can use to validate it. I checked that it gets the right answer for both positive and negative n, but be careful with negative powers, because matrix inversion can be very numerically unstable.

matrix-pow.stan

functions {
  matrix matrix_pow(matrix a, int n) {
    if (rows(a) != cols(a)) {
      reject("a must be square, but was ", rows(a), " x ", cols(a));
    }
    int M = rows(a);
    if (n == 0) return diag_matrix(rep_vector(1, rows(a)));
    if (n < 0) {
      return matrix_pow(inverse(a), -n);
    }
    if (n == 1) {
      return a;
    }
    if (n == 2) {
      return a * a;
    }
    if (n == 3) {
      return a * a * a;
    }
    int result_defined = 0;
    matrix[M, M] result;
    int z_defined = 0;
    matrix[M, M] z;
    int nn = n; // nn can be modified, n can't
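    while (nn > 0) {
      // z runs through a, a^2, a^4, a^8, ...
      if (!z_defined) {
        z = a;
        z_defined = 1;
      } else {
        z = z * z;
      }
      int bit = modulus(nn, 2);  // lowest remaining bit of n
      nn = nn %/% 2;             // integer division drops that bit
      if (bit) {
        // multiply in the power of 2 for this set bit
        if (!result_defined) {
          result = z;
          result_defined = 1;
        } else {
          result = result * z;
        }
      }
    }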
    return result;
  }
}
data {
  int<lower=0> M;
  matrix[M, M] a;
  int n;
}
transformed data {
  print("****************************************");
  print("***** matrix power(", a, ", ", n, "): *****", matrix_pow(a, n));
  print("****************************************");
}
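If you want to check the algorithm without compiling a Stan model, here is a rough Python port of the `matrix_pow` function above. It is a sketch, not a library function; the name just mirrors the Stan function. It drops the special cases for n = 1, 2, 3 because the loop already handles them. NumPy's built-in `np.linalg.matrix_power` serves as the reference.

```python
import numpy as np


def matrix_pow(a, n):
    """Square-and-multiply matrix power, mirroring the Stan matrix_pow above.

    Negative n inverts first, as the Stan version does; that can be
    numerically unstable for ill-conditioned a.
    """
    a = np.asarray(a, dtype=float)
    if a.ndim != 2 or a.shape[0] != a.shape[1]:
        raise ValueError(f"a must be square, but was {a.shape}")
    if n == 0:
        return np.eye(a.shape[0])
    if n < 0:
        return matrix_pow(np.linalg.inv(a), -n)
    result = None  # plays the role of result_defined == 0
    z = None       # plays the role of z_defined == 0
    nn = n
    while nn > 0:
        z = a if z is None else z @ z   # a, a^2, a^4, a^8, ...
        bit = nn % 2
        nn //= 2
        if bit:
            result = z if result is None else result @ z
    return result


a = np.array([[1.0, 0.4], [0.6, 1.1]])
for n in range(-5, 13):
    assert np.allclose(matrix_pow(a, n), np.linalg.matrix_power(a, n)), n
print(matrix_pow(a, -5))
```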

Here’s some CmdStanPy code to test:

>>> import cmdstanpy as csp
>>> import numpy as np

>>> a = np.array([[1, 0.4], [0.6, 1.1]])

>>> a
array([[1. , 0.4],
       [0.6, 1.1]])

>>> np.linalg.matrix_power(a, -5)
array([[ 10.27311319,  -7.49098132],
       [-11.23647198,   8.40036786]])

>>> model = csp.CmdStanModel(stan_file = 'matrix-pow.stan')

>>> model.sample(data = {'M': a.shape[0], 'a': a, 'n': -5}, show_console=True, chains=1, iter_warmup=1, iter_sampling=1)

This produces a pile of console output that includes this line:

Chain [1] ***** matrix power([[1,0.4],[0.6,1.1]], -5): *****[[10.2731,-7.49098],[-11.2365,8.40037]]

You can try different values of n in the data dictionary and compare against the NumPy results.
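Here is one way to script that comparison, as a sketch. `matches_numpy` is a made-up helper name. You paste in the matrix that Stan prints and compare it with NumPy using a tolerance, because the printed values shown here are rounded to about six significant digits.

```python
import numpy as np


def matches_numpy(a, n, printed, rtol=1e-5):
    """Compare a matrix printed by Stan against np.linalg.matrix_power(a, n).

    The printed values are rounded, so an exact comparison would fail;
    a relative tolerance of 1e-5 absorbs rounding at ~6 significant digits.
    """
    reference = np.linalg.matrix_power(np.asarray(a, dtype=float), n)
    return np.allclose(np.asarray(printed, dtype=float), reference, rtol=rtol, atol=1e-8)


a = np.array([[1, 0.4], [0.6, 1.1]])
# Copied from the "Chain [1] ***** matrix power(...)" console line for n = -5.
stan_printed = [[10.2731, -7.49098],
                [-11.2365, 8.40037]]
print(matches_numpy(a, -5, stan_printed))  # True
```

For another n, rerun the model and replace `stan_printed` with the matrix from the new console line.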


I also created an issue in the hope that someone will implement this in C++ and add it to Stan itself:


Well, this is embarrassing. We already have the function coded into Stan reasonably efficiently, but it’s just not showing up in Google searches, and I didn’t bother to search within our docs. If you need to search the docs, you have to scroll up on the landing page to find the search box, which is greyed out in the upper left and so easy to miss. (That’s not my excuse: I knew how to find search, I just never use it because Google is almost always better.) Anyway, here’s the doc:

It’s arguably in the wrong section; I’d be inclined to put it under matrix arithmetic. But that wasn’t my problem either, because I don’t search by scanning section heads.