I’ve set up a nonlinear model in brms that I would like to run with the reduce_sum function. I don’t have much experience coding directly in Stan, so please let me know if I should provide any additional information that would be helpful.
The setup for this model in brms is:
bf_1 <- bf(Accuracy ~ ppe(repetition, model_time_z, stability_z, b, m, bl, tau),
           b + m + bl + tau ~ 1 + (1 | User),
           nl = TRUE)
prior_1 <-
  prior(normal(0, 1), nlpar = 'b') +
  prior(normal(0, 1), nlpar = 'm') +
  prior(normal(0, 1), nlpar = 'bl') +
  prior(normal(0, 1), nlpar = 'tau')
fit_1 <- brm(bf_1,
             prior = prior_1,
             stanvars = stanvars,
             data = dat_prepped)
Here the nonlinear function (called “ppe”) is supplied through the stanvars argument.
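In case it’s useful, the stanvars object is created along these lines, passing the ppe definition as raw Stan code into the functions block (the string below is a sketch of what’s in my script, with comments added):

stan_funs <- "
  real ppe(real N, real T, real St, real b, real m, real bl, real tau) {
    real s = .1;              // fixed scaling parameter
    real c = .1;              // fixed learning-rate exponent
    real d = b + m * St;      // decay rate as a linear function of stability
    real forget = T^-d;       // power-law forgetting over model time
    real learn = (bl + N)^c;  // power-law learning over repetitions
    real M = learn * forget;  // predicted activation
    return (tau - M) / s;     // scaled distance from threshold, used as the logit
  }
"
stanvars <- stanvar(scode = stan_funs, block = "functions")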
The Stan code produced by brms is:
// generated with brms 2.12.0
functions {
  real ppe(real N,
           real T,
           real St,
           real b,
           real m,
           real bl,
           real tau) {
    real s = .1;
    real c = .1;
    real d;
    real M;
    real P;
    real forget;
    real learn;
    d = b + m * St;
    forget = T^-d;
    learn = (bl + N) ^ c;
    M = learn * forget;
    P = (tau - M) / s;
    return P;
  }
}
data {
  int<lower=1> N;  // number of observations
  int Y[N];  // response variable
  int<lower=1> K_b;  // number of population-level effects
  matrix[N, K_b] X_b;  // population-level design matrix
  int<lower=1> K_m;  // number of population-level effects
  matrix[N, K_m] X_m;  // population-level design matrix
  int<lower=1> K_bl;  // number of population-level effects
  matrix[N, K_bl] X_bl;  // population-level design matrix
  int<lower=1> K_tau;  // number of population-level effects
  matrix[N, K_tau] X_tau;  // population-level design matrix
  // covariate vectors for non-linear functions
  int C_1[N];
  vector[N] C_2;
  vector[N] C_3;
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_b_1;
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_m_1;
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  int<lower=1> J_3[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_bl_1;
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  int<lower=1> J_4[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_tau_1;
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
}
parameters {
  vector[K_b] b_b;  // population-level effects
  vector[K_m] b_m;  // population-level effects
  vector[K_bl] b_bl;  // population-level effects
  vector[K_tau] b_tau;  // population-level effects
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  vector[N_1] z_1[M_1];  // standardized group-level effects
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // standardized group-level effects
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  vector[N_3] z_3[M_3];  // standardized group-level effects
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  vector[N_4] z_4[M_4];  // standardized group-level effects
}
transformed parameters {
  vector[N_1] r_1_b_1;  // actual group-level effects
  vector[N_2] r_2_m_1;  // actual group-level effects
  vector[N_3] r_3_bl_1;  // actual group-level effects
  vector[N_4] r_4_tau_1;  // actual group-level effects
  r_1_b_1 = (sd_1[1] * (z_1[1]));
  r_2_m_1 = (sd_2[1] * (z_2[1]));
  r_3_bl_1 = (sd_3[1] * (z_3[1]));
  r_4_tau_1 = (sd_4[1] * (z_4[1]));
}
model {
  // initialize linear predictor term
  vector[N] nlp_b = X_b * b_b;
  // initialize linear predictor term
  vector[N] nlp_m = X_m * b_m;
  // initialize linear predictor term
  vector[N] nlp_bl = X_bl * b_bl;
  // initialize linear predictor term
  vector[N] nlp_tau = X_tau * b_tau;
  // initialize non-linear predictor term
  vector[N] mu;
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_b[n] += r_1_b_1[J_1[n]] * Z_1_b_1[n];
  }
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_m[n] += r_2_m_1[J_2[n]] * Z_2_m_1[n];
  }
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_bl[n] += r_3_bl_1[J_3[n]] * Z_3_bl_1[n];
  }
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_tau[n] += r_4_tau_1[J_4[n]] * Z_4_tau_1[n];
  }
  for (n in 1:N) {
    // compute non-linear predictor values
    mu[n] = ppe(C_1[n], C_2[n], C_3[n], nlp_b[n], nlp_m[n], nlp_bl[n], nlp_tau[n]);
  }
  // priors including all constants
  target += normal_lpdf(b_b | 0, 1);
  target += normal_lpdf(b_m | 0, 1);
  target += normal_lpdf(b_bl | 0, 1);
  target += normal_lpdf(b_tau | 0, 1);
  target += student_t_lpdf(sd_1 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_1[1] | 0, 1);
  target += student_t_lpdf(sd_2 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_2[1] | 0, 1);
  target += student_t_lpdf(sd_3 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_3[1] | 0, 1);
  target += student_t_lpdf(sd_4 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_4[1] | 0, 1);
  // likelihood including all constants
  if (!prior_only) {
    target += bernoulli_logit_lpmf(Y | mu);
  }
}
generated quantities {
}
And I’ve tried to incorporate reduce_sum by moving most of the model block (everything but the priors) into a partial_sum function:
// generated with brms 2.12.0
functions {
  real ppe(real N,
           real T,
           real St,
           real b,
           real m,
           real bl,
           real tau) {
    real s = .1;
    real c = .1;
    real d;
    real M;
    real P;
    real forget;
    real learn;
    d = b + m * St;
    forget = T^-d;
    learn = (bl + N) ^ c;
    M = learn * forget;
    P = (tau - M) / s;
    return P;
  }
  real partial_sum(int[] slice_n_Y, int start, int end,
                   int[] C_1, vector C_2, vector C_3,
                   int[] J_1, int[] J_2, int[] J_3, int[] J_4,
                   matrix X_b, matrix X_m, matrix X_bl, matrix X_tau,
                   vector b_b, vector b_m, vector b_bl, vector b_tau,
                   vector r_1_b_1, vector r_2_m_1, vector r_3_bl_1, vector r_4_tau_1,
                   vector Z_1_b_1, vector Z_2_m_1, vector Z_3_bl_1, vector Z_4_tau_1) {
    // initialize linear predictor term
    vector[size(slice_n_Y)] nlp_b = X_b * b_b;
    // initialize linear predictor term
    vector[size(slice_n_Y)] nlp_m = X_m * b_m;
    // initialize linear predictor term
    vector[size(slice_n_Y)] nlp_bl = X_bl * b_bl;
    // initialize linear predictor term
    vector[size(slice_n_Y)] nlp_tau = X_tau * b_tau;
    vector[size(slice_n_Y)] mu;
    for (n in 1:((end - start) + 1)) {
      // add more terms to the linear predictor
      nlp_b[n] += r_1_b_1[J_1[n]] * Z_1_b_1[n];
    }
    for (n in 1:((end - start) + 1)) {
      // add more terms to the linear predictor
      nlp_m[n] += r_2_m_1[J_2[n]] * Z_2_m_1[n];
    }
    for (n in 1:((end - start) + 1)) {
      // add more terms to the linear predictor
      nlp_bl[n] += r_3_bl_1[J_3[n]] * Z_3_bl_1[n];
    }
    for (n in 1:((end - start) + 1)) {
      // add more terms to the linear predictor
      nlp_tau[n] += r_4_tau_1[J_4[n]] * Z_4_tau_1[n];
    }
    for (n in 1:((end - start) + 1)) {
      mu[n] = ppe(C_1[n], C_2[n], C_3[n], nlp_b[n], nlp_m[n], nlp_bl[n], nlp_tau[n]);
    }
    return bernoulli_logit_lpmf(slice_n_Y | mu[1:((end - start) + 1)]);
  }
}
data {
  int<lower=1> N;  // number of observations
  int Y[N];  // response variable
  int<lower=1> K_b;  // number of population-level effects
  matrix[N, K_b] X_b;  // population-level design matrix
  int<lower=1> K_m;  // number of population-level effects
  matrix[N, K_m] X_m;  // population-level design matrix
  int<lower=1> K_bl;  // number of population-level effects
  matrix[N, K_bl] X_bl;  // population-level design matrix
  int<lower=1> K_tau;  // number of population-level effects
  matrix[N, K_tau] X_tau;  // population-level design matrix
  // covariate vectors for non-linear functions
  int C_1[N];
  vector[N] C_2;
  vector[N] C_3;
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_b_1;
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_m_1;
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  int<lower=1> J_3[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_bl_1;
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  int<lower=1> J_4[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_tau_1;
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
}
parameters {
  vector[K_b] b_b;  // population-level effects
  vector[K_m] b_m;  // population-level effects
  vector[K_bl] b_bl;  // population-level effects
  vector[K_tau] b_tau;  // population-level effects
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  vector[N_1] z_1[M_1];  // standardized group-level effects
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // standardized group-level effects
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  vector[N_3] z_3[M_3];  // standardized group-level effects
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  vector[N_4] z_4[M_4];  // standardized group-level effects
}
transformed parameters {
  vector[N_1] r_1_b_1;  // actual group-level effects
  vector[N_2] r_2_m_1;  // actual group-level effects
  vector[N_3] r_3_bl_1;  // actual group-level effects
  vector[N_4] r_4_tau_1;  // actual group-level effects
  r_1_b_1 = (sd_1[1] * (z_1[1]));
  r_2_m_1 = (sd_2[1] * (z_2[1]));
  r_3_bl_1 = (sd_3[1] * (z_3[1]));
  r_4_tau_1 = (sd_4[1] * (z_4[1]));
}
model {
  // priors including all constants
  target += normal_lpdf(b_b | 0, 1);
  target += normal_lpdf(b_m | 0, 1);
  target += normal_lpdf(b_bl | 0, 1);
  target += normal_lpdf(b_tau | 0, 1);
  target += student_t_lpdf(sd_1 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_1[1] | 0, 1);
  target += student_t_lpdf(sd_2 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_2[1] | 0, 1);
  target += student_t_lpdf(sd_3 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_3[1] | 0, 1);
  target += student_t_lpdf(sd_4 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_4[1] | 0, 1);
  // likelihood including all constants
  if (!prior_only) {
    target += reduce_sum(partial_sum, Y, 1,
                         C_1, C_2, C_3,
                         J_1, J_2, J_3, J_4,
                         X_b, X_m, X_bl, X_tau,
                         b_b, b_m, b_bl, b_tau,
                         r_1_b_1, r_2_m_1, r_3_bl_1, r_4_tau_1,
                         Z_1_b_1, Z_2_m_1, Z_3_bl_1, Z_4_tau_1);
  }
}
generated quantities {
}
This script will run in cmdstanr, but the time to fit is about twice as long as the original model, even though I’m using all 8 of my logical cores. I’m almost certain that I haven’t made sound use of reduce_sum. For example, I suspect that indexing C_1, the J_* indicators, and the Z_* vectors from 1 rather than offsetting by start, and computing the full X_b * b_b products inside every partial sum, are not how the slicing is meant to work. I was hoping for a bit of guidance on how I can improve the model. Thanks!
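In case it matters, I’m compiling and running the threaded model roughly like this (the file name and data list are placeholders for the ones in my script):

library(cmdstanr)

# Compile with threading enabled so reduce_sum can use multiple cores.
mod <- cmdstan_model("ppe_reduce_sum.stan",
                     cpp_options = list(stan_threads = TRUE))

# stan_data is the list produced by brms::make_standata() for the model above.
fit <- mod$sample(data = stan_data,
                  chains = 1,
                  threads_per_chain = 8)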