Memory issues with custom model

Hello

I am currently working on a model to fit simulated data (if I have coded it correctly, the model should exactly match the data-generating process). The problem is a massive memory blow-up when I run it. I am using cmdstanr, and I get the same problem whether I use variational() or sample().

By my calculations, there should be <5 MB of data (assuming int and real each take 8 bytes, i.e. the C types long and double), plus 17 parameters (negligible memory) and 40,000 transformed parameters (~0.3 MB).
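The estimate above can be checked with a few lines of arithmetic. This Python sketch takes N = 141804 and N_pt = 20000 from the data dump further down and keeps the same 8-bytes-per-value assumption (on common 64-bit platforms a C++ int is 4 bytes, so if anything this overstates the integer arrays). The function names are made up for illustration.

```python
# Back-of-the-envelope size of the data block, assuming (as in the post)
# that every int and real occupies 8 bytes.
BYTES_PER_VALUE = 8
MiB = 1024 ** 2


def data_bytes(n_records: int, n_patients: int) -> int:
    # Per-record arrays: recordtype, event, t_l, t_r
    per_record = 4 * n_records * BYTES_PER_VALUE
    # Per-patient arrays: plco, cigsmok, pt_records, pt_first
    per_patient = 4 * n_patients * BYTES_PER_VALUE
    return per_record + per_patient


def transformed_param_bytes(n_patients: int) -> int:
    # ab_lambda and oy_lambda: one value each per patient
    return 2 * n_patients * BYTES_PER_VALUE


print(f"data: {data_bytes(141804, 20000) / MiB:.2f} MiB")                      # data: 4.94 MiB
print(f"transformed params: {transformed_param_bytes(20000) / MiB:.2f} MiB")  # transformed params: 0.31 MiB
```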

When I run the model, it immediately takes all available memory on my machine (typically 8-10 GB), and then everything freezes as the system tries to give it more.

I have separated the mathematics of the model into functions. In particular, partial_sum computes the log-likelihood for a slice of the data (so that I can eventually use reduce_sum), and it in turn calls two small functions, markov and natural_history.

I would be very grateful for any help!

I have posted the full model below in case there is anything stupid I’ve done…


functions {
  // A simple function to project the state from time t0 to t
  vector markov(real   t0,
                real   t,
                vector y,
                real   ab_lambda,
                real   ab_gamma,
                real   lambda_bc,
                real   lambda_bd,
                real   lambda_ce,
                real   lambda_cz,
                real   oy_lambda,
                real   oy_gamma) {
    int N = 40;
    vector[7] result = y;
    matrix[7, 7] trans_mat;
    real t_step = (t - t0) / N;
    
    trans_mat[1, 2:3] = [0, 0];
    trans_mat[2,   3] = 0;
    trans_mat[3:6, 1] = [0, 0, 0, 0]';
    trans_mat[4,   3] = 0;
    trans_mat[5:6, 2] = [0, 0]';
    trans_mat[1:7, 4:7] = rep_matrix(0, 7, 4);
    
    for (i in 0:(N - 1)) {
      real r_exit_a;
      real r_exit_b;
      real r_exit_c;
      
      real H_ab_l = ab_lambda * (t0 + i * t_step) ^ ab_gamma;
      real H_ab_r = ab_lambda * (t0 + (i + 1) * t_step) ^ ab_gamma;
      
      real H_oy_l = oy_lambda * (t0 + i * t_step) ^ oy_gamma;
      real H_oy_r = oy_lambda * (t0 + (i + 1) * t_step) ^ oy_gamma;
      
      r_exit_a = (H_ab_r - H_ab_l) + (H_oy_r - H_oy_l);
      r_exit_b = (lambda_bc + lambda_bd) * t_step + (H_oy_r - H_oy_l);
      r_exit_c = (lambda_ce + lambda_cz) * t_step + (H_oy_r - H_oy_l);
      
      trans_mat[1, 1] = exp(-r_exit_a);
      trans_mat[2, 1] = (H_ab_r - H_ab_l) / r_exit_a * (1 - trans_mat[1, 1]);
      trans_mat[7, 1] = (H_oy_r - H_oy_l) / r_exit_a * (1 - trans_mat[1, 1]);
      
      trans_mat[2, 2] = exp(-r_exit_b);
      trans_mat[3, 2] = lambda_bc * t_step / r_exit_b * (1 - trans_mat[2, 2]);
      trans_mat[4, 2] = lambda_bd * t_step / r_exit_b * (1 - trans_mat[2, 2]);
      // H_oy_r - H_oy_l is already the cumulative hazard over the sub-step,
      // so (as for trans_mat[7, 1]) it is not scaled by t_step
      trans_mat[7, 2] = (H_oy_r - H_oy_l) / r_exit_b * (1 - trans_mat[2, 2]);
      
      trans_mat[3, 3] = exp(-r_exit_c);
      trans_mat[5, 3] = lambda_ce * t_step / r_exit_c * (1 - trans_mat[3, 3]);
      trans_mat[6, 3] = lambda_cz * t_step / r_exit_c * (1 - trans_mat[3, 3]);
      trans_mat[7, 3] = (H_oy_r - H_oy_l) / r_exit_c * (1 - trans_mat[3, 3]);
      
      result = trans_mat * result;
    }
    
    return result;
  }
  
  // Specification of the ODE
  vector natural_history(real   t,
                         vector y,
                         real   ab_lambda,
                         real   ab_gamma,
                         real   lambda_bc,
                         real   lambda_bd,
                         real   lambda_ce,
                         real   lambda_cz,
                         real   oy_lambda,
                         real   oy_gamma) {
    vector[7] dydt;
    real h_ab;
    real h_oy;
    
    h_ab = ab_lambda * ab_gamma * t ^ (ab_gamma - 1);
    h_oy = oy_lambda * oy_gamma * t ^ (oy_gamma - 1);
    
    dydt[1] = -(h_ab + h_oy) * y[1];
    dydt[2] = h_ab * y[1] - (lambda_bc + lambda_bd + h_oy) * y[2];
    dydt[3] = lambda_bc * y[2] - (lambda_ce + lambda_cz + h_oy) * y[3];
    dydt[4] = lambda_bd * y[2];
    dydt[5] = lambda_ce * y[3];
    dydt[6] = lambda_cz * y[3];
    dydt[7] = h_oy * (y[1] + y[2] + y[3]);
    
    return dydt;
  }
  
  // The main work of the model
  // For each patient it estimates their baseline state and then iterates over their follow-up records
  real partial_sum(int[] pt_slice,
                   int start, int end,
                   // Data
                   real[] plco,
                   int[]  recordtype,
                   int[]  event,
                   real[] t_l,
                   real[] t_r,
                   int[]  pt_first,
                   int[]  pt_records,
                   // Parameters
                   real sens,
                   real be_plco,
                   real be_cons,
                   real bl_plco,
                   real bl_cons,
                   real ab_gamma,
                   real lambda_bc,
                   real lambda_bd,
                   real lambda_ce,
                   real lambda_cz,
                   real oy_gamma,
                   real[] ab_lambda,
                   real[] oy_lambda) {
    real loglik = 0;
    
    // We have a slice of patients to calculate the likelihood for
    for (pid in start:end) {
      vector[7] state;
      
      // Baseline state model
      real beta_early;
      real beta_late;
      real denom;
      
      beta_early = be_plco * plco[pid] + be_cons;
      beta_late  = bl_plco * plco[pid] + bl_cons;
      denom      = 1 + exp(beta_early) + exp(beta_late);
      state[1]   = 1 / denom;
      state[2]   = exp(beta_early) / denom;
      state[3]   = exp(beta_late) / denom;
      state[4:7] = [0, 0, 0, 0]';
      
      // Subsequent observations
      for (n in (pt_first[pid]):(pt_first[pid] + pt_records[pid] - 1)) {
        vector[7] prior_state = state;
        
        if (recordtype[n] == 1) {
          // It's a screen
          if (event[n] == 1) {
            // Negative screen
            loglik       += log(prior_state[1] +
                                  prior_state[2] * (1 - sens));
            state[1]   = (1 - prior_state[2]) /
                          (1 - sens * prior_state[2]);
            state[2]   = 1 - state[1];
            state[3:7] = [0, 0, 0, 0, 0]';
          } else if (event[n] == 2) {
            // Screen-detected early lung cancer
            loglik += log(prior_state[2]);
          } else if (event[n] == 3) {
            // Screen-detected late lung cancer
            loglik += log(prior_state[3]);
          }
        } else if (recordtype[n] == 2) {
          // It's a follow-up
          
          // Simulate forwards
          state = markov(t_l[n],
                         t_r[n],
                         prior_state,
                         ab_lambda[pid],
                         ab_gamma,
                         lambda_bc,
                         lambda_bd,
                         lambda_ce,
                         lambda_cz,
                         oy_lambda[pid],
                         oy_gamma);
          
          if (event[n] == 4) {
            // Survive to end of period
            real d     = state[1] + state[2] + state[3];
            loglik    += log(d);
            state[1]   = state[1] / d;
            state[2]   = state[2] / d;
            state[3]   = state[3] / d;
            state[4:7] = [0, 0, 0, 0]';
          } else if (event[n] >= 5) {
            vector[7] dydt;
            dydt = natural_history(t_r[n],
                                   state,
                                   ab_lambda[pid],
                                   ab_gamma,
                                   lambda_bc,
                                   lambda_bd,
                                   lambda_ce,
                                   lambda_cz,
                                   oy_lambda[pid],
                                   oy_gamma);
            
            if (event[n] == 5) {
              // Non-screen-detected early lung cancer
              loglik += log(dydt[4]);
            } else if (event[n] == 6) {
              // Non-screen-detected late lung cancer
              loglik += log(dydt[5]);
            } else if (event[n] == 7) {
              // Death without lung cancer diagnosis
              loglik += log(dydt[6] + dydt[7]);
            }
          }
        }
      }
    
    }
    
    return loglik;
  }
}

data {
  int<lower=0>         N;
  int<lower=0>         N_pt;
  
  real                 plco[N_pt];
  int<lower=0,upper=1> cigsmok[N_pt];
  int<lower=1,upper=N> pt_records[N_pt];
  int<lower=1,upper=N> pt_first[N_pt];
  int<lower=0,upper=2> recordtype[N];
  int<lower=0,upper=7> event[N];
  real<lower=0>        t_l[N];
  real<lower=0>        t_r[N];
}

parameters {
  real<lower=0,upper=1> sens;
  
  real be_plco;
  real be_cons;
  real bl_plco;
  real bl_cons;
  
  real ab_plco;
  real ab_plco_cigsmok;
  real ab_cons;
  real<lower=1> ab_gamma;
  
  real<lower=0> lambda_bc;
  real<lower=0> lambda_bd;
  real<lower=0> lambda_ce;
  real<lower=0> lambda_cz;
  
  real oy_plco;
  real oy_plco_cigsmok;
  real oy_cons;
  real<lower=1> oy_gamma;
}

transformed parameters {
  real ab_lambda[N_pt];
  real oy_lambda[N_pt];
  
  for (n in 1:N_pt) {
    // Assumes ab_cons and oy_cons are meant as the intercepts of the log-rates
    ab_lambda[n] = exp(ab_cons + ab_plco * plco[n] + ab_plco_cigsmok * plco[n] * cigsmok[n]);
    oy_lambda[n] = exp(oy_cons + oy_plco * plco[n] + oy_plco_cigsmok * plco[n] * cigsmok[n]);
  }
}

model {
  // Priors
  sens ~ beta(3, 2);

  be_plco ~ normal(1, 1);
  be_cons ~ normal(-2, 4);
  bl_plco ~ normal(1, 1);
  bl_cons ~ normal(-2, 4);
  
  ab_plco ~ normal(1, 2);
  ab_plco_cigsmok ~ normal(0, 2);
  ab_cons ~ normal(-2, 4);
  ab_gamma ~ lognormal(1, 0.5);
  
  lambda_bc ~ lognormal(0, 1);
  lambda_bd ~ lognormal(0, 1);
  lambda_ce ~ lognormal(1, 1);
  lambda_cz ~ lognormal(0, 1);
  
  oy_plco ~ normal(1, 2);
  oy_plco_cigsmok ~ normal(0, 2);
  oy_cons ~ normal(-2, 4);
  oy_gamma ~ lognormal(1, 0.5);
  
  // I have separated the model into a function in order to parallelise
  // But this is the sequential version of the code
  target += partial_sum(pt_first,
                        1, N_pt,
                        // Data
                        plco,
                        recordtype,
                        event,
                        t_l,
                        t_r,
                        pt_first,
                        pt_records,
                        // Parameters
                        sens,
                        be_plco,
                        be_cons,
                        bl_plco,
                        bl_cons,
                        ab_gamma,
                        lambda_bc,
                        lambda_bd,
                        lambda_ce,
                        lambda_cz,
                        oy_gamma,
                        ab_lambda,
                        oy_lambda);
}
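In markov(), each sub-step lets mass leave a transient state with probability 1 - exp(-r_exit) and splits that mass between destinations in proportion to each cause's cumulative hazard over the sub-step, so every column of trans_mat should sum to 1. A quick numerical check of that property for column 2, in Python with made-up parameter values (the helper column_two is hypothetical and only mirrors the structure of the Stan code):

```python
import math


def column_two(t0, t1, lambda_bc, lambda_bd, oy_lambda, oy_gamma):
    """Entries of column 2 of one markov() sub-step, keyed by row index.

    The other-cause hazard has cumulative hazard oy_lambda * t ** oy_gamma
    (Weibull-type), and the leaving mass is split in proportion to each
    cause's cumulative hazard over [t0, t1].
    """
    dt = t1 - t0
    d_oy = oy_lambda * (t1 ** oy_gamma - t0 ** oy_gamma)  # cumulative hazard increment
    r_exit = (lambda_bc + lambda_bd) * dt + d_oy
    stay = math.exp(-r_exit)
    leave = 1.0 - stay
    return {
        2: stay,
        3: lambda_bc * dt / r_exit * leave,
        4: lambda_bd * dt / r_exit * leave,
        7: d_oy / r_exit * leave,
    }


col = column_two(t0=1.0, t1=1.1, lambda_bc=0.5, lambda_bd=0.2,
                 oy_lambda=0.01, oy_gamma=2.0)
print(abs(sum(col.values()) - 1.0) < 1e-12)  # True: probability is conserved
```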

In case it is relevant, I am running on Windows 7.

The first 20 rows of the data are:

> data
$recordtype
 [1] 0 2 1 2 1 2 1 2 0 1 2 1 2 1 2 0 2 1 2 1
 [ reached getOption("max.print") -- omitted 141784 entries ]

$event
 [1] 0 4 1 4 1 4 1 7 0 1 4 1 4 1 6 0 4 1 4 1
 [ reached getOption("max.print") -- omitted 141784 entries ]

$t_l
 [1] 0.000000 0.000000 0.017418 0.017418 1.035733 1.035733 2.270258 2.270258 0.000000 0.000000 0.000000
[12] 1.003202 1.003202 2.181076 2.181076 0.000000 0.000000 0.017609 0.017609 0.973374
 [ reached getOption("max.print") -- omitted 141784 entries ]

$t_r
 [1] 0.000000 0.017418 0.017418 1.035733 1.035733 2.270258 2.270258 3.969926 0.000000 0.000000 1.003202
[12] 1.003202 2.181076 2.181076 3.486169 0.000000 0.017609 0.017609 0.973374 0.973374
 [ reached getOption("max.print") -- omitted 141784 entries ]

$cigsmok
 [1] 0 1 0 0 0 0 0 1 0 0 0 0 1 0 0 1 1 1 0 1
 [ reached getOption("max.print") -- omitted 19980 entries ]

$plco
 [1] -3.725846 -2.886477 -4.707330 -3.694662 -3.361498 -3.199309 -3.081870 -3.133034 -3.934819 -5.523318
[11] -3.451841 -5.379946 -4.636202 -3.164560 -3.080575 -2.720887 -4.159040 -3.527086 -5.244920 -5.050195
 [ reached getOption("max.print") -- omitted 19980 entries ]

$pt_records
 [1] 8 7 8 8 7 8 8 8 7 8 7 7 7 7 7 8 7 7 8 8
 [ reached getOption("max.print") -- omitted 19980 entries ]

$pt_first
 [1]   1   9  16  24  32  39  47  55  63  70  78  85  92  99 106 113 121 128 135 143
 [ reached getOption("max.print") -- omitted 19980 entries ]

$N
[1] 141804

$N_pt
[1] 20000

Hoping somebody can provide some advice…

I have rerun it with N_pt = 1000, 2000 and 4000 (N ≈ 7 * N_pt), and memory usage appears to scale linearly. I still cannot work out why memory usage should be more than 3000 times the size of the data.

N_pt    Memory
----------------
1000    697 MB
2000    1.39 GB
4000    2.80 GB
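A rough extrapolation from the table: fitting a line through the origin (ignoring any fixed overhead, and reading 1.39 GB and 2.80 GB as 1390 MB and 2800 MB) gives about 0.7 MB per patient, or about 14 GB for the full N_pt = 20000, which would explain why an 8-10 GB machine runs out of memory. A minimal Python sketch of that arithmetic:

```python
# Peak memory in MB by N_pt, read off the table (assuming 1 GB = 1000 MB).
measurements = {1000: 697.0, 2000: 1390.0, 4000: 2800.0}


def mb_per_patient(points):
    """Least-squares slope of memory vs. N_pt for a line through the origin."""
    num = sum(n * mb for n, mb in points.items())
    den = sum(n * n for n in points)
    return num / den


slope = mb_per_patient(measurements)
print(f"{slope:.3f} MB per patient")                     # 0.699 MB per patient
print(f"N_pt = 20000 -> {slope * 20000 / 1000:.1f} GB")  # N_pt = 20000 -> 14.0 GB
```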

The autodiff tree can get large pretty quickly. The thing to remember is that most variables in a Stan model are autodiff variables, so even if they act like doubles, they carry around more information. For a scalar there is the value (8 bytes), the adjoint used by reverse-mode autodiff (another 8 bytes), and a pointer the C++ implementation needs behind the scenes (another 8 bytes), so roughly three times the memory of a plain double.

On top of that, every operation creates temporaries, and these take up space too. Even though they aren't visible in your code, they have to be kept so that reverse-mode autodiff can work.

For example, in this code:

real a = 1.0;
real b = 2.0;
real c = 3.0;
real d = a + b + c;

There will be a hidden temporary for a + b (the expression is evaluated as (a + b) + c), and it takes memory as well. The same is true for every operation inside every loop in your model, so I assume that is where the blow-up is happening.
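To make the point about hidden temporaries concrete, here is a toy operator-overloading reverse-mode autodiff in Python. It is only an illustration, not how Stan implements autodiff (Python objects are far larger than the 24 bytes per scalar mentioned above), and the class and function names are hypothetical. Every arithmetic result becomes a recorded node, including the unnamed a + b, and a loop records new nodes on every iteration.

```python
class Tape:
    def __init__(self):
        self.nodes = []  # every Var created, in creation order


class Var:
    def __init__(self, tape, value, parents=()):
        self.tape = tape
        self.value = value
        self.adjoint = 0.0
        self.parents = parents  # (parent Var, local partial derivative) pairs
        tape.nodes.append(self)

    def __add__(self, other):
        return Var(self.tape, self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.tape, self.value * other.value,
                   ((self, other.value), (other, self.value)))


def grad(tape, output):
    """Reverse sweep: push adjoints from the output back to the inputs."""
    output.adjoint = 1.0
    for node in reversed(tape.nodes):  # parents are always created before children
        for parent, partial in node.parents:
            parent.adjoint += partial * node.adjoint


tape = Tape()
a, b, c = Var(tape, 1.0), Var(tape, 2.0), Var(tape, 3.0)
d = a + b + c                           # parsed as (a + b) + c
print(len(tape.nodes))                  # 5: a, b, c, the hidden temporary a + b, and d
grad(tape, d)
print(a.adjoint, b.adjoint, c.adjoint)  # 1.0 1.0 1.0

# A loop records one node per iteration, e.g. 40 sub-steps as in markov()
loop_tape = Tape()
x, rate = Var(loop_tape, 1.0), Var(loop_tape, 0.99)
for _ in range(40):
    x = x * rate
print(len(loop_tape.nodes))             # 42
```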

Have you tested partial_sum with reduce_sum yet? I think that, because of the way reduce_sum works, you should be able to limit how much of the autodiff tree is in memory at any one time. Could you rerun your tests with reduce_sum and see whether the peak memory changes?
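A toy Python sketch of why slicing can cap memory. It is not Stan's implementation; it assumes, as the suggestion above does, that each slice can be differentiated on its own short-lived tape, after which only the slice's value and gradient need to be kept. Tape node counts stand in for memory, and the model sum(x * theta**2) and all function names are made up for illustration. Run serially like this, the peak tape size is set by grainsize rather than by the total amount of data; with threads, presumably one such tape per thread is alive at once.

```python
class Tape:
    """Records every autodiff node created while evaluating one slice."""
    def __init__(self):
        self.nodes = []


class Var:
    def __init__(self, tape, value, parents=()):
        self.tape = tape
        self.value = value
        self.adjoint = 0.0
        self.parents = parents  # (parent Var, local partial derivative) pairs
        tape.nodes.append(self)


def add(a, b):
    return Var(a.tape, a.value + b.value, ((a, 1.0), (b, 1.0)))


def mul(a, b):
    return Var(a.tape, a.value * b.value, ((a, b.value), (b, a.value)))


def backward(tape, out):
    # Nodes are created after their parents, so reverse creation order
    # is a valid order for propagating adjoints.
    out.adjoint = 1.0
    for node in reversed(tape.nodes):
        for parent, partial in node.parents:
            parent.adjoint += partial * node.adjoint


def slice_value_and_grad(theta, xs):
    """Value and d/dtheta of sum(x * theta**2) over one slice, on a fresh tape."""
    tape = Tape()
    t = Var(tape, theta)
    total = Var(tape, 0.0)
    for x in xs:
        total = add(total, mul(mul(t, Var(tape, x)), t))
    backward(tape, total)
    return total.value, t.adjoint, len(tape.nodes)


def chunked_value_and_grad(theta, xs, grainsize):
    """Serial slicing: only one slice's tape is referenced at any time."""
    value, gradient, peak_nodes = 0.0, 0.0, 0
    for start in range(0, len(xs), grainsize):
        v, g, n = slice_value_and_grad(theta, xs[start:start + grainsize])
        value += v
        gradient += g
        peak_nodes = max(peak_nodes, n)  # that slice's tape can now be reclaimed
    return value, gradient, peak_nodes


xs = [float(i) for i in range(1, 1001)]
print(slice_value_and_grad(0.5, xs))                   # (125125.0, 500500.0, 4002)
print(chunked_value_and_grad(0.5, xs, grainsize=100))  # (125125.0, 500500.0, 402)
```

Both calls return the same value and gradient; only the peak number of live nodes differs.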


Wow @bbbales2 that did the trick!

It took me ages to work out why the grainsize I picked made no difference. It turns out I wasn't compiling with stan_threads = TRUE, so it wasn't actually sharding the dataset despite the grainsize setting.

I can now limit the memory usage to a much more sensible amount.

Thank you so much!


We have a compatibility version of reduce_sum for when threading is turned off. I hadn't thought about it before you mentioned this, but maybe it should respect the grainsize for exactly this reason (giant data problems).
