Finally got the chance to test this out in comparison.
Code below - used some data from a non-representative survey of MTurk workers from the 2020 election; tested a number of random effects for a sample size of ~5000. I used random effects for gender, age, income and state.
I left the gender effects as random, instead of just having a "male fixed effect" (which would be more efficient, I think), because in the matrix formulation you don't want to have to look at which effects you can make "fixed" and which you want to keep as "random". Even if there is no shrinkage for k < 3, you still get to identify the model via the soft sum-to-zero constraint on gender, so that's good enough for me.
Your intuition bears out - the sparse matrix approach is more than twice as slow. It's a shame, because the coding-time gains can be quite large imho and the results seem to be pretty much identical, suggesting the approach is in principle viable.
I wonder if there is any way to speed up the sparse-matrix computation? Again, I think this is pretty vital for situations with large numbers of random effects... you're not going to code up 100 effects separately - it's just so inefficient.
I attach the plots for convergence tradvsparse_convergence.pdf (8.5 KB) and coefficient comparisons tradvsparse_reff.pdf (16.4 KB) (mean, 2.5th quantile and 97.5th quantile), and the computation times are:
print(reffmatrix.time )
Time difference of 5.866842 mins
print(trad.time)
Time difference of 2.488841 mins
# # # # # # # # # # # # # # # Check Speed and accuracy of this random effect model # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Assemble the data list shared by both Stan models (the sparse-matrix model
# reads Z / q / reff_vars_id / n_reff / Z*_size added below; the traditional
# model reads the *_id / N_* entries built here).
# NOTE(review): gender_id and state_id are matched against the factor levels of
# SF_complete_temp, while N_gender / N_state count the levels of
# train.voters.star -- this assumes the two objects carry identical level sets;
# confirm, otherwise ids can disagree with (or exceed) the N_* counts.
temp.list = list( n = dim(train.voters.star)[1],  # total number of observations
Y = train.voters.star$vote2020,  # binary outcome vector
gender_id = match(train.voters.star$gender,levels(SF_complete_temp$gender)),
N_gender = nlevels(train.voters.star$gender),
age_id = match(train.voters.star$age_bins,levels(train.voters.star$age_bins)),
N_age = nlevels(train.voters.star$age_bins),
income_id = match(train.voters.star$commercial_estimated_hh_income,levels(train.voters.star$commercial_estimated_hh_income)),
N_income = nlevels(train.voters.star$commercial_estimated_hh_income),
state_id = match(train.voters.star$state,levels(SF_complete_temp$state)),
N_state = nlevels(train.voters.star$state)
)
# Build the dense random-effects design matrix Z by one-hot encoding each
# random-effect variable, then record the sizes Stan needs for the CSR
# (compressed sparse row) arrays produced by csr_extract_w/v/u().
reff_vars.temp <- c("gender", "age_bins", "commercial_estimated_hh_income", "state")
Z.temp <- data.table()
for (j in seq_along(reff_vars.temp)) {
  temp <- train.voters.star[, reff_vars.temp[j], with = FALSE]
  # "___" separator lets us recover the parent variable from each dummy column
  colnames(temp) <- paste(colnames(temp), "___", sep = "")
  Z.temp <- cbind(Z.temp, model.matrix(~ . - 1, data = temp))
}
# map every dummy column back to the index of its parent variable
reff_vars_id.temp <- match(sub("\\_\\_\\_.*", "", colnames(Z.temp)), reff_vars.temp)
temp.list$Z <- Z.temp
temp.list$q <- dim(Z.temp)[2]
temp.list$reff_vars_id <- reff_vars_id.temp
temp.list$n_reff <- max(reff_vars_id.temp)
# CSR sizes (Stan Functions Reference, csr_extract_*):
#   - w (values) and v (column indices) both have one entry per non-zero of Z
#   - u (row start pointers) has one entry per row plus one
# BUG FIX: the original referenced a global `Z` instead of `Z.temp`, and sized
# Zv / Zu as the counts of non-empty columns / rows, which do not match the
# lengths of the arrays csr_extract_v() / csr_extract_u() actually return.
nnz.temp <- sum(as.matrix(Z.temp) != 0)
temp.list$Zw_size <- nnz.temp
temp.list$Zv_size <- nnz.temp
temp.list$Zu_size <- nrow(Z.temp) + 1
# # # # Multilevel Regression Code + BYM Spatial Smoothing
# Sparse-matrix formulation: all random effects enter through the single
# one-hot design matrix Z, converted to CSR form inside Stan so the random
# effects predictor is one sparse matrix-vector product.
# (Fix: the data-block comments mislabelled the three size scalars as the
# CSR arrays themselves; also a typo and a stale copied-over comment.)
model.reffmatrix= "
data {
// Outcome data
int<lower = 1> n; // total number of observations
int<lower = 0> Y[n]; // vector of labels
// Random effects data
int<lower = 1> q; // number of random-effect covariates
matrix[n,q] Z; // random design matrix
int<lower = 1> reff_vars_id[q]; // variable-group ids (for local shrinkage)
int<lower = 1> n_reff; // no. of separate variance components
// Data for sparse matrix operations
int<lower = 1> Zw_size; // number of non-zero entries of Z (length of w)
int<lower = 1> Zv_size; // length of the column-index array v (equals Zw_size)
int<lower = 1> Zu_size; // length of the row-start array u (n + 1)
}
transformed data {
// Translate Z to its CSR sparse representation
vector[Zw_size] Zw = csr_extract_w(Z);
int Zv[Zv_size] = csr_extract_v(Z);
int Zu[Zu_size] = csr_extract_u(Z);
}
parameters {
real alpha; // global fixed intercept
vector[q] eta; // random effects
real<lower = 0> reff_scales[n_reff]; // scale of the random effects
}
transformed parameters{
vector[q] eta_star; // scaled random effects
vector[n] reff_pred; // random effects predictor
vector[n] mu; // logit-scale total linear predictor
// calculate random-effects predictor
for(i in 1:q){
eta_star[i] = eta[i]*reff_scales[reff_vars_id[i]];
}
reff_pred = csr_matrix_times_vector(n, q, Zw, Zv, Zu, eta_star);
// linear function of the logit-scale propensity to vote republican
mu = alpha + reff_pred;
}
model {
// // // Fixed Effects
alpha ~ std_normal();
eta ~ std_normal();
reff_scales ~ std_normal();
// // // Likelihood
Y ~ bernoulli_logit(mu);
}
"
# Parameters to retain from the sparse-matrix fit
pars.reffmatrix <- c("alpha", "eta_star", "reff_scales")
# Traditional formulation: one effect vector and one scale parameter per
# grouping variable, each coded by hand in the Stan program below; the
# predictor indexes each scaled effect vector by its observation-level id.
model.trad= "
data {
// Outcome data
int<lower = 1> n; // total number of observations
int<lower = 0> Y[n]; // vector of labels
// Random effects data
int gender_id[n]; // index of gender categories
int<lower = 0> N_gender; // number of gender categories
int age_id[n]; // index of age categories
int<lower = 0> N_age; // number of age categories
int income_id[n]; // index of income categories
int<lower = 0> N_income; // number of income categories
int state_id[n]; // index of state categories
int<lower = 0> N_state; // number of state categories
}
parameters {
real alpha; // global fixed intercept
vector[N_gender] eta_gender;
vector[N_age] eta_age;
vector[N_income] eta_income;
vector[N_state] eta_state;
real<lower = 0> gender_scale;
real<lower = 0> age_scale;
real<lower = 0> income_scale;
real<lower = 0> state_scale;
}
transformed parameters{
vector[N_gender] eta_gender_star; //
vector[N_age] eta_age_star; //
vector[N_income] eta_income_star; //
vector[N_state] eta_state_star; //
vector[n] mu; // logit-scale linear predictor
eta_gender_star = eta_gender * gender_scale;
eta_age_star = eta_age * age_scale;
eta_income_star = eta_income * income_scale;
eta_state_star = eta_state * state_scale;
// linear function of the logit-scale propensity to vote republican
mu = alpha + eta_gender_star[gender_id] + eta_age_star[age_id] + eta_income_star[income_id] + eta_state_star[state_id];
}
model {
// // // Fixed Effects
alpha ~ std_normal();
eta_gender ~ std_normal();
eta_age ~ std_normal();
eta_income ~ std_normal();
eta_state ~ std_normal();
gender_scale ~ std_normal();
age_scale ~ std_normal();
income_scale ~ std_normal();
state_scale ~ std_normal();
// // // Likelihood
Y ~ bernoulli_logit(mu);
}
"
# Parameters to retain from the traditional fit
pars.trad <- c(
  "alpha",
  "eta_gender_star",
  "eta_age_star",
  "eta_income_star",
  "eta_state_star"
)
# # # Fit Total Model # # # # # # # # # # # # # # # # # # # # # # # #
# Fit the sparse-matrix model and time the whole sampling run.
start.time <- Sys.time()
stan.model.fit.reffmatrix <- stan(model_code = model.reffmatrix,
                                  data = temp.list,
                                  iter = 500,
                                  warmup = 250,
                                  thin = 4,
                                  pars = pars.reffmatrix,
                                  cores = 4,
                                  chains = 4,
                                  control = list(max_treedepth = 10),
                                  verbose = TRUE)
end.time <- Sys.time()
reffmatrix.time <- end.time - start.time
# BUG FIX: elapsed time was never printed for this fit (the traditional fit
# below prints trad.time, and the reported transcript shows both times)
print(reffmatrix.time)
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Fit the traditional (hand-coded effects) model and time the sampling run.
start.time <- Sys.time()
stan.model.fit.trad <- stan(
  model_code = model.trad,
  data = temp.list,
  iter = 500,
  warmup = 250,
  thin = 4,
  pars = pars.trad,
  cores = 4,
  chains = 4,
  control = list(max_treedepth = 10),
  verbose = TRUE
)
end.time <- Sys.time()
trad.time <- end.time - start.time
print(trad.time)
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# Pull posterior draws and summary tables for both fitted models.
model.params.reffmatrix <- extract(stan.model.fit.reffmatrix, pars = pars.reffmatrix)
summary_fit.reffmatrix  <- summary(stan.model.fit.reffmatrix)
model.params.trad       <- extract(stan.model.fit.trad, pars = pars.trad)
summary_fit.trad        <- summary(stan.model.fit.trad)
# save summary
par(mfrow = c(1, 2))
# # # Convergence Diagnostics
# Plot split-Rhat for every parameter in a stanfit summary, coloured by
# parameter family, with the conventional Rhat = 1.1 threshold marked.
# REFACTOR: the two copy-pasted plotting sections are now a single helper.
# FIX: parameters are grouped by exact match on the name with the "[index]"
# suffix stripped; the original grep(params[j], ...) substring match could
# assign points to the wrong family when one name is a prefix of another.
plot_rhat <- function(summary_fit, main_title) {
  rhat <- summary_fit$summary[, "Rhat"]
  plot(rhat, pch = NA,
       ylim = c(min(min(rhat, na.rm = TRUE), 0.85),
                max(max(rhat, na.rm = TRUE), 1.5)),
       bty = "n", ylab = "Rhat", xlab = "Index of Parameters",
       main = main_title)
  abline(h = 1.1, col = "red", lty = 2)
  par_names <- sub("\\[.*", "", rownames(summary_fit$summary))
  params <- unique(par_names)
  pchs <- unlist(lapply(X = 15:16, function(x) rep(x, 8)))
  cols <- rep(seq(1, 8, 1), 2)
  for (j in seq_along(params)) {
    idx <- which(par_names == params[j])
    points(x = idx, y = rhat[idx],
           pch = pchs[j], col = adjustcolor(cols[j], 0.5))
  }
  legend("topleft", legend = params,
         col = adjustcolor(cols[seq_along(params)], 0.85),
         pch = pchs[seq_along(params)])
  invisible(params)
}
plot_rhat(summary_fit.trad, "Convergence Diagnostics - Traditional Approach")
plot_rhat(summary_fit.reffmatrix, "Convergence Diagnostics - Sparse Matrix Approach")
# # # check estimates
# Compare posterior summaries of every random-effect parameter ("eta*")
# between the two fits; points on the 45-degree line indicate agreement.
# REFACTOR: one loop over the three summary statistics instead of three
# copy-pasted plot calls.
par(mfrow = c(1, 3))
eta_rows.trad <- grep("eta", rownames(summary_fit.trad$summary))
eta_rows.sparse <- grep("eta", rownames(summary_fit.reffmatrix$summary))
for (stat in c("mean", "2.5%", "97.5%")) {
  plot(y = summary_fit.trad$summary[eta_rows.trad, stat],
       x = summary_fit.reffmatrix$summary[eta_rows.sparse, stat],
       main = stat, ylab = "traditional", xlab = "sparse")
  abline(0, 1)
}