I’m working on fitting Hoff’s AMEN model in Stan, using the excellent code from Adam Lauretig here as a starting point: Hoff's AMEN model in Stan
Where I’m running into problems is that my network has ~37k nodes, so getting even a fairly trivial number of samples takes north of a week. I’d like to make the model as efficient as I can without `reduce_sum`, and then add `reduce_sum` if needed. I’ve pasted my code below. I’ve made some changes to the linked code because my network isn’t directed (i.e. it’s symmetric).
The hangup would seem to be the loop in the model block (`for(n in 1:n_dyads)`). I’m wondering if there’s a way to vectorize this?
Any suggestions welcome - thanks!
data{
int n_nodes;
int n_dyads;
//int N; //total obs. should be n_dyads * 2
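array[n_dyads] int<lower=1, upper=n_nodes> sender_id; // node index of the first member of each dyad
array[n_dyads] int<lower=1, upper=n_nodes> receiver_id; // node index of the second member of each dyad
//int dyad_id;
int K; // number of latent dimensions
int B; //number of dyad level covariates
matrix[n_dyads, B] X; //covariates
array[n_dyads] real Y; // dyad-level outcome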
}
parameters{
real intercept;
vector[B] beta;
//cholesky_factor_corr[2] corr_nodes; // correlation matrix w/in hh - not necessary in symmetric case
real<lower=0> sigma_nodes; // sd w/in nodes
vector[n_nodes] z_nodes; // for node non-centered parameterization, vector since no send/receive correlation
cholesky_factor_corr[K * 2] corr_multi_effects; // correlation matrix for multiplicative effect
vector<lower=0>[K * 2] sigma_multi_effects; // sd
matrix[K * 2, n_nodes] z_multi_effects; // Multi-effect non-centered term
}
transformed parameters{
vector[n_nodes] mean_nodes; // node parameter mean
matrix[n_nodes, K * 2] mean_multi_effects; // multi-effect mean
mean_nodes = sigma_nodes * z_nodes; // sd * standard-normal draw (non-centered)
mean_multi_effects = (diag_pre_multiply(
sigma_multi_effects, corr_multi_effects) * z_multi_effects)'; // sd *correlation
}
model{
intercept ~ normal(0, 5);
beta ~ normal(0,5);
//node terms
z_nodes ~ normal(0, 1); // z_nodes is already a vector, so no to_vector() needed
//corr_nodes ~ lkj_corr_cholesky(5);
sigma_nodes ~ gamma(1, 1);
// multi-effect terms
to_vector(z_multi_effects) ~ normal(0, 1);
corr_multi_effects ~ lkj_corr_cholesky(5);
sigma_multi_effects ~ gamma(1, 1);
for(n in 1:n_dyads){
Y[n] ~ normal(intercept + X[n]*beta +
mean_nodes[sender_id[n]] + mean_nodes[receiver_id[n]] +
mean_multi_effects[sender_id[n], 1:K] *
(mean_multi_effects[receiver_id[n], (K+1):(K*2)])',
1);
}
}
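
For reference, here is my rough, untested guess at what a vectorized version of the model block could look like. It assumes that multi-indexing with the `sender_id` / `receiver_id` integer arrays and `rows_dot_product` do what I think they do, and `mu` is just a local name I made up for the per-dyad linear predictor. Everything else is unchanged from the code above.

```stan
model {
  // priors, same as in the looped version
  intercept ~ normal(0, 5);
  beta ~ normal(0, 5);
  z_nodes ~ normal(0, 1);
  sigma_nodes ~ gamma(1, 1);
  to_vector(z_multi_effects) ~ normal(0, 1);
  corr_multi_effects ~ lkj_corr_cholesky(5);
  sigma_multi_effects ~ gamma(1, 1);

  {
    // mu (hypothetical name): linear predictor for all dyads at once
    vector[n_dyads] mu
      = intercept + X * beta
        + mean_nodes[sender_id] + mean_nodes[receiver_id]
        // row-wise dot product of the first K columns (sender side)
        // with the last K columns (receiver side)
        + rows_dot_product(
            mean_multi_effects[sender_id, 1:K],
            mean_multi_effects[receiver_id, (K + 1):(K * 2)]);

    // single vectorized likelihood statement instead of n_dyads scalar ones
    Y ~ normal(mu, 1);
  }
}
```

The idea is that one vectorized `normal` call replaces the `n_dyads` separate sampling statements, but I haven't checked whether this gives identical results or how much time it actually saves at this scale.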