Speeding up AME network model sampling

I’m fitting Hoff’s AMEN (additive and multiplicative effects network) model in Stan, using Adam Lauretig’s excellent code as a starting point: Hoff's AMEN model in Stan

My problem is scale. The network has ~37k nodes, so drawing even a small number of samples takes more than a week. I’d like to make the model as efficient as possible without reduce_sum first, and then add reduce_sum if needed. My code is below; I’ve changed the linked code because my network is undirected (i.e., the adjacency matrix is symmetric).
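For the reduce_sum step, a minimal sketch, assuming the data and parameter names from the model below, a residual sd called sigma_y (a hypothetical name, because the end of the posted model is cut off), and a Stan version that still accepts the old real Y[n] / int[] array syntax used in this post. The function name partial_sum is also just illustrative. Each slice builds its own linear predictor, so a thread only touches the dyads it was handed:

```stan
functions {
  // Log-likelihood of dyads start..end. reduce_sum passes Y[start:end]
  // as Y_slice; every remaining argument is shared by all slices.
  real partial_sum(real[] Y_slice, int start, int end,
                   real intercept, vector beta, matrix X,
                   vector mean_nodes, matrix mean_multi_effects,
                   int[] sender_id, int[] receiver_id, int K,
                   real sigma_y) {
    // Vectorized linear predictor for this slice only
    vector[end - start + 1] mu = intercept + X[start:end] * beta
      + mean_nodes[sender_id[start:end]]
      + mean_nodes[receiver_id[start:end]]
      + rows_dot_product(
          mean_multi_effects[sender_id[start:end], 1:K],
          mean_multi_effects[receiver_id[start:end], (K + 1):(2 * K)]);
    return normal_lpdf(Y_slice | mu, sigma_y);
  }
}
model {
  // ... priors as in the model below ...

  // Replaces the for (n in 1:n_dyads) loop; grainsize = 1 lets the
  // scheduler choose slice sizes.
  target += reduce_sum(partial_sum, Y, 1,
                       intercept, beta, X, mean_nodes, mean_multi_effects,
                       sender_id, receiver_id, K, sigma_y);
}
```

The model has to be compiled with threading enabled (STAN_THREADS) for reduce_sum to run in parallel. A grainsize other than 1 may be worth tuning once the timing is known.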

The bottleneck seems to be the loop over dyads in the model block (for (n in 1:n_dyads)). Is there a way to vectorize it?
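Written out, the loop evaluates the per-dyad mean below. Here $\beta_0$ is intercept, $x_n^\top$ is row n of X, $a$ is mean_nodes, $s_n$ = sender_id[n], $r_n$ = receiver_id[n], $u_i$ is mean_multi_effects[i, 1:K], $v_j$ is mean_multi_effects[j, (K+1):(2K)], and $\sigma_y$ stands for the residual sd that is cut off at the end of the posted model. Every term except $x_n^\top\beta$ is a lookup by sender or receiver index, and the last term is a row-wise dot product. This structure is what makes the stacked form on the second line possible:

$$
\mu_n = \beta_0 + x_n^\top \beta + a_{s_n} + a_{r_n} + u_{s_n}^\top v_{r_n}, \qquad Y_n \sim \mathcal{N}(\mu_n, \sigma_y)
$$

$$
\boldsymbol{\mu} = \beta_0 \mathbf{1} + X\beta + a_{s} + a_{r} + \operatorname{rowsum}\!\left(U_{s} \circ V_{r}\right)
$$

where $U_s$ and $V_r$ are the $n_{\text{dyads}} \times K$ matrices obtained by stacking $u_{s_n}^\top$ and $v_{r_n}^\top$ for every dyad, and $\circ$ is the elementwise product.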

Any suggestions welcome - thanks!

data {
  int n_nodes;
  int n_dyads;
  //int N; //total obs. should be n_dyads * 2
  int sender_id[n_dyads];
  int receiver_id[n_dyads];
  //int dyad_id;
  int K; // number of latent dimensions
  int B; // number of dyad-level covariates
  matrix[n_dyads, B] X; // covariates
  real Y[n_dyads];
}
parameters {
  real intercept;
  vector[B] beta;
  //cholesky_factor_corr[2] corr_nodes; // correlation matrix w/in hh - not necessary in symmetric case
  real<lower=0> sigma_nodes; // sd w/in nodes
  vector[n_nodes] z_nodes; // node non-centered parameterization; a vector since there is no send/receive correlation
  cholesky_factor_corr[K * 2] corr_multi_effects; // Cholesky factor of the multiplicative-effect correlation matrix
  vector<lower=0>[K * 2] sigma_multi_effects; // sd
  matrix[K * 2, n_nodes] z_multi_effects; // multi-effect non-centered term
  real<lower=0> sigma_y; // residual sd (hypothetical name: this part of the post was cut off)
}
transformed parameters {
  vector[n_nodes] mean_nodes; // node effects
  matrix[n_nodes, K * 2] mean_multi_effects; // columns 1:K sender side, K+1:2K receiver side
  mean_nodes = sigma_nodes * z_nodes; // sd * standard normal
  mean_multi_effects = (diag_pre_multiply(
    sigma_multi_effects, corr_multi_effects) * z_multi_effects)'; // sd * correlation * standard normal
}
model {
  intercept ~ normal(0, 5);
  beta ~ normal(0, 5);
  //node terms
  z_nodes ~ normal(0, 1);
  //corr_nodes ~ lkj_corr_cholesky(5);
  sigma_nodes ~ gamma(1, 1);

  // multi-effect terms
  to_vector(z_multi_effects) ~ normal(0, 1);
  corr_multi_effects ~ lkj_corr_cholesky(5);
  sigma_multi_effects ~ gamma(1, 1);

  for (n in 1:n_dyads) {
    Y[n] ~ normal(intercept + X[n] * beta +
                  mean_nodes[sender_id[n]] + mean_nodes[receiver_id[n]] +
                  mean_multi_effects[sender_id[n], 1:K] *
                  (mean_multi_effects[receiver_id[n], (K+1):(K*2)])',
                  sigma_y);
  }
}
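One way to drop the loop: gather every dyad's effects at once with multi-indexing (indexing a vector or matrix by an int array) and compute the sender-receiver products with rows_dot_product, leaving a single normal statement. A sketch of a replacement model block, assuming the declarations above and a residual sd named sigma_y (hypothetical, since the end of the posted model is cut off). It should give the same log density as the loop up to floating-point rounding; how much faster it is on a 37k-node network is not guaranteed:

```stan
model {
  // Linear predictor for all dyads at once.
  // mean_nodes[sender_id] gathers one node effect per dyad (vector[n_dyads]).
  // mean_multi_effects[sender_id, 1:K] gathers the sender-side latent factors
  // into an n_dyads x K matrix; the receiver side uses columns K+1:2K.
  vector[n_dyads] mu = intercept + X * beta
    + mean_nodes[sender_id] + mean_nodes[receiver_id]
    + rows_dot_product(mean_multi_effects[sender_id, 1:K],
                       mean_multi_effects[receiver_id, (K + 1):(2 * K)]);

  intercept ~ normal(0, 5);
  beta ~ normal(0, 5);
  z_nodes ~ normal(0, 1);
  sigma_nodes ~ gamma(1, 1);
  to_vector(z_multi_effects) ~ normal(0, 1);
  corr_multi_effects ~ lkj_corr_cholesky(5);
  sigma_multi_effects ~ gamma(1, 1);

  // One vectorized sampling statement instead of n_dyads scalar ones
  Y ~ normal(mu, sigma_y);
}
```

The expected gain comes from building far fewer autodiff nodes: one matrix-vector product, one rows_dot_product, and one vectorized normal, instead of n_dyads separate scalar expressions. The local mu is declared at the top of the block because Stan requires declarations to precede statements.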