Hi! I’m very new to the wonderful world of Stan, and also to this type of model, so I’m not sure which of my issues are about the model and which are about the implementation in Stan. Many of my problems might be resolved if I could just find more literature on the model.
I’m trying to implement a latent Gaussian model of categorical data. This is what I have so far:
```stan
data {
  int<lower=0> N;                  // Number of samples
  int<lower=0> Dx;                 // Dimension of the observed variables
  int<lower=0> Dz;                 // Dimension of the latent variables
  int<lower=1> K;                  // Cardinality (number of categories) of the observed variables
  int<lower=1, upper=K> X[N, Dx];  // Observed categories, coded 1..K
}
parameters {
  matrix[N, Dz] Z;                 // The latent matrix
  matrix[K, Dz] W[Dx];             // One K x Dz weight matrix per observed variable
  vector<lower=0>[Dz] alpha;       // ARD precisions
}
transformed parameters {
  vector<lower=0>[Dz] t_alpha;     // ARD standard deviations
  t_alpha = inv(sqrt(alpha));
}
model {
  to_vector(Z) ~ normal(0, 1);
  alpha ~ gamma(1e-3, 1e-3);
  // Every entry in column dz of each W[dx] shares the scale t_alpha[dz]
  for (dx in 1:Dx) {
    for (dz in 1:Dz) {
      col(W[dx], dz) ~ normal(0, t_alpha[dz]);
    }
  }
  for (n in 1:N) {
    for (dx in 1:Dx) {
      target += categorical_logit_lpmf(X[n, dx] | W[dx] * Z[n]');
    }
  }
}
```
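To see what the likelihood term does (and does not) pin down, here is a NumPy re-implementation of it, run on random inputs. The helper name `loglik` is hypothetical, not from any library. The checks suggest that the likelihood is exactly unchanged when the sign of a latent dimension is flipped in both Z and W, or when two latent dimensions are swapped; the priors are symmetric under these moves too (with alpha permuted along), so the posterior has many equivalent modes. The likelihood is also blind to adding a common vector to all K rows of `W[dx]`, a direction that only the prior on W constrains. This is one plausible, unverified reason NUTS chains could disagree.

```python
import numpy as np

def loglik(X, Z, W):
    """Sum of categorical_logit_lpmf(X[n,dx] | W[dx] * Z[n]') over n and dx.

    X: (N, Dx) ints coded 1..K; Z: (N, Dz); W: (Dx, K, Dz).
    """
    logits = np.einsum('dkz,nz->ndk', W, Z)                # (N, Dx, K)
    logits = logits - logits.max(axis=-1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    # Log-probability of the observed category (Stan's 1..K -> NumPy's 0..K-1)
    return np.take_along_axis(log_probs, (X - 1)[..., None], axis=-1).sum()

rng = np.random.default_rng(0)
N, Dx, Dz, K = 50, 4, 3, 3
Z = rng.normal(size=(N, Dz))
W = rng.normal(size=(Dx, K, Dz))
X = rng.integers(1, K + 1, size=(N, Dx))   # categories 1..K
base = loglik(X, Z, W)

# Flip the sign of latent dimension 0 in Z and in every W[dx]
flip = np.array([-1.0, 1.0, 1.0])
ll_flip = loglik(X, Z * flip, W * flip)

# Swap latent dimensions 0 and 1 in Z and W together
perm = [1, 0, 2]
ll_perm = loglik(X, Z[:, perm], W[:, :, perm])

# Add the same vector c[dx] to every row (category) of W[dx]: each logit for
# (n, dx) moves by c[dx] @ Z[n], and the softmax ignores a common shift
c = rng.normal(size=(Dx, 1, Dz))
ll_shift = loglik(X, Z, W + c)
```

All three modified log-likelihoods should equal `base` up to floating-point error.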
I’m not certain this is even correct, but it isn’t obviously wrong either: at least with variational inference, I get something that looks reasonable. This brings me to my first issue, which is that I get very different results with VB and NUTS. I generated mock data as follows:
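```python
import numpy as np
from random import choices
from scipy.special import softmax

N = 100
Dx = 10
Dz_true = 2   # latent dimensions used to simulate the data
Dz = 10       # latent dimensions given to the model (ARD should switch off the rest)
K = 3
mu, sigma = 0, 1.0
Z = np.random.normal(mu, sigma, (N, Dz_true))
alpha = np.ones(Dz_true)
W = np.zeros((Dx, K, Dz_true))
for dx in range(Dx):
    for dz in range(Dz_true):
        W[dx, :, dz] = np.random.normal(0, 1 / alpha[dz]**0.5, K)
X = np.zeros((N, Dx), dtype=int)
for n in range(N):
    for dx in range(Dx):
        # Categories are coded 1..K, as Stan's categorical distribution expects
        X[n, dx] = choices(np.arange(1, K + 1), softmax(np.matmul(W[dx], Z[n])))[0]
data = {'N': N, 'Dx': Dx, 'Dz': Dz, 'K': K, 'X': X}
```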
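The same simulation can be written without Python loops. This is a sketch of the same generative model, not a drop-in replacement: `simulate` is a hypothetical name, it uses NumPy's `default_rng` (so its random stream differs from `np.random.normal` plus `random.choices`), and it draws each category by inverse-CDF sampling on the softmax probabilities.

```python
import numpy as np

def simulate(N=100, Dx=10, Dz_true=2, K=3, seed=0):
    """Vectorised simulation of the latent Gaussian categorical model."""
    rng = np.random.default_rng(seed)
    Z = rng.normal(0.0, 1.0, size=(N, Dz_true))
    alpha = np.ones(Dz_true)                        # ARD precisions
    # W[dx, k, dz] ~ Normal(0, 1/sqrt(alpha[dz])), broadcast over the last axis
    W = rng.normal(0.0, 1.0, size=(Dx, K, Dz_true)) / np.sqrt(alpha)
    logits = np.einsum('dkz,nz->ndk', W, Z)         # (N, Dx, K)
    p = np.exp(logits - logits.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)              # softmax over categories
    cdf = np.cumsum(p, axis=-1)
    cdf[..., -1] = 1.0                              # guard against round-off
    # Inverse-CDF draw: first category whose cumulative probability exceeds u
    u = rng.random(size=(N, Dx, 1))
    X = (u < cdf).argmax(axis=-1) + 1               # categories coded 1..K
    return X, Z, W

X_sim, Z_sim, W_sim = simulate()
```

`X_sim` can go into the same `data` dictionary in place of `X`.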
If I then try to fit the model using NUTS:
```python
# model is the compiled pystan.StanModel (PyStan 2)
fit_NUTS = model.sampling(data=data)
```
It takes ages to run (about 11 minutes on my laptop), and neither n_eff nor Rhat looks good. As a crude check, I plotted the following array:
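```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Posterior draws as a DataFrame (one row per draw, one column per parameter)
fit_NUTS_summary_df = fit_NUTS.to_dataframe()
w_columns_df = fit_NUTS_summary_df[fit_NUTS_summary_df.columns[fit_NUTS_summary_df.columns.str.contains('W[', regex=False)]]
w_means = w_columns_df.mean()   # posterior mean of each W entry, pooled over all chains
W_array = np.zeros((Dx, K, Dz), dtype=float)
for index_str, value in w_means.items():
    # extract_coords_from_index_string is my own helper (not shown): 'W[i,j,k]' -> (i, j, k)
    coord = extract_coords_from_index_string(index_str)
    W_array[coord[0] - 1, coord[1] - 1, coord[2] - 1] = value
# Squared norm of each latent column of each W[dx], summed over the K categories: shape (Dx, Dz)
W_NUTS_df = pd.DataFrame(np.sum(W_array**2, axis=1))
plt.figure(figsize=(15, 8))
sns.heatmap(W_NUTS_df, annot=True, cmap='bwr', center=0);
```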
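The loop above relies on `extract_coords_from_index_string`, which isn’t shown. A hypothetical stand-in (`parse_index` and `means_to_array` are illustrative names, not from any library) can parse names like `W[3,2,1]` with a regex and fill the array directly. It accepts anything with an `.items()` method, so either the pandas Series `w_means` or a plain dict. As an aside, in PyStan 2 `fit_NUTS.extract()['W']` should already return the draws as an array of shape (draws, Dx, K, Dz), which may make the parsing unnecessary.

```python
import re
import numpy as np

# Matches Stan-style names such as 'W[3,2,1]' and captures the index list
_INDEX_RE = re.compile(r'^(\w+)\[([\d,\s]+)\]$')

def parse_index(name):
    """Turn a Stan column name like 'W[3,2,1]' into zero-based (2, 1, 0)."""
    match = _INDEX_RE.match(name)
    if match is None:
        raise ValueError(f'not an indexed parameter name: {name!r}')
    return tuple(int(i) - 1 for i in match.group(2).split(','))

def means_to_array(means, shape):
    """Fill an array from a mapping {'W[i,j,k]': value} (Series or dict)."""
    out = np.zeros(shape)
    for name, value in means.items():
        out[parse_index(name)] = value
    return out

# Usage with the objects above (not run here):
# W_array = means_to_array(w_means, (Dx, K, Dz))
# col_norms = np.sum(W_array**2, axis=1)   # (Dx, Dz), as in the heatmap
```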
If I have done all this correctly, all but two columns should be close to zero, which isn’t clearly the case. When I try again with variational inference, however, it looks better:
```python
%%time
fit_VB = model.vb(data=data)
```
Now the fitting takes only about 3 seconds, and the results look a lot better:
```python
means_df = pd.DataFrame({'mean_par_names': fit_VB['mean_par_names'],
                         'mean_pars': fit_VB['mean_pars']})
W_means_df = means_df[means_df['mean_par_names'].str.contains('W[', regex=False)]
W_means_df = W_means_df.set_index('mean_par_names')
W_means_df.index.name = None   # `del df.index.name` fails in recent pandas
W_array = np.zeros((Dx, K, Dz), dtype=float)
for index_str, row in W_means_df.iterrows():
    coord = extract_coords_from_index_string(index_str)
    W_array[coord[0] - 1, coord[1] - 1, coord[2] - 1] = row['mean_pars']
W_df = pd.DataFrame(np.sum(W_array**2, axis=1))
plt.figure(figsize=(15, 8))
sns.heatmap(W_df, annot=True, cmap='bwr', center=0);
```
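The two heatmaps may not be directly comparable: the NUTS one averages W over all chains before squaring. The posterior density is unchanged if a latent dimension’s sign is flipped in both Z and every `W[dx]`, or if two latent dimensions are swapped (with alpha), so different chains could settle on mirror-image or relabelled solutions, and averaging them can cancel columns that are large within every chain. Mean-field VB fits a single Gaussian, so it typically describes just one such solution. A minimal sketch of a check, assuming the NUTS draws of W have been arranged into an array of shape (chains, draws, Dx, K, Dz) (how to get there depends on the PyStan version; `column_norms` is a hypothetical helper), is to compute the heatmap matrix per chain as well as pooled:

```python
import numpy as np

def column_norms(W_draws):
    """Squared column norms of posterior-mean W, per chain and pooled.

    W_draws: array of shape (n_chains, n_draws, Dx, K, Dz).
    Returns (per_chain, pooled) with shapes (n_chains, Dx, Dz) and (Dx, Dz).
    """
    chain_means = W_draws.mean(axis=1)                     # (n_chains, Dx, K, Dz)
    per_chain = (chain_means ** 2).sum(axis=2)             # sum over the K categories
    pooled = (chain_means.mean(axis=0) ** 2).sum(axis=1)   # what the pooled heatmap shows
    return per_chain, pooled

# Two chains in mirror-image modes: W and -W
rng = np.random.default_rng(2)
chain = 1.0 + 0.1 * rng.normal(size=(500, 10, 3, 10))
W_draws = np.stack([chain, -chain])
per_chain, pooled = column_norms(W_draws)
```

Here every per-chain entry is about 3, while the pooled matrix is exactly zero, so the pooled heatmap alone would hide two perfectly clear chains.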
So my questions are:
- How can I change the model to get better results with NUTS?
- How can I change my implementation to make inference more efficient?
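On the first question, the Rhat values are the main symptom. For reference, here is a NumPy sketch of the classic split-R-hat described in Gelman et al.’s *Bayesian Data Analysis* (3rd ed.); recent Stan versions report a rank-normalised variant, so numbers will not match Stan’s output exactly. The name `split_rhat` is my own. The usage lines mimic four chains stuck in mirror-image modes (centred at +2 or -2), the kind of pattern a sign flip of Z and W could produce: each chain looks fine on its own, yet R-hat comes out far above 1.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for one scalar quantity.

    draws: array of shape (n_chains, n_iter) holding post-warmup draws.
    """
    half = draws.shape[1] // 2
    # Split every chain into two halves so drift within a chain also shows up
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    B = n * chains.mean(axis=1).var(ddof=1)    # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()      # mean within-chain variance
    var_plus = (n - 1) / n * W + B / n         # pooled estimate of posterior variance
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(3)
mixed = rng.normal(size=(4, 1000))                  # four chains, same target
offsets = np.array([2.0, -2.0, 2.0, -2.0])[:, None]
stuck = mixed + offsets                             # chains in mirror-image modes
print(split_rhat(mixed), split_rhat(stuck))
```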
More generally, if people know of other implementations of this model, or of any literature on it that they could point me towards, I would be enormously grateful.