So I posted a while back about my dissatisfaction with the way the multivariate normal increasingly constrains correlations toward zero in high-dimensional scenarios, and I also posted an alternative that I thought was somewhat promising, if more in the “sufficient stats” spirit than generative. It seemed to work well in the non-hierarchical case I presented then, but just this weekend I started applying it to the hierarchical case and discovered that there it’s even worse in its bias of the correlations toward zero! As a sanity check (which I should have started with) I coded up a “sample from the prior” model, only to discover that I’ve clearly done something fundamentally wrong that I’m hoping someone can point out.
To back up a bit: the standard multivariate description of the data is that we have a latent set of n samples from a k-dimensional multivariate normal, and we observe each of the n \times k latent values some number of times with some degree of measurement noise, yielding the observation-level data. If we wanted to check that Stan produces the priors we intend for just the latent parts of this model, we’d simply run:
#include helper_functions.stan // for flatten_lower_tri()
data{
int<lower=1> n ;
int<lower=1> k ;
}
transformed data{
int num_r = (k*(k-1)) %/% 2 ; // number of correlations implied by k
}
parameters{
row_vector[k] v[n] ; // latent values
row_vector[k] v_m ; // mean for each k
vector<lower=0>[k] v_s ; // sd for each k
cholesky_factor_corr[k] r_ ; // correlations among k's (on Cholesky-factor scale)
}
model{
// priors
v_m ~ std_normal() ;
v_s ~ weibull(2,1) ;
r_ ~ lkj_corr_cholesky(1) ;
// "centered" parameterization of v
v ~ multi_normal_cholesky(
v_m
, diag_pre_multiply(v_s,r_)
) ;
}
generated quantities{
// r: the lower-tri of the correlation matrix, flattened to a vector for efficient storage
vector[num_r] r = flatten_lower_tri( multiply_lower_tri_self_transpose(r_) ) ;
}
We could then sample this model and inspect whether we get back the expected distributions for the parameters given their priors.
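For concreteness, the full generative story described at the top (latent multivariate-normal values plus noisy repeated observations) corresponds to something like the following R sketch; the particular values of n, k, the number of observations per latent value, and the noise SD are arbitrary here, and the cov2cor(crossprod()) construction is just a quick way to get some valid correlation matrix, not a draw from the LKJ prior:
library(MASS) # for mvrnorm()
set.seed(1)
n <- 100 ; k <- 4 ; n_obs <- 5 ; noise_sd <- 0.5
v_m <- rnorm(k) # mean for each k
v_s <- rweibull(k, shape = 2, scale = 1) # sd for each k
r_mat <- cov2cor(crossprod(matrix(rnorm(k * k), k))) # some valid correlation matrix
v <- mvrnorm(n, mu = v_m, Sigma = diag(v_s) %*% r_mat %*% diag(v_s)) # n x k latent values
obs <- replicate(n_obs, v + rnorm(n * k, sd = noise_sd)) # n x k x n_obs noisy observations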
OK, now the pairwise model seeks to avoid using the multi_normal structure to inform the correlations; instead it iterates over the pairs of k, computes the “empirical” correlation for each pair, and uses the expected sampling distribution of a correlation to increment the target. I’ve implemented this in a “check the priors” model as:
#include helper_functions.stan // for flatten_lower_tri()
data{
int<lower=1> n ;
int<lower=1> k ;
}
transformed data{
int num_r = (k*(k-1)) %/% 2 ; // number of correlations implied by k
}
parameters{
matrix[k,n] v_ ; // helper for non-centered latent values
vector[k] v_m ; // mean for each k
vector<lower=0>[k] v_s ; // sd for each k
vector<lower=-1,upper=1>[num_r] r ; // correlations among k's
}
transformed parameters{
matrix[k,n] v ; // latent values
for(i_k in 1:k){
v[i_k] = v_m[i_k] + v_[i_k] * v_s[i_k] ;
}
}
model{
// Priors
v_m ~ std_normal() ;
v_s ~ weibull(2,1) ;
r ~ uniform(-1,1) ;
// non-centered parameterization for v
to_vector(v_) ~ std_normal() ; // implies v[i_k] ~ normal(v_m[i_k],v_s[i_k])
// Now for the correlations
// First transform each row in v by the *empirical* mean & sd
matrix[k,n] v_transformed ;
for(i_k in 1:k){
v_transformed[i_k] = (v[i_k]-mean(v[i_k]))/sd(v[i_k]) ;
}
// Next iterate over pairs computing the *empirical* correlation
vector[num_r] empirical_r ;
int r_index = 0 ;
for(i_k in 1:(k-1)){
for(j_k in (i_k+1):k){
r_index = r_index + 1 ;
empirical_r[r_index] = (
dot_product(
v_transformed[i_k]
,v_transformed[j_k]
)/(n-1)
) ;
}
}
// Finally, increment target using the sampling distribution of
// an empirical correlation:
atanh(empirical_r) ~ normal( atanh(r), 1/sqrt(n-3) ) ;
}
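As an aside on that last sampling statement: it relies on the standard Fisher-z result that, for a sample correlation computed from n bivariate-normal draws with true correlation rho, atanh of the sample correlation is approximately normal with mean atanh(rho) and sd 1/sqrt(n-3). A quick R simulation (assuming MASS is available) lets you check that approximation:
set.seed(1)
n <- 100 ; rho <- 0.3
Sigma <- matrix(c(1, rho, rho, 1), 2, 2)
z <- replicate(1e4, atanh(cor(MASS::mvrnorm(n, c(0, 0), Sigma))[1, 2])) # atanh of 1e4 sample correlations
round(c(mean(z), sd(z)), 3) # simulated mean & sd
round(c(atanh(rho), 1 / sqrt(n - 3)), 3) # Fisher-z approximation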
However, while sampling runs pass the standard diagnostics, the resulting distributions of the correlations are far from the expected uniform (see the histograms produced by the R code at the bottom of this post).
So I’m obviously doing something wrong.
My first thought is that these correlations look consistent with the sampler exploring values for v that aren’t informed at all by the uniform prior on the correlations, instead sampling v as if it had an identity correlation matrix, while the samples for r do get informed by the values in v and thereby cluster around zero. That is, while r and v should mutually inform one another, there seems to be a strong flow of information from v to r but a weak or completely absent flow of information from r to v. But even if this is an accurate insight into what’s going on, I have no idea how to fix it. Anyone have any ideas?
R code for those who want to explore:
library(tidyverse)
mod = cmdstanr::cmdstan_model(
'stan/hierarchical_pairwise_prior_check.stan'
, include_paths = 'stan' # directory containing helper_functions.stan
)
iter_warmup = 2e3 # for good measure
iter_sampling = 1e3
parallel_chains = parallel::detectCores()/2
fit = mod$sample(
data = list(n=100,k=4)
, chains = parallel_chains*2 # for good measure
, parallel_chains = parallel_chains
, iter_warmup = iter_warmup
, iter_sampling = iter_sampling
, seed = 1
, refresh = (iter_warmup+iter_sampling)/10
, init = .1
)
# check diagnostics
fit$cmdstan_diagnose()
# viz the correlations
(
fit$draws(
variables = 'r'
, format = 'data.frame'
)
%>% pivot_longer(
cols = c(-.chain,-.iteration,-.draw)
)
%>% ggplot()
+ facet_grid(name~.)
+ geom_histogram(
mapping = aes(x=value)
)
+ scale_x_continuous(
limits=c(-1,1)
, expand = c(0,0)
, breaks = seq(-1,1,by=.2)
)
)
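And, if it’s useful to quantify the shrinkage rather than just eyeball the histograms, the posterior SD of each r can be compared against the SD of a uniform(-1,1) distribution, which is sqrt(4/12) ≈ 0.577:
fit$summary(variables = 'r', 'mean', 'sd') # each sd should be near 0.577 if the uniform prior were recovered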