Hi,
The following Stan code fails to run on my laptop because of its large RAM requirements, even with as few as 5000 data points (with $9 \le S \le 13$). I haven't run into RAM problems with Stan before, so I'm not sure how to change the code to reduce its memory footprint. The code builds a square matrix with $(S+1)(S+2)/2 \approx S^2/2$ rows and columns, but since $S$ is small (< 20) I don't see why this should need gigabytes of RAM. Any help in bringing down the RAM requirements would be appreciated.
Thanks.
functions{
// Build the generator of the continuous-time Markov chain over the states
// listed in stateVar (row i of stateVar is the pair {k, m} of state i).
// Entry [dst, src] is the rate of jumping from state src into state dst; each
// diagonal entry is minus the total outflow rate of that state, so every
// column sums to zero and the matrix acts on column vectors of state
// probabilities (findProbDist propagates with exp(age * RateMatrix) * p0).
// NOTE(review): the rates divide by (s - 1), so s must be at least 2.
matrix generateRateMatrix(int[,] stateVar, int num_states, real lam, real mu, real gamma, int s){
  matrix[num_states, num_states] R = rep_matrix(0., num_states, num_states);
  for (dst in 1:num_states){
    for (src in 1:num_states){
      // (k, m) is the source state; (dk, dm) is the jump from src to dst.
      int k = stateVar[src, 1];
      int m = stateVar[src, 2];
      int dk = stateVar[dst, 1] - k;
      int dm = stateVar[dst, 2] - m;
      if (dk == 0 && dm == 0){
        // Diagonal: minus the sum of the six outgoing rates below.
        R[dst, src] = -(2*((k+m)*(s-m-k)+k*m)*lam/(s-1)
                        + (k+2*m)*gamma + (2*s-(k+2*m))*mu);
      } else if (dk == 1 && dm == 0){
        R[dst, src] = (s-m-k) * (k*lam/(s-1) + 2*mu);
      } else if (dk == -1 && dm == 0){
        R[dst, src] = k * ((s-m-k)*lam/(s-1)+gamma);
      } else if (dk == 0 && dm == 1){
        R[dst, src] = m * (s-m-k) * lam / (s-1);
      } else if (dk == 0 && dm == -1){
        R[dst, src] = m * (s-m-k) * lam / (s-1);
      } else if (dk == -1 && dm == 1){
        R[dst, src] = k * (m*lam/(s-1)+mu);
      } else if (dk == 1 && dm == -1){
        R[dst, src] = m * (k*lam/(s-1)+2*gamma);
      }
      // Any other (dk, dm) is not a single jump: the entry stays 0.
    }
  }
  return R;
}
// Enumerate every state (k, m) with k, m >= 0 and k + m <= s, ordered by m
// (outer, 0..s) and then k (inner, 0..s-m). Row i of the result is {k, m}
// for state i, so state 1 is (0, 0) and the last state is (0, s).
// num_states is expected to be (s+1)(s+2)/2, the number of such pairs.
int[,] generatestateVar(int s, int num_states){
  int states[num_states, 2];
  int idx = 0;
  for (m in 0:s){
    for (k in 0:(s - m)){
      idx += 1;
      states[idx, 1] = k;
      states[idx, 2] = m;
    }
  }
  return states;
}
// Propagate the initial distribution (probability 0.5 on state 1 and 0.5 on
// state num_states) forward by time `age` under RateMatrix, i.e.
// p_t = exp(age * RateMatrix) * p0, then collapse the state probabilities onto
// the 2s+1 values of k + 2m: entry k+2m+1 of the result is the total
// probability of all states with that value.
// NOTE(review): when RateMatrix depends on parameters, this differentiates
// through a dense num_states x num_states matrix exponential (num_states = 105
// for s = 13); that is likely the dominant memory cost of the model, so
// confirm with profiling before optimising elsewhere.
vector findProbDist(matrix RateMatrix, int[,] stateVar, int num_states, int s, real age){
  matrix[num_states, 1] p0 = rep_matrix(0., num_states, 1);
  matrix[num_states, 1] p_t;
  vector[2*s+1] dist = rep_vector(0., 2*s+1);
  p0[1, 1] = 0.5;
  p0[num_states, 1] = 0.5;
  p_t = scale_matrix_exp_multiply(age, RateMatrix, p0);
  for (i in 1:num_states)
    dist[stateVar[i, 1] + 2*stateVar[i, 2] + 1] += p_t[i, 1];
  return dist;
}
// Build the state list and the generator for stem-cell number s, then return
// the length-(2s+1) distribution over k + 2m after time `age`.
vector wrapper_function(int num_states, real lam, real mu, real gamma, int s, real age){
  int states[num_states, 2] = generatestateVar(s, num_states);
  return findProbDist(generateRateMatrix(states, num_states, lam, mu, gamma, s),
                      states, num_states, s, age);
}
// Length-(2s+1) distribution over k + 2m for stem-cell number s at time age.
vector runModel(real lam, real mu, real gamma, int s, real age){
  // Count of pairs (k, m) with k, m >= 0 and k + m <= s. The product of two
  // consecutive integers is even, so this integer division is exact.
  int n_states = ((s + 1) * (s + 2)) / 2;
  return wrapper_function(n_states, lam, mu, gamma, s, age);
}
// Location of mixture component k (1..2s+1) for stem-cell number s: the
// components are evenly spaced on [0, 1] at (k-1)/(2s). Because every argument
// is an int, the result is a plain constant rather than an autodiff variable.
real stemcell_peak(int k, int s){
  return (k - 1.0) / (2.0 * s);
}

// Log likelihood of observations y for stem-cell number s. Each y[n] is a
// mixture over K = 2s+1 components; component k has weight ProbDist[k] (from
// runModel) and a normal(peak_k, sigma) density truncated to [0, 1]:
//   sum_n log sum_k ProbDist[k] * normal(y[n] | peak_k, sigma) / Z_k,
//   Z_k = P(0 < normal(peak_k, sigma) < 1).
// NOTE(review): runModel's rate matrix divides by (s - 1), so s >= 2 is needed.
real noisy_stemcell_lpdf(real[] y, real lam, real mu, real gamma, real sigma, real age, int s){
  int N = num_elements(y);
  int K = 2*s+1;
  vector[K] ltruncate;
  vector[K] lbase;
  vector[N] LL;
  // Peaks come from stemcell_peak instead of a local vector: in Stan's
  // generated C++ a function-local real/vector takes the autodiff type of the
  // parameter arguments, which would add gradient terms for these constants.
  for (k in 1:K)
    ltruncate[k] = log_diff_exp(normal_lcdf(1 | stemcell_peak(k, s), sigma),
                                normal_lcdf(0 | stemcell_peak(k, s), sigma));
  // Per-component log weight minus log truncation constant. It is the same for
  // every observation, so it is computed once here rather than inside the loop
  // over n (which created N * K redundant autodiff nodes per call).
  lbase = log(runModel(lam, mu, gamma, s, age)) - ltruncate;
  for (n in 1:N){
    vector[K] lps = lbase;
    for (k in 1:K)
      lps[k] += normal_lpdf(y[n] | stemcell_peak(k, s), sigma);
    LL[n] = log_sum_exp(lps);
  }
  return sum(LL);
}
// Draw from normal(mean, sigma) restricted to [lb, ub] by rejection sampling:
// keep drawing until a value is not below lb and not above ub. The expected
// number of draws is 1 / P(lb <= normal(mean, sigma) <= ub).
real normal_lub_rng(real mean, real sigma, real lb, real ub) {
  real draw;
  while (1) {
    draw = normal_rng(mean, sigma);
    if (!(draw < lb || draw > ub))
      break;
  }
  return draw;
}
// Simulate one observation for stem-cell number s: choose a component from
// the distribution over k + 2m returned by runModel, then draw a normal around
// that component's location (component-1)/(2s), truncated to [0, 1].
real random_draw_rng(real lam, real mu, real gamma, real age, int s, real sigma){
  int component = categorical_rng(runModel(lam, mu, gamma, s, age));
  real centre = (component - 1.0) / (2.0 * s);
  return normal_lub_rng(centre, sigma, 0.0, 1.0);
}
}
data {
  int<lower=0> N;                 // Number of Sites
  int<lower=1> T;                 // Number of candidate stem-cell numbers
  // Stem Cell Number candidates, each given weight 1/T. The rate matrix divides
  // by (S - 1) and the peak locations by 2S, so values below 2 give inf/NaN.
  int<lower=2> S[T];
  real<lower=0,upper=1> y[N];     // Fraction methylated
  // Time over which the generator is applied (exp(age * RateMatrix)); a
  // negative value would not produce a probability vector.
  real<lower=0> age;
}
transformed data {
  // Log of the uniform mixture weight 1/T given to each candidate S[t].
  real log_unif = -log(T);
}
parameters {
// All rates are non-negative; declaration order fixes the output column order.
real<lower=0> lam; // Replacement rate
real<lower=0> mu; // Methylation rate
real<lower=0> gamma; // Demethylation rate
real<lower=0> sigma; // Std. dev. of the normal noise around each peak (see noisy_stemcell_lpdf)
}
transformed parameters{
  // lp[t] = log(1/T) + log likelihood of y under stem-cell number S[t]. These
  // are the mixture terms marginalised in the model block and the logits used
  // to draw t in generated quantities.
  vector[T] lp;
  for (t in 1:T)
    lp[t] = log_unif + noisy_stemcell_lpdf(y | lam, mu, gamma, sigma, age, S[t]);
}
model {
// Half-normal priors: every parameter is declared with <lower=0>.
lam ~ normal(0, 1); // Prior
mu ~ normal(0, 0.01); // Prior
gamma ~ normal(0, 0.01); // Prior
sigma ~ normal(0, 0.1); // Prior
// Marginalise the discrete stem-cell number: log sum_t exp(lp[t]), where each
// lp[t] already includes the uniform weight log(1/T).
target += log_sum_exp(lp);
}
generated quantities{
// Draw a stem-cell index t with probability proportional to exp(lp[t]), i.e.
// its posterior probability given the current parameter values, then simulate
// one observation under stem-cell number S[t].
int<lower=1, upper=T> t;
real<lower=0,upper=1> y_hat;
t = categorical_logit_rng(lp);
y_hat = random_draw_rng(lam, mu, gamma, age, S[t], sigma);
}