Dealing with memory issues in Markov chain style model

Hi Stan people,

I’ve recently been trying to get a Markov chain type model to work in Stan. However, I have found that the model quickly uses up all of my system memory and often ends up crashing, even with a tiny amount of data. I usually have N_s_predictors = 13, N_d_predictors = 6, max_time = 92, and N_individuals = 5000 (the full dataset has around 60,000 individuals, but I can usually get a good fit with 5,000 on other versions of the model). The problems occur even when running on just N_individuals = 200.

While my data block may be reasonably sized, I have spent a lot of time reducing the number of parameters and model block components so that they do not scale with the number of individuals. On the advice of previous posts I found on here, I have introduced local scoping and vectorised many of my model block variables. So I’m now at a bit of a loss as to why all my memory is being used up even for the small data sets.

To sort this out, I have also switched to maximum likelihood fitting, removing the priors and using CmdStan’s $optimize and $mle functions, with the plan of fitting the model via frequentist-style bootstrapping (still using Stan, as the optimizer is fast). However, I now find myself unable to do even this.

The model code is below:

data{
  int N_individuals;
  int N_s_predictors;
  int N_d_predictors;
  int max_time;
    
  array[N_individuals] int init_times;
  array[N_individuals,N_s_predictors] real s_predictors;
  array[N_individuals,max_time,N_d_predictors] int d_predictors;
  
  array[N_individuals,max_time] int clin_outcomes;
}
parameters {
  //normalizer p
  array[4,5] real<lower=0> p;
  array[4,5] real base_par_1;

  array[N_s_predictors,4,5] real static_coefs;
  //array[N_s_predictors,4,5] real static_decay;

  array[N_d_predictors,4,5] real dynam_coefs;
  array[N_d_predictors,4,5] real<upper=0> dynam_decay;

  array[4,5] real age_coefs;
  array[4,5] real<upper=0> age_decay;
  
  array[4,N_d_predictors] real base_par_dpred;
  array[4,N_d_predictors] real<lower=0> p_dpred;
  array[4,N_d_predictors] real<upper=0> age_decay_dpred;
  array[4,N_d_predictors] real age_coef_dpred;

}
model{
  //run chain
  for (individual in 1:N_individuals){
    array[5] vector[max_time] state_prob;
    array[5] vector[max_time] state_prob_guess;
    //initialize chain
    for (clin_outcome2 in 1:5){
      if (clin_outcomes[individual,init_times[individual]] == clin_outcome2){
        state_prob[clin_outcome2,init_times[individual]] = 1;
        state_prob_guess[clin_outcome2,init_times[individual]] = 1;
      } else {
        state_prob[clin_outcome2,init_times[individual]] = 0;
        state_prob_guess[clin_outcome2,init_times[individual]] = 0;
      }
      
      for (time in (init_times[individual]+1):max_time){
        state_prob[clin_outcome2,time] = 0;
        state_prob_guess[clin_outcome2,time] = 0;
      }
    }
    //initialize dprobabilities as well
    array[N_d_predictors] vector[max_time] dprob;
    array[N_d_predictors] vector[max_time] dprob_guess;
    for (dpred in 1:N_d_predictors){
      if (d_predictors[individual,init_times[individual],dpred] == 1){//currently, this assumes that if first point is missing
        dprob[dpred,init_times[individual]] = 1;//then event has not occurred, this should be replaced with a probability later
      } else {
        dprob[dpred,init_times[individual]] = 0;
      }
    }
    
    array[N_d_predictors] vector[max_time] switch_vector;
    //transition probabilities, declared before the time loop so that the values
    //assigned at time-1 are still available when they are used at time
    array[5,5] vector[max_time] transition;
    for (time in (init_times[individual]):max_time){
      
      
      for (dpred in 1:N_d_predictors){
        switch_vector[dpred] = rep_vector(0,max_time);
        if (d_predictors[individual,time,dpred] == 1){
          dprob[dpred,time] = 1;
        } else if (d_predictors[individual,time,dpred] == 0 || time == init_times[individual]) {//currently, this assumes that if first point is missing
          dprob[dpred,time] = 0;//then event has not occurred, this should be replaced with a probability later
        } else {//if no recorded value, we estimate the probability of them having the condition
          dprob[dpred,time] = tanh(sum(to_vector(p_dpred[,dpred]) .* exp(time * to_vector(base_par_dpred[,dpred])) .* exp(to_vector(age_coef_dpred[,dpred]) .* exp((time - init_times[individual]) .* to_vector(age_decay_dpred[,dpred])))));
        }
        dprob_guess[dpred,time] = tanh(sum(to_vector(p_dpred[,dpred]) .* exp(time * to_vector(base_par_dpred[,dpred])) .* exp(to_vector(age_coef_dpred[,dpred]) .* exp((time - init_times[individual]) .* to_vector(age_decay_dpred[,dpred])))));

        
        //vector of probabilities of where we think the last switch was
        switch_vector[dpred,time] = dprob[dpred,time];
        if (init_times[individual] < time){
          //loop backwards from time-1 down to init_times[individual]
          //(a Stan for-loop does not execute when its lower bound exceeds its upper bound)
          for (step in 1:(time - init_times[individual])){
            int time2 = time - step;
            switch_vector[dpred,time2] = (1 - sum(switch_vector[dpred,(time2+1):time])) * dprob[dpred,time2];
          }
        }
        
      }
      

      //calculate multipliers
      array[5,5] real static_multiplier;
      array[5,5] real dynamic_multiplier;
      array[5,5] real age_multiplier;
      array[5,5] real base_propensity;
      for (clin_outcome in 1:4){
        for (clin_outcome2 in 1:5){
          base_propensity[clin_outcome,clin_outcome2] = p[clin_outcome,clin_outcome2] * exp(time * base_par_1[clin_outcome,clin_outcome2]);
          static_multiplier[clin_outcome,clin_outcome2] = exp(dot_product(to_vector(s_predictors[individual,]),to_vector(static_coefs[,clin_outcome,clin_outcome2]) ));//.* to_vector(exp(time*to_vector(static_decay[,clin_outcome,clin_outcome2])))));
          dynamic_multiplier[clin_outcome,clin_outcome2] = 1;
          for (d_pred in 1:N_d_predictors){
            //running product over dynamic predictors
            dynamic_multiplier[clin_outcome,clin_outcome2] = dynamic_multiplier[clin_outcome,clin_outcome2] * prod(exp(dynam_coefs[d_pred,clin_outcome,clin_outcome2]*exp(switch_vector[d_pred]*dynam_decay[d_pred,clin_outcome,clin_outcome2])));
          }
          age_multiplier[clin_outcome,clin_outcome2] = exp(age_coefs[clin_outcome,clin_outcome2] * exp((time-init_times[individual])*age_decay[clin_outcome,clin_outcome2]));
        }
      }
      
      //calculate transitions
      for (clin_outcome in 1:1){
        transition[clin_outcome,2,time] = tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,2,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,2,time] - transition[clin_outcome,3,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,1,time] = 1 - transition[clin_outcome,2,time]- transition[clin_outcome,3,time]- transition[clin_outcome,4,time]- transition[clin_outcome,5,time];
      }
      
      for (clin_outcome in 2:2){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,3,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,2,time] = 1 - transition[clin_outcome,1,time] - sum(transition[clin_outcome,3:5,time]);
      }
      
      for (clin_outcome in 3:3){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,2,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,3,time] = 1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]- transition[clin_outcome,4,time]- transition[clin_outcome,5,time];
      }
      
      for (clin_outcome in 4:4){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,2,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,4,time] = 1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,3,time] - transition[clin_outcome,5,time];
      }
      
      transition[5,1,time] = 0;
      transition[5,2,time] = 0;
      transition[5,3,time] = 0;
      transition[5,4,time] = 0;
      transition[5,5,time] = 1;

      //do transition
      if (time > init_times[individual]){
        for (clin_outcome in 1:5){
          for (clin_outcome2 in 1:5){
            state_prob[clin_outcome2,time] = state_prob[clin_outcome2,time] + state_prob[clin_outcome,time-1] * transition[clin_outcome,clin_outcome2,time-1];
          }
        }
        
        //Now if we are at a known point - record guess and correct
        if (clin_outcomes[individual,time] != -1){
          for (clin_outcome in 1:5){
            state_prob_guess[clin_outcome,time] = state_prob[clin_outcome,time];
            //Also condition the result if we are on a training individual, or if we are below the threshold for test individuals
            if (clin_outcomes[individual,time] == clin_outcome){
              state_prob[clin_outcome,time] = 1;
            } else {
              state_prob[clin_outcome,time] = 0;
            }
          }
        }
      }
    }
  
  
    //Add likelihood to target
    for (time in (init_times[individual]+1):max_time){

      //There can be some small computational errors leading to some probabilities being below zero
      for (clin_outcome in 1:5){
        state_prob_guess[clin_outcome,time] = fmax(1e-10, state_prob_guess[clin_outcome,time]);
      }

      if (clin_outcomes[individual,time] != -1){
        clin_outcomes[individual,time] ~ categorical(to_vector(state_prob_guess[,time])/sum(state_prob_guess[,time]));
      }

      //Similarly for dpreds
      for (dpred in 1:N_d_predictors){
        if (d_predictors[individual,time,dpred] != -1){
          (d_predictors[individual,time,dpred]+1) ~ categorical(to_vector([1-dprob_guess[dpred,time],dprob_guess[dpred,time]]));
        }
      }
    }
  }
}

The model is quite complex at the moment, and I’d be happy to explain any details if requested. I am kind of hoping that you’ll see some really obvious computational issue though.

I’d be very grateful for any suggestions you could make. Thanks.


The model’s a bit too much to dive through, but here are a few things to keep in mind:

  1. Stan is going to store all the parameters and transformed parameters every iteration, and this can get big. You can work out how big by counting 8 bytes or so for each value saved. There are a lot of 4 x 5 x N_s_predictors and 4 x 5 x N_d_predictors arrays here (see the rough count after this list).

  2. Every intermediate computation result consumes about 40 bytes on the autodiff stack, so this is likely to be where you run out of memory. The stack gets reused each iteration, but a single log density and gradient evaluation can still consume a lot of memory.

  3. You need to make sure Stan has access to enough memory. Do you have a 64-bit OS, and how much memory are you giving it now?
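To put rough numbers on (1) and (2) for the model above (a back-of-the-envelope count, so treat it as order-of-magnitude only): the parameters block declares 676 scalars in total (20 + 20 + 260 + 120 + 120 + 20 + 20 + 24 + 24 + 24 + 24 across the eleven arrays), so saving them costs roughly 676 x 8 bytes = 5.4 KB per iteration, or tens of MB over a whole run, which is unlikely to be your problem. Point (2) is different because the expression graph is rebuilt for every log density and gradient evaluation. For example, sum(switch_vector[dpred,(time2+1):time]) inside the nested time2 loop does on the order of max_time^2 scalar operations per dynamic predictor per time step, so with max_time = 92, N_d_predictors = 6, and even just N_individuals = 200, that one line alone generates something like 10^8 autodiff nodes, which at ~40 bytes each is several GB for a single gradient evaluation.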

I’m guessing from looking at your code that the main problem is (2). If you use something like CmdStanPy or CmdStan, it streams output to disk. So that might be a way to get it to run without blowing out memory. Then you can fire up a new job to read the data back in.

Bonus. I’d try very hard not to use exp if you can avoid it. It’s very numerically unstable. The longer you can stay on the log scale, the more stable things will be. Instead of adding, use log_sum_exp; instead of multiplying, add; and instead of exponentiating, multiply. A sketch of what this looks like for one line of the model is below.
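For instance, the dynamic_multiplier update in your model multiplies a long sequence of exponentials. Assuming it’s meant as a running product over the dynamic predictors, you can accumulate a sum on the log scale instead and exponentiate once at the end (a sketch only; check it against the math you intend):

real log_dynamic_multiplier = 0;
for (d_pred in 1:N_d_predictors){
  // prod(exp(v)) == exp(sum(v)), so sum the logs instead of multiplying exponentials
  log_dynamic_multiplier += dynam_coefs[d_pred,clin_outcome,clin_outcome2]
    * sum(exp(switch_vector[d_pred] * dynam_decay[d_pred,clin_outcome,clin_outcome2]));
}
// exponentiate only where the multiplier itself is needed, or better, keep adding logs downstream

That removes the outer layer of exp calls from the autodiff graph and avoids overflow in the long product.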

Hi Bob,

Thanks very much for the response and explanation - I don’t have much of a background in computer science, so I don’t know much about how this works.

I think you may be correct that the issue is point 2 - I was not aware that intermediate computations would increase memory usage (I would’ve thought that the computer would be able to ‘forget’ values as soon as each computation is completed?), and had been focusing on reducing the things mentioned in point 1.

By testing the removal of certain operations, I have been able to reduce memory usage somewhat and have identified some lines I think are causing the issues, so I might be able to sort something out now that I know what the problem is.

It might be tricky though, so with respect to the other things you mentioned: I am using a 64-bit OS with 16GB of memory. I have also had to use an additional 50-ish GB of virtual memory, although I’d rather avoid relying on that as much as possible, as I am told it slows computation a lot. It looks like I may not have any choice though.

I was also already using CmdStan with the $optimize function. I don’t have a great understanding of how this works, to be honest - does it stream to disk by default, or do I need to do something, for example with the output_dir argument?

Thanks also for the point about using exponentials - just to clarify that I’ve understood: are you recommending that I switch from things that look like exp(a)*exp(b) to things that look like exp(a+b), leaving the exponentiation until last, or redesigning the model so that there aren’t exponentials any more?

Thanks.

In normal computations over primitive numbers, that’s true. That’s even true of forward-mode automatic differentiation (but that doesn’t scale well in dimension).

Stan and other reverse-mode automatic differentiation systems (which do scale well in dimension) build up an expression graph holding the result and every argument of every operation, plus a pointer to the code that computes the partial derivatives. That code gets called in a backward pass over the expression graph, from the result back to the inputs. There’s an explanation and diagrams here:

Yes. The only way virtual memory is fast is if the data’s accessed sequentially. It’s not the throughput (the width of the pipe) so much as the latency (how long the water takes to arrive once you open the tap).

Yes! If you can push the algebra to the point where you don’t need exponentials at all, that’s best. Otherwise, push the exponentiation as late in the pipeline as possible. The problem is that exp(500) overflows and exp(-500) underflows; these are massive (minuscule) numbers. It’s even worse if you’re dealing with differences of numbers near each other. For example, if I have a mixture model with log probabilities, the mixture log probability is

log(lambda * exp(log p(y | theta[1])) + (1 - lambda) * exp(log p(y | theta[2])))

but that’s incredibly badly behaved computationally in terms of arithmetic precision. It’s much better to code this as

log_sum_exp(log(lambda) + log p(y | theta[1]),
            log1m(lambda) + log p(y | theta[2]))

where log_sum_exp(u, v) = log(exp(u) + exp(v)), but is done in a clever way so as not to lose precision, and where log1m(u) = log(1 - u), but is more stable. Sorry this is so manual!
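To make that concrete in actual Stan code, here’s a minimal two-component mixture using this pattern, with normal components just standing in for whatever the real components are (this is the standard construction from the Stan User’s Guide, not something specific to your model):

data {
  int<lower=1> N;
  vector[N] y;
}
parameters {
  real<lower=0, upper=1> lambda;  // mixing proportion
  ordered[2] mu;                  // component means, ordered for identifiability
  real<lower=0> sigma;            // shared component scale
}
model {
  for (n in 1:N) {
    // everything stays on the log scale until log_sum_exp combines the terms
    target += log_sum_exp(log(lambda) + normal_lpdf(y[n] | mu[1], sigma),
                          log1m(lambda) + normal_lpdf(y[n] | mu[2], sigma));
  }
}

Here log(lambda) and log1m(lambda) keep the mixing weights on the log scale, and log_sum_exp does the final combination without ever materializing the raw probabilities.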