How to speed up my Stan sampling

I’m trying to build a Stan model that marginalizes out a discrete parameter, but my discrete parameter is complicated. My input data are a 90×90 correlation matrix for each subject, measuring the correlation between brain regions. These 90 regions can be divided into 13 mutually exclusive groups, so I want to create 13·(13−1)/2 + 13 = 91 latent group-level connection variables (the upper triangle of a 13×13 matrix, since we assume the latent parameter for groups i and j equals the one for groups j and i). The likelihood is that every connectivity value in the 90×90 matrix whose regions fall in groups i and j comes from the same normal distribution with mean \mu_ij, where i and j are group indices. Below is my code. It works, but it’s just super slow… I think that’s probably because I’m using too many for loops here. Does anyone have ideas on how to speed it up? Thank you!
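In symbols, writing $g(k)$ for the unknown group label of region $k$: for subject $s$ and regions $k < j$,

$$a^{(s)}_{kj} \mid g(k) = q,\ g(j) = r \ \sim\ \mathcal{N}\big(\mu^{(s)}_{qr},\ \sigma^2\big), \qquad \mu^{(s)}_{qr} = \mu^{(s)}_{rq},$$

and because the labels are unknown, the model sums the log membership probabilities and the normal log density over all $(q, r)$ combinations.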

functions {
    #include /fun/hs.fun
    // Find the number of elements in the vector x that equal real number y
    int num_matches(vector x, real y) {
        int n = 0;
        for (i in 1:rows(x))
            if (x[i] == y)
            n += 1;
        return n;
    }
    
    // Find the indexes of the elements in the vector x that equal real number y
    array[] int which_equal(vector x, real y) {
        array[num_matches(x, y)] int match_positions;
        int pos = 1;
        for (i in 1:rows(x)) {
            if (x[i] == y) {
            match_positions[pos] = i;
            pos += 1;
            }
        }
        return match_positions;
    }

    // Map the (i, j) entry (1-indexed, i <= j) of the upper triangle of an
    // n x n matrix to its row-major position in a flattened vector of
    // length n * (n + 1) / 2, matching np.triu_indices ordering
    int get_flattened_index(int n, int i, int j) {
        return (i - 1) * (2 * n + 2 - i) / 2 + (j - i + 1);
    }

}
data {
   int<lower=1> n;  // number of subjects
   int<lower=1> N;  // number of data points
   vector[N] y;  // outcomes
   int<lower=0> P; // number of covariates
   matrix[n,P+1] X; // design matrix of intercept and covariates
   vector[N] age; // age of subjects
   array[N] int subj; // subject id
   int<lower = 1> v; // number of brain regions
   int<lower = 1> Q; // number of networks
   array[n] matrix[v,v] A; // connectivity matrix

    // whether the regularized horseshoe should be used
    int<lower=0, upper=1> regularized;

    // degrees of freedom for the half-t priors on lambda
    real<lower=1> nu;

    // scale for the half-t prior on tau
    real<lower=0> global_scale;

    // degrees of freedom for the half-t prior on tau
    real<lower=1> global_df;

    // slab scale for the regularized horseshoe
    real<lower=0> slab_scale;

    // slab degrees of freedom for the regularized horseshoe
    real<lower=0> slab_df;

}
transformed data{

}
parameters{
    matrix[n, Q*(Q-1)/2+Q] m_qr; // mean of each latent network mediator
    real<lower=0> sigma_qr; // SD of connectivity values around each latent network mean
    vector[P+1] beta_x; // coefficients for m to y design matrix
    real beta_age;
    // vector[Q*(Q-1)/2+Q] beta_m; // coefficients for m to y mediators
    real<lower=0> epsilon; // residual SD of the outcome model
    vector[n] U; // random effects
    real<lower=0> sigma_U; // SD of random effects
    real<lower=0> tau;
    vector<lower=0>[Q*(Q-1)/2+Q] lambda_0;
    // vector<lower=0>[Q*(Q-1)/2+Q] lambda_1;
    vector[Q*(Q-1)/2+Q] noise_0;
    // vector[Q*(Q-1)/2+Q] noise_1;
    array[v] simplex[Q] ROI_net; // network prob of each ROI
}
transformed parameters{
    vector[Q*(Q-1)/2+Q] beta_m;
    beta_m = hs(noise_0, lambda_0, tau);
    // vector[Q*(Q-1)/2+Q] beta_m_age;
    // beta_m_age = hs(noise_1, lambda_1, tau);
}
model{
    for (i in 1:v){
        ROI_net[i] ~ dirichlet(rep_vector(1, Q));
    }
    int m_idx;
    for (k in 1:v){
        for (j in (k+1):v){
            for (q in 1:Q){
                for (r in 1:Q){
                    // (q, r) and (r, q) share the same upper-triangular index
                    m_idx = get_flattened_index(Q, min(q, r), max(q, r));
                    target += log(ROI_net[k,q]) + log(ROI_net[j,r]) + normal_lpdf(A[:, k, j] | m_qr[:,m_idx], sigma_qr);
                }
            }
        }
    }

    real mu;
    for (i in 1:N){
        mu = X[subj[i],]*beta_x + age[i]*beta_age + m_qr[subj[i],]*beta_m + U[subj[i]];
        target += normal_lpdf(y[i] | mu, epsilon);
        // target += normal_lpdf(y[i] | X[subj[i],]*beta_x + age[i]*beta_age + m_qr[subj[i],]*beta_m + age[i]*m_qr*beta_m_age, epsilon);
    }


    lambda_0 ~ student_t(nu, 0, 1);
    // lambda_1 ~ student_t(nu, 0, 1);
    tau ~ student_t(global_df, 0, global_scale * epsilon);
    noise_0 ~ std_normal();
    // noise_1 ~ std_normal();
    epsilon ~ inv_gamma(0.1, 0.1);
    U ~ normal(0, sigma_U);
    sigma_U ~ inv_gamma(0.1, 0.1);
    sigma_qr ~ inv_gamma(1, 1);
    beta_x ~ normal(0, 1);
    beta_age ~ normal(0, 1);
}
Here is the Python code I use to simulate data and fit the model:

import numpy as np
import pandas as pd

# Define model parameters
np.random.seed(123)
v = 90
Q = 13
n = 100
group_size = np.array([7, 6, 8, 8, 5, 2, 8, 5, 10, 7, 8, 5, 11])
group_vec = np.repeat(np.arange(Q), group_size)  # group label of each of the 90 regions

# Define the latent means
n_pairs = Q * (Q + 1) // 2  # = Q*(Q-1)/2 + Q = 91 upper-triangular network pairs
latent_mu = np.random.normal(0, 1, size=(n, n_pairs))
latent_mu_matrix = np.zeros([n, Q, Q])
ind = np.triu_indices(Q)
latent_mu_matrix[:, ind[0], ind[1]] = latent_mu
latent_sigma = 1

sub_id = np.tile(np.arange(1, n+1),2)
# generate the connectivity matrix
# generate a_i_jl
A = np.zeros([n, v, v])

for i in range(n):
    for j in range(v):
        for l in range(v):
            if j == l:
                A[i,j,l] = 5
            elif j < l:
                A[i,j,l] = np.random.normal(latent_mu_matrix[i, group_vec[j], group_vec[l]], latent_sigma)
    ind_lower = np.tril_indices(v, -1)
    A[i][ind_lower] = A[i].T[ind_lower]

ind = np.triu_indices(v, k=1)
a_jl = np.zeros([n, int(v*(v-1)/2)])
for i in range(n):
    a_jl[i] = A[i][ind]
    
triui = np.triu_indices(v, k=1)
col_names = ['ROI' + str(triui[0][i]) + '_' + str(triui[1][i])
             for i in range(len(triui[0]))]


triui_net = np.triu_indices(Q)
net_net_df = pd.DataFrame({'net_1': triui_net[0], 'net_2': triui_net[1]}, dtype=int)
net_net_df['net_net'] = np.arange(n_pairs, dtype=int)
net_net_names = ['NET' + str(triui_net[0][i]) + '_' + str(triui_net[1][i])
                 for i in range(len(triui_net[0]))]
# generate y
latent_index = [0, 5, 9, 29, 44, 48, 50, 79, 85, 87]
random_sigma = 1.5
sigma = 2
T = np.array([0,1])
age = np.repeat(T, n)
N = n*len(T)
beta_latent = np.zeros(n_pairs)
beta_latent[latent_index] = 5  # only the selected network pairs have nonzero effects

x1 = np.random.normal(0, 1, n)
x2 = np.random.binomial(1, 0.5, n)
intercept = np.ones(n)
X = np.column_stack((intercept, x1, x2))
beta_x = np.array([1, 1, 1])

beta_age = 10
y = (np.tile(np.random.normal(0, random_sigma, n), 2)  # subject random effects
     + np.tile(np.dot(X, beta_x), 2)                   # covariate effects
     + np.tile(np.dot(latent_mu, beta_latent), 2)      # latent mediator effects
     + beta_age * age                                  # age effect
     + np.random.normal(0, sigma, N))                  # residual noise

import nest_asyncio
nest_asyncio.apply()
import cmdstanpy


simulated_data = {'n': n,
                  'N': N,
                  'y': y,
                  'age': age,
                  'subj': sub_id,
                  'X': X,
                  'P': 2,
                  'v': v,
                  'Q': Q,  # must match the Q used to simulate the latent means
                  'A': A,
                  'regularized': False,
                  'nu': 1,
                  'global_scale': 2,
                  'global_df': 4,
                  'slab_scale': 2,
                  'slab_df': 4}

stan_model = cmdstanpy.CmdStanModel(stan_file='stan/lmm_latent_net.stan')


fit = stan_model.sample(data=simulated_data, chains=1, iter_sampling=1000, show_console=True, seed=611)

You’ll want to change as many of your for loops into matrix operations as possible. If that doesn’t give you enough of a boost in speed, you can look into partially or fully parallelizing your code within-chain using reduce_sum and a partial sum function.
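To make that concrete, here is a minimal, untested sketch of a restructured model block, reusing the names from the model above (including get_flattened_index from the functions block). The main ideas: precompute log(ROI_net) once instead of calling log() inside the quadruple loop, evaluate the expensive normal_lpdf once per network pair instead of once per ordered (q, r) combination, and collapse the per-observation outcome loop into a single vectorized statement.

model{
    for (i in 1:v){
        ROI_net[i] ~ dirichlet(rep_vector(1, Q));
    }

    // precompute the log membership probabilities once
    array[v] vector[Q] log_ROI_net;
    for (i in 1:v)
        log_ROI_net[i] = log(ROI_net[i]);

    int n_pairs = Q*(Q-1)/2 + Q;
    for (k in 1:v){
        for (j in (k+1):v){
            // the normal_lpdf depends only on the network pair, so evaluate it
            // Q*(Q+1)/2 = 91 times here instead of Q*Q = 169 times
            vector[n_pairs] lp;
            for (m in 1:n_pairs)
                lp[m] = normal_lpdf(A[:, k, j] | m_qr[:, m], sigma_qr);
            for (q in 1:Q)
                for (r in 1:Q)
                    target += log_ROI_net[k, q] + log_ROI_net[j, r]
                              + lp[get_flattened_index(Q, min(q, r), max(q, r))];
        }
    }

    // one vectorized outcome statement instead of a loop over N:
    // indexing a matrix by the integer array subj selects rows
    vector[N] mu = X[subj] * beta_x + beta_age * age
                   + m_qr[subj] * beta_m + U[subj];
    target += normal_lpdf(y | mu, epsilon);

    // ... priors as in the original model ...
}

If that still isn’t fast enough, the region-pair loop is the natural target for reduce_sum. A sketch of a partial-sum function (the names partial_pairs_lp, pair_idx, pair_k, pair_j, and grainsize are mine, not from the original model; pair_k and pair_j would list the two regions of each of the v*(v-1)/2 pairs and could be built in transformed data):

functions {
    // pair_idx is a dummy slicing argument (1, ..., number of region pairs);
    // assumes get_flattened_index is defined earlier in this block
    real partial_pairs_lp(array[] int pair_idx, int start, int end,
                          array[] int pair_k, array[] int pair_j,
                          array[] matrix A, matrix m_qr, real sigma_qr,
                          array[] vector log_ROI_net, int Q) {
        int n_pairs = Q*(Q-1)/2 + Q;
        real total = 0;
        for (p in start:end) {
            int k = pair_k[p];
            int j = pair_j[p];
            vector[n_pairs] lp;
            for (m in 1:n_pairs)
                lp[m] = normal_lpdf(A[:, k, j] | m_qr[:, m], sigma_qr);
            for (q in 1:Q)
                for (r in 1:Q)
                    total += log_ROI_net[k, q] + log_ROI_net[j, r]
                             + lp[get_flattened_index(Q, min(q, r), max(q, r))];
        }
        return total;
    }
}

and in the model block:

    target += reduce_sum(partial_pairs_lp, pair_idx, grainsize,
                         pair_k, pair_j, A, m_qr, sigma_qr, log_ROI_net, Q);

Note that reduce_sum only helps if threading is enabled: compile with cpp_options={'STAN_THREADS': True} in CmdStanModel and pass threads_per_chain to sample().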