Before I realized that the brms built-in negative binomial was parameterized the way I needed, I had started to create my own custom family. I found that in some circumstances the fit converged better with the custom family than with the built-in. Below is an example in which the custom family samples better than the built-in. I compared the stancode and it differs in a couple of places… those differences are shown below the R code.
The question is why the two fits perform so differently and whether I am using the brms built-in negative binomial incorrectly. I hope not to debate whether my simulation, model, or priors make sense (unless that directly pertains to why the methods perform differently), since they were chosen to illustrate the issue.
Set up…
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(brms)
library(tidyverse)
library(diffobj)
Create custom family
neg_binomial_2 <- custom_family("neg_binomial_2",
                                dpars = c("mu", "phi"),
                                links = c("log", "log"),
                                type = "int")
Create a simulated data set
set.seed(23)
n <- 1000
df <- tibble(
  count = rnbinom(n = n,
                  size = c(1, 10),  # different shape per treatment
                  mu = c(10, 4)),   # different mean per treatment
  treatment = rep(c("A", "B"), n / 2),
  group = as.character(rep(1:10, each = n / 10)))  # some grouping factor
df <- df %>%
  group_by(group) %>%
  mutate(count = count + rnbinom(n = 1, size = 20, mu = 1))  # add a grouping effect
# I am not actually going to use the grouping effect in the initial models,
# but having the unaccounted-for noise in the data helps to illustrate the differences
# between the built-in and the custom family. The issue can persist when the grouping
# effects are included in the model.
head(df, 12)
df %>% group_by(treatment) %>%
  summarize(mean = mean(count), variance = var(count),
            phi = mean(count)^2 / (var(count) - mean(count)))
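As a quick check on the parameterization (a sketch only; phi_hat is a throwaway helper, not used later): Stan's neg_binomial_2(mu, phi) and R's rnbinom(size, mu) share the relation Var(Y) = mu + mu^2 / phi, so the moment estimator above should roughly recover the size values used in the simulation, before the extra group-level noise is added.
phi_hat <- function(mu, size, n = 1e5) {
  # moment estimator of phi, inverting Var(Y) = mu + mu^2 / phi
  y <- rnbinom(n, size = size, mu = mu)
  mean(y)^2 / (var(y) - mean(y))
}
phi_hat(mu = 10, size = 1)   # roughly 1
phi_hat(mu = 4,  size = 10)  # roughly 10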
Fit using built-in
system.time(fit1 <- brm(bf(count ~ treatment, shape ~ treatment),
                        family = negbinomial(),
                        prior = set_prior("student_t(3,0,10)", class = "b"),
                        data = df))
summary(fit1)
The above takes about 1400 seconds on my computer (4-core Mac). The effective sample sizes are all less than 10, and Rhat is >> 1.
Fit using custom family
system.time(fit2 <- brm(bf(count ~ treatment, phi ~ treatment),
                        family = neg_binomial_2,
                        prior = set_prior("student_t(3,0,10)", class = "b"),
                        data = df))
summary(fit2)
The custom-family fit takes about 200 seconds. The effective sample sizes are > 1000 for three of the four parameters, and Rhat is 1.00 for those three and 1.04 for the last. So something is clearly different.
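As a quick cross-check (just a sketch), the population-level estimates from the two fits can be compared directly; if the two parameterizations are really the same, these should agree once both models converge.
fixef(fit1)  # built-in: shape ~ treatment
fixef(fit2)  # custom family: phi ~ treatment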
Stancode differences
diffObj(stancode(fit1), stancode(fit2))
Apart from the trivial naming difference (phi vs. shape), the following differ.
The built-in does not transform mu; only shape is exponentiated, whereas the custom family exponentiates both phi and mu:
// built-in (fit1)
for (n in 1:N) {
  shape[n] = exp(shape[n]);
}
// custom family (fit2)
for (n in 1:N) {
  phi[n] = exp(phi[n]);
  mu[n] = exp(mu[n]);
}
The likelihood function, and the way it is called, are also different:
// built-in (fit1)
// likelihood including all constants
if (!prior_only) {
  target += neg_binomial_2_log_lpmf(Y | mu, shape);
}
// custom family (fit2)
// likelihood including all constants
if (!prior_only) {
  for (n in 1:N) {
    target += neg_binomial_2_lpmf(Y[n] | mu[n], phi[n]);
  }
}
I assume that mu is not transformed in the built-in because neg_binomial_2_log_lpmf is called, and that function applies the exp() transformation internally(?)
So… why does the custom family come closer to convergence? Am I calling the built-in incorrectly?
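One way to check that assumption is to expose thin wrappers around the two lpmfs with rstan (a sketch; the wrapper names nb2_log_check and nb2_check are made up for this check):
wrappers <- "
functions {
  real nb2_log_check(int y, real eta, real phi) {
    return neg_binomial_2_log_lpmf(y | eta, phi);
  }
  real nb2_check(int y, real mu, real phi) {
    return neg_binomial_2_lpmf(y | mu, phi);
  }
}
"
expose_stan_functions(stanc(model_code = wrappers))
nb2_log_check(7, 1.5, 3)
nb2_check(7, exp(1.5), 3)  # should give the same log probability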
Thanks,
Julin
Full Stan code
built-in
> stancode(fit1)
// generated with brms 2.4.3
functions {
}
data {
  int<lower=1> N; // total number of observations
  int Y[N]; // response variable
  int<lower=1> K; // number of population-level effects
  matrix[N, K] X; // population-level design matrix
  int<lower=1> K_shape; // number of population-level effects
  matrix[N, K_shape] X_shape; // population-level design matrix
  int prior_only; // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, K - 1] Xc; // centered version of X
  vector[K - 1] means_X; // column means of X before centering
  int Kc_shape = K_shape - 1;
  matrix[N, K_shape - 1] Xc_shape; // centered version of X_shape
  vector[K_shape - 1] means_X_shape; // column means of X_shape before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
  for (i in 2:K_shape) {
    means_X_shape[i - 1] = mean(X_shape[, i]);
    Xc_shape[, i - 1] = X_shape[, i] - means_X_shape[i - 1];
  }
}
parameters {
  vector[Kc] b; // population-level effects
  real temp_Intercept; // temporary intercept
  vector[Kc_shape] b_shape; // population-level effects
  real temp_shape_Intercept; // temporary intercept
}
transformed parameters {
}
model {
  vector[N] mu = temp_Intercept + Xc * b;
  vector[N] shape = temp_shape_Intercept + Xc_shape * b_shape;
  for (n in 1:N) {
    shape[n] = exp(shape[n]);
  }
  // priors including all constants
  target += student_t_lpdf(b | 3,0,10);
  target += student_t_lpdf(temp_Intercept | 3, 2, 10);
  target += student_t_lpdf(temp_shape_Intercept | 3, 0, 10);
  // likelihood including all constants
  if (!prior_only) {
    target += neg_binomial_2_log_lpmf(Y | mu, shape);
  }
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = temp_Intercept - dot_product(means_X, b);
  // actual population-level intercept
  real b_shape_Intercept = temp_shape_Intercept - dot_product(means_X_shape, b_shape);
}
custom
> stancode(fit2)
// generated with brms 2.4.3
functions {
}
data {
  int<lower=1> N; // total number of observations
  int Y[N]; // response variable
  int<lower=1> K; // number of population-level effects
  matrix[N, K] X; // population-level design matrix
  int<lower=1> K_phi; // number of population-level effects
  matrix[N, K_phi] X_phi; // population-level design matrix
  int prior_only; // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, K - 1] Xc; // centered version of X
  vector[K - 1] means_X; // column means of X before centering
  int Kc_phi = K_phi - 1;
  matrix[N, K_phi - 1] Xc_phi; // centered version of X_phi
  vector[K_phi - 1] means_X_phi; // column means of X_phi before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
  for (i in 2:K_phi) {
    means_X_phi[i - 1] = mean(X_phi[, i]);
    Xc_phi[, i - 1] = X_phi[, i] - means_X_phi[i - 1];
  }
}
parameters {
  vector[Kc] b; // population-level effects
  real temp_Intercept; // temporary intercept
  vector[Kc_phi] b_phi; // population-level effects
  real temp_phi_Intercept; // temporary intercept
}
transformed parameters {
}
model {
  vector[N] mu = temp_Intercept + Xc * b;
  vector[N] phi = temp_phi_Intercept + Xc_phi * b_phi;
  for (n in 1:N) {
    phi[n] = exp(phi[n]);
    mu[n] = exp(mu[n]);
  }
  // priors including all constants
  target += student_t_lpdf(b | 3,0,10);
  target += student_t_lpdf(temp_Intercept | 3, 2, 10);
  target += student_t_lpdf(temp_phi_Intercept | 3, 0, 10);
  // likelihood including all constants
  if (!prior_only) {
    for (n in 1:N) {
      target += neg_binomial_2_lpmf(Y[n] | mu[n], phi[n]);
    }
  }
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = temp_Intercept - dot_product(means_X, b);
  // actual population-level intercept
  real b_phi_Intercept = temp_phi_Intercept - dot_product(means_X_phi, b_phi);
}
Session info
> sessionInfo()
R version 3.5.1 (2018-07-02)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS 10.14.2
Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] bindrcpp_0.2.2 diffobj_0.2.2 forcats_0.3.0 stringr_1.3.1
[5] dplyr_0.7.6 purrr_0.2.5 readr_1.1.1 tidyr_0.8.1
[9] tibble_1.4.2 tidyverse_1.2.1 brms_2.4.3 Rcpp_0.12.18
[13] rstan_2.17.3 StanHeaders_2.17.2 ggplot2_3.0.0
loaded via a namespace (and not attached):
[1] httr_1.3.1 Brobdingnag_1.2-6 jsonlite_1.5 modelr_0.1.2
[5] gtools_3.8.1 threejs_0.3.1 shiny_1.1.0 assertthat_0.2.0
[9] stats4_3.5.1 cellranger_1.1.0 yaml_2.2.0 pillar_1.3.0
[13] backports_1.1.2 lattice_0.20-35 glue_1.3.0 digest_0.6.16
[17] promises_1.0.1 rvest_0.3.2 colorspace_1.3-2 htmltools_0.3.6
[21] httpuv_1.4.5 Matrix_1.2-14 plyr_1.8.4 dygraphs_1.1.1.6
[25] pkgconfig_2.0.2 broom_0.5.0 haven_1.1.2 xtable_1.8-2
[29] mvtnorm_1.0-8 scales_1.0.0 later_0.7.3 bayesplot_1.6.0
[33] DT_0.4 withr_2.1.2 shinyjs_1.0 lazyeval_0.2.1
[37] cli_1.0.0 readxl_1.1.0 magrittr_1.5 crayon_1.3.4
[41] mime_0.5 nlme_3.1-137 xml2_1.2.0 xts_0.11-0
[45] colourpicker_1.0 rsconnect_0.8.8 tools_3.5.1 loo_2.0.0
[49] hms_0.4.2 matrixStats_0.54.0 munsell_0.5.0 compiler_3.5.1
[53] rlang_0.2.2.9000 grid_3.5.1 ggridges_0.5.0 rstudioapi_0.7
[57] htmlwidgets_1.2 crosstalk_1.0.0 igraph_1.2.2 miniUI_0.1.1.1
[61] base64enc_0.1-3 codetools_0.2-15 gtable_0.2.0 inline_0.3.15
[65] abind_1.4-5 markdown_0.8 reshape2_1.4.3 R6_2.2.2
[69] lubridate_1.7.4 gridExtra_2.3 rstantools_1.5.1 zoo_1.8-3
[73] knitr_1.20 bridgesampling_0.5-2 bindr_0.1.1 shinystan_2.5.0
[77] shinythemes_1.1.1 stringi_1.2.4 parallel_3.5.1 tidyselect_0.2.4
[81] coda_0.19-1