My colleagues and I are working on an R package that fits dynamic coevolutionary models in Stan. The model allows users to infer causal directionality (e.g. X → Y) for multiple variables coevolving on a phylogenetic tree. For those interested, the full model algorithm is described in this pre-print.
The main package function takes a dataset, a phylogeny, and a list of variables and associated response distributions. It then generates Stan code, a data list for Stan, compiles the model, and fits the model with cmdstanr.
Since these models can run very slowly, we are interested in whether there are any ways we can feasibly improve the efficiency of the Stan code produced by the package. It would be great if some experienced users could cast an eye over the code and give us a sense of what could be improved. It may be that there’s nothing we can do to speed it up.
Here is some example code for two Bernoulli variables. The model is rather complicated as it has to repeatedly solve a differential equation over different segments of the phylogeny.
Thanks so much in advance!
functions {
  // Returns the Kronecker product of A and B: an (rows(A) * rows(B)) x (cols(A) * cols(B))
  // matrix whose (i, j)-th block, of size rows(B) x cols(B), equals A[i, j] * B.
  matrix kronecker_prod(matrix A, matrix B) {
    int m = rows(A);
    int n = cols(A);
    int p = rows(B);
    int q = cols(B);
    matrix[m * p, n * q] C;
    for (i in 1:m) {
      for (j in 1:n) {
        // block assignment replaces the two innermost scalar loops
        C[(p * (i - 1) + 1):(p * i), (q * (j - 1) + 1):(q * j)] = A[i, j] * B;
      }
    }
    return C;
  }

  // Expected auto and cross effects over a discrete time t: exp(A * t).
  matrix A_dt(matrix A, real t) {
    return matrix_exp(A * t);
  }

  // Returns the "A sharp" matrix kron(A, I) + kron(I, A), where I is the rows(A) x rows(A)
  // identity. With column-major vec(), A_sharp(A) * vec(X) = vec(A * X + X * A').
  matrix A_sharp(matrix A) {
    matrix[rows(A), cols(A)] I = identity_matrix(rows(A));
    return kronecker_prod(A, I) + kronecker_prod(I, A);
  }

  // Covariance of the stochastic change accumulated over a segment of length ts:
  //   V = integral_0^ts exp(A s) * Q * exp(A s)' ds.
  // V solves the Lyapunov equation A V + V A' = exp(A ts) Q exp(A ts)' - Q, i.e.
  //   A_sharp(A) * vec(V) = vec(exp(A ts) Q exp(A ts)' - Q).
  // This is the same quantity as A_sharp^-1 * (exp(A_sharp ts) - I) * vec(Q), because
  // exp(A_sharp(A) * ts) = kron(exp(A ts), exp(A ts)). So only a rows(A) x rows(A) matrix
  // exponential is needed, and a linear solve replaces the explicit matrix inverse.
  // Assumes A and Q are square and of the same size.
  matrix cov_drift(matrix A, matrix Q, real ts) {
    int K = rows(A);
    matrix[K, K] A_t = matrix_exp(A * ts);
    vector[K * K] rhs = to_vector(A_t * Q * A_t' - Q);
    // to_vector and to_matrix are both column-major, so to_matrix undoes to_vector
    return to_matrix(A_sharp(A) \ rhs, K, K);
  }

  // Returns the number of elements of x that are exactly equal to y.
  int num_matches(vector x, real y) {
    int n = 0;
    for (i in 1:rows(x)) {
      if (x[i] == y) {
        n += 1;
      }
    }
    return n;
  }

  // Returns, in increasing order, the indices i for which x[i] == y.
  array[] int which_equal(vector x, real y) {
    array[num_matches(x, y)] int match_positions;
    int pos = 1;
    for (i in 1:rows(x)) {
      if (x[i] == y) {
        match_positions[pos] = i;
        pos += 1;
      }
    }
    return match_positions;
  }
}
data{
  int<lower=1> N_tips; // number of tips
  int<lower=0> N_obs; // number of observations
  int<lower=1> J; // number of response traits
  int<lower=1> N_seg; // total number of segments in the tree
  array[N_seg] int<lower=1, upper=N_seg> node_seq; // index of tree nodes (rows of eta)
  // NOTE(review): left unconstrained because the root's parent entry (parent[1]) is never read
  // and its encoding is not visible here; parent[i] for i >= 2 must lie in 1..N_seg.
  array[N_seg] int parent; // index of the parent node of each descendent
  array[N_seg] real ts; // time since parent
  array[N_seg] int<lower=0, upper=1> tip; // indicator of whether a given segment ends in a tip
  // 1 = estimate the effect, 0 = fix it at zero; any other value previously left A undefined
  array[J,J] int<lower=0, upper=1> effects_mat; // which effects should be estimated?
  // A_offdiag has num_effects - J elements, so num_effects must count the J diagonal terms
  int<lower=J, upper=J*J> num_effects; // number of effects being estimated
  matrix[N_obs,J] y; // observed data
  matrix<lower=0, upper=1>[N_obs,J] miss; // are data points missing? (1 = missing)
  // bounded by N_seg because tip_id indexes the N_seg rows of eta and drift_tips
  array[N_obs] int<lower=1, upper=N_seg> tip_id; // index between 1 and N_tips that gives the group id
  int<lower=0, upper=1> prior_only; // should the likelihood be ignored?
}
parameters{
vector<upper=0>[J] A_diag; // autoregressive (diagonal) terms of A, constrained <= 0
// cross-lagged (off-diagonal) terms of A: one entry per off-diagonal effects_mat flag equal to 1
// (presumably num_effects counts the J diagonal terms as well -- confirm with the R side)
vector[num_effects - J] A_offdiag; // cross-lagged terms of A
vector<lower=0>[J] Q_diag; // self-drift terms (diagonal of Q), constrained >= 0
vector[J] b; // SDE intercepts
vector[J] eta_anc; // ancestral states, assigned to the root row eta[node_seq[1],]
// row i-1 gives the standard-normal innovations for segment i (i = 2..N_seg); these are
// scaled by the Cholesky factor of that segment's drift covariance
matrix[N_seg - 1,J] z_drift; // stochastic drift, unscaled and uncorrelated
}
transformed parameters{
  matrix[N_seg,J] eta; // node states; for tip segments the terminal drift is held in drift_tips
  matrix[J,J] Q; // drift matrix
  matrix[J,J] I; // identity matrix
  matrix[J,J] A; // selection matrix
  // NOTE(review): only the first J*(J-1)/2 entries are read when filling Q, and all are 0,
  // so Q is currently diagonal; kept as a transformed parameter so the output columns are unchanged
  vector[J*J - J] Q_offdiag = rep_vector(0.0, J*J - J);
  matrix[N_seg,J] drift_tips; // terminal drift for segments ending in a tip, -99 elsewhere
  matrix[N_seg,J] sigma_tips; // diagonal of Q for segments ending in a tip, -99 elsewhere

  // fill A matrix //////////
  // A_offdiag is consumed first for the upper triangle (row by row), then for the lower
  // triangle (column by column); off-diagonal entries not flagged with 1 are fixed at 0
  {
    int ticker = 1;
    for (i in 1:(J-1)) {
      for (j in (i+1):J) {
        if (effects_mat[i,j] == 1) {
          A[i,j] = A_offdiag[ticker];
          ticker += 1;
        } else {
          A[i,j] = 0;
        }
      }
    }
    for (i in 1:(J-1)) {
      for (j in (i+1):J) {
        if (effects_mat[j,i] == 1) {
          A[j,i] = A_offdiag[ticker];
          ticker += 1;
        } else {
          A[j,i] = 0;
        }
      }
    }
    for (j in 1:J) A[j,j] = A_diag[j];
  }

  // fill Q matrix (symmetric) //////////
  {
    int ticker = 1;
    for (i in 1:(J-1)) {
      for (j in (i+1):J) {
        Q[i,j] = Q_offdiag[ticker];
        Q[j,i] = Q[i,j]; // symmetry of covariance
        ticker += 1;
      }
    }
    for (j in 1:J) Q[j,j] = Q_diag[j];
  }

  // identity matrix
  I = identity_matrix(J);

  // root: ancestral states and placeholders
  eta[node_seq[1]] = eta_anc';
  drift_tips[node_seq[1]] = rep_row_vector(-99, J);
  sigma_tips[node_seq[1]] = rep_row_vector(-99, J);

  // NOTE(review): assumes segments are ordered so that eta[parent[i]] is filled before
  // segment i is processed -- confirm against the ordering produced on the R side
  {
    // Loop-invariant quantities, computed once per log-density evaluation rather than per segment.
    // A^-1 commutes with exp(A t), so A^-1 (exp(A t) - I) b = exp(A t) A^-1 b - A^-1 b.
    vector[J] A_inv_b = A \ b;
    // inverse of the Kronecker sum kron(A, I) + kron(I, A), shared by every segment
    matrix[J * J, J * J] A_sharp_inv = inverse(A_sharp(A));

    for (i in 2:N_seg) {
      // amount of deterministic change (selection)
      matrix[J,J] A_delta = A_dt(A, ts[i]);
      // variance-covariance matrix of stochastic change (drift): the same quantity as
      // cov_drift(A, Q, ts[i]), but reusing A_delta instead of recomputing matrix exponentials
      matrix[J,J] VCV = to_matrix(A_sharp_inv * to_vector(A_delta * Q * A_delta' - Q), J, J);
      // symmetrize so round-off asymmetry cannot make cholesky_decompose reject
      vector[J] drift_seg = cholesky_decompose(0.5 * (VCV + VCV')) * z_drift[i-1]';
      // expected state at the end of the segment, before drift
      vector[J] eta_mean = A_delta * (eta[parent[i]]' + A_inv_b) - A_inv_b;
      if (tip[i] == 0) {
        // not a tip: add the drift to the node state
        eta[node_seq[i]] = (eta_mean + drift_seg)';
        drift_tips[node_seq[i]] = rep_row_vector(-99, J);
        sigma_tips[node_seq[i]] = rep_row_vector(-99, J);
      } else {
        // tip: keep the drift separate; the model block adds drift_tips to eta
        eta[node_seq[i]] = eta_mean';
        drift_tips[node_seq[i]] = drift_seg';
        sigma_tips[node_seq[i]] = diagonal(Q)';
      }
    }
  }
}
model{
  // priors
  b ~ std_normal();
  eta_anc ~ std_normal();
  to_vector(z_drift) ~ std_normal();
  A_offdiag ~ std_normal();
  A_diag ~ std_normal(); // half-normal, as A_diag is declared <upper=0>
  Q_diag ~ std_normal(); // half-normal, as Q_diag is declared <lower=0>
  if (!prior_only) {
    // One vectorized Bernoulli-logit statement per response column (columns 1 and 2 here)
    // instead of one scalar statement per observation. Rows with miss[i, j] != 0 are skipped.
    for (j in 1:2) {
      array[num_matches(col(miss, j), 0)] int obs = which_equal(col(miss, j), 0);
      array[size(obs)] int tips_j = tip_id[obs];
      to_int(to_array_1d(y[obs, j])) ~ bernoulli_logit(eta[tips_j, j] + drift_tips[tips_j, j]);
    }
  }
}