Negative binomial shape and phi parameters

Before I realized that the brms built-in negative binomial family was parameterized the way I needed, I had started to create my own custom family. I found that in some circumstances the custom family converged better than the built-in one. Below is an example in which the custom family samples better than the built-in. I compared the Stan code of the two models, and it differs in a couple of places; the differences are shown below the R code.

The question is why the two methods of fitting perform differently, and whether I am using the brms built-in negative binomial incorrectly. I would rather not debate whether my simulation, model, or priors make sense (unless that directly bears on why the methods perform differently), since they were chosen to illustrate the issue.

Set up…

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(brms)
library(tidyverse)
library(diffobj)

Create custom family

neg_binomial_2 <- custom_family("neg_binomial_2",
                                dpars = c("mu", "phi"),
                                links = c("log", "log"),
                                type = "int")

Create a simulated data set

set.seed(23)
n <- 1000
df <- tibble(
  count=rnbinom(n=n,
                size=c(1,10), #different shape per treatment
                mu=c(10,4)), #different mean per treatment
  treatment=rep(c("A","B"), n/2),
  group=as.character(rep(1:10,each=n/10))) #some grouping factor

df <- df %>%
  group_by(group) %>%
  mutate(count =  count + rnbinom(n=1, size=20, mu=1)) # add a grouping effect
# I am not actually going to use the grouping effect in the initial models,
# but having the unaccounted-for noise in the data helps to illustrate the differences between
# the built-in and the custom family. The issue can persist when the grouping effects
# are included in the model.
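The simulation above relies on two recycling behaviours that are easy to miss. A minimal base-R sketch of the same structure (assuming only `stats::rnbinom()`, no tidyverse; the names `sim`, `offsets`, and `added` are illustrative) makes them explicit: the length-2 `size` and `mu` vectors are recycled across the draws, so they pair up with the alternating `treatment` labels, and one offset is drawn per group and added to every row in that group. It is not expected to reproduce the exact random draws of the tidyverse version, because dplyr processes the character group labels in a different order.

```r
set.seed(23)
n <- 1000

# Length-2 parameter vectors are recycled over the n draws:
# positions 1, 3, 5, ... use size = 1,  mu = 10 (treatment "A")
# positions 2, 4, 6, ... use size = 10, mu = 4  (treatment "B")
count <- rnbinom(n = n, size = c(1, 10), mu = c(10, 4))
treatment <- rep(c("A", "B"), n / 2)
group <- rep(1:10, each = n / 10)

# One grouping offset per group, shared by every row in that group
# (the base-R analogue of rnbinom(n = 1, ...) inside a grouped mutate()).
offsets <- rnbinom(n = 10, size = 20, mu = 1)
sim <- data.frame(
  count = count + offsets[group],
  treatment = treatment,
  group = as.character(group)
)

# Per-treatment means before the offset: roughly 10 for A and 4 for B
tapply(count, treatment, mean)

# The offset added to each row takes a single value within each group
added <- sim$count - count
tapply(added, sim$group, function(x) length(unique(x)))
```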

head(df,12)

df %>% group_by(treatment) %>%
  summarize(mean=mean(count), variance=var(count),
            phi=mean(count)^2 / (var(count) - mean(count)))
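The `phi` column above is the method-of-moments estimate obtained from the NB2 variance relation Var(Y) = mu + mu^2/phi, solved for phi. R's `size`, Stan's `phi`, and the brms `shape` all denote this same dispersion parameter under the NB2 parameterization. As a sanity check of the estimator, a sketch on a large clean sample (assuming no grouping offset; `mom_phi` is a hypothetical helper) should roughly recover the simulated `size` values of 1 and 10. The values computed from `df` itself will generally differ, because the grouping offset shifts both the mean and the variance.

```r
set.seed(1)

# Method-of-moments estimate of the NB2 dispersion:
# Var(Y) = mu + mu^2 / phi  =>  phi = mu^2 / (Var(Y) - mu)
mom_phi <- function(y) {
  m <- mean(y)
  m^2 / (var(y) - m)
}

n_big <- 2e5
y_A <- rnbinom(n_big, size = 1,  mu = 10)  # treatment A parameters
y_B <- rnbinom(n_big, size = 10, mu = 4)   # treatment B parameters

est <- c(A = mom_phi(y_A), B = mom_phi(y_B))
est
```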

Fit using built-in

system.time(fit1 <- brm(bf(count ~ treatment, shape ~ treatment),
            family = negbinomial(),
            prior=set_prior("student_t(3,0,10)", class="b"), 
            data=df))

summary(fit1)

The above takes about 1400 seconds on my computer (4-core Mac). The effective sample sizes are all below 10, and Rhat is >> 1.
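To make "Rhat >> 1" concrete, here is a minimal base-R sketch of the classic split R-hat described in BDA3 (the Gelman-Rubin statistic computed on half-chains); later Stan tooling added a rank-normalized refinement, so reported values may differ slightly. The function name `split_rhat` is hypothetical, and the toy chains are simulated, not draws from `fit1`.

```r
# draws: matrix with one column per chain (iterations x chains)
split_rhat <- function(draws) {
  n <- nrow(draws) %/% 2
  # Split every chain into its first and second halves
  halves <- cbind(draws[1:n, , drop = FALSE],
                  draws[(n + 1):(2 * n), , drop = FALSE])
  B <- n * var(colMeans(halves))        # between-chain variance
  W <- mean(apply(halves, 2, var))      # mean within-chain variance
  var_plus <- (n - 1) / n * W + B / n   # pooled posterior variance estimate
  sqrt(var_plus / W)
}

set.seed(3)
# Four well-mixed chains targeting the same distribution: R-hat close to 1
good <- matrix(rnorm(4000), ncol = 4)
# Four chains stuck around different values: R-hat well above 1
bad <- sapply(c(-3, 0, 3, 6), function(m) rnorm(1000, mean = m))

c(good = split_rhat(good), bad = split_rhat(bad))
```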

Fit using custom family

system.time(fit2 <- brm(bf(count ~ treatment, phi ~ treatment),
            family = neg_binomial_2,
            prior=set_prior("student_t(3,0,10)", class="b"),
            data = df))

summary(fit2)

The custom family fit takes about 200 seconds. The effective sample sizes are > 1000 for 3 of the 4 parameters, and Rhat is 1.00 for those three and 1.04 for the last. So something is clearly different.

Stancode differences

diffObj(stancode(fit1), stancode(fit2))

Apart from the trivial difference (phi vs. shape), the following differ.

The built-in does not exponentiate mu:

  • built-in
for (n in 1:N) { 
   shape[n] = exp(shape[n]); 
 } 
  • custom
for (n in 1:N) { 
   phi[n] = exp(phi[n]); 
   mu[n] = exp(mu[n]); 
} 

The likelihood function, and the way it is called, are different:

  • built-in
  // likelihood including all constants 
  if (!prior_only) { 
    target += neg_binomial_2_log_lpmf(Y | mu, shape);
  } 
  • custom
  // likelihood including all constants 
  if (!prior_only) { 
    for (n in 1:N) {
      target += neg_binomial_2_lpmf(Y[n] | mu[n], phi[n]);
    }
  } 

I assume that mu is not transformed in the built-in because neg_binomial_2_log_lpmf is called, and that function does the transformation internally (?)
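That assumption matches how Stan defines the two functions: `neg_binomial_2_log_lpmf(y | eta, phi)` is the NB2 log-pmf with mean `exp(eta)`, so in exact arithmetic it equals `neg_binomial_2_lpmf(y | exp(eta), phi)`. The base-R sketch below writes out the log-scale form with a hypothetical helper `nb2_log_lpmf()` and checks it against `dnbinom()` with `mu = exp(eta)`; it also checks that one vectorized sum equals the per-observation loop used in the custom family. It only illustrates the math and is not Stan's implementation, which may evaluate these terms differently.

```r
# NB2 log-pmf parameterized by the log mean eta (mu = exp(eta)):
# log p(y) = lgamma(y + phi) - lgamma(y + 1) - lgamma(phi)
#            + y   * (eta      - log(exp(eta) + phi))
#            + phi * (log(phi) - log(exp(eta) + phi))
nb2_log_lpmf <- function(y, eta, phi) {
  # log(exp(eta) + phi) via log-sum-exp, so large eta does not overflow
  log_mu_plus_phi <- pmax(eta, log(phi)) + log1p(exp(-abs(eta - log(phi))))
  lgamma(y + phi) - lgamma(y + 1) - lgamma(phi) +
    y * (eta - log_mu_plus_phi) +
    phi * (log(phi) - log_mu_plus_phi)
}

set.seed(2)
y   <- rnbinom(20, size = 3, mu = 5)
eta <- rnorm(20, mean = log(5), sd = 0.3)   # linear predictor on the log scale
phi <- exp(rnorm(20, mean = 1, sd = 0.2))   # phi = exp(linear predictor), as in both models

# Log-scale form (like the built-in) vs. mean-scale form (like the custom family)
log_scale  <- nb2_log_lpmf(y, eta, phi)
mean_scale <- dnbinom(y, size = phi, mu = exp(eta), log = TRUE)

# Per-observation loop, mirroring the custom family's for (n in 1:N) statement
loop_total <- 0
for (i in seq_along(y)) {
  loop_total <- loop_total + dnbinom(y[i], size = phi[i], mu = exp(eta[i]), log = TRUE)
}

all.equal(log_scale, mean_scale)
all.equal(sum(log_scale), loop_total)
```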

So…why does the custom come closer to convergence? Am I calling the built-in incorrectly?

Thanks,

Julin

Full Stan code

built-in

> stancode(fit1)
// generated with brms 2.4.3
functions { 
} 
data { 
  int<lower=1> N;  // total number of observations 
  int Y[N];  // response variable 
  int<lower=1> K;  // number of population-level effects 
  matrix[N, K] X;  // population-level design matrix 
  int<lower=1> K_shape;  // number of population-level effects 
  matrix[N, K_shape] X_shape;  // population-level design matrix 
  int prior_only;  // should the likelihood be ignored? 
} 
transformed data { 
  int Kc = K - 1; 
  matrix[N, K - 1] Xc;  // centered version of X 
  vector[K - 1] means_X;  // column means of X before centering 
  int Kc_shape = K_shape - 1; 
  matrix[N, K_shape - 1] Xc_shape;  // centered version of X_shape 
  vector[K_shape - 1] means_X_shape;  // column means of X_shape before centering 
  for (i in 2:K) { 
    means_X[i - 1] = mean(X[, i]); 
    Xc[, i - 1] = X[, i] - means_X[i - 1]; 
  } 
  for (i in 2:K_shape) { 
    means_X_shape[i - 1] = mean(X_shape[, i]); 
    Xc_shape[, i - 1] = X_shape[, i] - means_X_shape[i - 1]; 
  } 
} 
parameters { 
  vector[Kc] b;  // population-level effects 
  real temp_Intercept;  // temporary intercept 
  vector[Kc_shape] b_shape;  // population-level effects 
  real temp_shape_Intercept;  // temporary intercept 
} 
transformed parameters { 
} 
model { 
  vector[N] mu = temp_Intercept + Xc * b;
  vector[N] shape = temp_shape_Intercept + Xc_shape * b_shape;
  for (n in 1:N) { 
    shape[n] = exp(shape[n]); 
  } 
  // priors including all constants 
  target += student_t_lpdf(b | 3,0,10); 
  target += student_t_lpdf(temp_Intercept | 3, 2, 10); 
  target += student_t_lpdf(temp_shape_Intercept | 3, 0, 10); 
  // likelihood including all constants 
  if (!prior_only) { 
    target += neg_binomial_2_log_lpmf(Y | mu, shape);
  } 
} 
generated quantities { 
  // actual population-level intercept 
  real b_Intercept = temp_Intercept - dot_product(means_X, b); 
  // actual population-level intercept 
  real b_shape_Intercept = temp_shape_Intercept - dot_product(means_X_shape, b_shape); 
} 

custom

> stancode(fit2)
// generated with brms 2.4.3
functions { 
} 
data { 
  int<lower=1> N;  // total number of observations 
  int Y[N];  // response variable 
  int<lower=1> K;  // number of population-level effects 
  matrix[N, K] X;  // population-level design matrix 
  int<lower=1> K_phi;  // number of population-level effects 
  matrix[N, K_phi] X_phi;  // population-level design matrix 
  int prior_only;  // should the likelihood be ignored? 
} 
transformed data { 
  int Kc = K - 1; 
  matrix[N, K - 1] Xc;  // centered version of X 
  vector[K - 1] means_X;  // column means of X before centering 
  int Kc_phi = K_phi - 1; 
  matrix[N, K_phi - 1] Xc_phi;  // centered version of X_phi 
  vector[K_phi - 1] means_X_phi;  // column means of X_phi before centering 
  for (i in 2:K) { 
    means_X[i - 1] = mean(X[, i]); 
    Xc[, i - 1] = X[, i] - means_X[i - 1]; 
  } 
  for (i in 2:K_phi) { 
    means_X_phi[i - 1] = mean(X_phi[, i]); 
    Xc_phi[, i - 1] = X_phi[, i] - means_X_phi[i - 1]; 
  } 
} 
parameters { 
  vector[Kc] b;  // population-level effects 
  real temp_Intercept;  // temporary intercept 
  vector[Kc_phi] b_phi;  // population-level effects 
  real temp_phi_Intercept;  // temporary intercept 
} 
transformed parameters { 
} 
model { 
  vector[N] mu = temp_Intercept + Xc * b;
  vector[N] phi = temp_phi_Intercept + Xc_phi * b_phi;
  for (n in 1:N) { 
    phi[n] = exp(phi[n]); 
    mu[n] = exp(mu[n]); 
  } 
  // priors including all constants 
  target += student_t_lpdf(b | 3,0,10); 
  target += student_t_lpdf(temp_Intercept | 3, 2, 10); 
  target += student_t_lpdf(temp_phi_Intercept | 3, 0, 10); 
  // likelihood including all constants 
  if (!prior_only) { 
    for (n in 1:N) {
      target += neg_binomial_2_lpmf(Y[n] | mu[n], phi[n]);
    }
  } 
} 
generated quantities { 
  // actual population-level intercept 
  real b_Intercept = temp_Intercept - dot_product(means_X, b); 
  // actual population-level intercept 
  real b_phi_Intercept = temp_phi_Intercept - dot_product(means_X_phi, b_phi); 
} 

Session info

> sessionInfo()
R version 3.5.1 (2018-07-02)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS  10.14.2

Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] bindrcpp_0.2.2     diffobj_0.2.2      forcats_0.3.0      stringr_1.3.1     
 [5] dplyr_0.7.6        purrr_0.2.5        readr_1.1.1        tidyr_0.8.1       
 [9] tibble_1.4.2       tidyverse_1.2.1    brms_2.4.3         Rcpp_0.12.18      
[13] rstan_2.17.3       StanHeaders_2.17.2 ggplot2_3.0.0     

loaded via a namespace (and not attached):
 [1] httr_1.3.1           Brobdingnag_1.2-6    jsonlite_1.5         modelr_0.1.2        
 [5] gtools_3.8.1         threejs_0.3.1        shiny_1.1.0          assertthat_0.2.0    
 [9] stats4_3.5.1         cellranger_1.1.0     yaml_2.2.0           pillar_1.3.0        
[13] backports_1.1.2      lattice_0.20-35      glue_1.3.0           digest_0.6.16       
[17] promises_1.0.1       rvest_0.3.2          colorspace_1.3-2     htmltools_0.3.6     
[21] httpuv_1.4.5         Matrix_1.2-14        plyr_1.8.4           dygraphs_1.1.1.6    
[25] pkgconfig_2.0.2      broom_0.5.0          haven_1.1.2          xtable_1.8-2        
[29] mvtnorm_1.0-8        scales_1.0.0         later_0.7.3          bayesplot_1.6.0     
[33] DT_0.4               withr_2.1.2          shinyjs_1.0          lazyeval_0.2.1      
[37] cli_1.0.0            readxl_1.1.0         magrittr_1.5         crayon_1.3.4        
[41] mime_0.5             nlme_3.1-137         xml2_1.2.0           xts_0.11-0          
[45] colourpicker_1.0     rsconnect_0.8.8      tools_3.5.1          loo_2.0.0           
[49] hms_0.4.2            matrixStats_0.54.0   munsell_0.5.0        compiler_3.5.1      
[53] rlang_0.2.2.9000     grid_3.5.1           ggridges_0.5.0       rstudioapi_0.7      
[57] htmlwidgets_1.2      crosstalk_1.0.0      igraph_1.2.2         miniUI_0.1.1.1      
[61] base64enc_0.1-3      codetools_0.2-15     gtable_0.2.0         inline_0.3.15       
[65] abind_1.4-5          markdown_0.8         reshape2_1.4.3       R6_2.2.2            
[69] lubridate_1.7.4      gridExtra_2.3        rstantools_1.5.1     zoo_1.8-3           
[73] knitr_1.20           bridgesampling_0.5-2 bindr_0.1.1          shinystan_2.5.0     
[77] shinythemes_1.1.1    stringi_1.2.4        parallel_3.5.1       tidyselect_0.2.4    
[81] coda_0.19-1         
>