Hi guys,
I came up with the following Stan model and I was wondering whether you have some advice on how to improve its performance (both computationally and statistically). Note that I will apply the model to cases where `N_uncensored + N_censored` will be around 500k, `NC` (the number of covariates) around 20, and `m` at most 10. The matrices `basis_evals_censored`, `basis_evals_uncensored` and `deriv_basis_evals_uncensored` contain the evaluations (and, for the last one, the derivatives) of the corresponding `m` I-spline basis functions for each of the individuals, at the corresponding survival time of that individual.
Are there any optimisations for the matrix/vector multiplications that could lead to performance gains?
Ultimately I would like to reduce runtime.
// Inputs to the model. Observations are supplied as two separate groups
// (censored / uncensored), each with its own design matrix and spline basis
// matrix. The basis matrices are laid out m x N, i.e. one column per observation.
data {
int<lower=0> N_uncensored; // number of uncensored data points
int<lower=0> N_censored; // number of censored data points
int<lower=1> m; // number of basis splines
int<lower=1> NC; // number of covariates
matrix[N_censored,NC] Q_censored; // Q-transf. of design matrix (censored)
matrix[N_uncensored,NC] Q_uncensored; // Q-transf. of design matrix (uncensored)
// R is inverted in transformed data and used to map betas_tr to betas.
// NOTE(review): presumably the R factor belonging to the Q matrices above - confirm.
matrix[NC, NC] R;
// NOTE(review): log_times_censored is not referenced by the model code in this file.
vector[N_censored] log_times_censored; // x=log(t) in the paper (censored)
vector[N_uncensored] log_times_uncensored; // x=log(t) in the paper (uncensored)
matrix[m,N_censored] basis_evals_censored; // ispline basis matrix (censored)
matrix[m,N_uncensored] basis_evals_uncensored; // ispline basis matrix (uncensored)
matrix[m,N_uncensored] deriv_basis_evals_uncensored; // derivatives of isplines matrix (uncensored)
}
transformed data {
  // Inverse of R, computed here so that the parameter blocks only need a
  // matrix-vector product to recover betas from betas_tr.
  matrix[NC, NC] R_inv;
  R_inv = inverse(R);
}
// Sampled parameters: spline coefficients, covariate coefficients on the
// Q-transformed scale, and an intercept.
parameters {
// The lower=0 bound keeps every spline coefficient non-negative.
row_vector<lower=0>[m] gammas; // regression coefficients for splines
// Multiplied directly with Q_censored / Q_uncensored in the model block;
// converted to betas via R_inv in transformed parameters.
vector[NC] betas_tr; // regression coefficients for covariates
real gamma_intercept; // \gamma_0 in the paper
}
transformed parameters {
  // Covariate coefficients obtained from the Q-scale coefficients by
  // applying the precomputed inverse of R.
  vector[NC] betas;
  betas = R_inv * betas_tr;
}
model {
  // Linear predictors: covariate effect (on the Q scale) + spline term + intercept.
  // gammas * basis is a 1 x N row_vector, so only that N-vector is transposed,
  // never the m x N data matrix.
  vector[N_censored] etas_censored
      = Q_censored * betas_tr + (gammas * basis_evals_censored)' + gamma_intercept;
  vector[N_uncensored] etas_uncensored
      = Q_uncensored * betas_tr + (gammas * basis_evals_uncensored)' + gamma_intercept;

  // Priors. std_normal() is the same density as normal(0, 1) but cheaper to
  // evaluate. gammas is declared with lower=0, so its prior is a half-normal.
  // betas is a fixed linear map of betas_tr, so the Jacobian is constant and
  // no adjustment is required.
  gammas ~ std_normal();
  betas ~ std_normal();
  gamma_intercept ~ std_normal();

  // Censored observations contribute -exp(eta) each.
  target += -sum(exp(etas_censored));

  // Uncensored observations contribute
  //   eta - exp(eta) - log_time + log(gammas * derivative basis)   each.
  // Every term is reduced with sum() on its own, so no length-N temporaries
  // are built for the element-wise additions/subtractions and the transpose
  // of the log term is unnecessary. The data-only log-time term is kept so
  // that the value of target is unchanged.
  target += sum(etas_uncensored)
            - sum(exp(etas_uncensored))
            - sum(log_times_uncensored)
            + sum(log(gammas * deriv_basis_evals_uncensored));
}