Reduce memory usage for PPC in multilevel models

Hi,

I’m fitting linear mixed models with raw Stan code and PyStan. My data are log-transformed reaction times from 16 participants. Each participant has roughly 2400 trials spread across 3 experimental factors (2 × 2 × 6 = 24 cells, i.e. about 100 trials per cell), so there are roughly 38,000 observations in total. The model fits well and behaves as expected. It is run on 6 chains with 2500 iterations each (including 1000 warmup iterations).


```stan
data {
  int<lower=0> N;               //N observations total
  int<lower=1> P;               //N population-level effects (intercept + beta)
  int<lower=0> J;               //N subj
  int<lower=1> n_u;             //N subj ranefs (e.g. 4 for (x * x | part)); same value as P if all effects have ranefs
  int<lower=1,upper=J> subj[N]; //subj ID vector
  matrix[N,P] X;                //fixef design matrix
  matrix[N, n_u] Z_u;           //subj ranef design matrix (same as X if ranefs = fixefs)
  vector[N] y;                  //DV
  vector[2] p_intercept;        //prior on intercept
  vector[P-1] p_fmu;            //priors on fixef mu
  vector[P-1] p_fsigma;         //priors on fixef sigma
  vector[2] p_r;                //priors on ranef
  int<lower=0, upper=1> logT;   //are the data log-transformed?
}

transformed data {
  matrix[N,P-1] X_beta;         //X without the intercept column
  X_beta = block(X,1,2,N,P-1);  //N x (P-1) block of X starting at row 1, column 2
}

parameters {
  real alpha;                   //intercept
  real<lower=0> sigma;          //residual std
  vector[P-1] beta;             //population-level effects coefs (w/o intercept)
  cholesky_factor_corr[n_u] L_u;//cholesky factor of subj ranef corr matrix
  vector<lower=0>[n_u] sigma_u; //subj ranef SDs INCLUDING intercept, hence beta[1] <-> sigma_u[2]
  vector[n_u] z_u[J];           //spherical subj ranef
}

transformed parameters {
  vector[n_u] u[J];             //subj ranefs
  {
    matrix[n_u,n_u] Sigma_u;    //Cholesky factor of subj ranef cov matrix
    Sigma_u = diag_pre_multiply(sigma_u,L_u);
    for(j in 1:J)
      u[j] = Sigma_u * z_u[j];
  }
}

model {
  vector[N] mu;
  //priors
  L_u ~ lkj_corr_cholesky(2.0);
  alpha ~ normal(p_intercept[1], p_intercept[2]);
  beta ~ normal(p_fmu, p_fsigma);
  for (j in 1:J)
    z_u[j] ~ normal(p_r[1],p_r[2]);

  //likelihood
  for (n in 1:N)
    mu[n] = alpha + X_beta[n] * beta + Z_u[n] * u[subj[n]];
  y ~ normal(mu, sigma);
}
```

The problem arises when I add the generated quantities block, which computes the log-likelihood and a posterior predictive draw for each data point:

```stan
generated quantities {
  matrix[n_u,n_u] Cor_u;        //correlations between subj ranefs
  vector[N] log_lik;            //pointwise log-likelihood
  vector[N] y_hat;              //posterior predictive draw for each observation
  real raw_intercept;           //raw intercept if log values
  vector[P-1] raw_beta;         //raw effect size if log values
  Cor_u = tcrossprod(L_u);      //correlations between random effects by subj
  if (logT == 1) {
    raw_intercept = exp(alpha);
    raw_beta = exp(alpha + beta) - raw_intercept;
  }
  for (n in 1:N) {
    real mu_n = alpha + X_beta[n] * beta + Z_u[n] * u[subj[n]]; //linear predictor, computed once
    log_lik[n] = normal_lpdf(y[n] | mu_n, sigma);
    y_hat[n] = normal_rng(mu_n, sigma);
  }
}
```
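To make the back-transformation in the generated quantities concrete, the small Python sketch below mirrors `raw_intercept = exp(alpha)` and `raw_beta = exp(alpha + beta) - raw_intercept` for a single posterior draw. The helper name `raw_scale_effects` and the input values are hypothetical, chosen only for illustration. Since `raw_beta` equals `exp(alpha) * (exp(beta) - 1)`, the same log-scale coefficient maps to a larger raw-scale difference when the intercept is larger.

```python
import math


def raw_scale_effects(alpha, beta):
    """Mirror the Stan back-transformation for one posterior draw.

    alpha: intercept on the log scale
    beta:  list of population-level coefficients on the log scale
    Returns (raw_intercept, raw_beta) on the original scale.
    """
    raw_intercept = math.exp(alpha)
    # Change from the raw intercept when one coefficient is added,
    # as in raw_beta = exp(alpha + beta) - raw_intercept.
    raw_beta = [math.exp(alpha + b) - raw_intercept for b in beta]
    return raw_intercept, raw_beta


# Hypothetical log-scale values for a single draw (not from the thread).
intercept, effects = raw_scale_effects(6.0, [0.05, -0.02])
print(round(intercept, 2), [round(e, 2) for e in effects])
```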

When I run the model with the generated quantities block, the fitted object becomes unreasonably large (7.4 GB). Since I am fitting 2 model versions to each of 4 measures, I cannot load all 8 models for my post-fit analysis (the models are fitted on an HPC cluster but are meant to be read on a PC with a normal amount of RAM and shared publicly afterwards).
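A quick back-of-the-envelope check of why the fit grows so large, assuming each stored value is an 8-byte double and counting only the 6 × 1500 post-warmup draws (any warmup draws or container overhead in the PyStan fit object are not included). With 16 × 2400 = 38,400 observations, each length-N vector in generated quantities costs roughly 2.8 GB, so `log_lik` and `y_hat` together come to about 5.5 GB, likely most of the 7.4 GB.

```python
# Rough size of the per-observation generated quantities.
# Assumptions: float64 storage, only post-warmup draws counted.
chains = 6
iterations = 2500
warmup = 1000
n_obs = 16 * 2400          # participants x trials per participant
bytes_per_value = 8        # float64

draws = chains * (iterations - warmup)          # posterior draws kept
values_per_vector = draws * n_obs               # one N-vector per draw
gb_per_vector = values_per_vector * bytes_per_value / 1e9

# log_lik and y_hat are both length-N vectors per draw.
print(f"{draws} draws, {gb_per_vector:.2f} GB per N-vector, "
      f"{2 * gb_per_vector:.2f} GB for log_lik + y_hat")
```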

I need these outputs for two reasons:

  1. To run posterior predictive checks
  2. To select the better-fitting model (of the 2 versions) for each measure using WAIC

Hence my question: is there a way to drastically reduce memory usage while keeping the log-likelihood and the predicted values for all data points? Or should I be more reasonable and use summary statistics instead?

Don’t use WAIC; use LOO. Don’t calculate log_lik in the generated quantities block. Instead, pass loo a function that returns the log-likelihood contribution of the i-th observation for all posterior draws, as documented:

function: A function f that takes arguments data_i and draws and returns a vector containing the log-likelihood for a single observation i evaluated at each posterior draw. The function should be written such that, for each observation i in 1:N, evaluating f(data_i = data[i,, drop=FALSE], draws = draws) results in a vector of length S (size of posterior sample). The log-likelihood function can also have additional arguments but data_i and draws are required.
If using the function method then the arguments data and draws must also be specified in the call to loo:

Thanks for your answer. Unfortunately, using the loo package would take me out of the Python environment I am working in. I would prefer a solution that can be implemented in Python.
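The `function` method quoted above belongs to the R loo package, and ArviZ does not currently support it (per the reply below). As a sketch of what such a per-observation function would compute in Python, the code below evaluates `normal_lpdf(y[n] | alpha + X_beta[n] * beta + Z_u[n] * u[subj[n]], sigma)` for one observation across all draws, so only a length-S vector is in memory at a time. The data layout (dicts with keys `y`, `x_beta`, `z_u`, `subj`) and the function names are hypothetical; `subj` is 0-based here, unlike Stan's 1-based indexing.

```python
import math


def normal_lpdf(y, mu, sigma):
    """Log density of a normal distribution (same quantity as Stan's normal_lpdf)."""
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def log_lik_i(data_i, draws):
    """Log-likelihood of one observation evaluated at every posterior draw.

    data_i: dict with 'y', 'x_beta' (row of X without intercept),
            'z_u' (row of Z_u) and 'subj' (0-based participant index).
    draws:  list of dicts with 'alpha', 'beta', 'u' (per-participant
            random-effect vectors) and 'sigma', one dict per draw.
    Returns a list of length S = len(draws).
    """
    out = []
    for d in draws:
        mu = (d["alpha"]
              + dot(data_i["x_beta"], d["beta"])
              + dot(data_i["z_u"], d["u"][data_i["subj"]]))
        out.append(normal_lpdf(data_i["y"], mu, d["sigma"]))
    return out


# Hypothetical single observation and two posterior draws, for illustration only.
obs = {"y": 6.2, "x_beta": [1.0, 0.0], "z_u": [1.0, 1.0], "subj": 1}
draws = [
    {"alpha": 6.0, "beta": [0.1, 0.3], "u": [[0.0, 0.0], [0.05, 0.05]], "sigma": 0.25},
    {"alpha": 5.9, "beta": [0.2, 0.1], "u": [[0.0, 0.0], [0.0, 0.1]], "sigma": 0.30},
]
ll = log_lik_i(obs, draws)  # one log-likelihood value per draw
```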

But if that is the only solution, I could use the loo package and run the PPCs on summary statistics instead of individual data points, right?

If the loo library for Python does not accept functions, then an issue should be filed against it. Doing posterior predictive checks on aggregates is possible, but I think it would tend to overlook a lot of misfit at the observation level.
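For the summary-statistic route, a posterior predictive check compares a statistic computed on the observed data with the same statistic computed on each replicated dataset. The sketch below is a minimal, library-free version; `ppc_pvalue` and `mean` are hypothetical helper names, and in practice `stat` could be, for example, the mean log RT within one of the 24 design cells.

```python
def ppc_pvalue(y, y_rep_draws, stat):
    """Posterior predictive p-value for a test statistic.

    y:           observed data (list of floats)
    y_rep_draws: list of replicated datasets, one per posterior draw
    stat:        function mapping a dataset to a scalar
    Returns the share of draws whose statistic is >= the observed one.
    """
    t_obs = stat(y)
    t_rep = [stat(y_rep) for y_rep in y_rep_draws]
    return sum(t >= t_obs for t in t_rep) / len(t_rep)


def mean(values):
    return sum(values) / len(values)


# Toy example: observed mean is 2.0; replicated means are 1, 2, 3, 4.
y_obs = [1.0, 2.0, 3.0]
y_rep_draws = [[1.0] * 3, [2.0] * 3, [3.0] * 3, [4.0] * 3]
p = ppc_pvalue(y_obs, y_rep_draws, mean)  # 0.75
```

Values of `p` near 0 or 1 indicate that the model rarely reproduces the observed statistic.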

Are you using ArviZ in Python? @ahartikainen might know whether it already supports a function argument for loo.

Currently we don’t support it.

We would need one of the following: standalone generated quantities for PyStan (see https://github.com/stan-dev/pystan/issues/416), the ability to call Stan functions from Python (see https://github.com/stan-dev/pystan/issues/409), or a custom Python function written by the user to create posterior predictive samples (this is prone to user error, e.g. a parameterization that differs from the Stan model).
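A sketch of the third option, a user-written Python function for posterior predictive samples, mirroring `y_hat[n] = normal_rng(alpha + X_beta[n] * beta + Z_u[n] * u[subj[n]], sigma)` from the model. The names `y_rep_i`, the dict layout and the 0-based `subj` index are hypothetical. It illustrates the error risk mentioned above: the draws must hold the transformed parameter `u`, not the spherical `z_u`, or the predictions will be wrong.

```python
import random


def y_rep_i(data_i, draws, rng):
    """One posterior predictive value for observation i per posterior draw.

    data_i: dict with 'x_beta', 'z_u' and 'subj' (0-based participant index).
    draws:  list of dicts with 'alpha', 'beta', 'u' (the transformed
            random effects, not z_u) and 'sigma'.
    rng:    a random.Random instance, so results are reproducible.
    """
    out = []
    for d in draws:
        mu = (d["alpha"]
              + sum(x * b for x, b in zip(data_i["x_beta"], d["beta"]))
              + sum(z * u for z, u in zip(data_i["z_u"], d["u"][data_i["subj"]])))
        out.append(rng.gauss(mu, d["sigma"]))
    return out


# Hypothetical observation and three identical draws, for illustration only.
rng = random.Random(2020)  # seeded for reproducibility
obs = {"x_beta": [1.0, 0.0], "z_u": [1.0, 1.0], "subj": 1}
draw = {"alpha": 6.0, "beta": [0.1, 0.3], "u": [[0.0, 0.0], [0.05, 0.05]], "sigma": 0.25}
y_rep = y_rep_i(obs, [draw] * 3, rng)  # three replicated values for observation i
```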

Can you convert your PyStan fit to an InferenceData object and save it to netCDF4? Compression is on by default, so the file size might go down a bit.
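A minimal sketch of that conversion, assuming PyStan 2 and an ArviZ version whose `az.from_pystan` accepts `posterior_predictive`, `log_likelihood` and `observed_data` (check the installed version, as arguments have changed across releases). `save_fit_as_netcdf` and `load_and_compare` are hypothetical helper names; `y_hat`, `log_lik` and `y` are the names from the model above. Since ArviZ does not take a log-likelihood function, `log_lik` still has to be stored for `az.compare` to run LOO.

```python
def save_fit_as_netcdf(fit, path):
    """Convert a PyStan 2 fit to InferenceData and write it to a netCDF file."""
    import arviz as az  # imported here so the sketch loads without ArviZ installed

    idata = az.from_pystan(
        posterior=fit,
        posterior_predictive="y_hat",   # generated quantity holding y_rep
        log_likelihood="log_lik",       # pointwise log-likelihood for LOO
        observed_data="y",              # observed outcome from the data block
    )
    idata.to_netcdf(path)  # compression is enabled by default
    return idata


def load_and_compare(paths):
    """Load saved fits (e.g. on a normal PC) and rank them by PSIS-LOO.

    paths: dict mapping a model name to a netCDF file written above.
    """
    import arviz as az

    idatas = {name: az.from_netcdf(p) for name, p in paths.items()}
    return az.compare(idatas, ic="loo")
```

Saving on the HPC and loading only the netCDF files later keeps the original PyStan fit objects out of memory on the analysis machine.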


Wow, you were right @ahartikainen! I converted the fit to netCDF and it compresses a lot, from 7.4 GB to 1.6 GB. That size is much easier to handle. Thanks a lot for all your replies.
