I’m trying to build a stan model with marginalization of discrete parameter. However, my discrete parameter is complex. My input data are a 9090 correlation matrix for each subject, which measures the correlation between regions. These 90 regions can be divided into 13 mutully exclusive groups. Then, I want to create 13(13-1)/2+13=91 (This is the upper triangular of a 1313 matrix, since we assume latent parameter group 1&j = groupj&i) latent group-level connection variables. we set the likelihood that all connectivity values in the 9090 matrix are from the same normal distribution with the mean \mu_ij where i and j are group indices. Below is my code, it works, but it’s just super slow… I think its probably because I’m useing too many for loops here. Does anyone have any ideas on how to speed it up? Thank you!
functions {
#include /fun/hs.fun
// Find the number of elements in the vector x that equal real number y
int num_matches(vector x, real y) {
int n = 0;
for (i in 1:rows(x))
if (x[i] == y)
n += 1;
return n;
}
// Find the indexes of the elements in the vector x that equal real number y
array[] int which_equal(vector x, real y) {
array[num_matches(x, y)] int match_positions;
int pos = 1;
for (i in 1:rows(x)) {
if (x[i] == y) {
match_positions[pos] = i;
pos += 1;
}
}
return match_positions;
}
int get_flattened_index(int n, int i, int j) {
return (i * (2 * n - i + 1)) / 2 + (j - i);
}
}
data {
int<lower=1> n; // number of subjects
int<lower=1> N; // number of data points
vector[N] y; // outcomes
int<lower=0> P; // number of covariates
matrix[n,P+1] X; // design matrix of intercept and covariates
vector[N] age; // age of subjects
array[N] int subj; // subject id
int<lower = 1> v; // number of brain regions
int<lower = 1> Q; // number of networks
array[n] matrix[v,v] A; // connectivity matrix
// whether the regularized horseshoe should be used
int<lower=0, upper=1> regularized;
// degrees of freedom for the half-t priors on lambda
real<lower=1> nu;
// scale for the half-t prior on tau
real<lower=0> global_scale;
// degrees of freedom for the half-t prior on tau
real<lower=1> global_df;
// slab scale for the regularized horseshoe
real<lower=0> slab_scale;
// slab degrees of freedom for the regularized horseshoe
real<lower=0> slab_df;
}
transformed data{
}
parameters{
matrix[n, Q*(Q-1)/2+Q] m_qr; // mean of each latent network mediator
real<lower=0> sigma_qr; // variance of each latent network mediator
vector[P+1] beta_x; // coefficients for m to y design matrix
real beta_age;
// vector[Q*(Q-1)/2+Q] beta_m; // coefficients for m to y mediators
real<lower=0> epsilon; // variance of model
vector[n] U; // random effects
real<lower=0> sigma_U; // variance of random effects
real<lower=0> tau;
vector<lower=0>[Q*(Q-1)/2+Q] lambda_0;
// vector<lower=0>[Q*(Q-1)/2+Q] lambda_1;
vector[Q*(Q-1)/2+Q] noise_0;
// vector[Q*(Q-1)/2+Q] noise_1;
array[v] simplex[Q] ROI_net; // network prob of each ROI
}
transformed parameters{
vector[Q*(Q-1)/2+Q] beta_m;
beta_m = hs(noise_0, lambda_0, tau);
// vector[Q*(Q-1)/2+Q] beta_m_age;
// beta_m_age = hs(noise_1, lambda_1, tau);
}
model{
for (i in 1:v){
ROI_net[i] ~ dirichlet(rep_vector(1, Q));
}
int m_idx;
for (k in 1:v){
for (j in k+1:v){
for (q in 1:Q){
for (r in 1:Q){
if (q <= r){
m_idx = get_flattened_index(Q, q, r);
target += log(ROI_net[k,q]) + log(ROI_net[j,r]) + normal_lpdf(A[:, k, j] | m_qr[:,m_idx], sigma_qr);
}else{
m_idx = get_flattened_index(Q, r, q);
target += log(ROI_net[k,q]) + log(ROI_net[j,r]) + normal_lpdf(A[:, k, j] | m_qr[:,m_idx], sigma_qr);
}
}
}
}
}
real mu;
for (i in 1:N){
mu = X[subj[i],]*beta_x + age[i]*beta_age + m_qr[subj[i],]*beta_m + U[subj[i]];
target += normal_lpdf(y[i] | mu, epsilon);
// target += normal_lpdf(y[i] | X[subj[i],]*beta_x + age[i]*beta_age + m_qr[subj[i],]*beta_m + age[i]*m_qr*beta_m_age, epsilon);
}
lambda_0 ~ student_t(nu, 0, 1);
// lambda_1 ~ student_t(nu, 0, 1);
tau ~ student_t(global_df, 0, global_scale * epsilon);
noise_0 ~ std_normal();
// noise_1 ~ std_normal();
epsilon ~ inv_gamma(0.1, 0.1);
U ~ normal(0, sigma_U);
sigma_U ~ inv_gamma(0.1, 0.1);
sigma_qr ~ inv_gamma(1, 1);
beta_x ~ normal(0, 1);
beta_age ~ normal(0, 1);
}
# Define model parameter
np.random.seed(123)
v = 90
Q = 13
n = 100
group_vec = np.concatenate([
np.repeat(0, 7),
np.repeat(1, 6),
np.repeat(2, 8),
np.repeat(3, 8),
np.repeat(4, 5),
np.repeat(5, 2),
np.repeat(6, 8),
np.repeat(7, 5),
np.repeat(8, 10),
np.repeat(9, 7),
np.repeat(10, 8),
np.repeat(11, 5),
np.repeat(12, 11)
])
group_size = np.array([7, 6, 8, 8, 5, 2, 8, 5, 10, 7, 8, 5, 11])
# Define the latent means
latent_mu = np.zeros([n,int(Q*(Q-1)/2+Q)])
for i in range(n):
latent_mu[i] = np.random.normal(0, 1, int(Q*(Q-1)/2+Q))
latent_mu_matrix = np.zeros([n, Q, Q])
for i in range(n):
ind = np.triu_indices(Q)
latent_mu_matrix[i, ind[0], ind[1]] = latent_mu[i]
latent_sigma = 1
sub_id = np.tile(np.arange(1, n+1),2)
# generate the connectivity matrix
# generate a_i_jl
A = np.zeros([n, v, v])
for i in range(n):
for j in range(v):
for l in range(v):
if j == l:
A[i,j,l] = 5
elif j < l:
A[i,j,l] = np.random.normal(latent_mu_matrix[i, group_vec[j], group_vec[l]], latent_sigma)
ind_lower = np.tril_indices(v, -1)
A[i][ind_lower] = A[i].T[ind_lower]
ind = np.triu_indices(v, k=1)
a_jl = np.zeros([n, int(v*(v-1)/2)])
for i in range(n):
a_jl[i] = A[i][ind]
for i in range(n):
triui = np.triu_indices_from(np.triu(A[i]), k=1)
col_names = []
for i in range(len(triui[0])):
col_names.append('ROI' + str(triui[0][i]) + '_' + str(triui[1][i]))
triui_net = np.triu_indices_from(np.triu(np.zeros([Q,Q])), k=0)
net_net_df = pd.DataFrame({'net_1':triui_net[0], 'net_2':triui_net[1]}, dtype=int)
net_net_df['net_net'] = np.arange(0, int(Q*(Q-1)/2+Q), dtype=int)
net_net_names = []
for i in range(len(triui_net[0])):
net_net_names.append('NET' + str(triui_net[0][i]) + '_' + str(triui_net[1][i]))
# generate y
latent_index = [0, 5, 9, 29, 44, 48, 50, 79, 85, 87]
random_sigma = 1.5
sigma = 2
T = np.array([0,1])
age = np.repeat(T, n)
N = n*len(T)
beta_latent = np.zeros(int(Q*(Q-1)/2+Q))
for i in range(int(Q*(Q-1)/2+Q)):
if i in latent_index:
beta_latent[i] = 5
x1 = np.random.normal(0, 1, n)
x2 = np.random.binomial(1, 0.5, n)
intercept = np.ones(n)
X = np.column_stack((intercept, x1, x2))
beta_x = np.array([1, 1, 1])
beta_age = 10
y = np.tile(np.random.normal(0, random_sigma, n), 2) + np.tile(np.dot(X, beta_x), 2) +np.tile(np.dot(latent_mu, beta_latent), 2) + beta_age*age + np.random.normal(0, sigma, N)
import nest_asyncio
nest_asyncio.apply()
import pandas as pd
import cmdstanpy
simulated_data = {'n': n,
'N': N,
'y': y,
'age': age,
'subj':sub_id,
'X': X,
'P': 2,
'v':90,
'Q':5,
'A':A,
'regularized':False,
'nu':1,
'global_scale':2,
'global_df':4,
'slab_scale':2,
'slab_df':4}
stan_model = cmdstanpy.CmdStanModel(stan_file='stan/lmm_latent_net.stan')
fit = stan_model.sample(simulated_data, chains=1, iter_sampling=1000, show_console=True, seed=611)