Hey all,
I’ve been working to translate a hierarchical nonlinear model from JAGS to Stan, mostly to allow timely inference on a large-ish dataset of ~100,000 observations. I’ve got a Stan version that works well on simulated data, but its performance isn’t much better than that of the JAGS version.
I suspect there are some things I could do to speed it up, but I’m relatively new to Stan, so I don’t have a good intuition for which strategies will be most fruitful. I’m hoping the community can help!
The model fits a nonlinear function (a log-polynomial) to overdispersed count data with a negative binomial distribution, and includes a non-centered parameterization of plot and species-specific random effects on the curve parameters.
The application is to help understand how changes in climate drive changes in the timing of flowering in plant communities. We used the JAGS version in this paper:
http://onlinelibrary.wiley.com/doi/10.1002/ecy.1996/full
Here’s the Stan program:
// Inputs. Bounds are declared so that malformed data (negative counts,
// group/species indices outside 1..NG / 1..NS) is rejected when the data is
// read instead of surfacing later as an indexing or likelihood error.
data {
  int<lower=0> N; //the number of observations
  int<lower=0> NG; //the number of groups
  int<lower=0> NS; //the number of species
  int<lower=0> y[N]; //the response (counts, so non-negative)
  vector[N] x1; //first predictor: the axis along which the curve is evaluated
  vector[N] x2; //second predictor: covariate for optimum, width and height
  int<lower=1, upper=NG> group[N]; //group index, used to index vectors of length NG
  int<lower=1, upper=NS> spp[N]; //species index, used to index vectors of length NS
}
// Sampled parameters. For each of the three curve parameters (optimum, width,
// height) there is an intercept/slope pair plus zero-mean random intercepts by
// group and by species. In the model block the sd vectors are indexed as
// [1] = optimum, [2] = width, [3] = height.
parameters {
  vector[2] beta_opt; //the regression parameters on the optimum: [1] intercept, [2] slope on x2
  vector[2] beta_width; //the regression parameters on the width (log scale; exponentiated later)
  vector[2] beta_height; //the regression parameters on the height
  vector[NG] opt_group; //random intercepts for the optimum
  vector[NG] height_group; //random intercepts for the height
  vector[NG] width_group; //random intercepts for the width
  vector[NS] opt_spp; //random intercepts for the optimum
  vector[NS] height_spp; //random intercepts for the height
  vector[NS] width_spp; //random intercepts for the width
  vector<lower=0>[3] group_sd; //sd for group random intercepts
  vector<lower=0>[3] spp_sd; //sd for species random intercepts
  real<lower=0> phi; //the overdispersion parameter passed to neg_binomial_2
}
// Per-observation curve parameters and the resulting expected count:
//   mu = exp(width * (x1 - opt)^2 + height)
transformed parameters {
  vector[N] mu; //expected count for each observation
  vector[N] opt; //optima: value of x1 at which the curve peaks
  vector[N] width; //width: coefficient on the squared distance from the optimum
  vector[N] height; //height: log of the expected count at the optimum
  
  opt = beta_opt[1] + x2 * beta_opt[2] + opt_group[group] + opt_spp[spp];
  //negated exponential keeps width strictly negative, so the curve has a peak
  width = -exp(beta_width[1] + x2 * beta_width[2] + width_group[group] + width_spp[spp]);
  height = beta_height[1] + x2 * beta_height[2] + height_group[group] + height_spp[spp];
  
  //exp() and square() apply elementwise to vectors and .* is the elementwise
  //product, so no loop over observations is needed.
  mu = exp(width .* square(x1 - opt) + height);
}
// Priors and likelihood. Every sampling statement is vectorized: one statement
// over a whole vector adds the same terms to the log density as the
// element-by-element version, but with far less overhead.
model {  
  
  //Priors for fixed effects
  phi ~ cauchy(10,50); //phi is declared with lower=0, so this acts as a Cauchy truncated at zero
  beta_width ~ normal(0,5); //intercept and slope
  beta_opt ~ normal(0,5); //intercept and slope
  beta_height ~ normal(0,5); //intercept and slope
  
  //Priors for random-effect standard deviations (half-Cauchy via lower=0)
  group_sd ~ cauchy(0,3);
  spp_sd ~ cauchy(0,5);
  
  //Random effects: zero-mean normal priors placed directly on each effect.
  //sd index: [1] = optimum, [2] = width, [3] = height.
  opt_group ~ normal(0,group_sd[1]);
  width_group ~ normal(0,group_sd[2]);
  height_group ~ normal(0,group_sd[3]);
  
  opt_spp ~ normal(0,spp_sd[1]);
  width_spp ~ normal(0,spp_sd[2]);
  height_spp ~ normal(0,spp_sd[3]);
  
  //likelihood
  y ~ neg_binomial_2(mu,phi);
}
// Posterior predictive replicates: one simulated count per observation, drawn
// from the negative binomial with that observation's mean and the shared phi.
generated quantities {
  vector[N] y_rep;
  for (obs in 1:N)
    y_rep[obs] = neg_binomial_2_rng(mu[obs], phi);
}
Data Simulation / Fitting / PP Check code is here:
library(rstan)
library(dplyr)
library(ggplot2)
library(shinystan)
# NOTE(review): hard-coded, machine-specific working directory. The relative
# Stan model path given to stan() further down is resolved against it, so the
# script only runs as-is where this directory exists.
setwd("~/code/pheno_scaling/")
####Simulates data####
# Draws predictors, group/species random intercepts and negative-binomial
# counts from the same curve the Stan model fits:
#   mu = exp(width * (x1 - opt)^2 + height)

# Fixed seed so the simulated data set and the recovery check are reproducible.
set.seed(2017)

n <- 1000
ngroups <- 10
nspp <- 10
phi <- 50

# True values for the width (log scale), height and optimum sub-models.
width_int <- 2
width_slope <- -0.2
width_group_sd <- 0.1
width_spp_sd <- 0.6
height_int <- 4
height_slope <- 0.2
height_group_sd <- 0.2
height_spp_sd <- 1
opt_int <- 0
opt_slope <- 0.4
opt_group_sd <- 0.1
opt_spp_sd <- 1

x1 <- rnorm(n, 0, 1)
x2 <- rnorm(n, 0, 1)
group <- sample(seq_len(ngroups), size = n, replace = TRUE)
spp <- sample(seq_len(nspp), size = n, replace = TRUE)
simdata <- data.frame(
  x1 = x1,
  x2 = x2,
  group = group,
  spp = spp
)

# One row of random intercepts per group / per species.
group_rnd_int <- data.frame(
  group = seq_len(ngroups),
  width_grp = rnorm(ngroups, 0, width_group_sd),
  height_grp = rnorm(ngroups, 0, height_group_sd),
  opt_grp = rnorm(ngroups, 0, opt_group_sd)
)
spp_rnd_int <- data.frame(
  spp = seq_len(nspp),
  width_spp = rnorm(nspp, 0, width_spp_sd),
  height_spp = rnorm(nspp, 0, height_spp_sd),
  opt_spp = rnorm(nspp, 0, opt_spp_sd)
)

# Attach each observation's group and species effects (row order is kept).
simdata <- left_join(simdata, group_rnd_int, by = "group")
simdata <- left_join(simdata, spp_rnd_int, by = "spp")

# Per-observation curve parameters. Each one gets BOTH the group and the
# species intercept; the group effect on the optimum used to be left out, so
# opt_group_sd had no counterpart in the simulated data.
width <- -exp(width_int + width_slope * x2 + simdata$width_grp + simdata$width_spp)
height <- height_int + height_slope * x2 + simdata$height_grp + simdata$height_spp
opt <- opt_int + opt_slope * x2 + simdata$opt_grp + simdata$opt_spp

# Expected count, then overdispersed counts (size is the Stan model's phi).
simdata$linpred <- exp(width * (x1 - opt)^2 + height)
simdata$y <- rnbinom(n = n, size = phi, mu = simdata$linpred)
# Scatter of the simulated data, coloured by count, with one panel per
# species (rows) and group (columns).
ggplot(simdata, aes(x = x2, y = x1, color = y)) +
  geom_point() +
  facet_grid(as.factor(spp) ~ as.factor(group)) +
  theme_bw()
####Fits the model in STAN.
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

# Only entries declared in the Stan data block are passed; the former `K`
# entry had no matching declaration there.
standata <- list(
  N = length(simdata$y),
  NG = ngroups,
  NS = nspp,
  y = simdata$y,
  x1 = simdata$x1,
  x2 = simdata$x2,
  group = simdata$group,
  spp = simdata$spp
)

# Parameters kept in the fit: population-level parameters plus the posterior
# predictive draws. Per-observation transformed parameters are not stored.
track_pars <- c(
  "beta_opt", "beta_width", "beta_height", "group_sd",
  "spp_sd", "phi", "y_rep"
)
fit <- stan(
  file = "./code/stanmodel_negbin_vertexform.stan",
  data = standata, iter = 1000, chains = 3,
  include = TRUE, pars = track_pars
)

# Parameters worth summarising and plotting (everything tracked except y_rep).
disp_pars <- c(
  "beta_opt", "beta_width", "beta_height", "group_sd",
  "spp_sd", "phi"
)
summary(fit, pars = disp_pars)
traceplot(fit, pars = disp_pars)
# Plot disp_pars rather than track_pars: the latter would also draw an
# interval for every one of the N y_rep elements.
plot(fit, pars = disp_pars)
#pairs(fit,pars=disp_pars)
####Posterior predictive checks.
library(bayesplot)

# Draws x observations matrix of replicated counts.
y_rep <- as.matrix(fit, pars = c("y_rep"))

# Overlay the densities of the first 50 replicated data sets on the observed
# counts. The title is added with ggtitle() because ppc_dens_overlay() has no
# `main` argument.
ppc_dens_overlay(y = simdata$y, yrep = y_rep[1:50, ]) +
  ggtitle("Data vs. Fit")

####Parameter recovery.
# True values in the same order as the columns extracted below; within the
# sd vectors the order is optimum, width, height.
true_params <- c(
  opt_int, opt_slope,
  width_int, width_slope,
  height_int, height_slope,
  opt_group_sd, width_group_sd, height_group_sd,
  opt_spp_sd, width_spp_sd, height_spp_sd, phi
)
recover_pars <- c(
  "beta_opt", "beta_width", "beta_height", "group_sd",
  "spp_sd", "phi"
)
# The previous call passed `other_pars`, which is never defined; the draws
# have to be extracted from the fit first.
recover_draws <- as.matrix(fit, pars = recover_pars)
stopifnot(ncol(recover_draws) == length(true_params))
mcmc_recover_hist(recover_draws, true = true_params)