Simple Stan regression model takes 60 hours to complete. Why?

I want to model the interaction between x1 and x2 in this simulated data with a logistic regression.


I am quite inexperienced in Stan, but to my knowledge this is a very simple Stan model, yet according to Stan it should take roughly 60 hours to sample from. This is a simple logistic regression with 3 predictors, which should be fairly easy to sample. I am using rstan, and my setup in R looks like this:

library(rstan)

X <- data.frame(x1 = rbinom(6534, 2, 0.3),
                x2 = rbinom(6534, 2, 0.3))

Y <- rbinom(6534, 1, 0.2)

logistic_regression <- stan_model("logistic_regression.stan")
logistic_fit <- sampling(logistic_regression,
                         list(N = dim(X)[1], 
                              y=Y, 
                              x=X),
                         iter=2000,
                         chains = 4,
                         save_warmup=FALSE)

My Stan model looks like this:

logistic_regression.stan


data{
  int<lower=1> N; // Rows
  int<lower=0, upper=1> y[N]; // Outcome variables
  matrix<lower=0, upper=2>[N, 2] x; // Predictor variables, always two columns
}

parameters{

  real a;
  row_vector[3] b; 
}

model{
  
  // Priors
  a ~ normal(0, 0.1);
  for (i in 1:3)
  {
    b[i] ~ normal(0, 0.2^2); 
  }

  // Likelihood
  for (i in 1:N)
  {
    y[i] ~ bernoulli_logit( a + b[1]*x[,1] + b[2]*x[,1] + b[3]*x[,1].*x[,2] );
  }

}

When I initialize sampling, I get this message:

Chain 1: Gradient evaluation took 21.6585 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 216585 seconds.
Chain 1: Adjust your expectations accordingly!

And the same message for chains 2, 3 and 4. I left it to run overnight and it wasn’t done the next morning. I have tried simple Stan models in the past without this kind of run time, but I don’t have the code anymore so I cannot compare against older models. I am working on a Linux server on which many other people work without a problem.

Why does sampling take this long for this model?

First, there’s a bug in your code that is probably causing the majority of the slowdown. You loop N times to express the likelihood for each element of y, but while each such element should be associated with a single row of x, you’re not indexing x accordingly (i.e. x[i,1]), so the entire set of b*x terms gets computed on every iteration of the loop. Since bernoulli_logit() is vectorized, you can fix this by simply skipping the loop and doing:

  y ~ bernoulli_logit( a + b[1]*x[,1] + b[2]*x[,1] + b[3]*x[,1].*x[,2] );
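
Alternatively, if you prefer to keep the loop, the fix is to index x by row so each y[i] only sees its own covariates. A minimal sketch of that version (still slower than the vectorized statement above):

  // Per-row likelihood; note x is indexed by row i
  for (i in 1:N)
  {
    y[i] ~ bernoulli_logit( a + b[1]*x[i,1] + b[2]*x[i,2] + b[3]*x[i,1]*x[i,2] );
  }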

There are a number of things you can do to make this example even faster:

  • Since all elements of the parameter vector b receive the same prior, you don’t need to assign them in a loop; you can instead simply write b ~ normal(0, 0.04); (it also helps very slightly to actually write 0.04 rather than 0.2^2)

  • You should pre-compute x[,1]*x[,2] in R and supply it as a third column of x (this and the previous point are shown in the sketch after this list)

  • (This is a big one.) Since, judging by your example, the columns x[,1] and x[,2] will only take 4 unique values in combination ([1,1], [1,0], [0,1], [0,0]), you can compute b*x for just the unique combinations. That alone should give a substantial speedup, but you can then also employ the sufficient-statistics trick for yet more speed: all observations that share a covariate combination have the same success probability, so their Bernoulli terms collapse into a single binomial on the count of successes.
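
For reference, here is a minimal sketch of your original model with the first two bullets folded in (not yet the sufficient-statistics trick from the third bullet); it assumes the interaction column has been pre-computed in R, e.g. with X$x3 <- X$x1 * X$x2, and passed as a third column of x:

data{
  int<lower=1> N; // Rows
  int<lower=0, upper=1> y[N]; // Outcome variable
  matrix<lower=0, upper=4>[N, 3] x; // Columns: x1, x2, and the pre-computed interaction x1*x2
}

parameters{
  real a;
  row_vector[3] b;
}

model{
  // Priors, vectorized over b
  a ~ normal(0, 0.1);
  b ~ normal(0, 0.04);

  // Vectorized likelihood over all N rows
  y ~ bernoulli_logit( a + b[1]*x[,1] + b[2]*x[,2] + b[3]*x[,3] );
}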

For posterity, here’s R code for your example written the way I’d recommend (it takes <1 s to sample):

library(tidyverse)
library(cmdstanr)

model_code = '
data{
  int<lower=1> K ; //number of unique combinations of the covariates
  matrix<lower=0,upper=4>[K,3] x ; //unique covariates
  int<lower=1> N[K] ; //number of observations associated with each row in X
  int sum_y[K]; // total successes for each row in X
}
parameters{
  real a;
  row_vector[3] b; 
}
model{
  // Priors
  a ~ std_normal();
  b ~ std_normal(); 
  // Likelihood
  sum_y ~ binomial(
    N
    , inv_logit(
      a 
      + b[1]*x[,1] 
      + b[2]*x[,2] 
      + b[3]*x[,3]
    ) 
  );
}
'
mod = 
	(
		model_code
		%>% cmdstanr::write_stan_file()
		%>% cmdstanr::cmdstan_model()
	)

set.seed(1)
N_tot = 1e4
X = 
	(
		tibble::tibble(
			x1 = rbinom(N_tot, 2, 0.3)
			, x2 = rbinom(N_tot, 2, 0.3)
			, x3 = x1*x2
		)
		%>% as.matrix()
	)
head(X)

intercept = rnorm(1)
coefs = rnorm(3)
print(c(intercept,coefs))

Y = rbinom(
	nrow(X)
	, 1
	, plogis(
		intercept 
		+ coefs[1]*X[,1] 
		+ coefs[2]*X[,2] 
		+ coefs[3]*X[,3]
	)
)

# get unique X and sum_Y
xy_summary =
	(
		cbind(X,Y)
		%>% tibble::as_tibble()
		%>% dplyr::group_by(x1,x2,x3)
		%>% dplyr::summarise(
			sum_y = sum(Y)
			, N = n()
			, .groups = 'drop'
		)
	)
print(xy_summary)

#sample the model
fit = 
	(
		tibble::lst(
			x = dplyr::select(xy_summary, -sum_y, -N)
			, sum_y = dplyr::pull(xy_summary,sum_y)
			, N = dplyr::pull(xy_summary,N)
			, K = nrow(x)
		)
		%>% mod$sample(
			chains = parallel::detectCores()/2-1
			, parallel_chains = parallel::detectCores()/2-1
		)
	)

fit$summary()

(
	fit$draws(variables = c('a','b'))
	%>% bayesplot::mcmc_intervals()
	+ geom_point(
		data = tibble::tibble(
			x = c(intercept,coefs)
			, y = 4:1
		)
		, mapping = aes(x=x,y=y)
		, colour = 'red'
	)
)

My bad, I see that in your example code you use rbinom(6534, 2, 0.3), yielding vectors with 3 unique values and therefore 9 unique combinations, not 4. It should still be a substantial speedup to compute only 9 things instead of 6534.

Oh, there was another bug in your code: you had b[2]*x[,1] where you should have had b[2]*x[,2].