Speeding up a hierarchical nonlinear model

Hey all,

I’ve been working to translate a hierarchical nonlinear model from JAGS to Stan, mostly to allow timely inference on a large-ish dataset of ~100,000 observations. I’ve got a version in Stan that works well for simulated data, but performance isn’t much better than the JAGS version.

I suspect there are some things I could do to speed it up, but I’m relatively new to Stan, so I don’t have a good intuition for which strategies will be most fruitful. I’m hoping the community can help!

The model fits a nonlinear function (a log-polynomial) to overdispersed count data with a negative binomial distribution, and includes plot- and species-specific random effects on the curve parameters.

The application is to help understand how changes in climate drive changes in the timing of flowering in plant communities. We used the JAGS version in this paper:
http://onlinelibrary.wiley.com/doi/10.1002/ecy.1996/full

Here’s the Stan program:

data {
  int N; //the number of observations
  int NG; //the number of groups
  int NS; //the number of species
  int y[N]; //the response
  vector[N] x1; //first predictor.
  vector[N] x2; //second predictor
  int group[N]; //group index
  int spp[N]; //species index
}
parameters {
  vector[2] beta_opt; //the regression parameters on the optimum
  vector[2] beta_width; //the regression parameters on the width
  vector[2] beta_height; //the regression parameters on the height
  vector[NG] opt_group; //random intercepts for the optimum
  vector[NG] height_group; //random intercepts for the height
  vector[NG] width_group; //random intercepts for the width
  vector[NS] opt_spp; //random intercepts for the optimum
  vector[NS] height_spp; //random intercepts for the height
  vector[NS] width_spp; //random intercepts for the width
  vector<lower=0>[3] group_sd; //sd for group random intercepts
  vector<lower=0>[3] spp_sd; //sd for species random intercepts
  real<lower=0> phi; //the overdispersion parameters
}
transformed parameters {
  vector[N] mu; //the expected count (negative binomial mean)
  vector[N] opt; //optima
  vector[N] width; //width
  vector[N] height; //height
  
  opt = beta_opt[1] + x2 * beta_opt[2] + opt_group[group] + opt_spp[spp];
  width = exp(beta_width[1] + x2 * beta_width[2] + width_group[group] + width_spp[spp])*-1;
  height = beta_height[1] + x2 * beta_height[2] + height_group[group] + height_spp[spp];
  
  //can't vectorize the exponentiation step.
  for(i in 1:N){
      mu[i] = exp(width[i] * square(x1[i] - opt[i]) + height[i]);
  }
}
model {  
  
  //Priors for fixed effects
  phi ~ cauchy(10,50);
  beta_width[1] ~ normal(0,5); //prior for the intercept
  beta_width[2] ~ normal(0,5); //prior for the slope
  beta_opt[1] ~ normal(0,5); //prior for the intercept
  beta_opt[2] ~ normal(0,5); //prior for the slope
  beta_height[1] ~ normal(0,5); //prior for the intercept
  beta_height[2] ~ normal(0,5); //prior for the slope
  
  //Priors for random effects
  group_sd[1] ~ cauchy(0,3);
  group_sd[2] ~ cauchy(0,3);
  group_sd[3] ~ cauchy(0,3);
  
  spp_sd[1] ~ cauchy(0,5);
  spp_sd[2] ~ cauchy(0,5);
  spp_sd[3] ~ cauchy(0,5);
  
  for(j in 1:NG){
    opt_group[j] ~ normal(0,group_sd[1]);
    width_group[j] ~ normal(0,group_sd[2]);
    height_group[j] ~ normal(0,group_sd[3]);
  }
  
  for(k in 1:NS){
    opt_spp[k] ~ normal(0,spp_sd[1]);
    width_spp[k] ~ normal(0,spp_sd[2]);
    height_spp[k] ~ normal(0,spp_sd[3]);
  }
  
  //likelihood
  y ~ neg_binomial_2(mu,phi);
}
generated quantities {
 vector[N] y_rep;
 for(n in 1:N){
  y_rep[n] = neg_binomial_2_rng(mu[n],phi); //posterior draws to get posterior predictive checks
 }
}

Data Simulation / Fitting / PP Check code is here:

library(rstan)
library(dplyr)
library(ggplot2)
library(shinystan)
setwd("~/code/pheno_scaling/")

####Simulates data####
n <- 1000
ngroups <- 10
nspp <- 10
phi <- 50
width_int <- 2
width_slope <- -0.2
width_group_sd <- 0.1
width_spp_sd <- 0.6
height_int <- 4
height_slope <- 0.2
height_group_sd <- 0.2
height_spp_sd <- 1
opt_int <- 0
opt_slope <- 0.4
opt_group_sd <- 0.1
opt_spp_sd <- 1

x1 <- rnorm(n,0,1)
x2 <- rnorm(n,0,1)
group <- sample(1:ngroups,size=n,replace=TRUE)
spp <- sample(1:nspp,size=n,replace=TRUE)

simdata <- data.frame(x1=x1,
                      x2=x2,
                      group=group,
                      spp=spp)

group_rnd_int <- data.frame(group=1:ngroups,
                            width_grp=rnorm(ngroups,0,width_group_sd),
                            height_grp=rnorm(ngroups,0,height_group_sd),
                            opt_grp=rnorm(ngroups,0,opt_group_sd))

spp_rnd_int <- data.frame(spp=1:nspp,
                            width_spp=rnorm(nspp,0,width_spp_sd),
                            height_spp=rnorm(nspp,0,height_spp_sd),
                            opt_spp=rnorm(nspp,0,opt_spp_sd))

simdata <- left_join(simdata,group_rnd_int,by="group")
simdata <- left_join(simdata,spp_rnd_int,by="spp")

width <- exp(width_int + width_slope*x2 + simdata$width_grp + simdata$width_spp)*-1
height <- height_int + height_slope*x2 + simdata$height_grp + simdata$height_spp
opt <- opt_int + opt_slope*x2 + simdata$opt_grp + simdata$opt_spp

simdata$linpred <- exp(width*(x1-opt)^2+height)
simdata$y <- rnbinom(n=n,size = phi, mu=simdata$linpred )

ggplot(simdata)+
  geom_point(aes(x=x2,y=x1,color=y))+
  facet_grid(facets=as.factor(spp)~as.factor(group))+
  theme_bw()

####Fits the model in Stan.
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

standata <- list(N=length(simdata$y),
                 NG=ngroups,
                 NS=nspp,
                 K=2,
                 y=simdata$y,
                 x1=simdata$x1,
                 x2=simdata$x2,
                 group=simdata$group,
                 spp=simdata$spp)

track_pars <- c("beta_opt","beta_width","beta_height","group_sd",
                "spp_sd","phi","y_rep")
fit <- stan("./code/stanmodel_negbin_vertexform.stan",
            data=standata,iter=1000,chains=3,
            include=TRUE,pars=track_pars)


disp_pars <- c("beta_opt","beta_width","beta_height","group_sd",
                "spp_sd","phi")
summary(fit,pars=disp_pars)
traceplot(fit,pars=disp_pars)
plot(fit,pars=track_pars)
#pairs(fit,pars=track_pars)

####Posterior predictive checks.
library(bayesplot)

y_rep <- as.matrix(fit, pars = c("y_rep"))
ppc_dens_overlay(y = simdata$y, yrep = y_rep[1:50, ]) + ggtitle("Data vs. Fit")

true_params <- c(opt_int,opt_slope,
                 width_int,width_slope,
                 height_int,height_slope,
                 opt_group_sd,width_group_sd,height_group_sd,
                 opt_spp_sd,width_spp_sd,height_spp_sd,phi)

mcmc_recover_hist(as.matrix(fit, pars = disp_pars), true = true_params)

Cauchy priors have been advised against lately, as I recall, but if you really want one, try the tan trick illustrated in the “reparameterizing the Cauchy” section of the manual.

But you can vectorize everything inside the exp, storing the result in a temporary variable that you then loop over to exponentiate.
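For instance, a minimal sketch of that idea against the transformed parameters block above (eta is just a temporary name I’m introducing here):

vector[N] eta; //everything inside the exp, computed with vectorized operations
eta = width .* square(x1 - opt) + height;
for(i in 1:N){
    mu[i] = exp(eta[i]); //only the exponentiation stays in a loop
}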

The for loops in the model block are unnecessary; you can just do:

opt_group ~ normal(0,group_sd[1]);
width_group ~ normal(0,group_sd[2]);
height_group ~ normal(0,group_sd[3]);
opt_spp ~ normal(0,spp_sd[1]);
width_spp  ~ normal(0,spp_sd[2]);
height_spp ~ normal(0,spp_sd[3]);

And the element-by-element priors on group_sd and spp_sd can just be:

group_sd ~ cauchy(0,3);
spp_sd ~ cauchy(0,5);

Ditto all the beta_… priors

If the number of unique rows in a table produced by cbind(group, spp) is far fewer than N, then your computations of opt, width, and height involve some redundant work that you could instead do just for those unique combinations, indexing into the result to continue with the rest of the computation (the part involving x1 and x2). This is the second question in the past week that would benefit from reducing redundant computation, so I’ll try to find time to post a fuller example.
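Roughly, though, the idea would be something like this sketch in transformed parameters (cell, cell_group, and cell_spp are hypothetical data you’d precompute in R, mapping each observation to its unique group-by-species combination, with NC combinations in total):

//hypothetical extra data: int NC; int cell[N]; int cell_group[NC]; int cell_spp[NC];
vector[NC] opt_cell;
vector[NC] width_cell;
vector[NC] height_cell;

opt_cell = opt_group[cell_group] + opt_spp[cell_spp];
width_cell = width_group[cell_group] + width_spp[cell_spp];
height_cell = height_group[cell_group] + height_spp[cell_spp];

//index back out to observation length and add the per-observation x2 terms
opt = beta_opt[1] + x2 * beta_opt[2] + opt_cell[cell];
width = exp(beta_width[1] + x2 * beta_width[2] + width_cell[cell])*-1;
height = beta_height[1] + x2 * beta_height[2] + height_cell[cell];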

Thanks Mike! I didn’t realize the distribution assignment worked on vectors, and I’ll give some thought to the redundant computation problem. An example would be really useful!

Cheers,
Ian

Mike provided some good hints

but can you be more specific about how you measure performance? N_eff per second?


Yeah…nothing exact. I’d have to go back and compute N_eff for the JAGS version. Mixing is relatively poor, though, even in the current Stan version incorporating some of Mike’s suggestions. For the worst-behaving parameters I’m getting about 0.75 N_eff per second on a run with 200 simulated observations, 5 species, and 5 groups. Depending on how that scales, it puts me into multi-day territory for the full dataset. Any suggestions would be greatly appreciated!

Here’s the current version with Mike’s suggestions implemented:

data {
  int N; //the number of observations
  int NG; //the number of groups
  int NS; //the number of species
  int y[N]; //the response
  vector[N] x1; //first predictor.
  vector[N] x2; //second predictor
  int group[N]; //group index
  int spp[N]; //species index
}
parameters {
  vector[2] beta_opt; //the regression parameters on the optimum
  vector[2] beta_width; //the regression parameters on the width
  vector[2] beta_height; //the regression parameters on the height
  vector[NG] opt_group; //random intercepts for the optimum
  vector[NG] height_group; //random intercepts for the height
  vector[NG] width_group; //random intercepts for the width
  vector[NS] opt_spp; //random intercepts for the optimum
  vector[NS] height_spp; //random intercepts for the height
  vector[NS] width_spp; //random intercepts for the width
  vector<lower=0>[3] group_sd; //sd for group random intercepts
  vector<lower=0>[3] spp_sd; //sd for species random intercepts
  real<lower=0> phi; //the overdispersion parameter
}
transformed parameters {
  vector[N] mu; //the expected count (negative binomial mean)
  vector[N] opt; //optima
  vector[N] width; //width
  vector[N] height; //height
  vector[N] x_min_opt; //vectorizes expression 'x1 - opt'
  
  opt = beta_opt[1] + x2 * beta_opt[2] + opt_group[group] + opt_spp[spp];
  width = exp(beta_width[1] + x2 * beta_width[2] + width_group[group] + width_spp[spp])*-1;
  height = beta_height[1] + x2 * beta_height[2] + height_group[group] + height_spp[spp];
  x_min_opt = x1 - opt;
  
  //can't vectorize the exponentiation step.
  for(i in 1:N){
      mu[i] = exp(width[i] * square(x_min_opt[i]) + height[i]);
  }
}
model {  
  
  //Priors for fixed effects
  phi ~ student_t(4,10,50);
  beta_width[1] ~ normal(0,5); //prior for the intercept
  beta_width[2] ~ normal(0,5); //prior for the slope
  beta_opt[1] ~ normal(0,5); //prior for the intercept
  beta_opt[2] ~ normal(0,5); //prior for the slope
  beta_height[1] ~ normal(0,5); //prior for the intercept
  beta_height[2] ~ normal(0,5); //prior for the slope
  
  //Priors for random effects
  group_sd ~ normal(0.5,2); //truncated normal bc parameter defined > 0
  spp_sd ~ normal(0.5,2); //truncated normal bc parameter defined > 0
  
  opt_group ~ normal(0,group_sd[1]);
  width_group ~ normal(0,group_sd[2]);
  height_group ~ normal(0,group_sd[3]);

  opt_spp ~ normal(0,spp_sd[1]);
  width_spp ~ normal(0,spp_sd[2]);
  height_spp ~ normal(0,spp_sd[3]);
  
  //likelihood
  y ~ neg_binomial_2(mu,phi);
}
generated quantities {
 vector[N] y_rep;
 for(n in 1:N){
  y_rep[n] = neg_binomial_2_rng(mu[n],phi); //posterior draws to get posterior predictive checks
 }
}

This:

beta_width[1] ~ normal(0,5); //prior for the intercept
beta_width[2] ~ normal(0,5); //prior for the slope
beta_opt[1] ~ normal(0,5); //prior for the intercept
beta_opt[2] ~ normal(0,5); //prior for the slope
beta_height[1] ~ normal(0,5); //prior for the intercept
beta_height[2] ~ normal(0,5); //prior for the slope

Can become:

beta_width ~ normal(0,5); 
beta_opt ~ normal(0,5); 
beta_height ~ normal(0,5);

Thanks! I left priors on the betas individually specified because I suspected that I might need different priors on slopes and intercepts when I run this with real data.

One small improvement: Instead of this

for(i in 1:N){
      mu[i] = exp(width[i] * square(x_min_opt[i]) + height[i]);
  }

and

y ~ neg_binomial_2(mu,phi);

use this instead (square is vectorized over vectors, so this works):

eta = width .* square(x_min_opt) + height;

and

y ~ neg_binomial_2_log(eta,phi);
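If you switch to the log parameterization, the generated quantities block needs a matching change; a sketch, assuming eta is stored as a transformed parameter:

generated quantities {
 vector[N] y_rep;
 for(n in 1:N){
  y_rep[n] = neg_binomial_2_log_rng(eta[n],phi); //replicated draws on the same log-mean scale
 }
}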

To analyse the mixing problems: which are the worst-behaving parameters? Are they (1) weakly identified, (2) strongly correlated, or (3) part of a funnel? Could you check (1) prior vs. posterior, how long the posterior tail is, treedepth, and E-BFMI; (2) step size per parameter divided by the marginal posterior scale of that parameter, correlations, and the treedepth distribution; (3) divergences and the treedepth distribution?

Do you use R? If so, you could use ShinyStan to check many of these, but you can get them out of other interfaces, too.
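For example, a minimal sketch using rstan’s built-in accessors:

library(shinystan)
launch_shinystan(fit) #interactive diagnostics: divergences, treedepth, E-BFMI, ...

#or pull the sampler diagnostics directly
sampler_params <- get_sampler_params(fit, inc_warmup = FALSE)
sapply(sampler_params, function(x) sum(x[, "divergent__"]))  #divergences per chain
sapply(sampler_params, function(x) max(x[, "treedepth__"]))  #max treedepth per chain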

Could you please suggest a currently acceptable alternative?

normal or student-t with 3 or 4 degrees of freedom (see: https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations)
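In this model, keeping the original scales, that could look something like this sketch (the <lower=0> declarations make these half-Normal / half-Student-t):

group_sd ~ student_t(3, 0, 3);
spp_sd ~ student_t(3, 0, 5);
//or
group_sd ~ normal(0, 3);
spp_sd ~ normal(0, 5);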


Finally got around to doing this: Tip for speeding up models with repeated observations

A truncated Student-t with 4 degrees of freedom is very close to exponential(1.0).
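In this model that would be something like (a sketch; note that exponential(1) assumes the sds are on roughly unit scale):

group_sd ~ exponential(1);
spp_sd ~ exponential(1);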

Make sure to do this with Stan’s calculation of N_eff as it’s much more conservative than the JAGS version.
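For example, with rstan you can get a rough N_eff-per-second number like this (a sketch, reusing disp_pars from the fitting script above):

n_eff <- summary(fit, pars = disp_pars)$summary[, "n_eff"]
total_seconds <- sum(get_elapsed_time(fit)) #warmup + sampling, summed over chains
sort(n_eff / total_seconds)[1:5]            #the slowest-mixing parameters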

Negative binomials can be hard to fit because the overdispersion makes a lot of their parameter space consistent with input data.