Joint Species Distribution Model Performance

Dear All,

I am opening this discussion to share the Stan code that user @pajarom@gmail.com and I have been developing to implement a Joint Species Distribution Model in Stan, whose performance we discussed in (Rstan on remote servers). Some of you suggested that the very long computation times this model requires could be due to the code structure.

A brief introduction on how the model works:

This is a hierarchical model with applications in ecology. It uses animal occurrence data to estimate coefficients that describe how different species respond to a set of environmental variables, and how this response is mediated by species-specific traits and by how closely each species is genetically related to the others. This framework is described in detail in Ovaskainen et al. (https://doi.org/10.1111/ele.12757) for those who are interested.

Every species is associated with a detection probability p, which was calculated using traditional distance sampling methods. We use these detection probabilities to model the observed count y of species groups in transect cell i at sampling time j as a binomial distribution conditional on the true population size N:

y_ij | N_i ~ Binomial(N_i, p)

The true population size N_i is then modelled as a Poisson distribution with mean \lambda_i, whose logarithm is a linear function in which the predictors are the measurable environmental covariates and the coefficients are the species' responses \beta, as shown below:

log(\lambda_i) = \beta_0 + \beta_1 X_i

To understand how the species' responses to the environment (the \beta parameters) depend on species-specific traits and on the species' phylogeny, we model the \beta parameters with a multivariate normal distribution:

\beta ~ N(\mu, V)

Here the mean \mu = vec(T Z) is a linear function of the trait matrix T, and the covariance V = \Sigma ⊗ (\rho C + (1 - \rho) I) combines the phylogenetic correlation matrix C with independent variation, weighted by \rho (this is the Kronecker structure built in the simulation code below).

Below I share the R code to simulate a small dataset and run the model, as well as the Stan code for the model itself.

# first we generate mock data

# a number of species 
# a phylogenetic tree indicating how these species are related 
# a set of species specific traits for every species 
# a set of environmental covariates (2 changing across space, 2 changing across space and time)


library(ape)     # for phylogenetic data
library(mvtnorm) # to simulate from multivariate normal distributions
library(ggplot2) # for the coefficient plots further below

n_sp=20 # number of species

n_env=2 # environmental covariates 

n_pars=n_env+1 #parameters (environmental cov + intercept)

n_t=2 # number of species specific traits 

# defining species specific  traits 
dgrass=rbeta(n_sp,1,1) # fraction of grass in diet 
log_bm=rnorm(n_sp,0,1) # log of body mass

#arrange species specific traits in a matrix + intercept
TT = as.matrix(cbind(rep(1, n_sp), scale(dgrass), scale(log_bm)))


# simulate phylogeny 
tree=rtree(n_sp)
CC=vcv(tree, corr=T) 



# sort species and re-arrange phylogenetic correlation matrix
tmp = dimnames(CC)
ids = as.numeric(as.factor(tmp[[1]]))

C = matrix(NA, ncol(CC), ncol(CC))
for(i in 1:ncol(CC)){
  for(j in 1:ncol(CC)){
    C[ids[i],ids[j]] = CC[i,j]
  }
}

# Z effect of species-specific trait on environmental responses (betas)
Z = matrix(rnorm((n_t + 1) * n_pars, 0 , 0.5),  (n_t + 1), n_pars)

# expected betas
M=TT%*%Z

# define the variation around the betas
Sigma=diag(n_pars)*0.3
rho=0.5 # defines the role of phylogenetic relatedness

betas=rmvnorm(1,mean=as.vector(M),kronecker(Sigma,rho*C+(1-rho)*diag(n_sp)))
# species responses come from a multivariate distribution where the vector 
# m is the species responses mediated by effect of traits and the variance 
# covariance matrix is adjusted to vary according to phylogenetic variation 
# via a kronecker product

Beta=matrix(betas[1,],n_sp,n_pars)

# animals are observed in transects via distance sampling, we simulate 
# observation like so 
n_sites=100

X=cbind(rep(1,n_sites),matrix(rnorm(n_sites*n_env),n_sites,n_env))
# X matrix contains all the environmental data (plus a column of 1s for the intercepts)

B=5 # max sampling distance

# simulate real abundances
N = matrix(NA, n_sites, n_sp)
for(i in 1: n_sp){
  N[,i] = rpois(n_sites, lambda = exp(X %*% Beta[i, ]))
}

p = rbeta(n_sp, 5,3) # probability of detection is unique for every species

# simulate observed dataset
data = NULL
for (i in 1:n_sites) {
  for(j in 1:n_sp){
    if(N[i,j] > 0){
      #d = runif(N[i,j], 0, B)
      #gs = rpois(N[i,j], lambda.group[j]) + 1
      #sigma = exp(Beta[j,1] + gs * Beta[j,2] + TT[j,2] * Beta[j,3])
      #p = exp(-d * d/(2 * (sigma^2)))
      y = rbinom(N[i,j], 1, p[j])
      #d = d[y == 1]
      #gs = gs[y == 1]
      y = y[y == 1]
      if (sum(y) > 0){
        data = rbind(data, cbind(rep(i, sum(y)), rep(j, sum(y)), y))
      }
    }
  }
}



colnames(data) = c("site","sp", "y")
datos = as.data.frame(data)


# run the stan model to estimate parameters

# library(cmdstanr)
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())


stan_dat <- list(
  n_obs = dim(datos)[1],
  area = rep(1.0, n_sites),
  n_sites = n_sites,
  site = as.integer(datos$site), # (c(s, rep(1:n_sites, each = nzs ))),
  K = dim(X)[2], 
  X = X,
  n_max = rep(100, n_sp),
  n_s = as.integer(n_sp),
  n_t = dim(TT)[2],
  TT = TT,
  C = C,
  ones = numeric(n_sp) + 1,
  sp = datos$sp,
  p_obs = p
)

pars <- c( "b_m", "rho",  "Sigma", "z", "D")

init_f <- function () list(b_m = matrix(0, n_sp, n_pars))

fit <- stan(file = 'model_demo.stan',
            data = stan_dat,
            init = init_f,
            pars = pars,
            iter = 1000, thin = 1, chains = 3)


# some plots
fit_summary <- summary(fit)$summary
op <- par(mfrow = c(1,2))
hist(fit_summary[,10], main = "R-hat")
hist(fit_summary[,9], main = "n-eff" )
par(op)

draws <- extract(fit, pars = "rho")

plot(density(draws$rho), main = "")
abline(v=rho)

# plot trait level parameters
zs <- fit_summary[grepl("z", rownames(fit_summary)),]
#plot(c(Z) - zs[,1])
df <- data.frame(x = 1:dim(zs)[1],
                 tz = c(Z),
                 fz = zs[,1],
                 L = zs[,4],
                 U = zs[,8])

ggplot(df, aes(x = x, y = tz)) +
  geom_point(size = 3, color="red") +
  geom_point(aes(y = fz), size = 2) +
  geom_linerange(aes(ymin = L, ymax = U)) +
  theme_classic()


# plot intercepts and slopes
bs <- fit_summary[grepl("b_m", rownames(fit_summary)),]

nf = layout(matrix(c(1,2,3,4,0,0),3,2,byrow=TRUE), widths=c(1,1), heights=c(1,1,0.1))
#layout.show(nf)
op <- par( mar = c(3, 3, 2, 2) + 0.1, mgp = c(3.5, 1, 0), las = 1, bty = "n", cex = 1.2)
plot(scale(dgrass), Beta[,2],  ylab = "", xlab = "", main = "intercept", ylim=c(-3,3))
points(scale(dgrass), bs[seq(1, (n_sp*3) ,by=3) + 1,1], pch = 19, col = 2)

plot(scale(dgrass), Beta[,3], ylab = "", xlab = "", main = "slope", ylim=c(-3,3) )
points(scale(dgrass), bs[seq(1, (n_sp*3) ,by=3) + 2,1], pch = 19, col = 2)

plot(scale(log_bm),Beta[,2], ylab = "", xlab = "", ylim=c(-3,3))
points(scale(log_bm), bs[seq(1, (n_sp*3) ,by=3) + 1,1], pch = 19, col = 2)

plot(scale(log_bm),Beta[,3], ylab = "", xlab = "", ylim=c(-3,3))
points(scale(log_bm),  bs[seq(1, (n_sp*3) ,by=3) + 2,1], pch = 19, col = 2)
mtext("         scaled grass           scaled log body mass", side = 1, line = -2, outer = TRUE, cex=1.3)
par(mfrow = c(1,1))

# another plot of coefficients and estimates
plot(c(Beta), 
     c(bs[seq(1, (n_sp*3) ,by=3) ,1], bs[seq(1, (n_sp*3) ,by=3) + 1,1], bs[seq(1, (n_sp*3) ,by=3) + 2,1]), xlab = "true value", ylab = "posterior mean")
abline(0,1)

# yet another one
df <- data.frame(x = 1:dim(bs)[1],
                 tb = c(Beta),
                 fb = c(bs[seq(1, (n_sp*3) ,by=3) ,1], bs[seq(1, (n_sp*3) ,by=3) + 1,1], bs[seq(1, (n_sp*3) ,by=3) + 2,1]),
                 L =  c(bs[seq(1, (n_sp*3) ,by=3) ,4], bs[seq(1, (n_sp*3) ,by=3) + 1,4], bs[seq(1, (n_sp*3) ,by=3) + 2,4]),
                 U =  c(bs[seq(1, (n_sp*3) ,by=3) ,8], bs[seq(1, (n_sp*3) ,by=3) + 1,8], bs[seq(1, (n_sp*3) ,by=3) + 2,8])
)

ggplot(df, aes(x = x, y = tb)) +
  geom_point(size = 2, color="red") +
  geom_point(aes(y = fb), size = 1) +
  geom_linerange(aes(ymin = L, ymax = U)) +
  theme_classic()

# plot density estimates
D <- fit_summary[grepl("D", rownames(fit_summary)),]

df <- data.frame(x = 1:dim(D)[1],
                 td = colSums(N)/n_sites,
                 fd = D[,1],
                 L =  D[,4],
                 U =  D[,8]
)

ggplot(df, aes(x = x, y = td)) +
  geom_point(size = 2, color="red") +
  geom_point(aes(y = fd), size = 1) +
  geom_linerange(aes(ymin = L, ymax = U)) +
  theme_classic()

And the Stan model:

functions { 
  
  /* compute the kronecker product
  * Args: 
  *   A,B: matrices 
  * Returns: 
  *   kronecker product of A and B
  */ 
  matrix kronecker(matrix A, matrix B) { 
    matrix[rows(A)*rows(B), cols(A)*cols(B)] kron; 
    for (i in 1:cols(A)) { 
      for (j in 1:rows(A)) { 
        kron[((j-1)*rows(B)+1):(j*rows(B)), ((i-1)*cols(B)+1):(i*cols(B))] = A[j,i] * B;
      } 
    } 
    return kron; 
  } 
  
  
  /* find the q-th quantile of a Poisson distribution by linear search
  * Args: 
  *   q: target cumulative probability
  *   lambda: Poisson mean
  *   max_x: upper cap for the search
  * Returns: 
  *   smallest x with CDF(x) >= q, or max_x if not reached
  */ 
  int qpois(real q, real lambda, int max_x) {
    int x = 0;
    real res = poisson_cdf(x, lambda);
    
    while(res < q && x < max_x){
      x = x + 1;
      res = poisson_cdf(x, lambda);
    }
    return x; 
  } 
} 

data { 
  int<lower=1> n_obs;             // total number of observation records
  int<lower=1> n_sites;           // number of sites/segments
  real<lower=0> area[n_sites];    // area of every site
  int<lower=1> K;                 // number of sample-level predictors
  int<lower=1> n_s;               // num of species
  int<lower=1> n_t;               // num species level predictors (traits)
  int<lower=1,upper=n_s> sp[n_obs];   // species id 
  int<lower=1,upper=n_sites> site[n_obs];
  matrix[n_sites, K] X;                 // obs-level design matrix 
  matrix[n_s, n_t] TT;            // species-level traits
  matrix[n_s, n_s] C;             // phylogenetic correlation matrix
  vector[n_s] ones;               // vector of 1s
  int<lower=1> n_max[n_s];        // Upper bound of population size per spp
  real<lower=0,upper=1> p_obs[n_s];
}

transformed data {
  
  int<lower=0> Y[n_sites, n_s]; // total spp by site
  
  for(i in 1:n_sites){
    for(j in 1:n_s){
      Y[i,j] = 0;
    }
  }
  
  for(i in 1:n_obs){
    Y[site[i], sp[i]] += 1;
  }
}


parameters {
  corr_matrix[K] Omega;           // correlation matrix for var-covar of betas
  vector<lower=0>[K] tau;         // scales for the variance covariance of betas
  vector[n_s * K] betas;
  real<lower=0,upper=1> rho;      // correlation between phylogeny and betas
  vector[n_t * K] z;              // coeffs for traits
  real<lower=0,upper=1> p[n_s];   // detection probability
}

transformed parameters { 
  matrix[K, K] Sigma = quad_form_diag(Omega, tau);
  matrix[n_s*K, n_s*K] S = kronecker(Sigma, rho * C + (1-rho) * diag_matrix(ones));
  matrix[n_t, K] Z = to_matrix(z, n_t, K);    
  vector[n_s * K] m = to_vector(TT * Z);        // mean of coeffs
  matrix[n_s, K] b_m = to_matrix(betas, n_s, K);  // coeffs
} 

model {
  matrix[n_sites, n_s] log_lambda;
  int Ymax[n_sites, n_s];
  // priors
  // p ~ beta(2,2);
  Omega ~ lkj_corr(2);
  tau ~ student_t(3,0,10); // cauchy(0, 2.5); // lognormal()
  betas ~ multi_normal(m, S);
  //rho ~ beta(2,2);
  z ~ normal(0,2);
  
  // mix prior on rho
  //target += log_sum_exp(log(0.5) +  beta_lpdf(rho|1, 10), log(0.5) +  beta_lpdf(rho|2,2));
  
  for (n in 1:n_sites){
    for(s in 1:n_s){
      log_lambda[n,s] = dot_product( X[n,] , b_m[s,]) + log(area[n]);
      Ymax[n,s] = Y[n,s] + qpois(0.9999, exp(log_lambda[n,s]) * (1 - p_obs[s]), Y[n,s] + n_max[s]);
    }
  }
  
  for (n in 1:n_sites){
    for(s in 1:n_s){
      vector[Ymax[n,s] - Y[n,s] + 1] lp;
      for (j in 1:(Ymax[n,s]  - Y[n,s] + 1)){
        lp[j] = poisson_log_lpmf(Y[n,s] + j - 1 | log_lambda[n,s]) 
        + binomial_lpmf(Y[n,s] | Y[n,s] + j - 1, p_obs[s]);
      }
      
      target += log_sum_exp(lp);
    }
  }
}

generated quantities{
  int<lower=0> N[n_sites,n_s];
  real<lower=0> D[n_s];
  
  for (n in 1:n_sites){
    for(s in 1:n_s){
      N[n,s]= poisson_log_rng(dot_product( X[n,] , b_m[s,]) + log(area[n])); 
      }
    }
  
  for(s in 1:n_s) D[s] = sum(N[,s])/sum(area);
  
}

This simulation takes about 20 minutes to run on my laptop; however, the version with the real dataset (where several more parameters are estimated) needs more than two weeks. Do you see any features of the model structure that may account for the very long computation times?

PS: hopefully the code will be formatted correctly in the post; if it isn't, I will try to fix it.


I’m no genius at optimizing parameterizations, but I do think there’s one clear target for saving some runtime, particularly if abundances are high and detection probabilities are low.

Currently, you marginalize over the true abundance by considering every possible value between the observed abundance and the .9999 quantile of the Poisson. But the observed abundance might be far lower than the .0001 quantile of the Poisson if detection probabilities are low. So you might be able to save a ton of terms by using max(Y, qpois(.0001)) as the lower bound for the marginalization?
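For a sense of the scale of the savings, here is a quick R illustration (my own sketch, with made-up numbers, using the simplified framing above where both bounds are quantiles of the same Poisson):

# Hypothetical numbers: high abundance, low detection
lambda <- 200   # Poisson mean of the true abundance N
p      <- 0.1   # detection probability
Y      <- 15    # a plausible observed count

upper     <- qpois(0.9999, lambda)          # upper bound of the sum
lower_now <- Y                              # current lower bound
lower_new <- max(Y, qpois(0.0001, lambda))  # proposed tighter lower bound

# number of terms in the marginalization, before and after
c(terms_now = upper - lower_now + 1,
  terms_new = upper - lower_new + 1)

With these numbers the tighter bound cuts the number of terms by more than half, and the gain grows as true abundance rises relative to the observed counts.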


IIRC Poisson + binomial marginalizes to a Poisson so that instead of this

vector[Ymax[n,s] - Y[n,s] + 1] lp;
for (j in 1:(Ymax[n,s]  - Y[n,s] + 1)){
  lp[j] = poisson_log_lpmf(Y[n,s] + j - 1 | log_lambda[n,s]) 
  + binomial_lpmf(Y[n,s] | Y[n,s] + j - 1, p_obs[s]);
}
target += log_sum_exp(lp);

you can just have

target += poisson_log_lpmf(Y[n,s] | log_lambda[n,s] + log(p_obs[s]));

and there’s no need to compute Ymax at all.
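If you want to convince yourself of the thinning identity numerically, a quick R check (made-up lambda and p):

set.seed(1)
lambda <- 7; p <- 0.4; n <- 1e5

N <- rpois(n, lambda)               # latent abundance
y <- rbinom(n, size = N, prob = p)  # binomial thinning of N

# the thinned counts behave like draws from Poisson(lambda * p)
c(mean(y), lambda * p)  # means agree
c(var(y),  lambda * p)  # and variance equals the mean, as for a Poisson

Since p_obs is data, the marginal likelihood of Y is exactly Poisson(lambda * p_obs), so nothing is lost by marginalizing; the original model already regenerates N in generated quantities anyway.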

Another change that probably helps is using Cholesky factors, so that Stan factorizes the small n_s x n_s and K x K matrices instead of the full (n_s*K) x (n_s*K) covariance at every evaluation.

parameters {
  cholesky_factor_corr[K] L_Omega;
  vector<lower=0>[K] tau;
}
transformed parameters { 
  matrix[K, K] L_Sigma = diag_pre_multiply(tau,L_Omega);
  matrix[n_s*K, n_s*K] L_S = kronecker(L_Sigma,
          cholesky_decompose(rho * C + diag_matrix(rep_vector(1 - rho, n_s))));
}
model {
  L_Omega ~ lkj_corr_cholesky(2);
  tau ~ student_t(3,0,10);
  betas ~ multi_normal_cholesky(m, L_S);
}

Lucky that Kronecker product and Cholesky decomposition play so well with each other.
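The identity at work is chol(A ⊗ B) = chol(A) ⊗ chol(B) (for lower-triangular factors). A minimal numerical check in R, with made-up SPD matrices:

set.seed(2)
A <- crossprod(matrix(rnorm(9), 3, 3)) + diag(3)   # SPD 3x3
B <- crossprod(matrix(rnorm(16), 4, 4)) + diag(4)  # SPD 4x4

# R's chol() returns upper-triangular factors; transpose to get lower
LA <- t(chol(A)); LB <- t(chol(B))

# the Cholesky factor of the Kronecker product equals
# the Kronecker product of the Cholesky factors
max(abs(t(chol(kronecker(A, B))) - kronecker(LA, LB)))  # ~1e-15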


I think you can drop the Kronecker product entirely and use a non-centered parameterization of the matrix normal distribution. You could have something like this:

parameters{
  cholesky_factor_corr[K] L_Omega;
  vector<lower=0>[K] tau;
  matrix[n_s, K] beta_std;
  matrix[n_t, K] Z; // No need for z as a vector
// everything else
}
transformed parameters{
  matrix[K, K] L_Sigma = diag_pre_multiply(tau, L_Omega);
  matrix[n_s, n_s] L_phylo = cholesky_decompose(
             rho * C + diag_matrix(rep_vector(1 - rho, n_s)));
  matrix[n_s,  K] m = TT * Z;        // mean of coeffs

  matrix[n_s, K] beta = m + L_phylo * beta_std * L_Sigma'; // Matrix Normal NCP
}
model {
  // The fixed priors are probably better as data arguments
  L_Omega ~ lkj_corr_cholesky(2);
  tau ~ student_t(3,0,10); 
  to_vector(beta_std) ~ std_normal();
  to_vector(Z) ~ normal(0, 2); 
}
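As a quick sanity check that this NCP matches the kronecker(Sigma, rho*C + (1-rho)*I) covariance of the original model, here is an R sketch (made-up dimensions; as.vector() stacks columns, like vec):

set.seed(3)
n_s <- 4; K <- 2
R_phylo <- crossprod(matrix(rnorm(n_s^2), n_s)) + diag(n_s)  # stands in for rho*C + (1-rho)*I
Sigma   <- crossprod(matrix(rnorm(K^2), K)) + diag(K)

L_R <- t(chol(R_phylo)); L_S <- t(chol(Sigma))

draws <- t(replicate(2e4, {
  E <- matrix(rnorm(n_s * K), n_s, K)   # beta_std
  as.vector(L_R %*% E %*% t(L_S))       # vec(L_phylo * beta_std * L_Sigma')
}))

# empirical covariance approaches Sigma ⊗ R_phylo, up to Monte Carlo error
max(abs(cov(draws) - kronecker(Sigma, R_phylo)))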