Dealing with memory issues in Markov chain style model

Hi Stan people,

I’ve recently been trying to get a Markov chain type model working in Stan, but the model quickly uses up all of my system memory and often ends up crashing, even with a tiny amount of data. I usually have N_s_predictors = 13, N_d_predictors = 6, max_time = 92, and N_individuals = 5000 (the full dataset is around 60,000 individuals, but I can usually get a good fit with 5000 on other versions of the model). However, the problems occur even when running on just N_individuals = 200.

While my data block may be reasonably sized, I have spent a lot of time reducing the number of parameters and model block components so that they do not scale with the number of individuals. On advice from previous posts I have found here, I have introduced local scoping and vectorised many of my model block variables. So I’m now at a bit of a loss as to why all my memory is being used up even for the small data sets.

To work around this, I have also switched to fitting with maximum likelihood, removing the priors and using cmdstan’s $optimize and $mle functions, with the plan to fit the model using frequentist-style bootstrapping (still using Stan, since the optimizer is fast). However, I now find myself unable to do even this.

The model code is below.

data{
  int N_individuals;
  int N_s_predictors;
  int N_d_predictors;
  int max_time;
    
  array[N_individuals] int init_times;
  array[N_individuals,N_s_predictors] real s_predictors;
  array[N_individuals,max_time,N_d_predictors] int d_predictors;
  
  array[N_individuals,max_time] int clin_outcomes;
}
parameters {
      //normalizer p
  array[4,5] real<lower=0> p;
  array[4,5] real base_par_1;//


  array[N_s_predictors,4,5] real static_coefs;
  //array[N_s_predictors,4,5] real static_decay;

  array[N_d_predictors,4,5] real dynam_coefs;
  array[N_d_predictors,4,5] real<upper=0> dynam_decay;

  array[4,5] real age_coefs;
  array[4,5] real<upper=0> age_decay;
  
  array[4,N_d_predictors] real base_par_dpred;
  array[4,N_d_predictors] real<lower=0> p_dpred;
  array[4,N_d_predictors] real<upper=0> age_decay_dpred;
  array[4,N_d_predictors] real age_coef_dpred;

}
model{
  //run chain
  for (individual in 1:N_individuals){
    array[5] vector[max_time] state_prob;
    array[5] vector[max_time] state_prob_guess;
    //initialize chain
    for (clin_outcome2 in 1:5){
      if (clin_outcomes[individual,init_times[individual]] == clin_outcome2){
        state_prob[clin_outcome2,init_times[individual]] = 1;
        state_prob_guess[clin_outcome2,init_times[individual]] = 1;
      } else {
        state_prob[clin_outcome2,init_times[individual]] = 0;
        state_prob_guess[clin_outcome2,init_times[individual]] = 0;
      }
      
      for (time in (init_times[individual]+1):max_time){
        state_prob[clin_outcome2,time] = 0;
        state_prob_guess[clin_outcome2,time] = 0;
      }
    }
    //initialize dprobabilities as well
    array[N_d_predictors] vector[max_time] dprob;
    array[N_d_predictors] vector[max_time] dprob_guess;
    for (dpred in 1:N_d_predictors){
      if (d_predictors[individual,init_times[individual],dpred] == 1){//currently, this assumes that if first point is missing
        dprob[dpred,init_times[individual]] = 1;//then event has not occurred, this should be replaced with a probability later
      } else {
        dprob[dpred,init_times[individual]] = 0;
      }
    }
    
    array[N_d_predictors] vector[max_time] switch_vector;
    for (time in (init_times[individual]):max_time){
      
      
      for (dpred in 1:N_d_predictors){
        switch_vector[dpred] = rep_vector(0,max_time);
        if (d_predictors[individual,time,dpred] == 1){
          dprob[dpred,time] = 1;
        } else if (d_predictors[individual,time,dpred] == 0 || time == init_times[individual]) {//currently, this assumes that if first point is missing
          dprob[dpred,time] = 0;//then event has not occurred, this should be replaced with a probability later
        } else {//if no recorded value, we estimate the probability of them having the condition
          dprob[dpred,time] = tanh(sum(to_vector(p_dpred[,dpred]) .* exp(time * to_vector(base_par_dpred[,dpred])) .* exp(to_vector(age_coef_dpred[,dpred]) .* exp((time - init_times[individual]) .* to_vector(age_decay_dpred[,dpred])))));
        }
        dprob_guess[dpred,time] = tanh(sum(to_vector(p_dpred[,dpred]) .* exp(time * to_vector(base_par_dpred[,dpred])) .* exp(to_vector(age_coef_dpred[,dpred]) .* exp((time - init_times[individual]) .* to_vector(age_decay_dpred[,dpred])))));

        
        //vector of probabilities of where we think the last switch was
        switch_vector[dpred,time] = dprob[dpred,time];
        if (init_times[individual] < time){
          for (time2 in (time-1):init_times[individual]){
            switch_vector[dpred,time2] = (1- sum(switch_vector[dpred,(time2+1):time])) * dprob[dpred,time2];
          }
        }
        
      }
      

      //calculate multipliers
      array[5,5] real static_multiplier;
      array[5,5] real dynamic_multiplier;
      array[5,5] real age_multiplier;
      array[5,5] real base_propensity;
      for (clin_outcome in 1:4){
        for (clin_outcome2 in 1:5){
          base_propensity[clin_outcome,clin_outcome2] = p[clin_outcome,clin_outcome2] * exp(time * base_par_1[clin_outcome,clin_outcome2]);
          static_multiplier[clin_outcome,clin_outcome2] = exp(dot_product(to_vector(s_predictors[individual,]),to_vector(static_coefs[,clin_outcome,clin_outcome2]) ));//.* to_vector(exp(time*to_vector(static_decay[,clin_outcome,clin_outcome2])))));
          dynamic_multiplier[clin_outcome,clin_outcome2] = 1;
          for (d_pred in 1:N_d_predictors){
            dynamic_multiplier[clin_outcome,clin_outcome2] = prod(dynamic_multiplier[clin_outcome,clin_outcome2] * exp(dynam_coefs[d_pred,clin_outcome,clin_outcome2]*exp(switch_vector[d_pred]*dynam_decay[d_pred,clin_outcome,clin_outcome2])));
          }
          age_multiplier[clin_outcome,clin_outcome2] = exp(age_coefs[clin_outcome,clin_outcome2] * exp((time-init_times[individual])*age_decay[clin_outcome,clin_outcome2]));
        }
      }
      
      //calculate transitions
      array[5,5] vector[max_time] transition;
      for (clin_outcome in 1:1){
        transition[clin_outcome,2,time] = tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,2,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,2,time] - transition[clin_outcome,3,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,1,time] = 1 - transition[clin_outcome,2,time]- transition[clin_outcome,3,time]- transition[clin_outcome,4,time]- transition[clin_outcome,5,time];
      }
      
      for (clin_outcome in 2:2){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,3,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,2,time] = 1 - transition[clin_outcome,1,time] - sum(transition[clin_outcome,3:5,time]);
      }
      
      for (clin_outcome in 3:3){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,2,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,3,time] = 1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]- transition[clin_outcome,4,time]- transition[clin_outcome,5,time];
      }
      
      for (clin_outcome in 4:4){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,2,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,4,time] = 1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,3,time] - transition[clin_outcome,5,time];
      }
      
      transition[5,1,time] = 0;
      transition[5,2,time] = 0;
      transition[5,3,time] = 0;
      transition[5,4,time] = 0;
      transition[5,5,time] = 1;

      //do transition
      if (time > init_times[individual]){
        for (clin_outcome in 1:5){
          for (clin_outcome2 in 1:5){
            state_prob[clin_outcome2,time] = state_prob[clin_outcome2,time] + state_prob[clin_outcome,time-1] * transition[clin_outcome,clin_outcome2,time-1];
          }
        }
        
        //Now if we are at a known point - record guess and correct
          if (clin_outcomes[individual,time] != -1){
            for (clin_outcome in 1:5){
              state_prob_guess[clin_outcome,time] = state_prob[clin_outcome,time];
              //Also condition result if we are on a training individual, or if we are below threshold for test individuals
                if (clin_outcomes[individual,time] == clin_outcome){
                  state_prob[clin_outcome,time] = 1;
                } else {
                  state_prob[clin_outcome,time] = 0;
                }
            }
          }
  
      }
    }
  
  
    //Add likelihood to target
    for (time in (init_times[individual]+1):max_time){
        
        //There can be some small computational errors leading to some probabilities being below zero
        for (clin_outcome in 1:5){
          state_prob_guess[clin_outcome,time] = max([1e-10,state_prob_guess[clin_outcome,time]]);
        }
  
        if (clin_outcomes[individual,time] != -1){
          clin_outcomes[individual,time] ~ categorical(to_vector(state_prob_guess[,time])/sum(state_prob_guess[,time]));
        }
        
        //Similarly for dpreds
        for (dpred in 1:N_d_predictors){
          if (d_predictors[individual,time,dpred] != -1){
            (d_predictors[individual,time,dpred]+1) ~ categorical(to_vector([1-dprob_guess[dpred,time],dprob_guess[dpred,time]]));
          }
        }
      }

    
    
  }
}

The model is quite complex at the moment, and I’d be happy to explain any details on request. I am rather hoping that you’ll spot some really obvious computational issue, though.

I’d be very grateful for any suggestions you could make. Thanks.


The model’s a bit too much to dive through, but here are a few things to keep in mind:

  1. Stan stores all the parameters and transformed parameters every iteration, and this can get big. You can work out how big by counting roughly 8 bytes for each value saved. There are a lot of 4 x 5 x N_s_predictors and 4 x 5 x N_d_predictors arrays here.
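For this particular parameter block, that count works out to only a few MB per thousand draws, which suggests the saved output itself is not the problem. Here is a rough back-of-envelope sketch (array sizes taken from the model above; the 8-bytes-per-value figure is an approximation):

```python
# Rough count of values Stan saves per draw for this parameter block,
# at roughly 8 bytes each (a back-of-envelope estimate, not a measurement).
N_s_predictors = 13
N_d_predictors = 6
per_draw = (
    2 * 4 * 5                     # p, base_par_1
    + N_s_predictors * 4 * 5      # static_coefs
    + 2 * N_d_predictors * 4 * 5  # dynam_coefs, dynam_decay
    + 2 * 4 * 5                   # age_coefs, age_decay
    + 4 * 4 * N_d_predictors     # base_par_dpred, p_dpred,
                                 # age_decay_dpred, age_coef_dpred
)
mb_per_1000_draws = per_draw * 8 * 1000 / 1e6
print(per_draw, "values per draw,", mb_per_1000_draws, "MB per 1000 draws")
```

So saved output is on the order of 5 MB per thousand draws, far too small to exhaust system memory on its own.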

  2. Every intermediate computation result consumes about 40 bytes on the autodiff stack, so this is likely where you run out of memory. The stack gets reused each iteration, but it can still consume a lot of memory within a single gradient evaluation.
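To get a feel for why (2) bites even at N_individuals = 200, here is a deliberately crude order-of-magnitude sketch. The per-step op counts below are guesses for illustration, not measurements; only the ~40-bytes-per-node figure comes from the point above:

```python
# Crude order-of-magnitude sketch of autodiff-stack growth, assuming
# ~40 bytes per intermediate node. The op counts are illustrative
# guesses, not profiled values.
N_individuals = 200
max_time = 92
N_d_predictors = 6

# Each time step builds a 5x5 transition matrix, with each entry
# chaining roughly a dozen autodiff operations, plus per-dpred work
# on the max_time-length switch_vector.
ops_per_step = 5 * 5 * 12 + N_d_predictors * max_time
nodes = N_individuals * max_time * ops_per_step
bytes_total = nodes * 40
print(f"~{bytes_total / 1e9:.1f} GB per gradient evaluation")
```

Even with these conservative guesses, the nested individual-by-time-by-state loops put you in the hundreds-of-MB-per-gradient range, and the real expression graph (all those exp, tanh, and to_vector calls) is likely much larger.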

  3. You need to make sure Stan has access to enough memory. Do you have a 64-bit OS, and how much memory are you giving Stan now?

I’m guessing from looking at your code that the main problem is (2). If you use something like CmdStanPy or CmdStan, output is streamed to disk, so that might be a way to get the model to run without blowing out memory. Then you can fire up a new job to read the data back in.

Bonus. I’d try very hard not to use exp if you can avoid it, as it’s very numerically unstable. The longer you can stay on the log scale, the more stable things will be. On the log scale, instead of adding use log_sum_exp, instead of multiplying add, and instead of exponentiating multiply.
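A quick demonstration of why this matters, in Python rather than Stan for easy experimentation. The naive computation overflows as soon as any term exceeds about 710 on the log scale, while the max-shift trick (which is what Stan’s log_sum_exp does internally) stays finite:

```python
import math

def log_sum_exp(xs):
    # Stable log(sum(exp(x))) via the max-shift trick: factor out the
    # largest term so every remaining exp argument is <= 0.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

xs = [1000.0, 1000.5]

# Naive approach: math.exp(1000.0) overflows a double.
try:
    naive = math.log(sum(math.exp(x) for x in xs))
except OverflowError:
    naive = float("inf")

stable = log_sum_exp(xs)
print("naive:", naive)    # overflows
print("stable:", stable)  # finite, ~1000.97
```

The same principle applies to your transition probabilities: working with log-probabilities and only exponentiating at the very end (if at all) keeps the intermediate values in a range where doubles behave well.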