This was more difficult than I initially thought; the epred functionality in particular needed a bit of tinkering. As you will see, it is all possible, but it simply replicates the functionality of the sratio() family and does so less flexibly (i.e., it requires a separate custom family for each number of categories). So although it works, the sratio() family is probably the way to go (even though the custom family samples noticeably faster).
In short, the key is to pass `specials = "categorical"` to `custom_family()`. Unfortunately, this renames the distributional parameters (to `mu2` and `mu3`, as you can see below). On the plus side, the model then requires only one formula.
library("tidyverse")
library("brms")
# Load carData's Womenlf data and recode participation into a factor whose
# levels are in the order not.work, parttime, fulltime (work2), plus its
# integer codes 1-3 (work3), which serve as the categorical response.
womenlf <- carData::Womenlf %>%
  mutate(
    work2 = factor(partic, levels = c("not.work", "parttime", "fulltime")),
    work3 = as.numeric(work2)
  )
str(womenlf)
# 'data.frame': 263 obs. of 6 variables:
# $ partic : Factor w/ 3 levels "fulltime","not.work",..: 2 2 2 2 2 2 2 1 2 2 ...
# $ hincome : int 15 13 45 23 19 7 15 7 15 23 ...
# $ children: Factor w/ 2 levels "absent","present": 2 2 2 2 2 2 2 2 2 2 ...
# $ region : Factor w/ 5 levels "Atlantic","BC",..: 3 3 3 3 3 3 3 3 3 3 ...
# $ work2 : Factor w/ 3 levels "not.work","parttime",..: 1 1 1 1 1 1 1 3 1 1 ...
# $ work3 : num 1 1 1 1 1 1 1 3 1 1 ...
# Custom conditional (tree-structured) family for a 3-category response.
# With specials = "categorical", brms treats the response as categorical and
# renames the distributional parameters to mu2 and mu3 (see the summary),
# so a single formula covers both.
mpt_cond <- custom_family("mpt_cond",
  links = "identity", dpars = c("mu", "mub"),
  type = "int",
  vars = c("n_cat"), specials = "categorical")

# Number of response categories, passed to Stan as data. It is derived from the
# factor rather than hard-coded so it cannot drift from the outcome coding.
stanvars <- stanvar(x = nlevels(womenlf$work2), name = "n_cat",
                    scode = "  int n_cat;")

# Stan helpers. mpt_cond_pred returns the category probabilities:
#   P(1) = inv_logit(mu)
#   P(2) = (1 - P(1)) * inv_logit(mu_b)
#   P(3) = (1 - P(1)) * (1 - inv_logit(mu_b))
# mpt_cond_lpmf reuses it, so the two can never disagree.
stan_lpmf <- stanvar(block = "functions", scode = "
  // y is not used here; it is kept so the signature matches mpt_cond_lpmf
  // and the R post-processing calls.
  vector mpt_cond_pred(int y, real mu, real mu_b, int n_cat) {
    real p_mu = inv_logit(mu);
    real p_mub = inv_logit(mu_b);
    vector[n_cat] prob;
    // The tree below defines exactly three categories; fail loudly instead
    // of indexing out of bounds or leaving entries uninitialised.
    if (n_cat != 3) {
      reject(\"mpt_cond supports exactly 3 categories; got n_cat = \", n_cat);
    }
    prob[1] = p_mu;
    prob[2] = (1 - p_mu) * p_mub;
    prob[3] = (1 - p_mu) * (1 - p_mub);
    return prob;
  }
  real mpt_cond_lpmf(int y, real mu, real mu_b, int n_cat) {
    return categorical_lpmf(y | mpt_cond_pred(y, mu, mu_b, n_cat));
  }")

fit1 <- brm(work3 ~ hincome + children + region,
            data = womenlf, family = mpt_cond,
            stanvars = stanvars + stan_lpmf)
fit1
This gives the same results as before:
> fit1
Family: mpt_cond
Links: mu2 = identity; mu3 = identity
Formula: work3 ~ hincome + children + region
Data: womenlf (Number of observations: 263)
Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
total post-warmup draws = 4000
Population-Level Effects:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu2_Intercept -1.30 0.55 -2.38 -0.19 1.00 2717 2404
mu3_Intercept -4.14 1.12 -6.41 -1.98 1.00 2627 2718
mu2_hincome 0.05 0.02 0.01 0.09 1.00 4302 3010
mu2_childrenpresent 1.65 0.31 1.05 2.26 1.00 3853 3037
mu2_regionBC -0.37 0.60 -1.55 0.82 1.00 2924 2948
mu2_regionOntario -0.22 0.49 -1.20 0.72 1.00 2513 2610
mu2_regionPrairie -0.51 0.58 -1.67 0.56 1.00 2960 3018
mu2_regionQuebec 0.16 0.51 -0.85 1.17 1.00 2513 2441
mu3_hincome 0.11 0.04 0.04 0.20 1.00 4019 3098
mu3_childrenpresent 3.01 0.61 1.90 4.29 1.00 3698 2458
mu3_regionBC 1.32 1.10 -0.75 3.55 1.00 2709 2951
mu3_regionOntario 0.20 0.90 -1.54 1.96 1.00 2290 2767
mu3_regionPrairie 0.43 1.02 -1.55 2.49 1.00 2682 2986
mu3_regionQuebec -0.15 0.99 -2.09 1.84 1.00 2583 2892
Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
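To enable post-processing (a log_lik function for loo() and a posterior_predict function for pp_check()), we need to expose the Stan functions to R and define the following: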
# Make the model's Stan functions (mpt_cond_pred, mpt_cond_lpmf) callable
# from R; vectorize = TRUE lets them take vectors of posterior draws.
expose_functions(fit1, vectorize = TRUE)

# Pointwise log-likelihood of observation i, one value per posterior draw;
# brms calls this for the custom family when computing loo().
log_lik_mpt_cond <- function(i, prep) {
  eta_mu <- brms::get_dpar(prep, "mu2", i = i)
  eta_mub <- brms::get_dpar(prep, "mu3", i = i)
  mpt_cond_lpmf(prep$data$Y[i], eta_mu, eta_mub, prep$data$n_cat)
}
loo(fit1)
# Computed from 4000 by 263 log-likelihood matrix
#
# Estimate SE
# elpd_loo -225.7 12.2
# p_loo 16.4 1.7
# looic 451.4 24.3
# ------
# Monte Carlo SE of elpd_loo is 0.1.
#
# All Pareto k estimates are good (k < 0.5).
# See help('pareto-k-diagnostic') for details.
# Posterior predictive draws for observation i: compute the category
# probabilities for every draw with the exposed Stan helper, then sample one
# category per draw.
posterior_predict_mpt_cond <- function(i, prep, ...) {
  eta_mu <- brms::get_dpar(prep, "mu2", i = i)
  eta_mub <- brms::get_dpar(prep, "mu3", i = i)
  # The vectorized helper returns one column of probabilities per draw;
  # transpose so each draw is a row, as rcat expects.
  prob_by_draw <- t(
    mpt_cond_pred(prep$data$Y[i], eta_mu, eta_mub, prep$data$n_cat)
  )
  extraDistr::rcat(length(eta_mu), prob_by_draw)
}
pp_check(fit1, type = "bars")
We can also get the expected category probabilities via posterior_epred():
# Expected category probabilities for every draw and observation, returned
# as a draws x observations x categories array; draws and categories are
# labelled "1", "2", ...
posterior_epred_mpt_cond <- function(prep) {
  eta_mu <- brms::get_dpar(prep, "mu2")
  eta_mub <- brms::get_dpar(prep, "mu3")
  # The vectorized helper is applied element-wise over the draws x obs
  # matrices and yields one column of probabilities per element, in
  # column-major order. y is recycled, but the Stan helper never uses it.
  flat <- mpt_cond_pred(
    y = prep$data$Y, mu = eta_mu, mu_b = eta_mub, n_cat = prep$data$n_cat
  )
  # Fold into categories x draws x obs, then move categories to the end.
  out <- aperm(array(flat, dim = c(dim(flat)[1], dim(eta_mu))), c(2, 3, 1))
  dimnames(out) <- list(
    as.character(seq_len(dim(out)[1])),
    NULL,
    as.character(seq_len(dim(out)[3]))
  )
  out
}
epred1 <- posterior_epred(fit1)
str(epred1)
# num [1:4000, 1:263, 1:3] 0.612 0.697 0.731 0.733 0.694 ...
# - attr(*, "dimnames")=List of 3
# ..$ : chr [1:4000] "1" "2" "3" "4" ...
# ..$ : NULL
# ..$ : chr [1:3] "1" "2" "3"
conditional_effects(fit1, "region", categorical = TRUE)
As mentioned above, this is essentially equivalent to the following sratio() model (which, however, samples more slowly):
# Reference model: the built-in sratio() family with category-specific (cs)
# effects for all predictors, fitted to the same response.
fit2 <- brm(
  formula = work3 ~ cs(hincome + children + region),
  data = womenlf,
  family = sratio()
)
fit2
pp_check(fit2, type = "bars")
conditional_effects(fit2, "region", categorical = TRUE)