Improving efficiency for dynamic coevolutionary Stan model

My colleagues and I are working on an R package that fits dynamic coevolutionary models in Stan. The model allows users to infer causal directionality (e.g. X → Y) for multiple variables coevolving on a phylogenetic tree. For those interested, the full model algorithm is described in this pre-print.

The main package function takes a dataset, a phylogeny, and a list of variables with their associated response distributions. It then generates the Stan code and a data list for Stan, compiles the model, and fits it with cmdstanr.

Since these models can run very slowly, we would like to know whether there are feasible ways to improve the efficiency of the Stan code produced by the package. It would be great if some experienced users could cast their eyes over the code and give us a sense of what could be improved. It may be that there’s nothing we can do to speed it up.

Here is some example code for two Bernoulli variables. The model is rather complicated as it has to repeatedly solve a differential equation over different segments of the phylogeny.

Thanks so much in advance!

functions {
  // returns the Kronecker Product
  matrix kronecker_prod(matrix A, matrix B) {
    matrix[rows(A) * rows(B), cols(A) * cols(B)] C;
    int m;
    int n;
    int p;
    int q;
    m = rows(A);
    n = cols(A);
    p = rows(B);
    q = cols(B);
    for (i in 1:m)
      for (j in 1:n)
        for (k in 1:p)
          for (l in 1:q)
            C[p*(i-1)+k,q*(j-1)+l] = A[i,j]*B[k,l];
    return C;
  }
  
  // expected auto and cross effects over a discrete time t
  matrix A_dt(matrix A, real t) {
    return( matrix_exp(A * t) );
  }
  
  // calculate A sharp matrix
  matrix A_sharp(matrix A) {
    matrix[rows(A) * rows(A), cols(A) * cols(A)] A_temp;
    matrix[rows(A),cols(A)] I; // identity matrix
    I = diag_matrix(rep_vector(1,rows(A)));
    A_temp = kronecker_prod(A,I) + kronecker_prod(I,A);
    return(A_temp);
  }
  
  // solve SDE
  matrix cov_drift(matrix A, matrix Q, real ts) {
    matrix[rows(A) * rows(A), cols(A) * cols(A)] A_sharp_temp;
    matrix[rows(A) * rows(A), cols(A) * cols(A)] I; // identity matrix
    vector[rows(Q)*cols(Q)] row_Q;
    vector[rows(A)*cols(A)] irow_vec;
    matrix[rows(A),cols(A)] irow_mat;
    I = diag_matrix(rep_vector(1,rows(A_sharp_temp)));
    A_sharp_temp = A_sharp(A);
    // row operation takes elements of a matrix rowwise and puts them into a column vector
    for (i in 1:rows(Q))
      for (j in 1:cols(Q)) {
        row_Q[i + (j-1)*rows(Q)] = Q[j,i];
      }
    irow_vec = inverse(A_sharp_temp) * (matrix_exp(A_sharp_temp * ts) - I) * row_Q;
    // irow takes elements of a column vector and puts them in a matrix rowwise
    {
      int row_size = rows(A);
      int row_ticker = 1;
      int col_ticker = 0;
      for (i in 1:num_elements(irow_vec)) {
        col_ticker += 1;
        if (col_ticker > row_size) {
          row_ticker += 1;
          col_ticker = 1;
        }
        irow_mat[row_ticker,col_ticker] = irow_vec[i];
      }
    }
    return(irow_mat);
  }
  
  // return number of matches of y in vector x
  int num_matches(vector x, real y) {
    int n = 0;
    for (i in 1:rows(x))
      if (x[i] == y)
        n += 1;
    return n;
  }
  
  // return indices in vector x where x == y
  array[] int which_equal(vector x, real y) {
    array [num_matches(x, y)] int match_positions;
    int pos = 1;
    for (i in 1:rows(x)) {
      if (x[i] == y) {
        match_positions[pos] = i;
        pos += 1;
      }
    }
    return match_positions;
  }
}
data{
  int N_tips; // number of tips
  int N_obs; // number of observations
  int J; // number of response traits
  int N_seg; // total number of segments in the tree
  array[N_seg] int node_seq; // index of tree nodes
  array[N_seg] int parent; // index of the parent node of each descendent
  array[N_seg] real ts; // time since parent
  array[N_seg] int tip; // indicator of whether a given segment ends in a tip
  array[J,J] int effects_mat; // which effects should be estimated?
  int num_effects; // number of effects being estimated
  matrix[N_obs,J] y; // observed data
  matrix[N_obs,J] miss; // are data points missing?
  array[N_obs] int tip_id; // index between 1 and N_tips that gives the group id
  int prior_only; // should the likelihood be ignored?
}
parameters{
  vector<upper=0>[J] A_diag; // autoregressive terms of A
  vector[num_effects - J] A_offdiag; // cross-lagged terms of A
  vector<lower=0>[J] Q_diag; // self-drift terms
  vector[J] b; // SDE intercepts
  vector[J] eta_anc; // ancestral states
  matrix[N_seg - 1,J] z_drift; // stochastic drift, unscaled and uncorrelated
}
transformed parameters{
  matrix[N_seg,J] eta;
  matrix[J,J] Q; // drift matrix
  matrix[J,J] I; // identity matrix
  matrix[J,J] A; // selection matrix
  vector[J*J - J] Q_offdiag = rep_vector(0.0, J*J - J);
  matrix[N_seg,J] drift_tips; // terminal drift parameters
  matrix[N_seg,J] sigma_tips; // terminal drift parameters
  // fill A matrix //////////
  {
    int ticker = 1;
    // fill upper triangle of matrix
    for (i in 1:(J-1)) {
      for (j in (i+1):J) {
        if (effects_mat[i,j] == 1) {
          A[i,j] = A_offdiag[ticker];
          ticker += 1;
        } else if (effects_mat[i,j] == 0) {
          A[i,j] = 0;
        }
      }
    }
    // fill lower triangle of matrix
    for (i in 1:(J-1)) {
      for (j in (i+1):J) {
        if (effects_mat[j,i] == 1) {
          A[j,i] = A_offdiag[ticker];
          ticker += 1;
        } else if (effects_mat[j,i] == 0) {
          A[j,i] = 0;
        }
      }
    }
    // fill diag of matrix
    for (j in 1:J) A[j,j] = A_diag[j];
  }
  // fill Q matrix //////////
  {
    int ticker = 1;
    for (i in 1:(J-1))
      for (j in (i+1):J) {
        Q[i,j] = Q_offdiag[ticker];
        Q[j,i] = Q[i,j]; // symmetry of covariance
        ticker += 1;
      }
    for (j in 1:J) Q[j,j] = Q_diag[j];
  }
  // identity matrix
  I = diag_matrix(rep_vector(1,J));
  // setting ancestral states and placeholders
  for (j in 1:J) {
    eta[node_seq[1],j] = eta_anc[j];
    drift_tips[node_seq[1],j] = -99;
    sigma_tips[node_seq[1],j] = -99;
  }
  for (i in 2:N_seg) {
    matrix[J,J] A_delta; // amount of deterministic change (selection)
    matrix[J,J] VCV; // variance-covariance matrix of stochastic change (drift)
    vector[J] drift_seg; // accumulated drift over the segment
    A_delta = A_dt(A, ts[i]);
    VCV = cov_drift(A, Q, ts[i]);
    drift_seg = cholesky_decompose(VCV) * to_vector( z_drift[i-1,] );
    // if not a tip, add the drift parameter
    if (tip[i] == 0) {
      eta[node_seq[i],] = to_row_vector(
        A_delta * to_vector(eta[parent[i],]) + (inverse(A) * (A_delta - I) * b) + drift_seg
      );
      drift_tips[node_seq[i],] = to_row_vector(rep_vector(-99, J));
      sigma_tips[node_seq[i],] = to_row_vector(rep_vector(-99, J));
    }
    // if is a tip, omit, we'll deal with it in the model block;
    else {
      eta[node_seq[i],] = to_row_vector(
        A_delta * to_vector(eta[parent[i],]) + (inverse(A) * (A_delta - I) * b)
      );
      drift_tips[node_seq[i],] = to_row_vector(drift_seg);
      sigma_tips[node_seq[i],] = to_row_vector(diagonal(Q));
    }
  }
}
model{
  b ~ std_normal();
  eta_anc ~ std_normal();
  to_vector(z_drift) ~ std_normal();
  A_offdiag ~ std_normal();
  A_diag ~ std_normal();
  Q_diag ~ std_normal();
  if (!prior_only) {
    for (i in 1:N_obs) {
        if (miss[i,1] == 0) to_int(y[i,1]) ~ bernoulli_logit(eta[tip_id[i],1] + drift_tips[tip_id[i],1]);
        if (miss[i,2] == 0) to_int(y[i,2]) ~ bernoulli_logit(eta[tip_id[i],2] + drift_tips[tip_id[i],2]);
    }
  }
}

There’s a way you can instrument the code to see where the compute’s going: wrap the sections you want to time in profile blocks.
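The instrumentation meant here is presumably Stan's profile blocks, which is what the profile(...) statements later in this thread use. A minimal sketch on a toy model (not the coevolutionary model; the block names and the toy data are placeholders):

```stan
data {
  int<lower=0> N;
  vector[N] y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  // each named block is timed separately (forward and reverse pass)
  profile("priors") {
    mu ~ std_normal();
    sigma ~ std_normal();
  }
  profile("likelihood") {
    y ~ normal(mu, sigma);
  }
}
```

With cmdstanr, calling fit$profiles() on the fitted object returns the timing table for each chain.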

Unfolding Kronecker products is rough. If there’s any way you can keep this implicit, it’ll probably be an enormous help. Especially since you’re just Kroneckering with the identity matrix in function A_sharp.

You don’t need parens around the return value: return( matrix_exp(A * t) ); can be return matrix_exp(A * t);

You can declare and define at the same time, so int m; followed later by m = rows(A); can be

int m = rows(A);

The Kronecker sum in A_sharp is a very simplified product. I would just go through and define the entries of A_temp directly here and not try to do them in terms of big identity matrices.
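A minimal sketch of filling the entries directly, assuming A is square. The function name kronecker_sum_self is hypothetical; it is meant to return the same matrix as kronecker_prod(A, I) + kronecker_prod(I, A) without building I:

```stan
functions {
  // A (x) I + I (x) A for a square matrix A, filled entry by entry
  matrix kronecker_sum_self(matrix A) {
    int d = rows(A);
    matrix[d * d, d * d] S = rep_matrix(0, d * d, d * d);
    for (i in 1:d) {
      for (j in 1:d) {
        for (k in 1:d) {
          // A (x) I: block (i, j) is A[i, j] times the identity
          S[d * (i - 1) + k, d * (j - 1) + k] += A[i, j];
          // I (x) A: every diagonal block k is a copy of A
          S[d * (k - 1) + i, d * (k - 1) + j] += A[i, j];
        }
      }
    }
    return S;
  }
}
```

This touches each nonzero entry once per term, instead of running two four-deep loops that mostly multiply by zero.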

There is almost never a justification for creating an identity matrix. You use this as matrix_exp(A_sharp_temp * ts) - I. This is better coded in Stan as

add_diag(matrix_exp(A_sharp_temp * ts), -1);

It will broadcast the -1 for you so you don’t need to allocate a vector.

The inverse(A_sharp_temp) * (...) line is just going to be nasty. Given it has a Kronecker structure, I would work with the Kronecker inverse, (A \otimes B)^{-1} = A^{-1} \otimes B^{-1}. But then the Kronecker structure from A_sharp isn’t exactly this. Whatever it is, simplifying this should be a top priority.
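For reference, the structure produced by A_sharp is a Kronecker sum rather than a Kronecker product. The standard identities below are stated under the row-wise vectorization used in cov_drift (written row(.) here, which is an assumed notation):

```latex
\begin{aligned}
A_\# &= A \otimes I + I \otimes A \\
e^{A_\# t} &= e^{A t} \otimes e^{A t} \\
A_\#\,\mathrm{row}(X) &= \mathrm{row}\!\left(A X + X A^{\top}\right)
\end{aligned}
```

The last line means that a solve against the J^2 x J^2 matrix A_# is equivalent to a J x J Lyapunov (Sylvester) equation A X + X A' = C, which is where a cheaper solution would have to come from.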

Given the complexity of this model, I’d urge you to put constraints where relevant on the data just for a sanity check and to help readers understand it.
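As a rough sketch, the data block with constraints might look like the following. The bounds are assumptions read off the comments and the way each variable is used in the posted model; parent and ts are left unconstrained because the thread does not say how the root segment is coded:

```stan
data {
  int<lower=1> N_tips; // number of tips
  int<lower=0> N_obs; // number of observations
  int<lower=1> J; // number of response traits
  int<lower=1> N_seg; // total number of segments in the tree
  array[N_seg] int<lower=1, upper=N_seg> node_seq; // indexes rows of eta
  array[N_seg] int parent; // unconstrained: root coding not stated
  array[N_seg] real ts; // unconstrained: root coding not stated
  array[N_seg] int<lower=0, upper=1> tip; // indicator
  array[J, J] int<lower=0, upper=1> effects_mat; // indicator
  int<lower=J, upper=J * J> num_effects; // diagonal plus estimated cross effects
  matrix[N_obs, J] y; // observed data
  matrix<lower=0, upper=1>[N_obs, J] miss; // assumed 0/1 coding
  array[N_obs] int<lower=1, upper=N_tips> tip_id;
  int<lower=0, upper=1> prior_only;
}
```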

Anything that’s constant should be declared in a transformed data block—this will be a big savings on useless autodiff that’d otherwise happen.

For filling the upper triangle, you can initialize to 0:

matrix[J,J] A = rep_matrix(0, J, J);

Then to fill, you want to create a 2D array of indexes:

transformed data {
  // off-diagonal pairs (i, j) where effects_mat[i, j] == 1, in the same order as A_offdiag
  array[num_effects - J, 2] int idxs = ...;
}
...
for (k in 1:(num_effects - J)) {
  A[idxs[k, 1], idxs[k, 2]] = A_offdiag[k];
}
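One way such an index array could be built, as a sketch. It assumes num_effects - J equals the number of off-diagonal ones in effects_mat, and it keeps the upper-then-lower ordering of the posted model so that A_offdiag retains its meaning. The name off_idx is hypothetical:

```stan
data {
  int<lower=1> J;
  array[J, J] int<lower=0, upper=1> effects_mat;
  int<lower=J> num_effects;
}
transformed data {
  // row k holds the (row, column) position in A of A_offdiag[k]
  array[num_effects - J, 2] int off_idx;
  {
    int k = 1;
    // upper triangle first, as in the original fill order
    for (i in 1:(J - 1)) {
      for (j in (i + 1):J) {
        if (effects_mat[i, j] == 1) {
          off_idx[k, 1] = i;
          off_idx[k, 2] = j;
          k += 1;
        }
      }
    }
    // then the lower triangle
    for (i in 1:(J - 1)) {
      for (j in (i + 1):J) {
        if (effects_mat[j, i] == 1) {
          off_idx[k, 1] = j;
          off_idx[k, 2] = i;
          k += 1;
        }
      }
    }
  }
}
parameters {
  vector<upper=0>[J] A_diag;
  vector[num_effects - J] A_offdiag;
}
transformed parameters {
  matrix[J, J] A = rep_matrix(0, J, J);
  for (k in 1:(num_effects - J)) {
    A[off_idx[k, 1], off_idx[k, 2]] = A_offdiag[k];
  }
  for (j in 1:J) {
    A[j, j] = A_diag[j];
  }
}
```

The branching on effects_mat now happens once, when the data are loaded, rather than on every evaluation of the log density.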

It reduces a quadratic loop over every entry of the matrix to a loop over only the estimated effects.

The to_vector( z_drift[i-1,] ) can just be z_drift[i-1]', but it’s even better if you represent z_drift as an array of vectors rather than as a matrix, which saves an extra allocation and copy.

You don’t ever want to invert and multiply. Use matrix division instead: A \ B gives the same result as inverse(A) * B but is more stable. So this should be A \ (A_delta - I). But again, don't represent the I to do it this way. Add the diagonal -1 element as before rather than creating more identity matrices.
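Put together, the deterministic part of a segment update could be sketched as below. The function name segment_mean is hypothetical, and it assumes the parent state is already available as a vector:

```stan
functions {
  // expected state at the end of a segment of length t, given the parent state
  vector segment_mean(matrix A, vector b, vector eta_parent, real t) {
    matrix[rows(A), cols(A)] A_delta = matrix_exp(A * t);
    // A \ ((A_delta - I) * b), with no explicit identity or inverse;
    // multiplying by b first keeps the right-hand side of the solve a vector
    return A_delta * eta_parent + A \ (add_diag(A_delta, -1) * b);
  }
}
```

In the posted loop this expression appears in both the tip and non-tip branches, so it could be computed once per segment and the drift added only when tip[i] == 0.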

Reshape eta (for example as an array of vectors) so that to_vector(eta[parent[i],]) can be just eta[parent[i]].

The to_int can be done in transformed data and saved here. In general, you want to precompute as much as you can in transformed data so that you don’t have to branch elsewhere. For example, the y values for all the i such that miss[i, 1] == 0 can be precomputed and stored in an array miss_y_1, and then the sampling statement is just

if (!prior_only) {
  miss_y_1 ~ bernoulli_logit(eta[miss_tip_id_1, 1] + drift_tips[miss_tip_id_1, 1]);
  miss_y_2 ~ ...
}

where you’ve also accumulated the indexes for miss_tip_id_1 to match miss_y_1.
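A sketch of that precomputation for the first trait, reusing the num_matches and which_equal helpers from the posted model. N_nonmiss_1 and rows_1 are hypothetical names, and miss is assumed to be coded 0/1:

```stan
functions {
  // number of entries of x equal to y
  int num_matches(vector x, real y) {
    int n = 0;
    for (i in 1:rows(x)) {
      if (x[i] == y) {
        n += 1;
      }
    }
    return n;
  }
  // positions in x where x == y
  array[] int which_equal(vector x, real y) {
    array[num_matches(x, y)] int match_positions;
    int pos = 1;
    for (i in 1:rows(x)) {
      if (x[i] == y) {
        match_positions[pos] = i;
        pos += 1;
      }
    }
    return match_positions;
  }
}
data {
  int<lower=0> N_obs;
  int<lower=1> J;
  matrix[N_obs, J] y;
  matrix[N_obs, J] miss;
  array[N_obs] int<lower=1> tip_id;
}
transformed data {
  // rows where trait 1 is observed (miss == 0)
  int N_nonmiss_1 = num_matches(miss[, 1], 0);
  array[N_nonmiss_1] int rows_1 = which_equal(miss[, 1], 0);
  array[N_nonmiss_1] int miss_tip_id_1 = tip_id[rows_1];
  array[N_nonmiss_1] int miss_y_1;
  for (k in 1:N_nonmiss_1) {
    miss_y_1[k] = to_int(y[rows_1[k], 1]);
  }
}
```

The same pattern would be repeated for the second trait, after which the model block needs no per-observation loop or missingness check.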

Can you code prior_only by just setting N_obs = 0?

That was a lot, but I hope you got the general idea of how to think about this. There’s a chapter of the User’s Guide on optimizing code that discusses pretty much everything I suggested in one way or another, at least indirectly.

In addition to modifying the Stan code, you may get substantial speed-ups as discussed in a Stan blog post: Options for improving Stan sampling speed

Thank you so much Bob and Aki! This advice will be invaluable for our package. I’ll pass this on to my colleagues.

ctsem solves SDEs using Kronecker products, matrix exponentials, etc. I spent a bit of time writing custom functions to compute the covariance and matrix exponentials faster. I haven’t looked in any detail at your case, but maybe it helps. Key bits of code:

  matrix ksolve(matrix A, matrix Q){
    int d= rows(A);
    int d2= (d*d-d)%/%2;
    matrix[d+d2,d+d2] O;
    vector[d+d2] triQ;
    matrix[d,d] AQ;
    int z=0; //z is row of output
    for(j in 1:d){//for column reference of solution vector
      for(i in 1:j){ //and row reference...
        if(j >= i){ //if i and j denote a covariance parameter (from upper tri)
          int y=0; //start new output row
          z+=1; //shift current output row down
          
          for(ci in 1:d){//for columns and
            for(ri in 1:d){ //rows of solution
              if(ci >= ri){ //when in upper tri (inc diag)
                y+=1; //move to next column of output
                
                if(i==j){ //if output row is for a diagonal element
                  if(ri==i) O[z,y] = 2*A[ri,ci];
                  if(ci==i) O[z,y] = 2*A[ci,ri];
                }
                
                if(i!=j){ //if output row is not for a diagonal element
                  if(y==z) O[z,y] = A[ri,ri] + A[ci,ci]; //if column of output matches row of output, sum both A diags
                  if(y!=z){ //otherwise...
                    // if solution element we refer to is related to output row...
                    if(ci==ri){ //if solution element is a variance
                      if(ci==i) O[z,y] = A[j,ci]; //if variance of solution corresponds to row of our output
                      if(ci==j) O[z,y] = A[i,ci]; //if variance of solution corresponds to col of our output
                    }
                    if(ci!=ri && (ri==i||ri==j||ci==i||ci==j)){//if solution element is a related covariance
                      //for row 1,2 / 2,1 of output, if solution row ri 1 (match) and column ci 3, we need A[2,3]
                      if(ri==i) O[z,y] = A[j,ci];
                      if(ri==j) O[z,y] = A[i,ci];
                      if(ci==i) O[z,y] = A[j,ri];
                      if(ci==j) O[z,y] = A[i,ri];
                    }
                  }
                }
                if(is_nan(O[z,y])) O[z,y]=0;
              }
            }
          }
        }
      }
    }
    
    z=0; //get upper tri of Q
    for(j in 1:d){
      for(i in 1:j){
        z+=1;
        triQ[z] = Q[i,j];
      }
    }
    triQ=-O \ triQ; //get upper tri of asymQ
    
    z=0; // put upper tri of asymQ into matrix
    for(j in 1:d){
      for(i in 1:j){
        z+=1;
        AQ[i,j] = triQ[z];
        if(i!=j) AQ[j,i] = triQ[z];
      }
    }
    return AQ;
  }

//split a matrix exponential into pre specified blocks (ie when off diagonals for row and column are 0)
  matrix expmSubsets(matrix m, array[,] int subsets){
    int nr = rows(m);
    matrix[nr,nr] e = rep_matrix(0,nr,nr);
    for(si in 1:size(subsets)){
      int n=0;
      for(j in 1:nr) n+= subsets[si,j]!=0;
      if(n > 1){
        e[subsets[si,1:n],subsets[si,1:n] ] = matrix_exp(m[subsets[si,1:n],subsets[si,1:n]]);
      } else e[subsets[si,1],subsets[si,1] ] = exp(m[subsets[si,1],subsets[si,1]]);
    }
    return e;
  }

//asymptotic diffusion covariance
asymDIFFUSIONcov = ksolve(JAx, DIFFUSIONcov);

//discrete time coefficients
eJAx =  expmSubsets(JAx * dt, JAxsubsets);

//discrete time diffusion covariance
discreteDIFFUSION =  asymDIFFUSIONcov - 
        quad_form_sym( asymDIFFUSIONcov, eJAx' );

I think with recent Stan updates it may also be possible to implement a faster solution to the covariance; see the Stack Overflow question "Implementing the Bartels–Stewart algorithm in Eigen3 -- real matrices only?"

Thank you @Charles_Driver. This is super helpful. I’ll pass this along.

I’m in the process of making the changes suggested above. Profiling suggests that our biggest bottleneck is the cov_drift() function, in particular the line starting with irow_vec (which has been updated to reflect Bob’s advice about avoiding identity matrices and multiplying by inverse matrices).

// solve SDE
matrix cov_drift(matrix A, matrix Q, real ts) {
  matrix[rows(A) * rows(A), cols(A) * cols(A)] A_sharp_temp;
  vector[rows(Q)*cols(Q)] row_Q;
  vector[rows(A)*cols(A)] irow_vec;
  matrix[rows(A),cols(A)] irow_mat;
  profile("kronecker_sum") {A_sharp_temp = kronecker_sum(A, A);}
  // row operation takes elements of a matrix rowwise and puts them into a column vector
  profile("row_Q") {
    for (i in 1:rows(Q))
      for (j in 1:cols(Q)) {
        row_Q[i + (j-1)*rows(Q)] = Q[j,i];
      }
  }
  profile("irow_vec") {irow_vec = (A_sharp_temp \ add_diag(matrix_exp(A_sharp_temp * ts), -1)) * row_Q;}
  // irow takes elements of a column vector and puts them in a matrix rowwise
  profile("irow_mat") {
    {
      int row_size = rows(A);
      int row_ticker = 1;
      int col_ticker = 0;
      for (i in 1:num_elements(irow_vec)) {
        col_ticker += 1;
        if (col_ticker > row_size) {
          row_ticker += 1;
          col_ticker = 1;
        }
        irow_mat[row_ticker,col_ticker] = irow_vec[i];
      }
    }
  }
  return irow_mat;
}
           name thread_id total_time forward_time reverse_time chain_stack no_chain_stack autodiff_calls no_autodiff_calls
1      irow_mat         1  0.0639180    0.0359620    0.0279560           0              0         589068             18018
2      irow_vec         1  9.4457300    5.9759000    3.4698300   534138035       59645256         589068             18018
3 kronecker_sum         1  0.7651890    0.5353950    0.2297940    37700416         589069         589068             18018
4         row_Q         1  0.0646981    0.0367919    0.0279062           0              0         589068             18018

Any ideas for how to improve the efficiency of this line would be greatly appreciated!
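One small change that should reduce the work on the irow_vec line, independent of any larger restructuring: multiply by row_Q before the solve, so that the left division has a vector on its right-hand side instead of a full matrix. The result is mathematically the same. A sketch, with a hypothetical function name:

```stan
functions {
  // A_sharp^{-1} * (exp(A_sharp * ts) - I) * row_Q, solved against a vector
  vector solve_expm_times(matrix A_sharp, vector row_Q, real ts) {
    return A_sharp \ (add_diag(matrix_exp(A_sharp * ts), -1) * row_Q);
  }
}
```

The matrix exponential of the J^2 x J^2 matrix is still computed for every segment, so this alone does not remove the bottleneck.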

It’s pretty hazy in my memory, but I’m pretty sure my code is basically just a large reduction in the number of operations involved in the irow_vec type bits, avoiding matrix multiplications in favour of an elementwise approach that exploits the sparsity and block structure. A neat benefit was also a big reduction in memory usage, I think.

Thanks Charles! Okay, we’ll dive into your code and see if there’s anything we can pull out for our use case.

Ok I took a little look at your paper. Convenient that you’re citing my work aha.

You want to compute:
[image: equation]

This requires all the Kronecker stuff for every different time interval. If your system parameters are not changing at every step in time, it should be more efficient to compute the asymptotic covariance just once (or when needed if changing occasionally):

[image: equation]

From this, you can compute the covariance for specific time intervals:

[image: equation]

(These are copied from Driver, C. C., & Voelkle, M. C. (2018). Hierarchical Bayesian continuous time dynamic modeling. Psychological Methods, 23(4), 774.)
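The equation images did not survive. Judging from the code in this thread (cov_drift for the first, ksolve for the second, and the quad_form_sym line for the third), the three expressions are presumably the following. This is a reconstruction rather than a copy from the paper, and the notation (A_#, row, irow, Q_infinity) is assumed:

```latex
\begin{aligned}
Q_{\Delta t} &= \mathrm{irow}\!\left(A_\#^{-1}\left[e^{A_\# \Delta t} - I\right]\mathrm{row}(Q)\right) \\
Q_{\infty} &= \mathrm{irow}\!\left(-A_\#^{-1}\,\mathrm{row}(Q)\right) \\
Q_{\Delta t} &= Q_{\infty} - e^{A \Delta t}\, Q_{\infty} \left(e^{A \Delta t}\right)^{\top}
\end{aligned}
```

Only the second expression involves A_#, and it does not depend on the time interval, which is why it can be computed once and reused for every segment.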

Because in some nonlinear SDE models the diffusion covariance or temporal coefficients change, requiring recomputation of the asymptotic covariance, I wanted a faster way to compute the asymptotic covariance. Using my code above, you can compute the asymptotic covariance faster than the ‘naive’ approach. I don’t know if your approach is what I call the naive approach, but I assume it is because that is simplest.

asymDIFFUSIONcov (asymptotic covariance, delta t → inf) = ksolve(A, Q);
ADT (discrete time A coefficients) = matrix_exp(A * dt);
discrete time covariance = asymDIFFUSIONcov - quad_form_sym(asymDIFFUSIONcov, ADT');

edit: corrected the last line to use ADT instead of A
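In terms of the original cov_drift, the per-segment step might then be sketched as below. The function name is hypothetical, and Q_inf is assumed to be the symmetric asymptotic covariance, computed once outside the loop over segments (for example with ksolve(A, Q)):

```stan
functions {
  // covariance of accumulated drift over an interval of length t,
  // given the asymptotic covariance Q_inf
  matrix cov_drift_from_asym(matrix A, matrix Q_inf, real t) {
    matrix[rows(A), cols(A)] A_delta = matrix_exp(A * t);
    // quad_form_sym(Q_inf, A_delta') is A_delta * Q_inf * A_delta'
    return Q_inf - quad_form_sym(Q_inf, A_delta');
  }
}
```

Each segment then needs only the J x J matrix exponential, which the model already computes as A_delta, and no J^2 x J^2 operations.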

Thanks so much, @Charles_Driver, for these optimizations and for your papers that laid key foundations for our Stan model! Already I am seeing big gains in speed. By the way, I think we are both in Zurich at the moment, which means I owe you a coffee/other drink sometime.

Thank you Charles! I will have to defer to @Erik_Ringen on these details.

I accept coffee/other drinks ;)