Marginal predictions and cross-validation with latent variable model

Hi all,

I am working on a generalised linear latent variable model to predict the joint abundance of different species at different sites. The model uses a factor analytic approach to allow the predicted abundances of each species to co-vary, through the inclusion of D site-specific uncorrelated latent factors, and a corresponding loading matrix \Lambda, which describes the response of each species to each latent factor.
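Marginally over the latent factors z_i, this induces a covariance of \Lambda \Lambda^\top between the species' linear predictors on the log scale, which is the implied among-species covariance the model is designed to capture.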

The model is as follows:

y_{ij} \sim \textrm{NegBinomial}(\mu_{ij}, \kappa) \\
\log(\mu_{ij}) = \alpha_i + \beta_{0j} + \beta_j x_i + z_i \lambda_j \\
\textrm{where} \quad z_i \sim \textrm{Normal}(0, 1)

Where y_{ij} is the count of species j at site i, \alpha_i is a site-specific varying intercept, \beta_{0j} is a species-specific varying intercept, \beta_j is a vector of varying species-specific coefficients for the site-level covariates x_i, z_i is the vector of D site-specific latent factors, and \lambda_j is the vector of loadings describing the response of species j to those latent factors.

I have written the model in Stan as follows, currently leaving out the covariates for an intercept-only model to aid development while I build my understanding:

data {
  int<lower=1> N; //Number of samples
  int<lower=1> S; //Number of species
  int<lower=1> D; //Number of latent dimensions

  array[N, S] int Y; //Species matrix
}
transformed data{
  // Number of non-zero lower triangular factor loadings
  // Ensures identifiability of the model - no rotation of factors
  int<lower=1> M; 
  M = D * (S - D) + D * (D - 1) / 2;
}
parameters {
  // Site intercepts
  real a_bar;
  real<lower=0> sigma_a;
  vector[N] a;
  
  // Species intercepts
  real<lower=0> sigma_b0;
  vector[S] b0;
  
  // Factor parameters
  vector[M] L_lower; // lower triangle of species loadings
  vector<lower=0>[D] L_diag; // Diagonal of species loadings (positive for identifiability)
  real<lower=0> sigma_L; // variance of species loadings
  
  // Latent variables
  matrix[D, N] LV; // Per-site latent variable

  // NegBin parameters
  real<lower=0> kappa;
}
transformed parameters {
  matrix[S, D] Lambda;
  
  // Assign parameters to L matrix:
  {
    int idx2; // Index for the lower diagonal loadings
    idx2 = 0;

    // Constraints to allow identifiability of loadings
    for (i in 1:(D - 1)) {
      for (j in (i + 1):D) {
        Lambda[i, j] = 0; // zeros above the diagonal
      }
    }
    for (i in 1:D) {
      Lambda[i, i] = L_diag[i]; // positive values on the diagonal
    }
    for (j in 1:D) {
      for (i in (j + 1):S) {
        idx2 = idx2 + 1;
        Lambda[i, j] = L_lower[idx2];
      }
    }
  }
}
model {
  // Factor priors
  to_vector(LV) ~ std_normal();
  L_lower ~ std_normal();
  L_diag ~ std_normal();
  
  // Random effect priors
  a ~ std_normal();
  b0 ~ std_normal();
  a_bar ~ std_normal();
  sigma_a ~ exponential(1);
  sigma_b0 ~ exponential(1);
  sigma_L ~ exponential(1);
  
  kappa ~ exponential(1);
  
  array[N] vector[S] mu;
  for (i in 1:N) {
      mu[i,] = exp(a_bar + a[i] * sigma_a + b0 * sigma_b0 + (Lambda * sigma_L) * LV[,i]);
      Y[i,] ~ neg_binomial_2(mu[i, ], kappa);
  }
}
generated quantities {
  // Calculate implied covariance matrix among species
  matrix[S, S] COV;
  COV = multiply_lower_tri_self_transpose(Lambda);
}

I am working with the spiders dataset from the mvabund package in R for testing, as it is conveniently small, and have also built a function to simulate data from the priors for additional testing (see Identifiability of GLLVM (factor analytic model) for code). The spiders dataset consists of counts of 12 species of spiders at 28 sites, with each site represented only once in the data, for a 28 x 12 matrix.

The model compiles and runs (though without any covariates it obviously doesn’t make very good predictions), and I am trying to figure out how to compare this simple model with more complex models once I build them, using cross-validation or LOO.

The goal behind the model would be to predict the probable range of abundance for each species at some new, unmeasured site. This suggests to me that I will have to marginalise out both the site-specific intercept \alpha_i as well as the D site-specific latent factors z_i, as otherwise prediction can only happen conditional on the estimates of these variables for specific sites. My intuition on doing that would be to generate a set of uncorrelated vectors of draws from a normal distribution, one per variable, and compute the log probability of each observation for all combinations of these vectors for each set of MCMC draws, which sounds computationally expensive!

My questions are:

  1. Would my method of estimating the marginal distribution work, and is there an easier way to approximate it?
  2. Assuming it is possible to estimate, what is the best method to perform cross-validation? Will PSIS-LOO work here, or will I have to look into K-fold methods? Ideally the solution would not involve re-fitting the model with left-out data as the dataset we are building will have possibly thousands of species across about 1,000 observations, which I suspect will be too time-consuming to run re-fits on.
  1. To simulate the posterior predictive distribution for a new site, you should be able to get away with drawing one sample from the joint random effects distribution per MCMC draw and then simulating data based on that. But it sounds like your goal is not just to simulate the prediction, but to evaluate the log-likelihood associated with a site conditional on the hyperparameters, for use in cross-validation at the site level. The challenge, then, is that you basically need to integrate the likelihood over the joint random effects distribution, and depending on how many latent factors you are using, this could be a high-dimensional integral. I don’t know what the best way to evaluate this integral is, but my intuition is that, at a minimum, you’re better off simulating a lot of times from the joint random effects distribution than simulating relatively fewer times and then doing the computation for all combinations of draws. If nothing else, this will simplify the structure of the computation and will also make it extensible to the case where the joint random effects distribution is not assumed to be independent. (A rough sketch of this simulation-based approach is included at the end of this reply.)

  2. My guess is that PSIS-LOO is not going to be your friend here. The issue is that there are multiple site-specific parameters in the model, and you want to do holdouts consisting of entire sites. I’m not sure whether even moment matching will be able to rescue this.

  3. However, FWIW the prediction task that you’ve outlined here seems a bit unusual to me. The whole point of the “joint species distribution model” sensu Warton et al is that you can leverage the information about other species’ abundance at a site to predict the abundance of any given species at the site. You’re suggesting to fit that model, but then deprive the prediction task of any of that extra information. Suppose you had a community that was structured along an elevational gradient. You could fit a big glmm with elevation as a covariate, and you wouldn’t hesitate to use the elevational information to inform the prediction task. Now suppose you fit the latent-variable model with a single latent factor, and that model is well-fitting and recovers a latent factor that recapitulates elevation. Your proposal now is to evaluate that model’s predictive performance while depriving it of the information encapsulated in that latent factor–that is, how well does the model predict to a new site where the elevation is unknown?

Edit to add: The prediction task that you’ve outlined seems like it’s useful if your main goal is to do inference on covariates in the model, and you think that the latent variable model is just a good way to properly clean up the residual variance. This would make perfect sense to me. On the other hand, if you are interested in the predictive value of the covariance structure encapsulated by the latent variables, then I don’t think it makes sense to define the prediction task over entirely new sites with no sampled species.
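
For concreteness, here's a rough sketch of what that first point could look like for the intercept-only model above, done inside the generated quantities block (untested; the number of simulation draws K, the variable names, and the choice to do this in Stan rather than in R afterwards are all just placeholder decisions):

generated quantities {
  // Rough sketch: site-level marginal log-likelihood, integrating over the
  // site intercept and the D latent factors by simple Monte Carlo.
  // K is the number of simulation draws per MCMC draw.
  int K = 500;
  vector[N] log_lik_marg;
  for (i in 1:N) {
    vector[K] lp;
    for (k in 1:K) {
      real a_sim = normal_rng(0, 1); // fresh standardised site intercept
      vector[D] z_sim;               // fresh latent factors
      vector[S] log_mu;
      for (d in 1:D) {
        z_sim[d] = normal_rng(0, 1);
      }
      log_mu = a_bar + a_sim * sigma_a + b0 * sigma_b0 + (Lambda * sigma_L) * z_sim;
      lp[k] = neg_binomial_2_log_lpmf(Y[i] | log_mu, kappa);
    }
    // log of the Monte Carlo average of the site likelihood
    log_lik_marg[i] = log_sum_exp(lp) - log(K);
  }
}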


Many thanks for your detailed response! A lot to chew on here, but I have some thoughts:

Re: point 3 and your edit, you got me thinking last night about what exactly the aims are here. In brief, we are investigating the microbiome of oak trees in the UK, and have sampled 350 trees at 30 geographically dispersed sites. The goal is to understand how the microbiome changes in response to the environmental gradients between sites, as well as with the severity of a bacterial tree disease. The disease is not present at all sites, but at all sites with the disease we have sampled both diseased and undiseased trees.

To me, this puts the focus of prediction on to the covariates, as you say - given that we know some details about a site, such as its annual rainfall or altitude, how confident can we be in saying what microbial species will be present? This inference to me would certainly suggest that marginalising over the latent site parameters would be a useful move.

Perhaps I'm completely wrong here, but my intuition was that it would be better not to integrate out the species-specific factor loadings. Keeping them would allow the residual uncertainty they capture, whether due to missing environmental predictors or (potentially) to interactions among species (assuming one can believe that co-occurrence = interaction, which is a dicey proposition), to be propagated to the prediction for each species at the new site.

The main proposal I had for inference on the values of the latent variables was for ordination, as shown in the paper which describes Boral (https://besjournals.onlinelibrary.wiley.com/doi/10.1111/2041-210X.12514), wherein the goal is to include a set of site-level covariates such that the latent factors no longer represent unexplained environmental gradients, and the position of sites within the ordination is essentially random. Assuming it's possible to do this with the data, integrating out the latent variables then makes more sense to me than if the latent variables are capturing useful information.

That all said, your reply prompted me to do some thinking last night. One of the benefits of the data we are collecting is that we will have multiple replicates per disease class per site. Leaving out some of this data (say two out of five trees from each class from each site) for cross-validation might be a better strategy here. That way, we can first explore whether even conditional on a known site, where we have estimates of the site intercept and latent variables, we can predict the left-out data with any degree of accuracy. Then, presuming that is possible, calculating the marginal predictions using the one draw per MCMC draw approach you suggest would allow us to extrapolate predictions to new sites, but without the added complexity of validating these predictions, which may be difficult without access to left-out data corresponding to completely new sites.


Hi all,

I wanted to follow up on this a bit to check my understanding of the PSIS-loo results I’m getting for this model. I’ve extended the model above to take a matrix of predictors X, with correlated varying effects for each species for each predictor. At a quick count with 6 predictor variables (and no species-specific intercept), there are 209 parameters, which I understand is a lot compared to the number of data points (336).

Running model$loo() gives the following, with a good number of Pareto k values > 0.7 (the intercept-only model above gives all points > 0.7…!):

Computed from 4000 by 336 log-likelihood matrix

         Estimate   SE
elpd_loo   -674.7 34.5
p_loo        99.3  7.7
looic      1349.4 69.0
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     180   53.6%   208       
 (0.5, 0.7]   (ok)       103   30.7%   70        
   (0.7, 1]   (bad)       47   14.0%   13        
   (1, Inf)   (very bad)   6    1.8%   6         
See help('pareto-k-diagnostic') for details.

Having done a bit of searching around Discourse, I found a comment from @avehtari on interpreting p_loo relative to the number of parameters and observations, which I think sums up what's going on here.

Here, p_loo is about half the actual number of parameters, which is nice (and suggests some regularisation is happening), but the actual number of parameters (209) is much bigger than n/5 (67.2), which I take to mean that the model is highly flexible.

To test this, I added more strongly regularising priors on the variance components (going from exponential(1) to exponential(2)), which reduced the number of observations with Pareto k > 0.7, suggesting that the problems are to do with model flexibility:

Computed from 4000 by 336 log-likelihood matrix

         Estimate   SE
elpd_loo   -676.6 34.4
p_loo        98.7  7.7
looic      1353.1 68.9
------
Monte Carlo SE of elpd_loo is NA.

Pareto k diagnostic values:
                         Count Pct.    Min. n_eff
(-Inf, 0.5]   (good)     174   51.8%   212       
 (0.5, 0.7]   (ok)       111   33.0%   46        
   (0.7, 1]   (bad)       49   14.6%   8         
   (1, Inf)   (very bad)   2    0.6%   12        
See help('pareto-k-diagnostic') for details.

However, increasing these rate parameters further seems like a poor solution to the problem in general, as 1) increasing them to very high levels doesn't remove the high Pareto k values, and also results in large numbers of divergent transitions, and 2) highly constrained variances would not, in my mind, fall in line with domain-specific knowledge, where species often differ strongly in their responses to the environment - the model has to allow the flexibility to estimate these accurately.

Is my understanding correct that this is a problem of model flexibility? For a too-flexible model, can these problems with PSIS-LOO be side-stepped by using other cross-validation techniques (e.g. k-fold CV), or are these also invalid under these conditions? It would be good in general (outside of my solution above using a permanently held-out sample) to be able to compare the predictive ability of these models with different covariates and increasing complexity (one option I'd like to explore is phylogenetic varying effects), where the data perhaps don't so easily permit a hold-out set.


You absolutely can use other cross-validation techniques. k-fold CV could be a good option; LOO with moment matching could also be a good option.

The PSIS-LOO procedure is an approximation to brute-force leave-one-out CV. It uses the fitted posterior and some black magic to approximate what the ELPD for a point would be if that point had been left out during model fitting. If that point has too much influence on the fitted posterior, then the black magic isn’t powerful enough to work. The Pareto-k diagnostic tells us if this is likely to be happening.

It turns out that one of the ways for a point to be influential is if the model is misspecified. Consider a linear regression where one point has extremely high leverage. If the true relationship is perfectly linear and homoskedastic, then the posterior shouldn’t change much regardless of whether this point is included. But if the model is misspecified, then the estimates for the slope and intercept might change a lot depending on whether a single point with high leverage is included.

Another way for a point to be influential is if the model is flexible. Consider the case of an overdispersed Poisson regression, where overdispersion is modeled as a Gaussian (on the link scale) random effect of observation ID. Now the posterior for the random effect parameter for a given observation depends strongly on the observed value, but the posterior for that parameter if the observation is left out is just drawn from the random-effect hyperparameters. So the PSIS-LOO black magic can’t deal with this and you see a high Pareto-k. There’s nothing about this situation that suggests that the model is misspecified or that cross-validation is invalid. It’s just that the approximations made in PSIS-LOO are very poor.
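
For reference, here's a minimal sketch of that overdispersed-Poisson setup (purely illustrative, and not one of the models in this thread):

data {
  int<lower=1> N;
  array[N] int<lower=0> y;
}
parameters {
  real alpha;
  real<lower=0> sigma_u;
  vector[N] u; // one random effect per observation
}
model {
  alpha ~ normal(0, 5);
  sigma_u ~ exponential(1);
  u ~ normal(0, sigma_u);
  // each u[n] is informed almost entirely by its own y[n],
  // so leaving out y[n] changes its posterior a lot
  y ~ poisson_log(alpha + u);
}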

Maybe @avehtari has more to add.


Your answer is perfect!

Many thanks @jsocolar - this clears things up a lot!

I played around a bit with LOO with moment matching - in this case, it didn't appear to help (no change in the distribution of Pareto k values), so I'll give k-fold CV a go - I may be back with more questions soon!

Not sure whether this will be relevant to your question, but if you haven’t already looked at it, I found this paper on marginalizing very useful when thinking about how to compare multivariate hierarchical models with different numbers of explanatory variables:

In my case, the aim was model comparison rather than prediction, for clusters (equivalent to your sites, I think) with known covariate values, integrating over cluster-specific parameters, and not assuming we have any information about the species present in a new cluster.

My recollection is that I didn't end up using their adaptive quadrature method, but instead did classical Monte Carlo, using lots of draws from the distribution over which I wanted to integrate, for each Stan iteration. It was quite expensive (I think I did the integration in C++ rather than R), and a better quadrature implementation than the one I tried might well be preferable.

Hello @prototaxites, I am glad I found your post here because I have run into a very similar problem with large Pareto k values using a confirmatory factor analysis model. So I wonder whether your k-fold CV worked, or whether you have any successful experience to share with us?

@Matthew_Spencer1 this marginal likelihood approach seems like the right thing to do. But I was wondering whether it is possible to use PSIS with marginal likelihoods.

My ultimate goal is to use Bayesian stacking to handle the multimodality issue, but I am stuck on the large Pareto k values from PSIS-LOO. If anyone has any ideas I would be most grateful!

Yes, it's possible. See, e.g., Section 3.6.1 in Bayesian Leave-One-Out Cross-Validation Approximations for Gaussian Latent Variable Models.

The roaches vignette has had an example of Poisson regression with random effects (Section 4). I have now also added an illustration of integrating out that latent individual-specific parameter for PSIS-LOO (Section 5). In this case the integration is easy, as it's just 1D (per observation). 2D should be feasible too, but with more parameters per group it will get more challenging. I hope the added example helps, at least conceptually.


Hi @Chen_Chen, I’m afraid I ended up dropping this due to lack of time (as well as not having any data yet!) - so nothing to report back yet. However, I will need to figure this out in time!

Hi Aki,

Many thanks for this - this is really illuminating! I had wondered if it would be possible to use the Stan integration functions to do this, so I’m pleased to see it’s possible. Can I ask how one would extend it to the 2D case? I had a look through the Stan documentation and didn’t see a 2-dimensional integration function, so I presume there’s another (probably more complicated!) way to do this. My first intuition would be to pass through the 1D integration function twice, but I don’t think this would capture the cross-relationship between the two integrated variables?


Nested use of 1D integration. This is what common 2D adaptive quadrature functions do internally, and the nested use captures the cross-relationship. It gets a bit ugly in Stan code, but it's doable. As it's done only in generated quantities, speed is not that big an issue (unless the number of observations is very large).

If the priors are normal (as in my example), it's likely that the adaptive quadrature can be replaced with non-adaptive nested Gauss-Hermite quadrature (and nesting is not needed if the correlation is low); in 2D, very good accuracy would probably be obtained with 256 quadrature points (and it might work even with fewer).
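
To make that concrete, here is a rough, untested sketch of the non-adaptive version for the 2D case in the intercept-only model above (assuming D = 2). The number of quadrature points Q and the Gauss-Hermite nodes gh_x and weights gh_w (for integrals of the form \int e^{-x^2} f(x) \, dx) are assumed to be supplied through the data block; they can be computed outside Stan, e.g. with statmod::gauss.quad(Q, kind = "hermite") in R:

generated quantities {
  // Rough sketch: per-observation marginal log-likelihood via non-adaptive
  // 2D Gauss-Hermite quadrature, integrating out the two latent factors.
  // Assumes D = 2 and that Q, gh_x and gh_w are declared in the data block.
  matrix[N, S] log_lik;
  for (i in 1:N) {
    for (j in 1:S) {
      real eta0 = a_bar + a[i] * sigma_a + b0[j] * sigma_b0;
      row_vector[D] lam = Lambda[j] * sigma_L;
      vector[Q * Q] lp;
      int q = 0;
      for (q1 in 1:Q) {
        for (q2 in 1:Q) {
          q += 1;
          // change of variables z = sqrt(2) * x turns the exp(-x^2) weight
          // into a standard normal density, up to the 1/pi constant below
          real z1 = sqrt2() * gh_x[q1];
          real z2 = sqrt2() * gh_x[q2];
          lp[q] = log(gh_w[q1]) + log(gh_w[q2])
                  + neg_binomial_2_log_lpmf(Y[i, j] | eta0 + lam[1] * z1 + lam[2] * z2, kappa);
        }
      }
      log_lik[i, j] = log_sum_exp(lp) - log(pi());
    }
  }
}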


That makes sense - I can see how having many nested integrations could get needlessly complex quite quickly, too.

Question about your adaptive quadrature code in the roaches example - the integrand function calculates the log probability, but exponentiates this to return the probability on the original scale. Then, you call the integrate_1d function and take the log of the output to calculate log_lik. Is there a reason you have to work on the standard probability scale for the integration? From my understanding of probabilities versus log-probabilities, it seems like you could run into numerical problems reasonably easily with this approach.


Remember that integration is like a sum. This is effectively doing log_sum_exp, except it's "log_integrate_exp".
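In other words, with quadrature nodes z_k and weights w_k, and f(z) the joint log density being integrated, \log \int e^{f(z)} \, dz \approx \log \sum_k w_k e^{f(z_k)}, which is just log_sum_exp over the nodes of f(z_k) + \log w_k: exponentiate the integrand, take the weighted sum, then take the log at the end.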


I tested. 32 Gauss-Hermite quadrature points are far from being enough in the roaches example. In general, it’s safer to use adaptive quadrature.


I’ve been playing around a bit trying to get this set up, and am running into what looks like some overflow issues with the two latent variables. Code is below (written to help clarify which elements are which in my mind, not for speed!).

functions {
    real integrate_2(real l2, real notused, array[] real theta,
               data array[] real Xi, data array[] int yi) {
    real intercept = theta[1];
    real beta = theta[2];
    real kappa = theta[3];
    vector[2] lambda = to_vector(theta[4:5]);
    real l1 = theta[6];

    return exp(normal_lpdf(l2 | 0, 1) + 
      normal_lpdf(l1 | 0, 1) + 
      neg_binomial_2_log_lpmf(yi | intercept + beta + lambda[1] * l1 + lambda[2] * l2, kappa));
  }
  
  real integrate(real l1, real notused, array[] real theta,
               data array[] real Xi, data array[] int yi) {
    
    real log_lik = log(
      integrate_1d(integrate_2,
        negative_infinity(),
        positive_infinity(),
        append_array(theta,{l1}),
        Xi,
        yi)
      );
          
    return exp(log_lik);
  }
}
data {
  int<lower=1> N; //Number of samples
  int<lower=1> S; //Number of species
  int<lower=1> D; //Number of latent dimensions

  array[N, S] int Y; //Species matrix
}
transformed data{
  // Number of non-zero lower-triangular factor loadings (including the diagonal)
  // Ensures identifiability of the model - no rotation of factors
  int<lower=1> M; 
  M = D * (S - D) + D * (D - 1) / 2 + D;
  
  array[0] real x_r;
}
parameters {
  // Site intercepts
  real a_bar;
  real<lower=0> sigma_a;
  vector[N] a;
  
  // Species intercepts
  real<lower=0> sigma_b0;
  vector[S] b0;
  
  // Factor parameters
  vector[M] L; // lower triangle of species loadings
  real<lower=0> sigma_L; // variance of species loadings
  
  // Latent variables
  matrix[D, N] LV_uncor; // Per-site latent variable

  // NegBin parameters
  real<lower=0> kappa;
}
transformed parameters {
  // Construct factor loading matrix
  matrix[S, D] Lambda_uncor;
  // Constraints to allow identifiability of loadings
  for (i in 1:(D-1)) { 
    for (j in (i+1):(D)){ 
      Lambda_uncor[i,j] = 0; 
    } 
  }
  {
    int index;
    index = 0;
    for (j in 1:D) {
      for (i in j:S) {
        index = index + 1;
        Lambda_uncor[i, j] = L[index];
      } 
    }
  }
}
model {
  // Factor priors
  to_vector(LV_uncor) ~ std_normal();
  L ~ std_normal();

  // Random effect priors
  a ~ std_normal();
  b0 ~ std_normal();
  a_bar ~ std_normal();
  sigma_a ~ exponential(1);
  sigma_b0 ~ exponential(1);
  sigma_L ~ exponential(1);
  
  // Negative Binomial scale parameter
  kappa ~ exponential(1);
  
  array[N] vector[S] mu;
  for (i in 1:N) {
      mu[i,] = a_bar + a[i] * sigma_a + b0 * sigma_b0 + (Lambda_uncor * sigma_L) * LV_uncor[,i];
      Y[i,] ~ neg_binomial_2_log(mu[i, ], kappa);
  }
}
generated quantities {
  // Calculate linear predictor, y_rep, log likelihoods for LOO
  matrix[N, S] log_lik;
  for (i in 1:N) {
    for(j in 1:S){
      real intercept = a_bar + a[i] * sigma_a;
      real beta = b0[j] * sigma_b0;
      vector[2] lambda = (Lambda_uncor[j,] * sigma_L)';

      log_lik[i, j] = log(
        integrate_1d(integrate,
          negative_infinity(),
          positive_infinity(),
          append_array(
            {intercept},
            append_array(
              {beta},
              append_array({kappa},
                to_array_1d(lambda)
                )
              )
            ),
            x_r,
          {Y[i,j]})
        );
    }
  }
}

Playing around with some print statements, values for the latent variables l1 and l2 appear to be identical and extremely high:

intercept: Chain 4 0.347085 
beta: Chain 4 -0.219371 
lambda[1]: Chain 4 2.22141 
lambda[2]: Chain 4 0 
l1: Chain 4 1.79769e+308 
l2: Chain 4 1.79769e+308 

Leading to the following error:

Chain 4 Exception: Exception: Exception: neg_binomial_2_log_lpmf: Log location parameter is inf, but must be finite! (in '/var/folders/78/j8s8yd9n3gl0z0gn5y1bzzhm0000gn/T/Rtmpb1iKZS/model-11bc58175412.stan', line 18, column 4 to line 20, column 95) (in '/var/folders/78/j8s8yd9n3gl0z0gn5y1bzzhm0000gn/T/Rtmpb1iKZS/model-11bc58175412.stan', line 26, column 4 to line 33, column 8) (in '/var/folders/78/j8s8yd9n3gl0z0gn5y1bzzhm0000gn/T/Rtmpb1iKZS/model-11bc58175412.stan', line 124, column 6 to line 139, column 10) 

Am I missing something obvious in my attempt to nest these integration calls?

I'm not surprised. There is a challenge here, as we need to integrate densities and not log-densities. I got some warnings about the accuracy already with 1D. One possibility is to first evaluate the log density with the two latent variables set to their mean values, and then subtract that log-density from the log-densities inside the integrand functions. Then the maximum density would be about 1, and there is a smaller chance of over/underflow. After getting the integration result and taking the log of that, the offset would be added back.

You are integrating from -Inf to Inf, so I guess 1.79769e+308 is approximating Inf. You could try a finite range, if you have a good guess for it (you can get the guess from the normal distributions, maybe something like ±9 sd).
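
As a rough, untested sketch of that rescaling idea (the name lp_ref and the extra element of theta are placeholders, not tested code), the inner integrand could look something like this, with the reference log density passed in as the sixth element of theta so that l1 moves to the seventh:

real integrate_2(real l2, real notused, array[] real theta,
                 data array[] real Xi, data array[] int yi) {
  real intercept = theta[1];
  real beta = theta[2];
  real kappa = theta[3];
  vector[2] lambda = to_vector(theta[4:5]);
  real lp_ref = theta[6]; // log density at l1 = l2 = 0, computed in generated quantities
  real l1 = theta[7];
  real lp = normal_lpdf(l2 | 0, 1) + normal_lpdf(l1 | 0, 1)
            + neg_binomial_2_log_lpmf(yi | intercept + beta
                                           + lambda[1] * l1 + lambda[2] * l2, kappa);
  return exp(lp - lp_ref); // close to 1 near the mode, so less over/underflow
}

In generated quantities, lp_ref would be computed once per observation, for example as lp_ref = 2 * normal_lpdf(0 | 0, 1) + neg_binomial_2_log_lpmf(Y[i, j] | intercept + beta, kappa), appended to theta before the outer integrate_1d call, and added back at the end: log_lik[i, j] = lp_ref + log(integrate_1d(integrate, ...)).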
