Parallelizing a Bayesian Hierarchical Model with Many Parameters

Hey all,

Our code has finally reached the point where we can run about twenty-five simulated subjects, each with a quarter of the data of a real subject, in about two hours (1000 warmup iterations, 1000 sampling iterations) and get pretty good parameter recovery. We’re looking to make the jump to our full data set of ~150 real subjects, and a minimum ESS calculation based on this paper indicates that we should run about 10000 iterations to achieve a significance level and tolerance of 0.05. Given that, and the non-linear growth in run time with the number of participants, we want to take full advantage of parallelization using reduce_sum.

Our cluster uses CmdStan 2.30.1.

This is our code:

functions {

	//custom function for lognormal inverse CDF
	matrix logncdfinv(vector x, vector m, vector s) {							

		int szx;
		int szm;
		szx = size(x);
		szm = size(m);	
		matrix[szx,szm] y;
		y = exp( rep_matrix(s',szx) * sqrt(2) .* rep_matrix(inv_erfc( -2 .* x + 2 ),szm) + rep_matrix(m',szx));
		return y;

		}

	//likelihood model for CASANDRE				
	real casandre_log(matrix resps, real guess, vector sds, vector se, real conf, int nt, int sampn) {			

		matrix[nt,4] llhC;

		matrix[nt,sampn] avgs;
		avgs = rep_matrix(se,sampn).*rep_matrix(sds',nt);
		
		for (tr in 1:nt){

			matrix[3,sampn] raws;

			for (rws in 1:sampn) {
				raws[1,rws] = normal_cdf(-conf,avgs[tr,rws],sds[rws]);
				}
			for (rws in 1:sampn) {
				raws[2,rws] = normal_cdf(0,avgs[tr,rws],sds[rws]);
				}
			for (rws in 1:sampn) {
				raws[3,rws] = normal_cdf(conf,avgs[tr,rws],sds[rws]);
				}

			vector[3] ratiodist;

			ratiodist[1] = mean(raws[1,:]);
			ratiodist[2] = mean(raws[2,:]);
			ratiodist[3] = mean(raws[3,:]);

			llhC[tr,1] = ((guess/4) + ((1-guess)*ratiodist[1]));
			llhC[tr,2] = ((guess/4) + ((1-guess)*(ratiodist[2]-ratiodist[1])));
			llhC[tr,3] = ((guess/4) + ((1-guess)*(ratiodist[3]-ratiodist[2])));
			llhC[tr,4] = ((guess/4) + ((1-guess)*(1-ratiodist[3])));

 		}

		real ll;
		ll = sum(columns_dot_product(resps, log(llhC)));
		return ll;
		
	}
		
}
data {

	int ns;
	int nt;
	int sampn;
	matrix[4,ns*nt] respslong;
	row_vector[ns*nt] orislong;
	vector[sampn] sampx;
	
}
transformed data {

	
	matrix[ns*nt,4] respstall;
	vector[ns*nt] oristall;
	int nth;

	respstall = respslong';
	oristall = orislong';

	array[ns] matrix[nt,4] resps;
	matrix[nt,ns] oris;

	nth = 1;

	for (sub in 1:ns) {

		resps[sub] = respstall[nth:(nth+nt-1),1:4];							

		oris[:,sub] = oristall[nth:(nth+nt-1)];

		nth = nth + nt;

	}

}
parameters {

	//Parameters
	vector<lower=0,upper=1>[ns] guess;
	vector<lower=0>[ns] sens;
	vector[ns] crit;
	vector<lower=0>[ns] meta;
	vector<lower=0>[ns] conf;

	//Hyperparameters
	real snm;
	real<lower=0> sns;
	real cm;
	real<lower=0> cs;
	real mm;
	real<lower=0> ms;
	real ccm;
	real<lower=0> ccs;

}
model {

	//Calculate local variables
	matrix[nt,ns] sm;
	vector[ns] sc;
	matrix[sampn,ns] xtrans;
	matrix[sampn,ns] sds;
	matrix[nt,ns] se;

	sm = oris.*rep_matrix(sens',nt);
	sc = crit.*sens;

	xtrans = logncdfinv(sampx,log(1/sqrt(meta.^2+1)),sqrt(log(meta.^2+1)));							

	sds = 1./xtrans;
	se = sm-rep_matrix(sc',nt);
	
	//Hyperpriors
	snm 	~	normal(0,1);
	sns 	~	lognormal(0,1);
	cm 	~	normal(0,1);
	cs 	~	lognormal(0,1);
	mm 	~	normal(0,1);
	ms 	~	lognormal(0,1);
	ccm 	~	normal(0,1);
	ccs 	~	lognormal(0,1);

	//Priors
	guess 	~	beta(1,193.0/3.0);
	sens	~	lognormal(snm,sns);
	crit	~	normal(cm,cs);
	meta  	~	lognormal(mm,ms);
	conf  	~	lognormal(ccm,ccs);

	//Loop through the participants
	for (i in 1:ns) {

		//Likelihood Model						
		resps[i] ~ casandre(guess[i],sds[:,i],se[:,i],conf[i],nt,sampn);						//likelihood model for this this participant

	}

}
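The two transformations at the top of the model block can be checked numerically outside Stan. The plain-Python sketch below is only an illustration: `inv_erfc` is a hypothetical bisection helper (Python's standard library has `math.erfc` but no inverse), and the test values are arbitrary. It checks that `logncdfinv` matches the lognormal quantile exp(m + s * Phi^-1(x)), and that the mapping applied to `meta` yields a lognormal with mean 1 and standard deviation `meta`.

```python
import math
from statistics import NormalDist


def inv_erfc(y, lo=-10.0, hi=10.0, iters=200):
    """Hypothetical helper: invert math.erfc on (0, 2) by bisection.

    erfc is strictly decreasing, so if erfc(mid) > y the root lies above mid.
    """
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if math.erfc(mid) > y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def logncdfinv(x, m, s):
    """Scalar version of the Stan logncdfinv: exp(s * sqrt(2) * inv_erfc(-2x + 2) + m)."""
    return math.exp(s * math.sqrt(2) * inv_erfc(-2 * x + 2) + m)


def lognormal_quantile(x, m, s):
    """Reference lognormal quantile: exp(m + s * Phi^-1(x))."""
    return math.exp(m + s * NormalDist().inv_cdf(x))


def meta_to_lognormal(meta):
    """(mu, sigma) that the model block passes to logncdfinv for one subject."""
    mu = math.log(1 / math.sqrt(meta ** 2 + 1))
    sigma = math.sqrt(math.log(meta ** 2 + 1))
    return mu, sigma


def lognormal_mean_sd(mu, sigma):
    """Closed-form mean and standard deviation of a lognormal(mu, sigma)."""
    mean = math.exp(mu + sigma ** 2 / 2)
    var = (math.exp(sigma ** 2) - 1) * math.exp(2 * mu + sigma ** 2)
    return mean, math.sqrt(var)
```

The identity behind the first check is sqrt(2) * erfc^-1(2 - 2x) = Phi^-1(x); the second follows because mu + sigma^2 / 2 = 0 and exp(sigma^2) - 1 = meta^2 under this mapping.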

As you can see, each subject has five parameters, four of which are of theoretical interest (sens, crit, meta, conf), and each parameter of interest has a prior with two hyperparameters. Thanks to a lot of help on a previous thread on the Stan forum, much of the calculation (perhaps as much of it as possible) has been pulled out of the main for loop and the function’s for loops and converted to matrix operations. This has resulted in a considerable speed-up, but a run with ~150 subjects (750 parameters, 8 hyperparameters), 800 trials each, and 10000 iterations still takes a long time, if it isn’t outright intractable for our purposes. So, if possible, we want to explicitly parallelize our code to take full advantage of our access to a large CPU cluster. Right now, the program, when compiled using OpenMPI and STAN_THREADS = TRUE, will utilize (according to our cluster read-outs) 25 + num_chains CPUs, even if we explicitly set num_threads higher.

Am I mistaken in believing that our likelihood model, which is calculated independently for each subject, is ripe for parallelization with reduce_sum? And should we expect that writing reduce_sum into the code will let it use a larger number of CPUs on our cluster? I’ve been trying to read up on how reduce_sum works and have watched videos on it, but I’m still pretty confused about how it’s implemented, particularly in a model like ours where each subject has its own set of parameters and the only pooled parameters are the hyperparameters. Any advice on how I can get there, or resources I can use for this sort of problem beyond the Stan User Manual and the Stan Functions Reference?
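A sum of independent per-subject log-likelihood terms is the shape of computation reduce_sum is designed for. The plain-Python sketch below is a serial stand-in, not Stan's implementation: `reduce_sum_serial` just walks fixed chunks, and the per-subject term is a hypothetical Bernoulli log-likelihood standing in for CASANDRE. It shows the contract a partial-sum function has to meet: the sliced argument arrives already cut down to elements start..end, while every shared argument (of any type: matrices, vectors, reals, ints) arrives whole and has to be indexed with start..end.

```python
import math
import random


def subject_loglik(resp, p):
    """Hypothetical per-subject term: Bernoulli log-likelihood of 0/1 responses."""
    return sum(math.log(p) if r else math.log(1 - p) for r in resp)


def partial_sum(resps_slice, start, end, p):
    """Shaped like a Stan partial-sum function (1-based, inclusive start..end).

    resps_slice is already sliced, so it is used as-is; the shared
    parameter vector p is whole, so it must be indexed with start..end.
    """
    p_slice = p[start - 1:end]
    return sum(subject_loglik(r, q) for r, q in zip(resps_slice, p_slice))


def reduce_sum_serial(f, sliced, grainsize, *shared):
    """Serial stand-in for reduce_sum: add f over fixed chunks of `sliced`."""
    n = len(sliced)
    total = 0.0
    for start in range(1, n + 1, grainsize):
        end = min(start + grainsize - 1, n)
        total += f(sliced[start - 1:end], start, end, *shared)
    return total


# Toy data: 12 subjects, 20 trials each, one parameter per subject.
rng = random.Random(1)
p = [rng.uniform(0.2, 0.8) for _ in range(12)]
resps = [[int(rng.random() < q) for _ in range(20)] for q in p]
full = sum(subject_loglik(r, q) for r, q in zip(resps, p))
```

In Stan the chunk boundaries are chosen at run time (grainsize is only a tuning hint), so the total must not depend on where the chunks fall; checking that invariance is a cheap way to catch off-by-one indexing in the shared arguments.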

I appreciate all of your help, especially as a newbie trying to code up this beast of a model. If you’re interested in where it comes from, here is a paper explaining exactly what it is. In short, it is a two-stage process model of metacognition (currently written for perceptual experiments).

Sincerely,
Corey

I guess what I’m struggling with most is following the online tutorials and documentation for writing the partial sum function. They all seem to be written for likelihood functions with very simple inputs, like integers, but each slice of the casandre likelihood function takes a mixture of a matrix (the responses), several vectors (scaled parameters), several reals (unscaled parameters), and integers (counts). So it’s not clear to me how to set up the slices in the partial sum function that reduce_sum relies on. Would it be the following?

partial_sum_casandre_log(array[] resps,
                         int start,
                         int end,
                         vector[] guess,
                         matrix[] sds,
                         matrix[] se,
                         vector[] conf,
                         int nt,
                         int sampn) {

  return casandre_log(resps[start:end],
                      guess[start:end],
                      sds[start:end,:],
                      se[start:end,:],
                      conf[start:end],
                      nt,
                      sampn);
  }

Problem solved! I’m going to respond to my own question in case it helps someone in the future. This is the parallelization of my code:

functions {

	//Custom function for lognormal inverse CDF
	matrix logncdfinv(vector x, vector m, vector s) {							

		int szx;
		int szm;
		szx = size(x);
		szm = size(m);	
		matrix[szx,szm] y;
		y = exp( rep_matrix(s',szx) * sqrt(2) .* rep_matrix(inv_erfc( -2 .* x + 2 ),szm) + rep_matrix(m',szx));

		return y;

		}

	//Likelihood model for CASANDRE				
	real casandre_log(array[] matrix resps, vector guess, matrix sds, matrix se, vector conf) {

		real ll;
		ll = 0;

		for (n in 1:size(resps)) {

			matrix[size(se[:,n]),4] llhC;

			matrix[size(se[:,n]),size(sds[:,n])] avgs;
			avgs = rep_matrix(se[:,n],size(sds[:,n])).*rep_matrix(sds[:,n]',size(se[:,n]));
		
			for (tr in 1:size(se[:,n])){

				matrix[3,size(sds[:,n])] raws;

				for (rws in 1:size(sds[:,n])) {
					raws[1,rws] = normal_cdf(-conf[n],avgs[tr,rws],sds[rws,n]);
					}
				for (rws in 1:size(sds[:,n])) {
					raws[2,rws] = normal_cdf(0,avgs[tr,rws],sds[rws,n]);
					}
				for (rws in 1:size(sds[:,n])) {
					raws[3,rws] = normal_cdf(conf[n],avgs[tr,rws],sds[rws,n]);
					}

				vector[3] ratiodist;

				ratiodist[1] = mean(raws[1,:]);
				ratiodist[2] = mean(raws[2,:]);
				ratiodist[3] = mean(raws[3,:]);

				llhC[tr,1] = ((guess[n]/4) + ((1-guess[n])*ratiodist[1]));
				llhC[tr,2] = ((guess[n]/4) + ((1-guess[n])*(ratiodist[2]-ratiodist[1])));
				llhC[tr,3] = ((guess[n]/4) + ((1-guess[n])*(ratiodist[3]-ratiodist[2])));
				llhC[tr,4] = ((guess[n]/4) + ((1-guess[n])*(1-ratiodist[3])));

			}

			ll += sum(columns_dot_product(resps[n], log(llhC)));

		}

		return ll;
		
	}

	//Partial sum function
	real partial_sum_casandre_log(array[] matrix slice_n_resps, int start, int end, vector guess, matrix sds, matrix se, vector conf) {

	return casandre_log(slice_n_resps, guess[start:end], sds[:,start:end], se[:,start:end], conf[start:end]);

	}
		
}
data {

	int ns;
	int nt;
	int sampn;
	matrix[4,ns*nt] respslong;
	row_vector[ns*nt] orislong;
	vector[sampn] sampx;
	
}
transformed data {

	matrix[ns*nt,4] respstall;
	vector[ns*nt] oristall;
	int nth;

	respstall = respslong';
	oristall = orislong';

	array[ns] matrix[nt,4] resps;
	matrix[nt,ns] oris;

	nth = 1;

	for (sub in 1:ns) {

		resps[sub] = respstall[nth:(nth+nt-1),1:4];									

		oris[:,sub] = oristall[nth:(nth+nt-1)];

		nth = nth + nt;

	}

}
parameters {

	//Parameters
	vector<lower=0,upper=1>[ns] guess;
	vector<lower=0>[ns] sens;
	vector[ns] crit;
	vector<lower=0>[ns] meta;
	vector<lower=0>[ns] conf;

	//Hyperparameters
	real snm;
	real<lower=0> sns;
	real cm;
	real<lower=0> cs;
	real mm;
	real<lower=0> ms;
	real ccm;
	real<lower=0> ccs;

}
model {

	//Calculate local variables
	matrix[nt,ns] sm;
	vector[ns] sc;
	matrix[sampn,ns] xtrans;	
	matrix[sampn,ns] sds;
	matrix[nt,ns] se;

	sm = oris.*rep_matrix(sens',nt);
	sc = crit.*sens;

	xtrans = logncdfinv(sampx,log(1/sqrt(meta.^2+1)),sqrt(log(meta.^2+1)));

	sds = 1./xtrans;
	se = sm-rep_matrix(sc',nt);
	
	//Hyperpriors
	snm 	~	normal(0,1);
	sns 	~	lognormal(0,1);
	cm 	~	normal(0,1);
	cs 	~	lognormal(0,1);
	mm 	~	normal(0,1);
	ms 	~	lognormal(0,1);
	ccm 	~	normal(0,1);
	ccs 	~	lognormal(0,1);

	//Priors
	guess 	~	beta(1,193.0/3.0);
	sens	~	lognormal(snm,sns);
	crit	~	normal(cm,cs);
	meta  	~	lognormal(mm,ms);	
	conf  	~	lognormal(ccm,ccs);

	//Likelihood model	
	target +=  reduce_sum(partial_sum_casandre_log,resps, 1, guess, sds, se, conf);

}
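For checking the likelihood outside Stan (for example, against another implementation of the model), each term of the reduce_sum above can be ported line by line. The plain-Python sketch below mirrors the loops of casandre_log for one subject; the function names and input values are made up for illustration. One property it exposes: the four response probabilities telescope to guess + (1 - guess) = 1, so their sum is a cheap sanity check on any rewrite of the inner loop.

```python
import math
from statistics import NormalDist


def casandre_trial_probs(se_tr, sds, guess, conf):
    """Port of the per-trial loop in casandre_log.

    se_tr: stimulus evidence for this trial (se[tr, n] in the Stan code)
    sds:   noise SD samples for this subject (sds[:, n])
    Returns [llhC[tr,1], ..., llhC[tr,4]].
    """
    raws = [[], [], []]
    for sd in sds:
        avg = se_tr * sd  # avgs[tr, rws] = se[tr] * sds[rws]
        for k, c in enumerate((-conf, 0.0, conf)):
            raws[k].append(NormalDist(avg, sd).cdf(c))
    r1, r2, r3 = (sum(row) / len(row) for row in raws)  # ratiodist
    g = guess / 4
    return [
        g + (1 - guess) * r1,
        g + (1 - guess) * (r2 - r1),
        g + (1 - guess) * (r3 - r2),
        g + (1 - guess) * (1 - r3),
    ]


def subject_loglik(resps, se, sds, guess, conf):
    """sum(columns_dot_product(resps[n], log(llhC))) for one subject.

    resps: one row of 4 response indicators (or counts) per trial.
    """
    ll = 0.0
    for resp_row, se_tr in zip(resps, se):
        probs = casandre_trial_probs(se_tr, sds, guess, conf)
        ll += sum(r * math.log(q) for r, q in zip(resp_row, probs))
    return ll
```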

Thanks for following up on this and posting your own worked-out solution. It would be great if you could file an issue on how we could improve the documentation. This is a bit of an ask, I know, but I think many people struggle with similar things.

BTW, the original intent of reduce_sum was to slice the first argument, which was good for performance in its first version. With the second version, things have already improved so much that slicing isn’t really needed for good performance. Hence, you can slice a dummy integer index sequence that labels the terms in the sum from 1 to the number of terms, and then do the slicing yourself by accessing the respective elements of the shared parameters. This leads to a few more copies being made during the reduce steps, but you won’t notice performance differences, and many people find this easier to code.
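A minimal plain-Python sketch of the index-sequence pattern, again with a serial stand-in for reduce_sum and a hypothetical Bernoulli per-subject term. The sliced argument is just the list of subject indices 1..ns; the data and the parameters are all passed as shared arguments and looked up by index inside the partial sum. In Stan, the index array could presumably be built once in transformed data (for example with `linspaced_int_array(ns, 1, ns)`), but treat that detail as an assumption to check against the Functions Reference for your version.

```python
import math


def subject_term(resp, p):
    """Hypothetical per-subject log-likelihood (Bernoulli)."""
    return sum(math.log(p) if r else math.log(1 - p) for r in resp)


def partial_sum_idx(idx_slice, start, end, resps, p):
    """Slice only the dummy index; look everything else up by subject index."""
    return sum(subject_term(resps[i - 1], p[i - 1]) for i in idx_slice)


def reduce_sum_serial(f, sliced, grainsize, *shared):
    """Serial stand-in for reduce_sum over fixed chunks (1-based start..end)."""
    n = len(sliced)
    total = 0.0
    for start in range(1, n + 1, grainsize):
        end = min(start + grainsize - 1, n)
        total += f(sliced[start - 1:end], start, end, *shared)
    return total


resps = [[1, 0, 1], [0, 0, 1], [1, 1, 1], [0, 1, 0]]
p = [0.6, 0.3, 0.8, 0.5]
sub_idx = list(range(1, len(resps) + 1))  # 1..ns, like a Stan int array
full = sum(subject_term(r, q) for r, q in zip(resps, p))
```

Because the partial sum never needs start and end for lookups, there is no separate indexing convention for sliced versus shared arguments to get wrong; the trade-off is the extra copying mentioned above.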

(To get more speed out of this program, you should try not to make the normal_cdf calls one by one; instead, collect things into a vector and then make one big vectorised call. This should give you a big speed boost.)


Hey,

Thanks! I may have some time to write up an issue in a month or so, after a big upcoming conference.

I also just tried vectorizing normal_cdf, but it appears to return only a single real, regardless of whether its inputs are vectors. This also appears to have been brought up in a previous thread, where it was noted that normal_cdf does not give vectorized outputs.

This might still work for me if a call to normal_cdf with vector inputs returned the sum of the normal CDFs of the elements in the input vectors, which I could then divide by the number of x-axis samples to find the mean. Otherwise, I may be misunderstanding what I should be doing, but it doesn’t appear to work.
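For reference, the Stan Functions Reference documents that a CDF called with vector arguments returns the product of the element-wise CDFs (and the corresponding _lcdf returns the sum of their logs), so neither form gives the sum needed for a mean. The plain-Python sketch below, with made-up values shaped like one row of the casandre_log loop, contrasts the three quantities.

```python
import math
from statistics import NormalDist


def cdfs(ys, mus, sigmas):
    """Element-wise normal CDF values (what the loops in casandre_log compute)."""
    return [NormalDist(mu, s).cdf(y) for y, mu, s in zip(ys, mus, sigmas)]


def vectorized_cdf(ys, mus, sigmas):
    """Stan-style vectorized normal_cdf: product of the element-wise CDFs."""
    return math.prod(cdfs(ys, mus, sigmas))


def vectorized_lcdf(ys, mus, sigmas):
    """Stan-style vectorized normal_lcdf: sum of the log CDFs."""
    return sum(math.log(c) for c in cdfs(ys, mus, sigmas))


def mean_cdf(ys, mus, sigmas):
    """What ratiodist needs: the mean of the element-wise CDFs."""
    values = cdfs(ys, mus, sigmas)
    return sum(values) / len(values)


ys = [0.7, 0.7, 0.7]       # conf, repeated for each sample
mus = [0.4, 0.5, 0.65]     # avgs[tr, :] = se[tr] * sds with se[tr] = 0.5
sigmas = [0.8, 1.0, 1.3]   # sds samples
```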

Sincerely,
Corey

Apologies! I misread. You need the CDF values one by one, in which case you have to make the calls one by one, of course. In that case, you could try Stan’s Phi_approx function if you are OK with an approximation.
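Two observations make Phi_approx (or the exact Phi) usable here. First, normal_cdf(c | mu, sigma) equals Phi((c - mu) / sigma), so a whole row of CDFs becomes one elementwise transform of a vector of z-scores; as far as I know, Phi and Phi_approx accept vector arguments in recent Stan versions (check the Functions Reference for yours). Second, because avgs[tr, rws] = se[tr] * sds[rws] in this model, the z-score simplifies to conf / sds[rws] - se[tr]. The plain-Python sketch below checks both numerically, using the logistic form of Phi_approx, inv_logit(0.07056 x^3 + 1.5976 x); the sample values are made up.

```python
import math
from statistics import NormalDist


def inv_logit(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def phi_approx(x):
    """Logistic approximation to the standard normal CDF (Stan's Phi_approx)."""
    return inv_logit(0.07056 * x ** 3 + 1.5976 * x)


PHI = NormalDist().cdf  # exact standard normal CDF


def row_mean_cdf(c, se_tr, sds, phi=PHI):
    """Mean over samples of normal_cdf(c | se_tr * sd, sd), via z = c / sd - se_tr."""
    zs = [c / sd - se_tr for sd in sds]
    return sum(phi(z) for z in zs) / len(zs)


def row_mean_cdf_loop(c, se_tr, sds):
    """Reference: the original one-call-per-sample loop."""
    return sum(NormalDist(se_tr * sd, sd).cdf(c) for sd in sds) / len(sds)


sds = [0.8, 1.0, 1.3]
```

In Stan, one row of raws might then become something like `mean(Phi_approx(conf[n] ./ sds[:, n] - se[tr, n]))`; this is an untested sketch, and whether it is actually faster than the loop would need to be measured.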

Thanks! I might actually try writing a user-defined normal_cdf function that takes matrices and uses matrix multiplication instead of looping. I know built-in Stan functions tend to be faster, but I’m wondering whether trading the for loop for matrix multiplication makes it worth it. If it results in a speed boost, I’ll post the function in the developer forum.

Edit: it was most definitely not faster!