I’m working on a specific model and implementing it in both Stan and brms for comparison. In brms it requires a custom family. I’ve found that the brms model samples less efficiently than the Stan version, even though I think the two should be nearly equivalent. Is brms doing something I’m not aware of? Any insight is really appreciated!
Here is the model in math. The code below simulates a dataset and fits the model in both brms and Stan.
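Reading the likelihood off the Stan code below (symbol names are chosen here for the write-up): $y_i$ is the total count for replicate $i$ of clone $j[i]$, each replicate starts from two founders, $t$ is the time, $m$ is the probability that a founder's lineage dies, and $\delta_0$ is a point mass at zero. The Poisson rates follow from `poisson_log_lpmf(abd | r*time)` and `poisson_log_lpmf(abd | log(2) + r*time)`.

$$
\begin{aligned}
y_i &\sim m^2\,\delta_0
      + 2m(1-m)\,\mathrm{Poisson}\!\left(e^{r_{j[i]} t}\right)
      + (1-m)^2\,\mathrm{Poisson}\!\left(2e^{r_{j[i]} t}\right) \\
r_j &\sim \mathrm{Normal}(\bar{r}, \sigma_r) \\
\bar{r} &\sim \mathrm{Normal}(1.5, 0.2) \\
\sigma_r &\sim \mathrm{Exponential}(3) \\
m &\sim \mathrm{Beta}(1.4, 5.6)
\end{aligned}
$$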
library(tidyverse)
library(rstan)
library(brms)
ngenotypes <- 300
rep_per_geno <- 3
r_bar <- 1.3
r_sd <- .3
r_geno <- rnorm(n = ngenotypes, mean = r_bar, sd = r_sd)
curve(1*exp(r_bar*x), xlim = c(0,2.5), lwd = 2)
walk(r_geno, ~curve(exp(.*x), add = TRUE, lwd = .5))
## project the growth of each original female
aphid_clone_data <- expand_grid(
  rep_id = 1:rep_per_geno,
  clone_id = 1:ngenotypes,
  first_aphid = 1:2
) |>
  mutate(
    clone_r = r_geno[clone_id],
    expect_aphids = 1*exp(clone_r*2),
    obs_aphids = rpois(n = length(expect_aphids), lambda = expect_aphids)
  )
# add in mortality and combine aphids
aphid_clone_mort_dat <- aphid_clone_data |>
  mutate(surv = rbinom(length(obs_aphids), size = 1, prob = .8),
         obs_aphids_alive = obs_aphids * surv)
aphid_clone_mort_sum <- aphid_clone_mort_dat |>
  group_by(clone_id, rep_id) |>
  summarize(tot_aphids = sum(obs_aphids_alive))
#> `summarise()` has grouped output by 'clone_id'. You can override using the
#> `.groups` argument.
# visualize:
aphid_clone_mort_sum |>
  ggplot(aes(x = tot_aphids)) +
  geom_histogram()
#> `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
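Before fitting, it may help to check that the three-component likelihood really matches this two-founder simulation. Below is a minimal base-R sketch; the helper names sim_two_founders() and p0_mixture() are made up for illustration, and it uses a single growth rate r = 1.3 and m = 0.2 (one minus the 0.8 survival probability above) rather than the clone-level rates.

```r
# Simulate replicates the same way as above: two founders, each founding a
# Poisson lineage that is wiped out entirely with probability m.
sim_two_founders <- function(n_reps, lambda, m) {
  counts <- matrix(rpois(2 * n_reps, lambda), ncol = 2)
  alive  <- matrix(rbinom(2 * n_reps, size = 1, prob = 1 - m), ncol = 2)
  rowSums(counts * alive)
}

# Analytic Pr(y = 0) under the mixture:
# m^2 + 2 m (1 - m) Pois(0 | lambda) + (1 - m)^2 Pois(0 | 2 lambda)
p0_mixture <- function(lambda, m) {
  m^2 + 2 * m * (1 - m) * dpois(0, lambda) + (1 - m)^2 * dpois(0, 2 * lambda)
}

set.seed(1)
lambda <- exp(1.3 * 2)  # expected offspring per founder (r = 1.3, time = 2)
m <- 0.2                # per-founder mortality probability
y <- sim_two_founders(1e5, lambda, m)

# Share of zeros and mean count: simulation vs. mixture formula
c(simulated = mean(y == 0), analytic = p0_mixture(lambda, m))
c(simulated = mean(y),      analytic = 2 * (1 - m) * lambda)
```

If the two columns agree, the mixture weights m^2, 2m(1 - m) and (1 - m)^2 describe the data-generating process used here.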
# try it with brms
poisson_mix_mortality <- custom_family(
  "poisson_mix_mortality",
  dpars = c("mu", "m"),
  links = c("identity", "identity"),
  lb = c(NA, 0), ub = c(NA, 1),
  type = "int",
  #vars = "vint1[n]"
  loop = TRUE
)
poisson_mix_mortality_fns <- "
  real poisson_mix_mortality_lpmf(int abd_i, real mu, real m) {
    real ll;
    if (abd_i == 0) {
      ll = log_sum_exp(
        [
          2 * log(m),
          log(2) + log(m) + log1m(m) + poisson_log_lpmf(abd_i | mu),
          2 * log1m(m) + poisson_log_lpmf(abd_i | log(2) + mu)
        ]
      );
    } else {
      ll = log_sum_exp(
        log(2) + log(m) + log1m(m) + poisson_log_lpmf(abd_i | mu),
        2 * log1m(m) + poisson_log_lpmf(abd_i | log(2) + mu)
      );
    }
    return ll;
  }
  int poisson_mix_mortality_rng(real mu, real m) {
    real p1 = square(m);        // Pr[0] component: both die
    real p2 = 2 * m * (1 - m);  // One dies, one lives
    real p3 = square(1 - m);    // Both live
    // Normalize to ensure valid probabilities
    real total = p1 + p2 + p3;
    p1 /= total;
    p2 /= total;
    p3 /= total;
    // Sample which mortality path to take
    real u = uniform_rng(0, 1);
    int n;
    if (u < p1) {
      n = 0;                             // both dead
    } else if (u < p1 + p2) {
      n = poisson_log_rng(mu);           // one survives
    } else {
      n = poisson_log_rng(log(2) + mu);  // both survive
    }
    return n;
  }
"
stanvars <- stanvar(scode = poisson_mix_mortality_fns,
                    block = "functions")
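Since the brms and Stan versions share this likelihood, a plain-R copy of the two Stan functions is handy for spot checks, for example confirming that the pmf sums to one and that the rng has the expected mean. This is a minimal line-by-line translation sketch; log_sum_exp() and the `_r` suffixes are my own hypothetical names, not part of brms or rstan.

```r
# Numerically stable log-sum-exp over a numeric vector (mirrors Stan's
# log_sum_exp); returns -Inf if every term is -Inf
log_sum_exp <- function(x) {
  x_max <- max(x)
  if (!is.finite(x_max)) return(x_max)
  x_max + log(sum(exp(x - x_max)))
}

# R copy of poisson_mix_mortality_lpmf(); mu is on the log scale, as in
# Stan's poisson_log_lpmf()
poisson_mix_mortality_lpmf_r <- function(y, mu, m) {
  lp <- c(
    log(2) + log(m) + log1p(-m) + dpois(y, exp(mu), log = TRUE),  # one survives
    2 * log1p(-m) + dpois(y, 2 * exp(mu), log = TRUE)               # both survive
  )
  if (y == 0) lp <- c(2 * log(m), lp)  # both die: point mass at zero
  log_sum_exp(lp)
}

# R copy of poisson_mix_mortality_rng(); m^2 + 2m(1 - m) + (1 - m)^2 = 1,
# so the path probabilities are used without renormalization
poisson_mix_mortality_rng_r <- function(mu, m) {
  u <- runif(1)
  if (u < m^2) {
    0L                                  # both dead
  } else if (u < m^2 + 2 * m * (1 - m)) {
    rpois(1, exp(mu))                   # one survives
  } else {
    rpois(1, 2 * exp(mu))               # both survive
  }
}

mu <- 1.3 * 2  # log expected offspring per founder (r = 1.3, time = 2)
m <- 0.2
pmf <- exp(sapply(0:500, poisson_mix_mortality_lpmf_r, mu = mu, m = m))
sum(pmf)       # should be 1 up to rounding

set.seed(2)
draws <- replicate(2e4, poisson_mix_mortality_rng_r(mu, m))
c(simulated_mean = mean(draws), analytic_mean = 2 * (1 - m) * exp(mu))
```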
# fit the model! NOTE that this is the "wrong" model because there actually ARE
# differences between clones, and this model ignores them. To test it properly,
# rerun the simulation above with r_sd = 0 or a very small number
poisson_mix_brm <- brm(
  tot_aphids ~ 0 + time,
  data = aphid_clone_mort_sum |>
    mutate(time = 2),
  family = poisson_mix_mortality,
  stanvars = stanvars,
  chains = 1
)
#> Compiling Stan program...
#> Trying to compile a simple C file
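Following the comment above, here is a minimal base-R sketch that regenerates the data with an adjustable between-clone spread, so the pooled model can be checked on data where it is actually correct. simulate_aphid_totals() is a hypothetical helper (not from any package); its defaults copy the values used above (300 clones, 3 replicates, r_bar = 1.3, r_sd = 0.3, survival 0.8, time = 2), and it returns the clone_id, time and tot_aphids columns that the brm() calls use.

```r
# Hypothetical helper wrapping the simulation above; r_sd = 0 removes all
# differences between clones, matching the non-hierarchical model
simulate_aphid_totals <- function(ngenotypes = 300, rep_per_geno = 3,
                                  r_bar = 1.3, r_sd = 0.3, m = 0.2,
                                  time = 2) {
  r_geno <- rnorm(ngenotypes, mean = r_bar, sd = r_sd)
  # one row per clone x replicate; each replicate starts from two founders
  grid <- expand.grid(clone_id = seq_len(ngenotypes),
                      rep_id = seq_len(rep_per_geno))
  lambda <- exp(r_geno[grid$clone_id] * time)
  n <- nrow(grid)
  # each founder's lineage grows as Poisson(lambda) and dies with prob. m
  founder_1 <- rpois(n, lambda) * rbinom(n, size = 1, prob = 1 - m)
  founder_2 <- rpois(n, lambda) * rbinom(n, size = 1, prob = 1 - m)
  data.frame(grid, time = time, tot_aphids = founder_1 + founder_2)
}

set.seed(42)
no_clone_var <- simulate_aphid_totals(r_sd = 0)
head(no_clone_var)
```

The result could then be passed as `data = no_clone_var` to the first brm() call.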
poisson_mix_bf <- bf(
  tot_aphids ~ 0 + time + (0 + time | clone_id),
  family = poisson_mix_mortality
)
get_prior(poisson_mix_bf,
          data = aphid_clone_mort_sum |>
            mutate(time = 2))
#>                 prior class coef    group resp dpar nlpar lb ub       source
#>                (flat)     b                                          default
#>                (flat)     b time                                (vectorized)
#>                (flat)     m                                0  1      default
#> student_t(3, 0, 14.8)    sd                                0         default
#> student_t(3, 0, 14.8)    sd      clone_id                  0    (vectorized)
#> student_t(3, 0, 14.8)    sd time clone_id                  0    (vectorized)
pois_mix_prior <- c(
  prior(normal(1.5, .2), class = "b", coef = "time"),
  prior(beta(7*.2, 7*(1-.2)), class = "m", lb = 0, ub = 1),
  prior(exponential(3), class = "sd")
)
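To read the scales of these priors (the Stan version uses the same ones), here is a small base-R sketch. It assumes the beta prior is meant in mean-times-concentration form (mean 0.2, concentration 7), and the last lines translate the normal(1.5, 0.2) prior on the mean growth rate into expected totals for two surviving founders at time = 2; the variable names are illustrative only.

```r
# Beta prior on m in mean/concentration form: beta(7 * 0.2, 7 * (1 - 0.2))
prior_mean_m <- 0.2
prior_conc_m <- 7
shape1 <- prior_conc_m * prior_mean_m
shape2 <- prior_conc_m * (1 - prior_mean_m)

m_prior <- c(
  mean  = shape1 / (shape1 + shape2),
  sd    = sqrt(shape1 * shape2 / ((shape1 + shape2)^2 * (shape1 + shape2 + 1))),
  lower = qbeta(0.025, shape1, shape2),
  upper = qbeta(0.975, shape1, shape2)
)

# exponential(3) on the between-clone sd: mean 1/3 and 95th percentile
sd_prior <- c(mean = 1 / 3, q95 = qexp(0.95, rate = 3))

# normal(1.5, 0.2) on the mean growth rate, expressed as a central 95% range
# for the expected total from two surviving founders after time = 2
time <- 2
count_prior <- 2 * exp(qnorm(c(0.025, 0.975), mean = 1.5, sd = 0.2) * time)

m_prior
sd_prior
count_prior
```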
hier_poisson_mix_brm <- brm(
  tot_aphids ~ 0 + time + (0 + time | clone_id),
  data = aphid_clone_mort_sum |>
    mutate(time = 2),
  family = poisson_mix_mortality,
  stanvars = stanvars,
  prior = pois_mix_prior,
  chains = 4, cores = 4, refresh = 0
)
#> Compiling Stan program...
#> Trying to compile a simple C file
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#bulk-ess
#> Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#tail-ess
hier_poisson_mix_brm
#> Family: poisson_mix_mortality
#> Links: mu = identity; m = identity
#> Formula: tot_aphids ~ 0 + time + (0 + time | clone_id)
#> Data: mutate(aphid_clone_mort_sum, time = 2) (Number of observations: 900)
#> Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#> total post-warmup draws = 4000
#>
#> Multilevel Hyperparameters:
#> ~clone_id (Number of levels: 300)
#> Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> sd(time) 0.31 0.01 0.28 0.33 1.01 693 1136
#>
#> Regression Coefficients:
#> Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> time 1.27 0.02 1.24 1.31 1.00 603 1167
#>
#> Further Distributional Parameters:
#> Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> m 0.19 0.01 0.17 0.21 1.00 2050 2942
#>
#> Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).
However, the Stan model runs without warnings and gives better samples, especially for r_bar (n_eff 2758, versus a bulk ESS of 603 for time in brms), though still not perfect:
mixture_stan <- rstan::stan(
  model_code = "
data {
  int n;
  int nclone;
  real time;
  // real N0;
  array[n] int abd;
  array[n] int clone_id;
}
parameters {
  real r_bar;
  real<lower=0,upper=1> m;
  real<lower=0> r_sd;
  vector[nclone] r_i;
}
model {
  // priors
  m ~ beta(7*.2, 7*(1-.2));
  r_bar ~ normal(1.5, .2);
  r_sd ~ exponential(3);
  r_i ~ normal(r_bar, r_sd);
  for (i in 1:n) {
    if (abd[i] == 0) {
      target += log_sum_exp(
        [
          2*log(m),
          log(2)+log(m)+log1m(m)
            + poisson_log_lpmf(abd[i] | r_i[clone_id[i]]*time),
          2*log1m(m)
            + poisson_log_lpmf(abd[i] | log(2) + r_i[clone_id[i]]*time)
        ]
      );
    } else {
      target += log_sum_exp(
        log(2)+log(m)+log1m(m)
          + poisson_log_lpmf(abd[i] | r_i[clone_id[i]]*time),
        2*log1m(m)
          + poisson_log_lpmf(abd[i] | log(2) + r_i[clone_id[i]]*time)
      );
    }
  }
}
",
  data = list(n = nrow(aphid_clone_mort_sum),
              nclone = max(aphid_clone_mort_sum$clone_id),
              time = 2,
              abd = aphid_clone_mort_sum$tot_aphids,
              clone_id = aphid_clone_mort_sum$clone_id
  ),
  refresh = 0)
#> Trying to compile a simple C file
#> Running /usr/lib/R/bin/R CMD SHLIB foo.c
#> using C compiler: ‘gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0’
mixture_stan
#> Inference for Stan model: anon_model.
#> 4 chains, each with iter=2000; warmup=1000; thin=1;
#> post-warmup draws per chain=1000, total post-warmup draws=4000.
#>
#> mean se_mean sd 2.5% 25% 50% 75% 97.5%
#> r_bar 1.27 0.00 0.02 1.24 1.26 1.27 1.29 1.31
#> m 0.19 0.00 0.01 0.17 0.18 0.19 0.20 0.22
#> r_sd 0.31 0.00 0.01 0.28 0.30 0.31 0.32 0.34
#> n_eff Rhat
#> r_bar 2758 1.00
#> m 544 1.02
#> r_sd 727 1.02
#>
#> Samples were drawn using NUTS(diag_e) at Sun May 25 10:44:47 2025.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at
#> convergence, Rhat=1).
Created on 2025-05-25 with reprex v2.1.1