Please share your Stan program and accompanying data if possible.
I have a multilevel regression model (900+ groups and 50+ observations per group).
I am trying to use reduce_sum() to speed up within-chain sampling. However, in my testing so far it gives no performance gain. I have tried slicing over the group index, over the predictors (X1, X2), and over the outcome Y. None of these performs any better than a single-core run.
Does anyone know the best way to use reduce_sum on a multilevel model?
Below is the code that slices over the group index. I have tried the same with X1, X2 and Y, but I do not see any performance gain with those either.
// do reduce_sum()
functions {
  /**
   * Log-likelihood contribution of observations start..end, written as the
   * partial-sum function for reduce_sum(). group_slice is the slice of the
   * sliced argument that corresponds to positions start..end, so
   * group_slice[k] is the group of observation (start + k - 1).
   *
   * @param group_slice group index of each observation in this slice
   * @param start first observation index of the slice (inclusive)
   * @param end last observation index of the slice (inclusive)
   * @param Y outcome vector (full length; sliced here with start:end)
   * @param X4 predictor multiplied by theta
   * @param X1 predictor multiplied by gamma
   * @param X2 predictor multiplied by delta
   * @param X3 predictor multiplied by beta
   * @param alpha per-group intercept
   * @param theta per-group coefficient on X4
   * @param gamma per-group coefficient on X1
   * @param delta per-group coefficient on X2
   * @param beta per-group coefficient on X3
   * @param sigma per-group scale of the normal likelihood
   * @return sum of normal log densities of Y[start:end]
   */
  real partial_sum(int[] group_slice, int start, int end,
                   vector Y, vector X4, vector X1,
                   vector X2, vector X3,
                   vector alpha, vector theta, vector gamma, vector delta, vector beta,
                   vector sigma) {
    int n_obs = end - start + 1;
    // Linear predictor for the slice; terms are added in the same order as
    // the single-expression form so the arithmetic is unchanged.
    vector[n_obs] mu = alpha[group_slice]
                       + theta[group_slice] .* X4[start:end]
                       + gamma[group_slice] .* X1[start:end]
                       + delta[group_slice] .* X2[start:end]
                       + beta[group_slice] .* X3[start:end];
    return normal_lpdf(Y[start:end] | mu, sigma[group_slice]);
  }
}
data {
  int<lower=0> N;                  // number of observations
  int<lower=0> J;                  // number of groups
  vector[N] Y;                     // outcome
  vector[N] X1;                    // predictor 1 (coefficient gamma)
  vector[N] X2;                    // predictor 2 (coefficient delta)
  vector[N] X3;                    // predictor 3 (coefficient beta)
  // Group id of each observation. It indexes the length-J parameter vectors,
  // so it must lie in 1..J. The bounds reject bad ids when the data are read,
  // instead of failing later with an index error inside reduce_sum.
  int<lower=1, upper=J> Group[N];
  vector[N] X4;                    // predictor 4 (coefficient theta)
}
transformed data {
  // Sample mean of the outcome; used as the centre of the prior on alpha.
  real meanY = mean(Y);
  // Grainsize hint for reduce_sum(). A value of 1 lets the TBB scheduler
  // choose the partitioning automatically, which is the recommended starting
  // point.
  // The previous fixed value of 3000, with roughly 13,000 observations, allowed
  // at most about 4-5 partial sums per log-density evaluation regardless of
  // how many threads were available, which caps any speed-up. Increase it only
  // after benchmarking.
  // NOTE: reduce_sum only runs in parallel when the model is compiled with
  // STAN_THREADS and sampled with more than one thread per chain.
  int grainsize = 1;
}
parameters {
// The *_raw vectors are non-negative magnitudes. They are negated in
// transformed parameters to give non-positive coefficients.
vector[J] alpha; // per-group intercept
vector[J] gamma; // per-group coefficient on X1 (unconstrained)
vector<lower=0>[J] delta_raw; // magnitude of the per-group X2 coefficient (delta = -delta_raw)
vector<lower=0>[J] beta_raw; // magnitude of the per-group X3 coefficient (beta = -beta_raw)
vector<lower=0>[J] sigma; // per-group scale (SD) of the normal likelihood for Y
real<lower=0,upper=2> tau; // rate of the exponential prior on beta_raw
real<lower=0,upper=2> phi; // rate of the exponential prior on delta_raw
vector<lower=0>[J] theta_raw; // magnitude of the per-group X4 coefficient (theta = -theta_raw)
}
transformed parameters {
  // Non-positive per-group coefficients, obtained by negating the <lower=0>
  // *_raw parameters. The <upper=0> bounds are checked again at the end of
  // this block.
  vector<upper=0>[J] beta;   // coefficient on X3
  vector<upper=0>[J] delta;  // coefficient on X2
  vector<upper=0>[J] theta;  // coefficient on X4

  theta = -theta_raw;
  delta = -delta_raw;
  beta = -beta_raw;
}
model {
// Priors. Sampling statements (~) drop constant terms from target.
// Statement order is left as is, because it fixes the order in which target
// is accumulated.
beta_raw ~ exponential(tau); // magnitude of the X3 coefficients; rate tau
tau ~ normal(log(2),0.03); // tight prior around log(2); tau is also bounded to (0, 2)
alpha ~ normal(meanY,1); // per-group intercepts centred on the outcome mean
gamma ~ std_normal(); // X1 coefficients
delta_raw ~ exponential(phi); // magnitude of the X2 coefficients; rate phi
phi ~ normal(0,10); // nearly flat over phi's bounded support (0, 2)
sigma ~ exponential(2); // per-group likelihood scales
theta_raw ~ normal(0,10); // half-normal, since theta_raw is declared <lower=0>
// Likelihood: Group is the sliced argument; every other argument is shared
// and passed in full to each partial_sum call. The slices run in parallel
// only when compiled with STAN_THREADS and more than one thread is used.
target += reduce_sum(partial_sum, Group, grainsize,Y, X4, X1,X2,X3,alpha,theta,gamma,delta,beta,sigma );
}