Hi Stan people,
I’ve recently been trying to get a Markov-chain-type model working in Stan. However, the model quickly uses up all of my system memory and often crashes, even with a tiny amount of data. I usually have N_s_predictors = 13, N_d_predictors = 6, max_time = 92 and N_individuals = 5000. The full dataset has roughly 60,000 individuals, but 5000 is usually enough for a good fit with other versions of the model. The problem occurs even with N_individuals = 200.
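For scale, here is a quick count of the scalar entries implied by the data block, using the sizes quoted above. This is plain Python arithmetic, and the helper name data_element_counts is only illustrative:

```python
def data_element_counts(n_individuals, n_s=13, n_d=6, max_time=92):
    """Number of scalar entries in each array of the Stan data block."""
    return {
        "init_times": n_individuals,
        "s_predictors": n_individuals * n_s,
        "d_predictors": n_individuals * max_time * n_d,
        "clin_outcomes": n_individuals * max_time,
    }


for n in (200, 5000):
    counts = data_element_counts(n)
    print(n, counts, "total:", sum(counts.values()))
```

Even at 5000 individuals this comes to a few million numbers. Stored as 8-byte values, that is on the order of tens of megabytes.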
My data block may be reasonably large, but I have spent a lot of time reducing the number of parameters and model-block quantities so that they do not scale with the number of individuals. Following advice from earlier posts here, I have introduced local scoping and vectorised many of the model-block variables. So I’m now at a bit of a loss as to why all my memory is being used up on such small datasets.
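The parameters block does have a fixed size. The sketch below tallies the declared array shapes with N_s_predictors = 13 and N_d_predictors = 6. It is plain Python, and parameter_count is a hypothetical helper:

```python
from math import prod


def parameter_count(n_s=13, n_d=6):
    """Total scalar parameters declared in the parameters block."""
    shapes = {
        "p": (4, 5),
        "base_par_1": (4, 5),
        "static_coefs": (n_s, 4, 5),
        "dynam_coefs": (n_d, 4, 5),
        "dynam_decay": (n_d, 4, 5),
        "age_coefs": (4, 5),
        "age_decay": (4, 5),
        "base_par_dpred": (4, n_d),
        "p_dpred": (4, n_d),
        "age_decay_dpred": (4, n_d),
        "age_coef_dpred": (4, n_d),
    }
    return sum(prod(shape) for shape in shapes.values())


print(parameter_count())  # 676
```

None of these shapes involve N_individuals.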
To work around this, I also switched to maximum-likelihood estimation. I removed the priors and used cmdstanr’s $optimize() and $mle() methods. My plan was to get uncertainty estimates from a frequentist-style bootstrap, still using Stan because its optimizer is fast. However, I can’t even get that to run now.
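The bootstrap plan amounts to resampling whole individuals with replacement and refitting on each resample. The sketch below builds one resampled data set in plain Python. It assumes the data are held in a dict keyed by the Stan data-block names. The helper bootstrap_resample is hypothetical, and the refit itself (e.g. via $optimize()) is not shown:

```python
import random

# Arrays in the Stan data block whose first dimension is N_individuals
PER_INDIVIDUAL = ("init_times", "s_predictors", "d_predictors", "clin_outcomes")


def bootstrap_resample(data, rng):
    """Return a copy of `data` with individuals drawn with replacement."""
    n = data["N_individuals"]
    idx = [rng.randrange(n) for _ in range(n)]
    resampled = dict(data)  # scalar sizes stay the same
    for key in PER_INDIVIDUAL:
        resampled[key] = [data[key][i] for i in idx]
    return resampled


# Toy example: 3 individuals, max_time = 2, one static and one dynamic predictor
toy = {
    "N_individuals": 3, "N_s_predictors": 1, "N_d_predictors": 1, "max_time": 2,
    "init_times": [1, 1, 2],
    "s_predictors": [[0.5], [1.0], [-0.3]],
    "d_predictors": [[[0], [1]], [[0], [-1]], [[-1], [0]]],
    "clin_outcomes": [[1, 2], [1, -1], [-1, 3]],
}
boot = bootstrap_resample(toy, random.Random(1))
print(boot["init_times"], boot["clin_outcomes"])
```

Resampling individuals rather than single time points keeps each individual’s trajectory intact within a resample.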
The model code is below:
```stan
data{
  int N_individuals;
  int N_s_predictors;
  int N_d_predictors;
  int max_time;
  array[N_individuals] int init_times;
  array[N_individuals,N_s_predictors] real s_predictors;
  array[N_individuals,max_time,N_d_predictors] int d_predictors;
  array[N_individuals,max_time] int clin_outcomes;
}
parameters {
  //normalizer p
  array[4,5] real<lower=0> p;
  array[4,5] real base_par_1;
  array[N_s_predictors,4,5] real static_coefs;
  //array[N_s_predictors,4,5] real static_decay;
  array[N_d_predictors,4,5] real dynam_coefs;
  array[N_d_predictors,4,5] real<upper=0> dynam_decay;
  array[4,5] real age_coefs;
  array[4,5] real<upper=0> age_decay;
  array[4,N_d_predictors] real base_par_dpred;
  array[4,N_d_predictors] real<lower=0> p_dpred;
  array[4,N_d_predictors] real<upper=0> age_decay_dpred;
  array[4,N_d_predictors] real age_coef_dpred;
}
model{
  //run chain
  for (individual in 1:N_individuals){
    array[5] vector[max_time] state_prob;
    array[5] vector[max_time] state_prob_guess;
    //initialize chain
    for (clin_outcome2 in 1:5){
      if (clin_outcomes[individual,init_times[individual]] == clin_outcome2){
        state_prob[clin_outcome2,init_times[individual]] = 1;
        state_prob_guess[clin_outcome2,init_times[individual]] = 1;
      } else {
        state_prob[clin_outcome2,init_times[individual]] = 0;
        state_prob_guess[clin_outcome2,init_times[individual]] = 0;
      }
      for (time in (init_times[individual]+1):max_time){
        state_prob[clin_outcome2,time] = 0;
        state_prob_guess[clin_outcome2,time] = 0;
      }
    }
    //initialize dprobabilities as well
    array[N_d_predictors] vector[max_time] dprob;
    array[N_d_predictors] vector[max_time] dprob_guess;
    for (dpred in 1:N_d_predictors){
      if (d_predictors[individual,init_times[individual],dpred] == 1){//currently, this assumes that if first point is missing
        dprob[dpred,init_times[individual]] = 1;//then event has not occurred, this should be replaced with a probability later
      } else {
        dprob[dpred,init_times[individual]] = 0;
      }
    }
    array[N_d_predictors] vector[max_time] switch_vector;
    // declared once per individual so that transition[..., time-1], written on the
    // previous iteration of the time loop, is still available in the update below
    array[5,5] vector[max_time] transition;
    for (time in (init_times[individual]):max_time){
      for (dpred in 1:N_d_predictors){
        switch_vector[dpred] = rep_vector(0,max_time);
        // model-based probability: scored in the likelihood and used to fill in missing values
        dprob_guess[dpred,time] = tanh(sum(to_vector(p_dpred[,dpred]) .* exp(time * to_vector(base_par_dpred[,dpred])) .* exp(to_vector(age_coef_dpred[,dpred]) .* exp((time - init_times[individual]) .* to_vector(age_decay_dpred[,dpred])))));
        if (d_predictors[individual,time,dpred] == 1){
          dprob[dpred,time] = 1;
        } else if (d_predictors[individual,time,dpred] == 0 || time == init_times[individual]) {//currently, this assumes that if first point is missing
          dprob[dpred,time] = 0;//then event has not occurred, this should be replaced with a probability later
        } else {//if no recorded value, we estimate the probability of them having the condition
          dprob[dpred,time] = dprob_guess[dpred,time];
        }
        //vector of probabilities of where we think the last switch was
        switch_vector[dpred,time] = dprob[dpred,time];
        if (init_times[individual] < time){
          // Stan ranges only count upwards, so walk back from time-1 to init_times[individual]
          // via an offset; remaining = 1 - sum(switch_vector[dpred,(time2+1):time])
          real remaining = 1 - switch_vector[dpred,time];
          for (k in 1:(time - init_times[individual])){
            int time2 = time - k;
            switch_vector[dpred,time2] = remaining * dprob[dpred,time2];
            remaining -= switch_vector[dpred,time2];
          }
        }
      }
      //calculate multipliers
      array[5,5] real static_multiplier;
      array[5,5] real dynamic_multiplier;
      array[5,5] real age_multiplier;
      array[5,5] real base_propensity;
      for (clin_outcome in 1:4){
        for (clin_outcome2 in 1:5){
          base_propensity[clin_outcome,clin_outcome2] = p[clin_outcome,clin_outcome2] * exp(time * base_par_1[clin_outcome,clin_outcome2]);
          static_multiplier[clin_outcome,clin_outcome2] = exp(dot_product(to_vector(s_predictors[individual,]),to_vector(static_coefs[,clin_outcome,clin_outcome2]) ));//.* to_vector(exp(time*to_vector(static_decay[,clin_outcome,clin_outcome2])))));
          dynamic_multiplier[clin_outcome,clin_outcome2] = 1;
          for (d_pred in 1:N_d_predictors){
            // multiply in this predictor's factor once; prod(exp(x)) = exp(sum(x)),
            // so a single outer exp() replaces max_time of them
            dynamic_multiplier[clin_outcome,clin_outcome2] *= exp(dynam_coefs[d_pred,clin_outcome,clin_outcome2] * sum(exp(switch_vector[d_pred] * dynam_decay[d_pred,clin_outcome,clin_outcome2])));
          }
          age_multiplier[clin_outcome,clin_outcome2] = exp(age_coefs[clin_outcome,clin_outcome2] * exp((time-init_times[individual])*age_decay[clin_outcome,clin_outcome2]));
        }
      }
      //calculate transitions
      for (clin_outcome in 1:1){
        transition[clin_outcome,2,time] = tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,2,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,2,time] - transition[clin_outcome,3,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,1,time] = 1 - transition[clin_outcome,2,time]- transition[clin_outcome,3,time]- transition[clin_outcome,4,time]- transition[clin_outcome,5,time];
      }
      for (clin_outcome in 2:2){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,3,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,2,time] = 1 - transition[clin_outcome,1,time] - sum(transition[clin_outcome,3:5,time]);
      }
      for (clin_outcome in 3:3){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,2,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,4,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,4] * static_multiplier[clin_outcome,4] * dynamic_multiplier[clin_outcome,4] * age_multiplier[clin_outcome,4]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,4,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,3,time] = 1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]- transition[clin_outcome,4,time]- transition[clin_outcome,5,time];
      }
      for (clin_outcome in 4:4){
        transition[clin_outcome,1,time] = tanh(base_propensity[clin_outcome,1] * static_multiplier[clin_outcome,1] * dynamic_multiplier[clin_outcome,1] * age_multiplier[clin_outcome,1]);
        transition[clin_outcome,2,time] = (1 - transition[clin_outcome,1,time]) * tanh(base_propensity[clin_outcome,2] * static_multiplier[clin_outcome,2] * dynamic_multiplier[clin_outcome,2] * age_multiplier[clin_outcome,2]);
        transition[clin_outcome,3,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time]) * tanh(base_propensity[clin_outcome,3] * static_multiplier[clin_outcome,3] * dynamic_multiplier[clin_outcome,3] * age_multiplier[clin_outcome,3]);
        transition[clin_outcome,5,time] = (1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,3,time]) * tanh(base_propensity[clin_outcome,5] * static_multiplier[clin_outcome,5] * dynamic_multiplier[clin_outcome,5] * age_multiplier[clin_outcome,5]);
        transition[clin_outcome,4,time] = 1 - transition[clin_outcome,1,time] - transition[clin_outcome,2,time] - transition[clin_outcome,3,time] - transition[clin_outcome,5,time];
      }
      transition[5,1,time] = 0;
      transition[5,2,time] = 0;
      transition[5,3,time] = 0;
      transition[5,4,time] = 0;
      transition[5,5,time] = 1;
      //do transition
      if (time > init_times[individual]){
        for (clin_outcome in 1:5){
          for (clin_outcome2 in 1:5){
            state_prob[clin_outcome2,time] = state_prob[clin_outcome2,time] + state_prob[clin_outcome,time-1] * transition[clin_outcome,clin_outcome2,time-1];
          }
        }
        //Now if we are at a known point - record guess and correct
        if (clin_outcomes[individual,time] != -1){
          for (clin_outcome in 1:5){
            state_prob_guess[clin_outcome,time] = state_prob[clin_outcome,time];
            //Also condition result if we are on a training individual, or if we are below threshold for test individuals
            if (clin_outcomes[individual,time] == clin_outcome){
              state_prob[clin_outcome,time] = 1;
            } else {
              state_prob[clin_outcome,time] = 0;
            }
          }
        }
      }
    }
    //Add likelihood to target
    for (time in (init_times[individual]+1):max_time){
      //There can be some small computational errors leading to some probabilities being below zero
      for (clin_outcome in 1:5){
        state_prob_guess[clin_outcome,time] = max([1e-10,state_prob_guess[clin_outcome,time]]);
      }
      if (clin_outcomes[individual,time] != -1){
        clin_outcomes[individual,time] ~ categorical(to_vector(state_prob_guess[,time])/sum(state_prob_guess[,time]));
      }
      //Similarly for dpreds
      for (dpred in 1:N_d_predictors){
        if (d_predictors[individual,time,dpred] != -1){
          // same likelihood as (d + 1) ~ categorical([1 - p, p]') for d in {0, 1}
          d_predictors[individual,time,dpred] ~ bernoulli(dprob_guess[dpred,time]);
        }
      }
    }
  }
}
```
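To make the transition and update steps easier to follow, here is a plain-Python sketch of what one step of the chain computes for a single individual. It uses 0-based states, so index 4 is the absorbing state 5. It assumes the products base_propensity × static_multiplier × dynamic_multiplier × age_multiplier are already available as a hypothetical rates[from][to] table for the four non-absorbing rows. The matrix T plays the role of transition[..., time-1]. It mirrors the model’s stick-breaking construction: off-diagonal targets are filled in increasing order, each taking a tanh share of the probability still unassigned, and the diagonal receives the remainder. The names transition_matrix, forward_step and rates are illustrative only.

```python
import math

N_STATES = 5  # states 1..5 of the Stan model; index 4 (state 5) is absorbing


def transition_matrix(rates):
    """Row-stochastic transition matrix built by stick-breaking.

    rates[r][c] stands in for base_propensity * static_multiplier *
    dynamic_multiplier * age_multiplier for the move r -> c (r = 0..3).
    """
    T = [[0.0] * N_STATES for _ in range(N_STATES)]
    for r in range(N_STATES - 1):
        left = 1.0  # probability not yet assigned in this row
        for c in range(N_STATES):
            if c == r:
                continue
            T[r][c] = left * math.tanh(rates[r][c])
            left -= T[r][c]
        T[r][r] = left  # staying in the same state gets the remainder
    T[N_STATES - 1][N_STATES - 1] = 1.0  # absorbing state
    return T


def forward_step(prev, T, observed=None):
    """One step of the chain for one individual.

    Returns (guess, updated): `guess` is the predicted distribution that the
    likelihood scores; `updated` is collapsed onto the observed state when
    there is one (0-based index), otherwise it is the prediction itself.
    """
    guess = [sum(prev[r] * T[r][c] for r in range(N_STATES))
             for c in range(N_STATES)]
    if observed is None:
        return guess, guess
    updated = [1.0 if c == observed else 0.0 for c in range(N_STATES)]
    return guess, updated


# Example: equal rates everywhere, individual known to start in state 1
rates = [[0.2] * N_STATES for _ in range(N_STATES - 1)]
T = transition_matrix(rates)
prev = [1.0, 0.0, 0.0, 0.0, 0.0]
guess, updated = forward_step(prev, T, observed=1)  # next outcome observed: state 2
print([round(x, 4) for x in guess], updated)
```

Because each tanh share lies in (0, 1), every row stays non-negative and sums to one by construction.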
The model is fairly complex at the moment, and I’m happy to explain any details. I’m hoping, though, that someone will spot an obvious computational issue.
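As a rough guide to where the work goes, the sketch below counts scalar operations in the dynamic_multiplier update. Each operation on a parameter-dependent value roughly corresponds to one reverse-mode autodiff node kept in memory until the gradient is done. The update works element-wise on a full length-max_time vector (switch_vector[d_pred]) for every individual, time step, (from, to) pair and dynamic predictor. ops_per_element is an approximation: about 5 for prod(dynamic_multiplier * exp(dynam_coefs * exp(switch_vector[d_pred] * dynam_decay))) and about 2 for an exp(dynam_coefs * sum(exp(switch_vector[d_pred] * dynam_decay))) form. The bytes-per-node figure is an assumption for illustration, not a measured Stan Math value. The helper name is hypothetical.

```python
def dynamic_multiplier_ops(n_individuals, max_time=92, n_d=6, n_from=4, n_to=5,
                           ops_per_element=5):
    """Approximate scalar ops from the element-wise dynamic_multiplier update.

    Counts every individual over all max_time steps (an upper bound, since the
    loop actually starts at init_times[individual]), every (from, to) pair and
    every dynamic predictor, each touching a length-max_time vector.
    """
    per_vector = ops_per_element * max_time + 1  # +1 for the final reduction
    return n_individuals * max_time * n_from * n_to * n_d * per_vector


ASSUMED_BYTES_PER_NODE = 32  # assumption for illustration only

for n in (200, 5000):
    for per_elem in (5, 2):
        ops = dynamic_multiplier_ops(n, ops_per_element=per_elem)
        gb = ops * ASSUMED_BYTES_PER_NODE / 1e9
        print(f"N={n}, {per_elem} ops/element: ~{ops:.2e} ops, ~{gb:.0f} GB")
```

Under these assumptions, N_individuals = 200 already implies several hundred million to about a billion operations per gradient evaluation. The count grows linearly in N_individuals and quadratically in max_time, because each time step touches a vector of length max_time.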
I’d be very grateful for any suggestions you could make. Thanks.