Fitting HMMs with time-varying transition matrices using brms: A prototype

I have a preprint out that models disease progression in Covid-19 patients: https://www.medrxiv.org/content/10.1101/2020.12.03.20239863v1 (with @paul.buerkner, @vianeylb and many others). The results themselves are not that interesting (we can’t learn much from the dataset we have), but I am quite excited that one of the things we tried was using Hidden Markov Models (HMMs) to model disease progression. We needed the HMMs to have transition matrices that vary both by patient (to adjust for baseline characteristics) and by time (to change after a treatment was administered). So this called for some kind of linear predictors on the transition probabilities… To let me iterate faster on the models (and because I hate reimplementing linear predictors, varying intercepts etc.), I hacked brms to write the “linear” part for me and then injected HMM code on top of it.

This post describes this HMM implementation, how it works, and what the next steps are to make it more broadly usable.

Linear predictors for HMMs?

I know of two reasonable ways to use linear predictors for transition matrices in HMMs. The first is to treat the transition probabilities from each state as a separate categorical regression. This could make sense in a lot of cases, but I expected the transitions in my model to be highly structured, looking something like this:

[Figure: allowed direct transitions between states, (a) the simple model and (b) the complex model]

The arrows represent “direct” transitions, e.g. a patient who is deteriorating from “AA” (breathing Ambient Air) to “Ventilated” will pass through the “Oxygen” state (needs supplemental O2), and a patient in the “Oxygen” state will not be discharged before being able to manage with only “AA”. However, our data are at daily resolution, and it is possible that two (or more) of those transitions happen within a single day, so the transition probability from “AA” to “Ventilated” is not zero. Still, we want to enforce that P(X_{t} = \mathtt{Ventilated} | X_{t - 1} = \mathtt{Oxygen}) > P(X_{t} = \mathtt{Ventilated} | X_{t - 1} = \mathtt{AA}).

This leads to a different formulation: we set up a rate matrix R so that for any two states i \neq j, R_{i,j} is the rate of transition from i to j. R is intended to be sparse, i.e. R_{i,j} \neq 0 only for transitions we explicitly intend to model as direct (the arrows above). Additionally, the diagonal elements are set to R_{i,i} = -\sum_{j \neq i} R_{i,j}, so each row of R sums to zero.
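A minimal sketch of how such a sparse rate matrix could be assembled, written in Python/NumPy rather than the R and Stan code the post is built on. The three states and all rate values are hypothetical; only the listed pairs get a nonzero rate (so there is no direct AA to Ventilated rate), and the diagonal is filled so that each row sums to zero.

```python
import numpy as np

def build_rate_matrix(states, rates):
    """Build R from {(from_state, to_state): rate}; unlisted pairs get rate 0."""
    index = {s: i for i, s in enumerate(states)}
    R = np.zeros((len(states), len(states)))
    for (src, dst), rate in rates.items():
        if src == dst:
            raise ValueError("diagonal entries are derived, not specified")
        R[index[src], index[dst]] = rate
    # R[i, i] = -sum_{j != i} R[i, j]; the diagonal is still zero here,
    # so the row sums equal the off-diagonal sums.
    np.fill_diagonal(R, -R.sum(axis=1))
    return R

# Hypothetical states and rates (per day); AA -> Ventilated is deliberately absent.
states = ["AA", "Oxygen", "Ventilated"]
rates = {
    ("AA", "Oxygen"): 0.3,
    ("Oxygen", "AA"): 0.4,
    ("Oxygen", "Ventilated"): 0.2,
    ("Ventilated", "Oxygen"): 0.1,
}
R = build_rate_matrix(states, rates)
print(R)
```

Only the four listed rates are free parameters here, which is the same kind of saving as the 13 versus 42 free parameters of the complex model.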

Assuming fixed rates and treating p(t) as a row vector of state probabilities at continuous time t, its evolution is given by the differential equation

\frac{dp(t)}{dt} = p(t)R

Given the initial state probabilities p(0), the solution to this equation is:

p(t) = p(0)\exp(tR)

where \exp is the matrix exponential. We can thus compute a discrete-time (one-day) transition matrix S = \exp(R) so that p(t + 1) = p(t)S, where S_{i,j} is the probability of being in state j one day after being in state i. The appeal of this approach is that we can enforce a lot of structure on the transition matrix S while allowing positive probability for almost all transitions. In the case of the complex model (case b in the figure above), the full transition matrix between the 8 states would have 42 free parameters (6 states with 7 free outgoing probabilities each, as there are no transitions from the “Death” and “Discharged” states), while the rate matrix for the same model has only 13 free parameters.
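To see how the structure carries over into the one-day transition matrix, here is a sketch that computes S = \exp(R) for a hypothetical three-state chain (AA, Oxygen, Ventilated, with no direct AA to Ventilated rate). The matrix exponential is a hand-rolled scaling-and-squaring Taylor approximation, used only to keep the example dependency-light; a real implementation would use a library routine such as scipy.linalg.expm or Stan's matrix_exp.

```python
import numpy as np

def expm(A, terms=20):
    """Matrix exponential via scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, ord=np.inf)
    # Pick s so that A / 2^s has norm below 1/2, where the series converges fast.
    s = 0
    while norm / 2**s >= 0.5:
        s += 1
    A_scaled = A / 2**s
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms + 1):
        term = term @ A_scaled / k   # A_scaled^k / k!
        result = result + term
    for _ in range(s):               # exp(A) = exp(A / 2^s)^(2^s)
        result = result @ result
    return result

# Hypothetical daily rates; row i holds the rates out of state i.
R = np.array([
    [-0.3,  0.3,  0.0],   # AA -> Oxygen
    [ 0.4, -0.6,  0.2],   # Oxygen -> AA, Oxygen -> Ventilated
    [ 0.0,  0.1, -0.1],   # Ventilated -> Oxygen
])
S = expm(R)
print(np.round(S, 3))
```

With these made-up rates, S has a small but positive AA to Ventilated entry even though the corresponding entry of R is zero, and that entry is smaller than the Oxygen to Ventilated one, matching the inequality described above. Rows of S sum to one because rows of R sum to zero.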

We also find the rate formulation theoretically appealing: disease progression takes place in continuous time, and the discretization into individual days is an artifact of the way we collected data, not of reality.

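Finally, we let the nonzero off-diagonal elements of R be described via a linear predictor \mu, i.e. for discrete time t and patient k we have R_{i,j;t,k} = \exp ( \mu_{i,j;t,k} ). The rates are thus constant within each day, and each patient-day gets its own transition matrix S_{t,k} = \exp(R_{t,k}).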
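A sketch of the step from linear predictor to rates for a single patient, assuming (hypothetically) one intercept per rate and a treatment effect shared by all rates in the same group. Because the treatment indicator switches on partway through the stay, the rates, and therefore that day's transition matrix, change from then on. All names and coefficient values are made up for illustration.

```python
import math

# Hypothetical coefficients: one intercept per rate, one treatment effect per rate group.
rate_group = {"AA_Oxygen": "worsening", "Oxygen_Ventilated": "worsening", "Oxygen_AA": "improving"}
intercept = {"AA_Oxygen": -1.2, "Oxygen_AA": -0.9, "Oxygen_Ventilated": -1.6}
treatment_effect = {"worsening": -0.5, "improving": 0.3}  # shared within a group

def rates_for_day(took_treatment):
    """Off-diagonal rates for one patient on one day: rate = exp(mu)."""
    rates = {}
    for rate_id, alpha in intercept.items():
        mu = alpha + treatment_effect[rate_group[rate_id]] * took_treatment
        rates[rate_id] = math.exp(mu)  # the exp link keeps every rate positive
    return rates

# Treatment starts on day index 3, so the rates differ from that day onward.
treatment_by_day = [0, 0, 0, 1, 1]
daily_rates = [rates_for_day(x) for x in treatment_by_day]
for day, r in enumerate(daily_rates):
    print(day, {k: round(v, 3) for k, v in r.items()})
```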

This approach is taken from Williams et al.: https://doi.org/10.1080/01621459.2019.1594831 (thanks @vianeylb for pointing me to this and for other help along the way).

Describing the model as a brms formula

In the current implementation, the available transitions (non-zero off-diagonal elements of R) are described in a data.frame like this (for the simple “a” model from above):

Here, .from, .to and .rate_id are required columns, while rate_group is additional data we can use in the predictor.

The code then takes a data frame with predictors for each series (patient in this case) and time point, and forms its cross product with the rate data frame, giving a combined data frame. You then write a formula describing the linear predictor based on this combined data frame, e.g.:

~0 + .rate_id + (0 + age + sex + took_favipiravir || rate_group)

Note that I combine the properties of individual rates (.rate_id, rate_group), time-invariant predictors (age, sex) and time-varying predictors (took_favipiravir). Also note that I try to make the model less flexible (and thus easier to fit) by having each covariate act exactly the same way on all rates in the same group.
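A sketch of the cross product step in Python, using plain lists of dicts where the post uses R data frames. The column names .from, .to, .rate_id, rate_group, age, sex and took_favipiravir follow the post; the .serie and .time columns and all concrete values are hypothetical.

```python
from itertools import product

# Hypothetical rate table (one row per allowed direct transition).
rates = [
    {".from": "AA", ".to": "Oxygen", ".rate_id": "AA_Oxygen", "rate_group": "worsening"},
    {".from": "Oxygen", ".to": "AA", ".rate_id": "Oxygen_AA", "rate_group": "improving"},
    {".from": "Oxygen", ".to": "Ventilated", ".rate_id": "Oxygen_Ventilated", "rate_group": "worsening"},
    {".from": "Ventilated", ".to": "Oxygen", ".rate_id": "Ventilated_Oxygen", "rate_group": "improving"},
]

# Hypothetical predictors, one row per (patient, day).
series_time = [
    {".serie": 1, ".time": 1, "age": 64, "sex": "F", "took_favipiravir": 0},
    {".serie": 1, ".time": 2, "age": 64, "sex": "F", "took_favipiravir": 1},
    {".serie": 2, ".time": 1, "age": 51, "sex": "M", "took_favipiravir": 0},
]

# Every (patient, day) row is paired with every rate row, so a formula like
# ~0 + .rate_id + (0 + age + sex + took_favipiravir || rate_group)
# can refer to both kinds of columns.
combined = [{**st, **rate} for st, rate in product(series_time, rates)]
print(len(combined))  # 3 patient-days x 4 rates
```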

My code then:

  • Builds a mapping between each rate, time and series (patient) and the corresponding row in the combined data frame

  • Collapses all rows of the combined data frame that have the same values of all predictors that are actually used (for efficiency), and adds a “dummy” response column to trick brms (this is ignored later).

  • Packages the mapping, the observations, the definition of the observation matrix, and other data outside the combined data frame via stanvars to be passed to the brms model.

  • Uses stanvars to define additional parameters.

  • Creates a new family object of class rate_hmm that overrides the stan_log_lik method from within brms. This needs to be done explicitly via registerS3method("stan_log_lik", class = "rate_hmm", method = stan_log_lik.rate_hmm, envir = asNamespace("brms")) (this is currently the biggest/only hack). The reason is that a regular custom_family in brms still generates a loop over observations. This way I let brms write only the code that computes the mu vector and replace everything after it.

  • Uses regular calls to brm etc. to run the model and wraps the result in a new brmshmmfit object that also keeps the mapping and other necessary data.
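The mapping and collapsing steps can be sketched as follows. This is not the actual package code: it assumes the combined rows are dicts with hypothetical .serie and .time keys, deduplicates them on the predictor columns actually used in the formula, and records, for every (series, time, rate) triple, which collapsed row holds its linear predictor. The "dummy" column stands in for the placeholder response mentioned in the list.

```python
def collapse(combined, used_columns, key_columns):
    """Deduplicate rows on used_columns; map key_columns -> unique row index."""
    unique_rows = []
    row_index = {}  # tuple of predictor values -> position in unique_rows
    mapping = {}    # (serie, time, rate_id) -> position in unique_rows
    for row in combined:
        values = tuple(row[c] for c in used_columns)
        if values not in row_index:
            row_index[values] = len(unique_rows)
            # "dummy" is a placeholder response; the custom likelihood ignores it.
            unique_rows.append({**dict(zip(used_columns, values)), "dummy": 0})
        mapping[tuple(row[c] for c in key_columns)] = row_index[values]
    return unique_rows, mapping

# Hypothetical combined rows: three of the four share identical predictor values.
combined = [
    {".serie": 1, ".time": 1, ".rate_id": "AA_Oxygen", "age": 64, "took_favipiravir": 0},
    {".serie": 1, ".time": 2, ".rate_id": "AA_Oxygen", "age": 64, "took_favipiravir": 0},
    {".serie": 1, ".time": 3, ".rate_id": "AA_Oxygen", "age": 64, "took_favipiravir": 1},
    {".serie": 2, ".time": 1, ".rate_id": "AA_Oxygen", "age": 64, "took_favipiravir": 0},
]
unique_rows, mapping = collapse(
    combined,
    used_columns=[".rate_id", "age", "took_favipiravir"],
    key_columns=[".serie", ".time", ".rate_id"],
)
print(len(unique_rows), mapping)
```

Only the unique rows need to go to brms; the mapping is then used in the injected likelihood code to look up the linear predictor for each rate, series and time.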

Collapsing repeated rows is essential: it also lets me avoid recomputing the same transition matrix over and over, and the matrix exponential is almost certainly the most expensive part of the model. (When running predictions from the fitted model, computing the matrix exponentials for ~3000 posterior samples easily takes several minutes even after this optimization, and that is without autodiff.)
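A sketch of the reuse idea: memoize the transition matrix on the (hashable) tuple of rates, so each distinct combination triggers exactly one matrix exponential. The two-state chain, the rates and the hand-rolled expm are hypothetical stand-ins for the real model, and functools.lru_cache is just one convenient way to memoize.

```python
import functools
import numpy as np

def expm(A, terms=20):
    """Matrix exponential by scaling and squaring with a truncated Taylor series."""
    s = 0
    while np.linalg.norm(A, ord=np.inf) / 2**s >= 0.5:
        s += 1
    term = result = np.eye(A.shape[0])
    for k in range(1, terms + 1):
        term = term @ (A / 2**s) / k
        result = result + term
    for _ in range(s):
        result = result @ result
    return result

@functools.lru_cache(maxsize=None)
def transition_matrix(up, down):
    """One-day transition matrix of a two-state chain with the given rates."""
    R = np.array([[-up, up], [down, -down]])
    return expm(R)

# Rates per (series, time) after collapsing; many rows share the same values.
rows = [(0.3, 0.4), (0.3, 0.4), (0.1, 0.4), (0.3, 0.4), (0.1, 0.4)]
matrices = [transition_matrix(up, down) for up, down in rows]
info = transition_matrix.cache_info()
print(info.misses, info.hits)  # 2 3
```

Here five lookups cost only two matrix exponentials; the returned arrays are shared between rows, so they must not be modified in place.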

The actual HMM code is derived from the new HMM methods in Stan.

Observations need not be available for all time points, but predictors have to be.
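For reference, a sketch of the forward algorithm with a separate transition matrix per day, where a missing observation (None) simply contributes no emission term. This is why predictors, and hence transition matrices, are needed at every time point while observations are not. It is plain NumPy with per-step rescaling rather than the log-space Stan code the post builds on, and all matrices and observations are hypothetical.

```python
import numpy as np

def hmm_log_likelihood(init, transitions, emission, observations):
    """init: (K,) initial state probabilities.
    transitions: list of (K, K) matrices; transitions[t] maps day t to day t + 1.
    emission: (K, M) matrix, emission[k, m] = P(observation m | hidden state k).
    observations: observation indices or None (missing), len(transitions) + 1 items.
    """
    alpha = np.asarray(init, dtype=float)
    log_lik = 0.0
    for t, obs in enumerate(observations):
        if t > 0:
            alpha = alpha @ transitions[t - 1]  # propagate through that day's matrix
        if obs is not None:
            alpha = alpha * emission[:, obs]    # weight by P(observation | state)
        total = alpha.sum()
        log_lik += np.log(total)                # log P(obs_t | obs_0..t-1)
        alpha = alpha / total                   # rescale to avoid underflow
    return log_lik

# Hypothetical two-state example: the transition matrix changes after day 1.
S_before = np.array([[0.9, 0.1], [0.2, 0.8]])
S_after = np.array([[0.95, 0.05], [0.4, 0.6]])
emission = np.eye(2)  # each hidden state produces exactly one observation
ll = hmm_log_likelihood([1.0, 0.0], [S_before, S_after, S_after], emission, [0, None, 1, 1])
print(ll)
```

The hidden state on the missing day is summed over, so the result equals log((0.9 · 0.05 + 0.1 · 0.6) · 0.6) = log(0.063).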

Observation matrix

The current implementation uses only a very simple observation model: either a hidden state corresponds directly to a single observation (but multiple hidden states can produce the same observation), or there is some probability of it producing an adjacent observation (as in “the patient would actually be OK on Ambient Air, but the doctors keep them on Oxygen just to be sure”). This is definitely a place for expansion. It would also be good to handle continuous observation models.
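A sketch of such a simple observation matrix: each hidden state maps to one observed category, several hidden states may share a category, and a state can optionally leak probability eps to neighbouring categories. Splitting eps equally between both neighbours is a simplification assumed here (the “kept on Oxygen” example would only shift in one direction); the states, the mapping and eps are hypothetical.

```python
import numpy as np

def observation_matrix(state_to_obs, eps, n_obs):
    """state_to_obs[k]: category that hidden state k usually produces.
    eps[k]: probability that state k is recorded as a neighbouring category
    (0 means the observation is exact)."""
    E = np.zeros((len(state_to_obs), n_obs))
    for k, (m, e) in enumerate(zip(state_to_obs, eps)):
        neighbours = [n for n in (m - 1, m + 1) if 0 <= n < n_obs]
        if not neighbours:
            e = 0.0  # nowhere to shift to, so the observation is exact
        E[k, m] = 1.0 - e
        for n in neighbours:
            E[k, n] += e / len(neighbours)
    return E

# Observed categories: 0 = AA, 1 = Oxygen, 2 = Ventilated.
# Hypothetical hidden states: AA, Oxygen, and two states that are both observed
# as Ventilated. Only the AA state is sometimes recorded as Oxygen.
E = observation_matrix(state_to_obs=[0, 1, 2, 2], eps=[0.1, 0.0, 0.0, 0.0], n_obs=3)
print(E)
```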

Miscellaneous and next steps

I was able to validate the method (at least for small-scale models) via SBC (simulation-based calibration), but brms lets you express some wild things that you will almost certainly be unable to fit. And although I ended up with quite simple models, I was able to iterate quickly through models with monotonic effects and other features I almost certainly wouldn’t have tried if I had to implement them myself :-) I could also focus on debugging the HMM and not my linear predictor code…

The code for the HMM part is with the rest of the paper at github.com/cas-bioinf/covid19retrospective (usage is shown in manuscript/hmm.Rmd).

There is also support for posterior predictions for new data, but a bunch of rough edges remain. I would totally love to develop this further: also allow the categorical regression formulation, add flexibility for observation models, …, but honestly, I am terrible at completing projects where I am the main stakeholder. So if you (or someone you know) have a dataset you would love to analyze with time- or series-varying HMMs, let me know. Building on real use cases would definitely let me build a better interface/feature set, and having someone who cares about the results will help make me actually ship the damn thing at some point :-)

I would also note that this approach (transform the dataset, build a mapping between the mu array and the parameters of interest, replace the likelihood) is potentially very general while not crazy hard to pull off, and it should make it possible to use the richness of linear predictors and supporting functions provided by brms for any type of model (e.g. ODEs).

Thanks for reading!


Hi Martin,

this is a really nice idea! And I am happy you could make use of brms as well :-)

Is there something on the brms side that I could do to make the required code easier? You mentioned avoiding a loop over observations in custom families? If it is just that, I can easily add a feature to turn off the loop, but I am also happy to add more if this makes working with such complex custom families easier.


Certainly yes, but not sure yet what would be the best way. I am trying to gather my thoughts on all of this. I will try to file some issues with brms in the near future to discuss details.

Hi

I would like to add that I am also working with HMMs right now, and it would be great if we could do this from brms.
I was thinking of working it through the nonlinear formulas, similar to IRT models, and then trying to set an AR1 structure between the states. I have just started working on this.
