Stan model uses a lot of RAM

Hi,

The following Stan code fails to run on my laptop because of its large RAM requirements, even with as few as 5000 data points (with $9 \le S \le 13$). I haven't run into RAM problems with Stan before, so I'm a bit lost on how to change the code to reduce its memory requirement. The code builds a square matrix with $(S+1)(S+2)/2$ rows and columns (roughly $S^2/2$), but $S$ is relatively small ($< 20$), so I don't see why this should take gigabytes of RAM. Any help bringing down the RAM requirements would be appreciated.

Thanks.

functions{
    // Build the generator (rate) matrix of a continuous-time Markov chain over the states
    // listed in stateVar. Entry [down, across] is the rate of moving from state `across`
    // = (k, m) into state `down`; only the six neighbour moves below are non-zero. The
    // diagonal holds minus the total outflow of each state, so every column sums to zero.
    //
    // stateVar       : num_states x 2 array; row i holds (k, m) for state i (any order).
    // num_states     : number of rows in stateVar and size of the returned square matrix.
    // lam, mu, gamma : rate parameters.
    // s              : total count with k + m <= s; must be >= 2 because every lam term is
    //                  divided by (s - 1). Smaller s is rejected instead of producing NaN/inf.
    matrix generateRateMatrix(int[,] stateVar, int num_states, real lam, real mu, real gamma, int s){
        matrix[num_states, num_states] RateMatrix = rep_matrix(0., num_states, num_states);
        real lam_scaled;

        // s == 1 would divide by zero and silently fill the matrix with NaN/inf.
        if (s < 2)
            reject("generateRateMatrix: s must be >= 2; found s = ", s);
        lam_scaled = lam / (s - 1);

        for (down in 1:num_states){
            int k_down = stateVar[down, 1];
            int m_down = stateVar[down, 2];

            for (across in 1:num_states){
                int k = stateVar[across, 1];
                int m = stateVar[across, 2];
                int r = s - m - k;  // the remaining count not in k or m

                // Every move whose target lies outside k + m <= s has a zero factor (r, k or m),
                // so the branches below never leak rate out of the state space.
                if (k == k_down-1 && m == m_down){
                    RateMatrix[down, across] = r * (k * lam_scaled + 2 * mu);
                }
                else if (k == k_down && m == m_down-1){
                    RateMatrix[down, across] = m * r * lam_scaled;
                }
                else if (k == k_down+1 && m == m_down-1){
                    RateMatrix[down, across] = k * (m * lam_scaled + mu);
                }
                else if (k == k_down+1 && m == m_down){
                    RateMatrix[down, across] = k * (r * lam_scaled + gamma);
                }
                else if (k == k_down && m == m_down+1){
                    RateMatrix[down, across] = m * r * lam_scaled;
                }
                else if (k == k_down-1 && m == m_down+1){
                    RateMatrix[down, across] = m * (k * lam_scaled + 2 * gamma);
                }
                else if (k == k_down && m == m_down){
                    // Minus the sum of the six outgoing rates of state (k, m).
                    RateMatrix[down, across] = -(2 * ((k + m) * r + k * m) * lam_scaled
                                 + (k + 2 * m) * gamma + (2 * s - (k + 2 * m)) * mu);
                }
            }
        }

        return RateMatrix;
    }

    // Enumerate every pair (k, m) with k, m >= 0 and k + m <= s, ordered by m (outer,
    // 0..s) and then k (inner). Row i of the result holds (k, m) for state i.
    // num_states is expected to be (s+1)(s+2)/2: a smaller value triggers an index error,
    // and a larger one leaves the trailing rows unassigned.
    int[,] generatestateVar(int s, int num_states){
        int states[num_states, 2];
        int row = 0;

        for (m in 0:s){
            // Bounding k by s - m visits exactly the pairs with k + m <= s.
            for (k in 0:(s - m)){
                row += 1;
                states[row, 1] = k;
                states[row, 2] = m;
            }
        }

        return states;
    }

    // Propagate an initial distribution through the chain for time `age`, then collapse the
    // state probabilities onto 2*s + 1 bins. The initial distribution puts mass 0.5 on the
    // first and 0.5 on the last state of stateVar. Element j of the result is the summed
    // probability of all states (k, m) with k + 2*m == j - 1.
    //
    // NOTE(review): when RateMatrix depends on parameters, autodiff has to go through the
    // matrix exponential of a num_states x num_states matrix. That is a likely driver of the
    // memory use discussed in this thread; confirm by profiling before restructuring.
    vector findProbDist(matrix RateMatrix, int[,] stateVar, int num_states, int s, real age){
        matrix[num_states, 1] start = rep_matrix(0., num_states, 1);
        matrix[num_states, 1] at_age;
        vector[2*s+1] bins;

        start[1, 1] = 0.5;
        start[num_states, 1] = 0.5;

        // exp(age * RateMatrix) * start
        at_age = scale_matrix_exp_multiply(age, RateMatrix, start);

        bins = rep_vector(0., 2*s+1);
        for (i in 1:num_states){
            int bin = stateVar[i, 1] + 2 * stateVar[i, 2] + 1;
            bins[bin] += at_age[i, 1];
        }

        return bins;
    }

    // Enumerate the states, assemble the rate matrix for (lam, mu, gamma), and return the
    // binned probability vector of length 2*s + 1 produced by findProbDist.
    vector wrapper_function(int num_states, real lam, real mu, real gamma, int s, real age){
        int states[num_states, 2] = generatestateVar(s, num_states);
        matrix[num_states, num_states] Q = generateRateMatrix(states, num_states, lam, mu, gamma, s);
        vector[2*s+1] dist = findProbDist(Q, states, num_states, s, age);

        return dist;
    }

    // Entry point for the chain: the number of pairs (k, m) with k + m <= s is the triangular
    // number (s+1)(s+2)/2. Returns the binned distribution of length 2*s + 1.
    vector runModel(real lam, real mu, real gamma, int s, real age){
        // (s+1)(s+2) is a product of consecutive integers, hence even: the division is exact.
        int n_states = ((s + 1) * (s + 2)) / 2;
        vector[2*s+1] dist = wrapper_function(n_states, lam, mu, gamma, s, age);

        return dist;
    }

    // Log-likelihood of the observations y under a (2*s + 1)-component mixture of normals,
    // each truncated to [0, 1]. Component k (k = 1..2*s+1) is centred at (k - 1) / (2*s),
    // has scale sigma, and has weight ProbDist[k] = runModel(lam, mu, gamma, s, age)[k].
    //
    // Returns  sum_n log sum_k ProbDist[k] * Normal(y[n] | peak[k], sigma) / Z_k,
    // where Z_k = Phi((1 - peak[k]) / sigma) - Phi(-peak[k] / sigma) is the mass of the
    // component inside [0, 1].
    real noisy_stemcell_lpdf(real[] y, real lam, real mu, real gamma, real sigma, real age, int s){
        int N = num_elements(y);
        int K = 2*s+1;
        vector[K] peak;
        vector[K] ltruncate;
        vector[K] log_weight;
        vector[N] LL;

        for (k in 1:K){
            peak[k] = (k - 1.0) / (2.0 * s);
            ltruncate[k] = log_diff_exp(normal_lcdf(1 | peak[k], sigma), normal_lcdf(0 | peak[k], sigma));
        }

        // log(mixture weight) - log(truncation mass) does not depend on the observation, so it
        // is computed once. Building it inside the loop would add N*K extra subtraction
        // nodes to the autodiff expression graph on every gradient evaluation.
        log_weight = log(runModel(lam, mu, gamma, s, age)) - ltruncate;

        for (n in 1:N){
            vector[K] lps = log_weight;
            for (k in 1:K){
                lps[k] += normal_lpdf(y[n] | peak[k], sigma);
            }
            LL[n] = log_sum_exp(lps);
        }

        return sum(LL);
    }

    // Draw from Normal(mean, sigma) truncated to [lb, ub] by rejection: keep redrawing until
    // the sample lies inside the bounds. This never terminates if the interval has
    // (numerically) zero probability under the normal.
    real normal_lub_rng(real mean, real sigma, real lb, real ub) {
        real draw = normal_rng(mean, sigma);

        while (draw < lb || draw > ub)
            draw = normal_rng(mean, sigma);

        return draw;
    }

    // Simulate one observation: pick a bin with the probabilities returned by runModel, then
    // draw from a normal centred at that bin's peak (bin - 1) / (2*s) with scale sigma,
    // truncated to [0, 1].
    real random_draw_rng(real lam, real mu, real gamma, real age, int s, real sigma){
        vector[2*s+1] dist = runModel(lam, mu, gamma, s, age);
        int component = categorical_rng(dist);
        real centre = (component - 1.0) / (2.0 * s);

        return normal_lub_rng(centre, sigma, 0.0, 1.0);
    }
}

data {
    int<lower=0> N;                  // Number of Sites
    int<lower=1> T;                  // Number of candidate stem-cell numbers in S
    // Candidate stem-cell numbers. The rate matrix divides every lam term by (S - 1),
    // so S = 1 (or less) would produce NaN/inf rates; enforce S >= 2 at data-read time.
    int<lower=2> S[T];
    real<lower=0,upper=1> y[N];      // Fraction methylated
    // Time argument of the matrix exponential exp(age * RateMatrix). A negative value would
    // run the chain backwards and yield invalid probabilities.
    real<lower=0> age;
}

transformed data {
    // Log of the uniform prior weight 1/T placed on each candidate value in S.
    real log_unif = -log(T);
}

parameters {
    // All four parameters are constrained positive; their half-normal priors are in the model block.
    real<lower=0> lam;            // Replacement rate
    real<lower=0> mu;                // Methylation rate
    real<lower=0> gamma;             // Demethylation rate
    real<lower=0> sigma;             // Scale of the truncated-normal noise around each peak
}

transformed parameters{
    // lp[t] = log prior weight of candidate S[t] (uniform, -log(T)) plus the log-likelihood
    // of y with the stem-cell number fixed to S[t]. Generated quantities reads lp.
    vector[T] lp;

    for (t in 1:T)
        lp[t] = log_unif + noisy_stemcell_lpdf(y | lam, mu, gamma, sigma, age, S[t]);
}

model {
    // Half-normal priors: every parameter is declared <lower=0>.
    lam ~ normal(0, 1);               // Prior
    mu ~ normal(0, 0.01);                // Prior
    gamma ~ normal(0, 0.01);             // Prior
    sigma ~ normal(0, 0.1);          // Prior

    // Marginalise the discrete stem-cell number: log sum_t exp(lp[t]),
    // where each lp[t] already includes the log prior weight -log(T).
    target += log_sum_exp(lp);
}

generated quantities{
    // t: index into S drawn with probabilities softmax(lp), i.e. the conditional posterior
    //    weight of each candidate stem-cell number given the current parameter draw.
    // y_hat: one simulated observation from the model with stem-cell number S[t].
    int<lower=1, upper=T> t;
    real<lower=0,upper=1> y_hat;

    t = categorical_logit_rng(lp);

    y_hat = random_draw_rng(lam, mu, gamma, age, S[t], sigma);
}
1 Like

It doesn't make sense to me either why this is so RAM-hungry. The usual suspects are large outputs from generated quantities or transformed parameters, but in your case both of those blocks are small, so that doesn't really explain it.

How many iterations are you running? Does this fail immediately or after some time?

Would you mind trying the simplified model below (with generated quantities and transformed parameters removed), just to see?

                }
                else if (k == k_down && m == m_down-1){
                    RateMatrix[down, across] = m * (s-m-k) * lam / (s-1);
                }
                else if (k == k_down+1 && m == m_down-1){
                    RateMatrix[down, across] = k * (m*lam/(s-1)+mu);
                }
                else if (k == k_down+1 && m == m_down){
                    RateMatrix[down, across] = k * ((s-m-k)*lam/(s-1)+gamma);
                }
                else if (k == k_down && m == m_down+1){
                    RateMatrix[down, across] = m * (s-m-k) * lam / (s-1);
                }
                else if (k == k_down-1 && m == m_down+1){
                    RateMatrix[down, across] = m * (k*lam/(s-1)+2*gamma);
                }
                else if (k == k_down && m == m_down){
                    RateMatrix[down, across] = -(2*((k+m)*(s-m-k)+k*m)*lam/(s-1)
                                 + (k+2*m)*gamma + (2*s-(k+2*m))*mu);
                }
            }
        }
        
        return RateMatrix;
    }

    // Enumerate every pair (k, m) with k, m >= 0 and k + m <= s, ordered by m (outer,
    // 0..s) and then k (inner). Row i of the result holds (k, m) for state i.
    // num_states is expected to be (s+1)(s+2)/2: a smaller value triggers an index error,
    // and a larger one leaves the trailing rows unassigned.
    int[,] generatestateVar(int s, int num_states){
        int states[num_states, 2];
        int row = 0;

        for (m in 0:s){
            // Bounding k by s - m visits exactly the pairs with k + m <= s.
            for (k in 0:(s - m)){
                row += 1;
                states[row, 1] = k;
                states[row, 2] = m;
            }
        }

        return states;
    }

    // Propagate an initial distribution through the chain for time `age`, then collapse the
    // state probabilities onto 2*s + 1 bins. The initial distribution puts mass 0.5 on the
    // first and 0.5 on the last state of stateVar. Element j of the result is the summed
    // probability of all states (k, m) with k + 2*m == j - 1.
    //
    // NOTE(review): when RateMatrix depends on parameters, autodiff has to go through the
    // matrix exponential of a num_states x num_states matrix. That is a likely driver of the
    // memory use discussed in this thread; confirm by profiling before restructuring.
    vector findProbDist(matrix RateMatrix, int[,] stateVar, int num_states, int s, real age){
        matrix[num_states, 1] start = rep_matrix(0., num_states, 1);
        matrix[num_states, 1] at_age;
        vector[2*s+1] bins;

        start[1, 1] = 0.5;
        start[num_states, 1] = 0.5;

        // exp(age * RateMatrix) * start
        at_age = scale_matrix_exp_multiply(age, RateMatrix, start);

        bins = rep_vector(0., 2*s+1);
        for (i in 1:num_states){
            int bin = stateVar[i, 1] + 2 * stateVar[i, 2] + 1;
            bins[bin] += at_age[i, 1];
        }

        return bins;
    }

    // Enumerate the states, assemble the rate matrix for (lam, mu, gamma), and return the
    // binned probability vector of length 2*s + 1 produced by findProbDist.
    vector wrapper_function(int num_states, real lam, real mu, real gamma, int s, real age){
        int states[num_states, 2] = generatestateVar(s, num_states);
        matrix[num_states, num_states] Q = generateRateMatrix(states, num_states, lam, mu, gamma, s);
        vector[2*s+1] dist = findProbDist(Q, states, num_states, s, age);

        return dist;
    }

    // Entry point for the chain: the number of pairs (k, m) with k + m <= s is the triangular
    // number (s+1)(s+2)/2. Returns the binned distribution of length 2*s + 1.
    vector runModel(real lam, real mu, real gamma, int s, real age){
        // (s+1)(s+2) is a product of consecutive integers, hence even: the division is exact.
        int n_states = ((s + 1) * (s + 2)) / 2;
        vector[2*s+1] dist = wrapper_function(n_states, lam, mu, gamma, s, age);

        return dist;
    }

    // Log-likelihood of the observations y under a (2*s + 1)-component mixture of normals,
    // each truncated to [0, 1]. Component k (k = 1..2*s+1) is centred at (k - 1) / (2*s),
    // has scale sigma, and has weight ProbDist[k] = runModel(lam, mu, gamma, s, age)[k].
    //
    // Returns  sum_n log sum_k ProbDist[k] * Normal(y[n] | peak[k], sigma) / Z_k,
    // where Z_k = Phi((1 - peak[k]) / sigma) - Phi(-peak[k] / sigma) is the mass of the
    // component inside [0, 1].
    real noisy_stemcell_lpdf(real[] y, real lam, real mu, real gamma, real sigma, real age, int s){
        int N = num_elements(y);
        int K = 2*s+1;
        vector[K] peak;
        vector[K] ltruncate;
        vector[K] log_weight;
        vector[N] LL;

        for (k in 1:K){
            peak[k] = (k - 1.0) / (2.0 * s);
            ltruncate[k] = log_diff_exp(normal_lcdf(1 | peak[k], sigma), normal_lcdf(0 | peak[k], sigma));
        }

        // log(mixture weight) - log(truncation mass) does not depend on the observation, so it
        // is computed once. Building it inside the loop would add N*K extra subtraction
        // nodes to the autodiff expression graph on every gradient evaluation.
        log_weight = log(runModel(lam, mu, gamma, s, age)) - ltruncate;

        for (n in 1:N){
            vector[K] lps = log_weight;
            for (k in 1:K){
                lps[k] += normal_lpdf(y[n] | peak[k], sigma);
            }
            LL[n] = log_sum_exp(lps);
        }

        return sum(LL);
    }

    // Draw from Normal(mean, sigma) truncated to [lb, ub] by rejection: keep redrawing until
    // the sample lies inside the bounds. This never terminates if the interval has
    // (numerically) zero probability under the normal.
    real normal_lub_rng(real mean, real sigma, real lb, real ub) {
        real draw = normal_rng(mean, sigma);

        while (draw < lb || draw > ub)
            draw = normal_rng(mean, sigma);

        return draw;
    }

    // Simulate one observation: pick a bin with the probabilities returned by runModel, then
    // draw from a normal centred at that bin's peak (bin - 1) / (2*s) with scale sigma,
    // truncated to [0, 1].
    real random_draw_rng(real lam, real mu, real gamma, real age, int s, real sigma){
        vector[2*s+1] dist = runModel(lam, mu, gamma, s, age);
        int component = categorical_rng(dist);
        real centre = (component - 1.0) / (2.0 * s);

        return normal_lub_rng(centre, sigma, 0.0, 1.0);
    }
}

data {
    int<lower=0> N;                  // Number of Sites
    int<lower=1> T;                  // Number of candidate stem-cell numbers in S
    // Candidate stem-cell numbers. The rate matrix divides every lam term by (S - 1),
    // so S = 1 (or less) would produce NaN/inf rates; enforce S >= 2 at data-read time.
    int<lower=2> S[T];
    real<lower=0,upper=1> y[N];      // Fraction methylated
    // Time argument of the matrix exponential exp(age * RateMatrix). A negative value would
    // run the chain backwards and yield invalid probabilities.
    real<lower=0> age;
}

transformed data {
    // Log of the uniform prior weight 1/T placed on each candidate value in S.
    real log_unif = -log(T);
}

parameters {
    // All four parameters are constrained positive; their half-normal priors are in the model block.
    real<lower=0> lam;            // Replacement rate
    real<lower=0> mu;                // Methylation rate
    real<lower=0> gamma;             // Demethylation rate
    real<lower=0> sigma;             // Scale of the truncated-normal noise around each peak
}

model {
    // Per-candidate log weight: uniform log prior -log(T) plus the log-likelihood of y
    // with the stem-cell number fixed to S[t]. Local here, so it is not written to output.
    vector[T] lp;

    for (t in 1:T)
        lp[t] = log_unif + noisy_stemcell_lpdf(y | lam, mu, gamma, sigma, age, S[t]);

    // Half-normal priors: every parameter is declared <lower=0>.
    lam ~ normal(0, 1);
    mu ~ normal(0, 0.01);
    gamma ~ normal(0, 0.01);
    sigma ~ normal(0, 0.1);

    // Marginalise the discrete stem-cell number: log sum_t exp(lp[t]).
    target += log_sum_exp(lp);
}

Remember that each operation creates a new node on the expression graph, with its own value and adjoint stored in memory. Given how iteratively the user-defined functions are implemented, I wouldn't be surprised if the expression graph is what's blowing out all of the memory here.

2 Likes

Is the RAM usage due to sampling or to C++ compilation? C++ compilation can take ~4 GB of RAM.

Thanks all for the responses.

I'm reasonably sure that @betanalpha is correct and that this is due to the way the sampling is implemented. The C++ code compiles fine, and there are no issues if I set the true S value to 7 for my simulated data and loop over S = [5, 6, 7, 8, 9]. But if I set S equal to 11 and loop over [9, 10, 11, 12, 13], the computation uses significantly more RAM. The code involves creating a square matrix with $(S+1)(S+2)/2$ rows and columns, but given the low maximum possible S value I didn't think this would be much of a problem. When I run this on the noiseless data (so not iterating over the different possible S values, just S = 11), sampling uses ~4 GB of RAM per chain and runs quickly.