Estimating Polychoric Correlation Matrix with Stan

Hi @JLC

Sorry to revive this thread, but I think I got an implementation working and wanted to share it in case anyone else has trouble.

I used the truncated multivariate normal code from @bgoodri above, as well as a modified Dirichlet prior for the cutpoints, which I adapted from @betanalpha's awesome blog post.

Here is the R code to generate some sample data:

library(tidyverse)
library(cmdstanr)
library(bayesplot)

# True latent correlation matrix (rho = 0.5) and 1,000 bivariate normal draws
sigma <- matrix(c(1, 0.5, 0.5, 1), nrow = 2, byrow = TRUE)
dat <- mvtnorm::rmvnorm(1e3, sigma = sigma)

plot(dat,pch='.')

cut_points <- list(
    c(-2,1),
    c(0, 1.7)
)

discretize_data <- function(x){
  c(
    sum(x[1] > cut_points[[1]]),
    sum(x[2] > cut_points[[2]])
  )
}

discrete_data <- dat %>% apply(., 1, discretize_data) %>% t()
table(discrete_data[,1],discrete_data[,2])


data_list <- list(
  D = ncol(discrete_data), N = nrow(discrete_data),
  y = discrete_data
)
fp <- file.path('***The path to your stan code***')
mod <- cmdstan_model(fp, force_recompile = getOption("cmdstanr_force_recompile", default = FALSE))
poly_cor <- mod$sample(data = data_list,
                       chains = 2,
                       parallel_chains = 2,
                       iter_warmup = 2000,
                       adapt_delta = 0.8,
                       save_warmup = FALSE,
                       iter_sampling = 2000)

As well as the stan code:

functions{
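    // Sequential truncated-normal construction: builds z one dimension at a time from
    // the uniforms u and returns the summed log probability of each truncation interval.
    real trunc_norm(array[] int y_data, array[] vector cut_points, matrix chol_mat, array[] real u, int D_y, real static_mu){
        real j_log_collection = 0;  // accumulated log interval probabilities
        real prev = 0;              // contribution of earlier dimensions to the conditional mean
        vector[D_y] z;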
        for (d in 1:D_y) {
            if (y_data[d] == 2){
                real bound_lower;
                real t;
                bound_lower = Phi((cut_points[d,2] - (static_mu + prev)) / (chol_mat[d,d]));
                t = bound_lower + (1 - bound_lower) * u[d];
                z[d] = inv_Phi(t);
                j_log_collection += log1m(bound_lower);
            }
            else if (y_data[d] == 1){
                real bound_lower;
                real bound_upper;
                real t;
                bound_lower = Phi((cut_points[d,1] - (static_mu + prev)) / (chol_mat[d,d]));
                bound_upper = Phi((cut_points[d,2] - (static_mu + prev)) / (chol_mat[d,d]));
                t = bound_lower + (bound_upper - bound_lower)*u[d];
                z[d] = inv_Phi(t);
                j_log_collection += log(bound_upper - bound_lower);
            }
            else if (y_data[d] == 0){
                real bound_upper;
                real t;
                bound_upper = Phi((cut_points[d,1] - (static_mu + prev)) / (chol_mat[d,d]));
                t = bound_upper*u[d];
                z[d] = inv_Phi(t);
                j_log_collection += log(bound_upper);
            }
            if (d < D_y){
                prev = chol_mat[d + 1, 1:d] * head(z,d);
            }
        }
        return j_log_collection;
    }

    real induced_dirichlet_lpdf(vector c, vector alpha, real gamma){
        int K = num_elements(c) + 1;
        vector[K - 1] cuml = Phi(c - gamma);
        vector[K] p;
        matrix[K,K] J = rep_matrix(0,K,K);

        p[1] = cuml[1];
        for (k in 2:(K-1)){
            p[k] = cuml[k] - cuml[k-1];
        }
        p[K] = 1 - cuml[K-1];

        for (k in 1:K) J[k,1] = 1;

        for (k in 2:K){
            real rho = exp(std_normal_lpdf(c[k-1] - gamma));
            J[k,k] = -rho;
            J[k - 1, k] = rho;
        }
        return dirichlet_lpdf(p | alpha) + log_determinant(J);
    }
}

data {
  int<lower=1> D;
  int<lower=0> N;
  // trunc_norm only handles the three categories 0, 1 and 2
  array[N, D] int<lower=0, upper=2> y;
}
parameters {
  cholesky_factor_corr[D] L_Omega;
  array[N,D] real<lower=0, upper=1> u;
  // 2 cutpoints per variable, since each variable has 3 categories (0, 1, 2)
  array[D] ordered[2] c_points;
}
model {
    L_Omega ~ lkj_corr_cholesky(4);
    // Flat induced Dirichlet prior on the 3 category probabilities of each variable
    for (d in 1:D) target += induced_dirichlet_lpdf(c_points[d] | rep_vector(1, 3), 0);

    for (n in 1:N) target += trunc_norm(y[n], c_points, L_Omega, u[n], D, 0);
}
generated quantities {
   corr_matrix[D] Omega;
   Omega = multiply_lower_tri_self_transpose(L_Omega);
}
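
In case the trunc_norm function is hard to follow, here is my reading of what it does for one observation, written as equations. The notation is mine rather than from the code: L is chol_mat, m_d is what the code calls prev (static_mu is 0 in the call above), and I pad each variable's cutpoints with c_{d,0} = -infinity and c_{d,3} = +infinity so the three branches collapse into one formula. Treat it as a sketch of the logic, not a formal derivation.

```latex
\begin{aligned}
m_d &= \sum_{j<d} L_{dj}\, z_j, \qquad m_1 = 0,\\
a_d &= \Phi\!\left(\frac{c_{d,k} - m_d}{L_{dd}}\right), \qquad
b_d = \Phi\!\left(\frac{c_{d,k+1} - m_d}{L_{dd}}\right)
\qquad \text{for } y_d = k \in \{0,1,2\},\\
z_d &= \Phi^{-1}\!\bigl(a_d + (b_d - a_d)\,u_d\bigr), \qquad u_d \in (0,1),\\
\texttt{target} &\mathrel{+}= \sum_{d=1}^{D} \log\,(b_d - a_d).
\end{aligned}
```

So each u_d is stretched onto the interval (a_d, b_d) allowed by the observed category, and the log of that interval's width is what gets added to the target.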

The above code does a pretty good job recovering the cut_points (true values -2, 1, 0 and 1.7) and the latent correlation (true value 0.5):

# A tibble: 5 x 10
  variable         mean  median     sd    mad      q5     q95  rhat ess_bulk ess_tail
  <chr>           <dbl>   <dbl>  <dbl>  <dbl>   <dbl>   <dbl> <dbl>    <dbl>    <dbl>
1 c_points[1,1] -2.02   -2.02   0.0881 0.0895 -2.17   -1.88    1.00    8672.    2688.
2 c_points[1,2]  0.920   0.919  0.0454 0.0464  0.847   0.994   1.00    5049.    2837.
3 c_points[2,1]  0.0172  0.0168 0.0394 0.0398 -0.0468  0.0827  1.00    9414.    3447.
4 c_points[2,2]  1.67    1.67   0.0688 0.0687  1.56    1.79    1.00    6424.    2501.
5 Omega[1,2]     0.480   0.480  0.0418 0.0412  0.409   0.548   1.00    3458.    3262.

I'm not sure how well it would scale to higher-dimensional data or more cutpoints, but this simple code works pretty well! Hope that helps!
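
P.S. To make the cutpoint prior a bit more concrete: as I understand it, induced_dirichlet_lpdf puts the Dirichlet on the category probabilities you get by differencing Phi at the cutpoints. Below is a minimal base-R sketch of that mapping, using the true cutpoints from the simulation. The helper name cutpoints_to_probs is something I made up for illustration; it is not part of the Stan code.

```r
# Sketch: map ordered cutpoints to category probabilities,
# p_k = Phi(c_k - gamma) - Phi(c_{k-1} - gamma), as in induced_dirichlet_lpdf.
# `cutpoints_to_probs` is a hypothetical helper used only for illustration.
cutpoints_to_probs <- function(cuts, gamma = 0) {
  cuml <- pnorm(cuts - gamma)  # cumulative probability at each cutpoint
  diff(c(0, cuml, 1))          # successive differences give the K probabilities
}

p1 <- cutpoints_to_probs(c(-2, 1))   # true cutpoints of the first variable
p2 <- cutpoints_to_probs(c(0, 1.7))  # true cutpoints of the second variable

print(round(p1, 3))
print(round(p2, 3))
```

These implied probabilities can be compared with the marginal proportions from table(discrete_data[,1]) and table(discrete_data[,2]) as a quick sanity check on the simulated data.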
