Dirichlet prior possible with stanvars?

Hi, I have a model that I’d like to fit:

y_{i} \sim \text{Ordered-Logit}(\mu_{i}, \tau) \\ \mu_{i} = \sum_{j = 1}^{J} \beta \phi_{j} x_{t-j} \\ \beta \sim \text{Normal}(0, 0.5) \\ \tau \sim \text{Normal}(0, 1.5) \\ \phi \sim \text{Dirichlet}(1)

Here, all x variables measure the same variable at different points in time. As such, the model treats y as a weighted average of x, where \beta serves to convert between the units of y and x and the simplex \phi represents the weight placed on each measurement period. For example, if we have 4 periods, \phi might be {0.7, 0.2, 0.1, 0}.

I have managed to fit the model in Stan, but would like to be able to use the various companion functions that come with brms. I know that dirichlet priors aren’t yet supported in brms. My question: is it possible to fit them anyway using the stanvars() function?

Yes, I think a stanvar for the model block should work. If there is a specific problem you have, feel free to report back. Best of luck!

@jackbailey Can you provide your raw Stan code?

Of course!

// Logistic regression of the 2019 vote on a weighted average of the four
// earlier votes. The simplex p holds the weight given to each earlier
// election; b scales the weighted average onto the logit scale.
data {
  int<lower=0> N;                       // Number of cases
  vector[N] v_05;                       // Voted Labour 2005
  vector[N] v_10;                       // Voted Labour 2010
  vector[N] v_15;                       // Voted Labour 2015
  vector[N] v_17;                       // Voted Labour 2017
  // Declared with the array[...] syntax: the older `int v_19[N]` form is
  // no longer accepted by current Stan releases.
  array[N] int<lower=0, upper=1> v_19;  // Voted Labour 2019
}
parameters {
  real a;        // Intercept (logit scale)
  real b;        // Common coefficient applied to every weighted predictor
  simplex[4] p;  // Weights on the four earlier elections; they sum to 1
}
model {

  // Specify priors
  target += normal_lpdf( a | 0, 1.5 );
  target += normal_lpdf( b | 0, 1.5 );
  // Dirichlet prior with concentration 0.5 on each of the 4 weights
  target += dirichlet_lpdf( p | rep_vector( 0.5, 4 ) );

  // Specify model
  // p[1] weights the most recent election (2017), p[4] the oldest (2005)
  v_19 ~ bernoulli_logit(a + p[1]*b*v_17 + p[2]*b*v_15 + p[3]*b*v_10 + p[4]*b*v_05);

}

EDIT: @maxbiostat edited this post for syntax highlighting.

1 Like

EDIT2: The main problem is that brms does not support changing the declaration of parameters. A Dirichlet prior by itself is no problem, but it is useless unless the parameter it is placed on is declared as a simplex.

Here is a workaround together with a similar dataset to make it work:

pacman::p_load(pacman, tidyverse, rio, magrittr, janitor, hablar, skimr, brms, pxweb, cmdstanr)

Data loading and preparation:

# Query specification passed to pxweb_get() below.
# Each list element names one dimension of the table and the values to fetch.
# NOTE(review): "Region" is later renamed to `municipality`, so these are
# presumably municipality codes, and "Partimm" = "S" presumably selects a
# single party -- confirm against the table metadata.
pxweb_query_list <- list(
  Region = c(
    "0114", "0115", "0117", "0120", "0123", "0125", "0126", "0127",
    "0128", "0136", "0138", "0139", "0140", "0160", "0162", "0163",
    "0180", "0181", "0182", "0183", "0184", "0186", "0187", "0188",
    "0191", "0192", "0305", "0319", "0330", "0331", "0360", "0380",
    "0381", "0382", "0428", "0461", "0480", "0481", "0482", "0483",
    "0484", "0486", "0488", "0509", "0512", "0513", "0560", "0561",
    "0562", "0563", "0580", "0581", "0582", "0583", "0584", "0586",
    "0604", "0617", "0642", "0643", "0662", "0665", "0680", "0682",
    "0683", "0684", "0685", "0686", "0687", "0760", "0761", "0763",
    "0764", "0765", "0767", "0780", "0781", "0821", "0834", "0840",
    "0860", "0861", "0862", "0880", "0881", "0882", "0883", "0884",
    "0885", "0980", "1060", "1080", "1081", "1082", "1083", "1214",
    "1229", "1230", "1231", "1233", "1256", "1257", "1260", "1261",
    "1262", "1263", "1264", "1265", "1266", "1267", "1270", "1272",
    "1273", "1275", "1276", "1277", "1278", "1280", "1281", "1282",
    "1283", "1284", "1285", "1286", "1287", "1290", "1291", "1292",
    "1293", "1315", "1380", "1381", "1382", "1383", "1384", "1401",
    "1402", "1407", "1415", "1419", "1421", "1427", "1430", "1435",
    "1438", "1439", "1440", "1441", "1442", "1443", "1444", "1445",
    "1446", "1447", "1452", "1460", "1461", "1462", "1463", "1465",
    "1466", "1470", "1471", "1472", "1473", "1480", "1481", "1482",
    "1484", "1485", "1486", "1487", "1488", "1489", "1490", "1491",
    "1492", "1493", "1494", "1495", "1496", "1497", "1498", "1499",
    "1715", "1730", "1737", "1760", "1761", "1762", "1763", "1764",
    "1765", "1766", "1780", "1781", "1782", "1783", "1784", "1785",
    "1814", "1860", "1861", "1862", "1863", "1864", "1880", "1881",
    "1882", "1883", "1884", "1885", "1904", "1907", "1960", "1961",
    "1962", "1980", "1981", "1982", "1983", "1984", "2021", "2023",
    "2026", "2029", "2031", "2034", "2039", "2061", "2062", "2080",
    "2081", "2082", "2083", "2084", "2085", "2101", "2104", "2121",
    "2132", "2161", "2180", "2181", "2182", "2183", "2184", "2260",
    "2262", "2280", "2281", "2282", "2283", "2284", "2303", "2305",
    "2309", "2313", "2321", "2326", "2361", "2380", "2401", "2403",
    "2404", "2409", "2417", "2418", "2421", "2422", "2425", "2460",
    "2462", "2463", "2480", "2481", "2482", "2505", "2506", "2510",
    "2513", "2514", "2518", "2521", "2523", "2560", "2580", "2581",
    "2582", "2583", "2584"
  ),
  Partimm = "S",
  ContentsCode = c("ME0104B6", "ME0104B7"),
  Tid = c("2002", "2006", "2010", "2014", "2018")
)

# Fetch the table described by pxweb_query_list from the PXWEB API
# (requires network access).
px_data <- pxweb_get(
  url = "http://api.scb.se/OV0104/v1/doris/sv/ssd/ME/ME0104/ME0104C/ME0104T3",
  query = pxweb_query_list
)

# Flatten the response into a data.frame, using text labels for both the
# column names and the variable values.
px_data_frame <- as.data.frame(
  px_data,
  column.name.type = "text",
  variable.value.type = "text"
)

# Reshape the download to one row per municipality with one column per
# measure and election year (count_<year>, proportion_<year>).
s_data <- px_data_frame %>%
  as_tibble() %>%
  remove_constant() %>%   # drop columns holding a single value
  clean_names() %>%
  rename(
    municipality = region,
    year = valar,
    proportion = andel_roster_av_giltiga_roster,
    count = antal_roster
  ) %>%
  pivot_longer(cols = c(count, proportion)) %>%
  pivot_wider(names_from = c(name, year), values_from = value) %>%
  # Only the 2018 count is used (as the outcome); drop the earlier counts.
  select(-c(count_2002, count_2006, count_2010, count_2014)) %>%
  # Shares arrive as percentages; rescale them to proportions.
  mutate(across(starts_with("proportion_"), function(p) p / 100)) %>%
  # Recover the binomial denominator for 2018 from the count and its share.
  mutate(population_2018 = round(count_2018 / proportion_2018)) %>%
  # The 2018 share is implied by the outcome, so it must not be a predictor.
  select(-proportion_2018)

The most important part:

# Extra Stan code: declare the coefficients of non-linear parameter `c`
# as a simplex in the parameters block.
s_code <- stanvar(
  scode = "simplex[K_c] b_c;",
  block = "parameters"
)

# Non-linear binomial model: intercept `a` plus scale `b` times `c`, where
# `c` is a no-intercept linear combination of the four earlier proportions.
bf_formula <- bf(
  count_2018 | trials(population_2018) ~ a + b * c,
  a ~ 1,
  b ~ 1,
  c ~ 0 + proportion_2002 + proportion_2006 + proportion_2010 +
    proportion_2014,
  nl = TRUE,
  family = binomial()
)

# Normal priors on `a` and `b`; Dirichlet prior on the coefficients of `c`.
priors <- c(
  set_prior("normal(0, 1.5)", nlpar = "a"),
  set_prior("normal(0, 1.5)", nlpar = "b"),
  set_prior("dirichlet(rep_vector( 0.5, K_c ))", nlpar = "c", class = "b")
)

# Build the data list brms would pass to Stan for this model.
stan_data <- make_standata(
  formula = bf_formula,
  prior = priors,
  data = s_data,
  stanvars = s_code
)

# Strip the class attribute so a plain list is handed to cmdstanr.
stan_data <- unclass(stan_data)

# Generate the Stan program, remove the default `vector[K_c] b_c;`
# declaration (the simplex declaration from s_code replaces it), write the
# program to a file and compile it.
cmdstanr_model <- make_stancode(
  formula = bf_formula,
  prior = priors,
  data = s_data,
  stanvars = s_code
) %>%
  str_remove("vector\\[K_c\\] b_c\\;") %>%
  write_stan_file() %>%
  cmdstan_model()

# Sample, running 4 chains in parallel.
cmdstanr_fit <- cmdstanr_model$sample(
  data = stan_data,
  parallel_chains = 4
)

# Read the CmdStan output CSV files back in as an rstan fit object.
rstan_fit <- rstan::read_stan_csv(cmdstanr_fit$output_files())

# Attach the compiled model to the fit as an attribute.
# NOTE(review): presumably this mirrors what brms stores for the cmdstanr
# backend -- confirm.
attributes(rstan_fit)$CmdStanModel <- cmdstanr_model

# Create an empty brmsfit (no sampling) to act as a shell for the
# externally fitted model.
weighted_mean_model <- brm(
  formula = bf_formula,
  prior = priors,
  data = s_data,
  backend = "cmdstanr",
  empty = TRUE
)

# Insert the fit obtained from the hand-edited Stan program.
weighted_mean_model$fit <- rstan_fit

# Rename the Stan parameters to brms' naming scheme so the usual
# post-processing methods work.
weighted_mean_model <- rename_pars(weighted_mean_model)

summary(weighted_mean_model)

EDIT: Restructured the code a bit.

4 Likes

Staffan! Thank you so much, this is great!

1 Like