Using Stan via brms to model a compositional response

I have data that looks like this:

d %>% select(mode, Density, bus_stops_per_1000, has_tram)
# A tibble: 101 x 4
   mode[,"walking"] [,"cycling"] [,"pt"] [,"car"] [,"other"] Density bus_stops_per_1000 has_tram
              <dbl>        <dbl>   <dbl>    <dbl>      <dbl>   <dbl>              <dbl> <lgl>   
 1             0.24     0.00195     0.17    0.588   0.000312   5694.               1.52 FALSE   
 2             0.19     0.08        0.12    0.55    0.06       2539.               4.89 FALSE   
 3             0.03     0.24        0.08    0.649   0.00108    2823.               2.02 FALSE   
 4             0.07     0.03        0.23    0.670   0.000385    673.              15.8  TRUE    
 5             0.09     0.02        0.2     0.687   0.00277    2943.               3.36 TRUE    
 6             0.12     0.00285     0.11    0.764   0.00353    7433.               1.84 FALSE   
 7             0.18     0.000414    0.13    0.686   0.00310    2589.               3.30 TRUE    
 8             0.03     0.34        0.24    0.387   0.00346    2844.               2.71 TRUE    
 9             0.08     0.04        0.21    0.669   0.00101    2189.               5.18 FALSE   
10             0.13     0.01        0.14    0.719   0.000635   2677.               1.33 FALSE   
# … with 91 more rows

The response is compositional with 5 columns which sum to 1 per row:

summary(rowSums(d$mode))
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
      1       1       1       1       1       1 
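Because summary() prints only four significant digits, rows that are slightly off 1 can look fine in its output. Rows with exact zeros also cause trouble, because a Dirichlet likelihood cannot handle them. A minimal base-R sketch of a stricter check follows; the helper name check_dirichlet_response() and the tolerance are illustrative assumptions, not part of brms:

```r
# Hypothetical helper (not part of brms): TRUE if every share lies strictly
# in (0, 1) and every row sums to 1 within `tol`.
check_dirichlet_response <- function(m, tol = 1e-8) {
  m <- as.matrix(m)
  in_range <- all(m > 0 & m < 1)
  sums_to_one <- all(abs(rowSums(m) - 1) < tol)
  if (!in_range) message("Some shares are <= 0 or >= 1; the Dirichlet needs strictly positive shares.")
  if (!sums_to_one) message("Some rows do not sum to 1 within the tolerance.")
  in_range && sums_to_one
}

# Usage on the real data (assumes d$mode is the 5-column share matrix):
# check_dirichlet_response(d$mode)
```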

I would like to estimate how the compositions change in response to some predictors. I have demonstrated a proof of concept using brms shown here (graphical output shown below).

In recent versions of brms, the following command does not finish, even after 2+ hours, despite the small size of the input dataset, which represents 101 cities around the world:

m = brm(mode ~ Density + bus_stops_per_1000 + has_tram, data = d, family = dirichlet(), inits = 0, iter = 10e4)

I’m quite new to this brave new (to me at least) world of Bayesian inference so it’s likely that I’m making some conceptual errors. However, based on reading the brms documentation and my evolving understanding of modelling, I think the above command should at least return a value.

Any ideas about what I’m doing wrong and how I can make this work would be very much appreciated. (A related question: if we predict mode shares for a particular mode using this method, is it possible to calculate intervals (credible intervals, in Bayesian terms) for the change in that mode? For example, if a city gained a tram system, the central estimate of the walking mode share increases by 0.008, but what is the distribution of that change? See here for an explanation.)
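On the related question: once you have posterior draws of the expected shares rather than a single point estimate, the distribution of the change is the distribution of the draw-by-draw difference between two scenarios. A minimal base-R sketch follows; share_change() is a hypothetical helper. The commented brms lines assume that posterior_epred() (available in recent brms versions) returns a draws × observations × categories array for a dirichlet() model, and that the median covariate values are a reasonable reference city:

```r
# Hypothetical helper: posterior summary of the change in one mode's expected
# share between scenario A and scenario B. Both inputs are draws x categories
# matrices with the category names as column names.
share_change <- function(epred_a, epred_b, mode, probs = c(0.025, 0.5, 0.975)) {
  diff <- epred_b[, mode] - epred_a[, mode]  # one difference per posterior draw
  c(mean = mean(diff), quantile(diff, probs))
}

# Sketch of use with a fitted model `m` (assumption, not run here):
# nd <- data.frame(Density = median(d$Density),
#                  bus_stops_per_1000 = median(d$bus_stops_per_1000),
#                  has_tram = c(FALSE, TRUE))
# ep <- posterior_epred(m, newdata = nd)   # draws x 2 x 5
# share_change(ep[, 1, ], ep[, 2, ], "walking")
```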

You should be able to reproduce the issue and generate results for the 101 cities using the reproducible example below:

library(brms)
library(dplyr)
library(sf)

d = readRDS(url("https://github.com/ATFutures/who3/blob/master/global-data/cities-101-osm-bus.Rds?raw=true"))
names(d)
# no zeros allowed: how to allow zero values?
d %>% sf::st_drop_geometry() %>% select(walking:cycling) %>% summary()
d = d %>% 
  st_drop_geometry() %>%
  # filter(cycling != 0) %>% 
  # filter(other != 0) %>% 
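  # zero shares (in %) in 'other' and 'cycling' are replaced by one random value in (0, 0.5) per column
  mutate_at(vars(other, cycling), ~case_when(. == 0 ~ runif(n = 1, min = 0, max = 0.5), TRUE ~ .)) %>%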
  ungroup() %>%
  select(-bb_poly) %>% 
  mutate(Density = Population / Area)
# excess of each city's row total over 100%, subtracted from car so that rows sum to 100
totals = d %>% select(walking:other, -City) %>% rowSums() - 100
d$car = d$car - totals
modes_matrix = d %>% select(walking:other) %>% as.matrix()
summary(rowSums(modes_matrix))
d = d %>% select(-(walking:other))
modes_matrix[modes_matrix == 0] = 0.01
d$mode = modes_matrix / 100
names(d)
head(d)
m = brm(mode ~ Density + bus_stops_per_1000 + has_tram, data = d, family = dirichlet(), inits = 0, iter = 10e4)
  • Operating System: Ubuntu 18.04
  • brms version: 2.13.3 (from packageVersion("brms"))

Created on 2020-07-01 by the reprex package (v0.3.0)

Update: I’m now trying the model with fewer iterations. As a colleague pointed out, 100k iterations (iter = 10e4) may explain why it was taking forever. Building on the example above, I’m now running the model as follows:

m = brm(
  mode ~ Density + bus_stops_per_1000 + has_tram,
  data = d,
  family = dirichlet(),
  inits = 0,
  iter = 1000,
  cores = 4,
  chains = 4
)
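Some quick arithmetic shows why the first call was so slow. Assuming the brms defaults of chains = 4 and warmup = floor(iter / 2), the number of sampling iterations scales with iter × chains. The helper below is a hypothetical illustration of that arithmetic, not a brms function:

```r
# Hypothetical helper: total iterations and post-warmup draws requested,
# assuming the brms defaults chains = 4 and warmup = floor(iter / 2).
sampling_budget <- function(iter, chains = 4, warmup = floor(iter / 2)) {
  c(total_iterations = chains * iter,
    post_warmup_draws = chains * (iter - warmup))
}

sampling_budget(10e4)  # original call: 10e4 is 100,000 iterations per chain
sampling_budget(1000)  # updated call
```

With iter = 10e4, the original call asks for 400,000 iterations in total, compared with 4,000 for iter = 1000, a 100-fold difference. Because cores was not set in the original call, brms’s default (getOption("mc.cores", 1)) would also have run the four chains one after another.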

Hi there. Without looking at your data or your model, I would also recommend setting regularizing priors wherever possible. This is one of the most important steps toward efficient computation (fast run times, good mixing, no problems with the posterior geometry, etc.).

You can get a sense of the default priors brms will set in your model by running something like:

model_formula <- bf(mode ~ Density + bus_stops_per_1000 + has_tram)
get_prior(model_formula,
          data = d,
          family = dirichlet())

In many cases, the brms defaults work great, but it’s always better to ensure they make sense for your model and your data.

Many thanks for the reply, @franzsf. I successfully generated the priors as you suggested, but unfortunately the code is still not working. I suspect there is a bug and would be grateful if others could try to reproduce the issue. On my computer the final command simply never completes. Interestingly, an HTML page opens (I believe to report progress on the Stan model run), but it’s blank. To simplify the code and reduce dependencies, I’ve created a smaller reproducible example incorporating your suggestion:

library(brms)
d = readRDS(url("https://github.com/ATFutures/who3/releases/download/0.0.1/cities-compositional-data.Rds"))
model_formula <- bf(mode ~ Density + bus_stops_per_1000 + has_tram)
prior = get_prior(model_formula,
                  data = d,
                  family = dirichlet())
prior
brm(model_formula, data = d, family = dirichlet(),
     iter = 1000, cores = 4, chains = 4, prior = prior)

Results below:

# with brms again
library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.13.3). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
d = readRDS(url("https://github.com/ATFutures/who3/releases/download/0.0.1/cities-compositional-data.Rds"))
model_formula <- bf(mode ~ Density + bus_stops_per_1000 + has_tram)
prior = get_prior(model_formula,
                  data = d,
                  family = dirichlet())
prior
#>                     prior     class               coef group resp      dpar
#> 1                                 b                                        
#> 2                         Intercept                                        
#> 3       gamma(0.01, 0.01)       phi                                        
#> 4                                 b                                   mucar
#> 5                                 b bus_stops_per_1000                mucar
#> 6                                 b            Density                mucar
#> 7                                 b       has_tramTRUE                mucar
#> 8  student_t(3, 0.1, 2.5) Intercept                                   mucar
#> 9                                 b                               mucycling
#> 10                                b bus_stops_per_1000            mucycling
#> 11                                b            Density            mucycling
#> 12                                b       has_tramTRUE            mucycling
#> 13 student_t(3, 0.1, 2.5) Intercept                               mucycling
#> 14                                b                                 muother
#> 15                                b bus_stops_per_1000              muother
#> 16                                b            Density              muother
#> 17                                b       has_tramTRUE              muother
#> 18 student_t(3, 0.1, 2.5) Intercept                                 muother
#> 19                                b                                    mupt
#> 20                                b bus_stops_per_1000                 mupt
#> 21                                b            Density                 mupt
#> 22                                b       has_tramTRUE                 mupt
#> 23 student_t(3, 0.1, 2.5) Intercept                                    mupt
#> (wrapped columns nlpar and bound: empty for all 23 rows)
# never completes even with only 101 observations:
# brm(model_formula, data = d, family = dirichlet(), iter = 1000, cores = 4, chains = 4, prior = prior)

Created on 2020-07-02 by the reprex package (v0.3.0)

Unless I’m missing something, the get_prior() command just outputs the default priors, given your data, model and family. You have to set priors in brm() explicitly, not by passing that output back as an object:

model <- brm(model_formula, 
             data = d, 
             family = dirichlet(), 
             iter = 1000, 
             cores = 4, 
             chains = 4, 
             prior = c(prior(normal(0, 1), class = "b", dpar = "mucar")))  # one example

See documentation for set_prior (https://cran.r-project.org/web/packages/brms/brms.pdf) etc. I couldn’t tell you what priors are most appropriate for your model.
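The single-prior example above can be extended to every non-reference category. This is a minimal sketch that assumes the four linear predictors listed by get_prior() (mucar, mucycling, muother, mupt). The normal(0, 1) and normal(0, 2.5) scales are placeholders that illustrate the syntax; they are not recommendations for this data:

```r
library(brms)

# Non-reference categories, as listed in the dpar column of get_prior()
dpars <- c("mucar", "mucycling", "muother", "mupt")

# For each category: one prior for all slopes and one for the intercept.
# c() combines brmsprior objects into a single prior specification.
tight_priors <- do.call(c, lapply(dpars, function(dp) {
  c(set_prior("normal(0, 1)", class = "b", dpar = dp),
    set_prior("normal(0, 2.5)", class = "Intercept", dpar = dp))
}))
tight_priors

# Pass to the model as before (not run here):
# brm(model_formula, data = d, family = dirichlet(), prior = tight_priors,
#     iter = 1000, cores = 4, chains = 4)
```

Slope priors are on the scale of each predictor’s units, and Density is in the thousands. Rescaling the predictors first may make a common prior like this more sensible.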

Once you have priors set, you can use sample_prior = "only" to run prior predictive checks and see whether the priors make sense.

Thanks for the swift reply. I see from the documentation that you can set priors in the brm() function call with prior = prior(normal(0, 1), class = b). I will give that a try and report back.

Update: @franzsf, I’ve double-checked and I think you can use the prior object. See ?get_prior, which contains the following example:

make_stancode(count ~ zAge + zBase * Trt + (1|patient) + (1|obs),
              data = epilepsy, family = poisson(), 
              prior = prior)

Have I misunderstood something?

I guess I did miss that; I’ve always specified priors explicitly. Note, however, the documentation also specifies how you can tweak the defaults brms will automatically set. Regardless of how you set them, the computational issues you were experiencing may be addressed by setting priors that regularize more than the defaults.

That makes more sense, but which changes to the priors will regularize more than the defaults?

I don’t think it’s really possible to make general recommendations outside the context of the model and data in question, and I’m certainly not qualified to do so. One idea might be to try replacing the student_t defaults with appropriately sized normals to tighten up the tails, then sample from the prior only and see what makes sense. The computational issues could also result from how the model is parameterized; I’d just try tightening up the priors as an initial step.
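To make "tighten up the tails" concrete: on the logit scale used by the Dirichlet model, an intercept or coefficient of ±10 is already extreme. The base-R sketch below compares how much prior probability falls beyond ±10 under a student_t(3, 0, 2.5) prior and under normal priors. It approximates the default’s location of 0.1 by 0, and the threshold of 10 is an arbitrary illustration:

```r
# Two-sided tail probability P(|x| > threshold) for a scaled Student-t
# centred at 0, and for a normal centred at 0.
tail_prob_t <- function(threshold, df, scale) 2 * pt(-threshold / scale, df = df)
tail_prob_normal <- function(threshold, sd) 2 * pnorm(-threshold / sd)

c(student_t_3_2.5 = tail_prob_t(10, df = 3, scale = 2.5),
  normal_2.5      = tail_prob_normal(10, sd = 2.5),
  normal_1        = tail_prob_normal(10, sd = 1))
```

The Student-t prior puts roughly 3% of its mass beyond ±10, more than two orders of magnitude more than normal(0, 2.5).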

Edit: actually, you should probably start with a simpler version of the model, get that running with appropriate priors, and then expand it.

Thanks, @franzsf. It sounds like starting from first principles with a very simple representation of the data is worthwhile; I can build up from something really simple, e.g. car:PT:walk shares in response to the presence/absence of trams.

However… I do suspect there is a bug here: regardless of the priors the model should actually run within a reasonable (less than 1 hour) amount of time with only 100 observations.

Could it also be that there isn’t enough data for the number of priors? I see:

nrow(prior)
[1] 23

Does that mean there are 23 priors?!
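The 23 rows of get_prior() include class-level rows (such as the general b rows) as well as one row per parameter, so the row count overstates the number of parameters. The count below assumes 5 categories with walking as the reference category (there is no muwalking row) and the 3 coefficients shown in the table:

```r
n_categories <- 5  # walking, cycling, pt, car, other
n_slopes <- 3      # bus_stops_per_1000, Density, has_tramTRUE
# One intercept plus the slopes for each non-reference category, plus phi
n_params <- (n_categories - 1) * (1 + n_slopes) + 1
n_params
```

This gives 17. The remaining 6 rows of the table (rows 1, 2, 4, 9, 14 and 19) are class-level entries rather than separate parameters.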

Digging into your example, it looks like there are problems with zeros and totals not summing to 100%. (Edited: you already attempted to deal with this, so unclear what’s going on there). Try the code below as a start. This ran in < 5 seconds on my machine.

library(brms)
library(data.table)

data <- as.data.table(readRDS(url("https://github.com/ATFutures/who3/blob/master/global-data/cities-101-osm-bus.Rds?raw=true")))

# Ensure no zeroes
summary(data$walking)
summary(data$cycling)
summary(data$pt)
summary(data$car)
summary(data$other)

# Fudge small numbers, since a zero-inflated Dirichlet is not native to brms
data[cycling == 0, cycling := 0.01]
data[other == 0, other := 0.01]

data[,total := walking + cycling + pt + car + other]

data[,walking := walking/total
     ][,cycling := cycling/total
       ][,pt := pt/total
         ][,car := car/total
           ][,other := other/total]

data[,check := walking+cycling+pt+car+other]
summary(data$check)


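# brms reads cbind() on the left-hand side of a formula as a multivariate model,
# so this alias is used to pass the five shares as a single matrix response
bind <- function(...) cbind(...)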

transp_formula <- bf(bind(walking,cycling,pt,car,other) ~ 1)

get_prior(transp_formula,
          family = dirichlet(),
          data = data)

transp_model <- brm(transp_formula,
                    data = data,
                    family = dirichlet(),
                    prior = c(prior(normal(0,4), class = "Intercept", dpar = "mucar"),
                              prior(normal(0,4), class = "Intercept", dpar = "mucycling"),
                              prior(normal(0,4), class = "Intercept", dpar = "muother"),
                              prior(normal(0,4), class = "Intercept", dpar = "mupt")),
                    chains = 4,
                    cores = 4,
                    iter = 1000,
                    warmup = 500)
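The intercepts of the intercept-only model above can be read on the share scale by passing them through the softmax, the inverse of the logit link. The base-R sketch below assumes walking (the first column in bind()) is the reference category with its linear predictor fixed at 0, which is consistent with the absence of a muwalking row in get_prior(). shares_from_eta() is a hypothetical helper. The second part is a rough manual prior predictive check of the normal(0, 4) intercept priors; it looks only at the implied expected walking share and ignores phi:

```r
# Hypothetical helper: expected shares from linear predictors (softmax),
# with walking as the reference category fixed at 0.
shares_from_eta <- function(eta) {
  eta <- c(walking = 0, eta)
  w <- exp(eta - max(eta))  # subtract the max for numerical stability
  w / sum(w)
}

shares_from_eta(c(cycling = 0, pt = 0, car = 0, other = 0))  # every share is 0.2

# Rough prior check: draw the four intercepts from normal(0, 4) and look at
# the implied expected walking share.
set.seed(1)
walking_prior <- replicate(4000, {
  eta <- setNames(rnorm(4, mean = 0, sd = 4), c("cycling", "pt", "car", "other"))
  shares_from_eta(eta)[["walking"]]
})
quantile(walking_prior, c(0.05, 0.5, 0.95))
```

If much of this prior mass sits near 0 or 1, the priors are looser than the observed shares would suggest. In the rows printed at the top, walking ranges from 0.03 to 0.24.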

Many thanks for the reply; that looks like a really good start. I think there’s an issue with my brms installation, as it’s still taking forever. I will try it on a different computer or in a Docker container when I get a chance.