I have a preprint out that models disease progression in Covid-19 patients: https://www.medrxiv.org/content/10.1101/2020.12.03.20239863v1 (with @paul.buerkner, @vianeylb and many others). The results themselves are not that interesting (we can’t learn much from the dataset we have), but I am quite excited that one of the things we tried was using Hidden Markov Models (HMMs) to model disease progression. We needed the HMMs to have transition matrices that vary both by patient (to adjust for baseline characteristics) and by time (to change after a treatment was administered). This called for some kind of linear predictor on the transition probabilities… To iterate faster on the models (and because I hate reimplementing linear predictors, varying intercepts etc.), I hacked `brms` to write the “linear” part for me and then injected HMM code on top of it.
This post describes the HMM implementation, how it works, and what the next steps are to make it more broadly usable.
Linear predictors for HMMs?
I know of two reasonable ways to use linear predictors for transition matrices in HMMs. The first is to treat the transition probabilities out of each state as a separate categorical regression. This could make sense in a lot of cases, but I expected the transitions in my model to be highly structured, looking something like this:
[Figure: allowed direct transitions between states for (a) the simple and (b) the complex model]
The arrows represent “direct” transitions: e.g. a patient deteriorating from “AA” (breathing Ambient Air) to “Ventilated” will pass through the “Oxygen” state (needs supplemental O2), and a patient in the “Oxygen” state will not be discharged before being able to cope with only “AA”. However, our data have daily resolution and two (or more) of those transitions can happen within a single day, so the transition probability from “AA” to “Ventilated” is not zero. Still, we want to enforce that $P(X_{t} = \mathtt{Ventilated} | X_{t - 1} = \mathtt{Oxygen}) > P(X_{t} = \mathtt{Ventilated} | X_{t - 1} = \mathtt{AA})$.
This leads to a different formulation: we set up a rate matrix $R$ so that for any two states $i \neq j$, $R_{i,j}$ is the rate of transition from $i$ to $j$. $R$ is intended to be sparse, i.e. $R_{i,j} \neq 0$ only for the transitions we explicitly want to model as direct (the arrows above). The diagonal elements are set as $R_{i,i} = -\sum_{j \neq i} R_{i,j}$, so every row of $R$ sums to zero.
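A minimal NumPy sketch of this construction (the post's implementation is R/Stan; the helper `rate_matrix`, the state names and the rate values are hypothetical and chosen only for illustration):

```python
import numpy as np

def rate_matrix(n_states, transitions):
    """Build a rate matrix R from a list of direct transitions.

    transitions: iterable of (from_state, to_state, rate) with rate > 0.
    R[i, j] is the rate of moving from i to j; unlisted off-diagonal entries
    stay zero. The diagonal is set so that every row sums to zero.
    """
    R = np.zeros((n_states, n_states))
    for i, j, rate in transitions:
        if i == j:
            raise ValueError("only off-diagonal rates can be specified")
        R[i, j] = rate
    # at this point the diagonal is zero, so the row sums are the off-diagonal sums
    np.fill_diagonal(R, -R.sum(axis=1))
    return R

# Hypothetical states; Discharged is absorbing (no outgoing rates)
AA, OXYGEN, VENTILATED, DISCHARGED = range(4)
R = rate_matrix(4, [
    (AA, OXYGEN, 0.2), (OXYGEN, AA, 0.3),
    (OXYGEN, VENTILATED, 0.1), (VENTILATED, OXYGEN, 0.05),
    (AA, DISCHARGED, 0.4),
])
```

`R[AA, VENTILATED]` stays zero here because that transition is not modelled directly.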
Assuming fixed rates and treating the state probabilities $p(t)$ as a row vector, their evolution in continuous time $t$ is given by the differential equation

$$\frac{\mathrm{d}p(t)}{\mathrm{d}t} = p(t) R$$

Given the initial state probabilities $p(0)$, the solution to this equation is:

$$p(t) = p(0) \exp(tR)$$

where $\exp$ is the matrix exponential. We can thus compute a discrete-time transition matrix $S = \exp(R)$ so that $p(t + 1) = p(t) S$. The appeal of this approach is that we can enforce a lot of structure on the transition matrix $S$ while allowing positive probability for almost all transitions. In the case of the complex model (case b in the figure above), the full transition matrix between the 8 states would have 42 free parameters (there are no transitions out of the “Death” and “Discharged” states, so each of the 6 remaining rows has 8 probabilities summing to one, giving 6 × 7 = 42), while the rate matrix for the same model has only 13 free parameters.
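To make the matrix exponential step concrete, here is an illustrative NumPy sketch. The `expm` helper is a plain scaling-and-squaring Taylor approximation written for this example (in practice one would use a tested routine such as `scipy.linalg.expm`), and the 4-state rate matrix is made up. With the convention that $R_{i,j}$ is the rate from $i$ to $j$, row $i$ of $S$ is the distribution after one day when starting in state $i$, so a row vector of probabilities is propagated as $p(t) S$.

```python
import numpy as np

def expm(A, n_terms=30):
    """Matrix exponential via scaling and squaring with a truncated Taylor series.

    Illustrative only; a production model should use a tested routine.
    """
    norm = np.abs(A).sum(axis=1).max()  # infinity norm of A
    # pick s so that ||A / 2**s|| <= 0.5, where the Taylor series converges quickly
    s = int(np.ceil(np.log2(norm / 0.5))) if norm > 0.5 else 0
    B = A / 2**s
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, n_terms + 1):
        term = term @ B / k  # B^k / k!
        result = result + term
    for _ in range(s):  # undo the scaling: exp(A) = exp(A / 2**s) ** (2**s)
        result = result @ result
    return result

# Hypothetical states: 0 = AA, 1 = Oxygen, 2 = Ventilated, 3 = Discharged (absorbing).
# There is no direct AA -> Ventilated rate.
R = np.array([
    [-0.6, 0.2, 0.0, 0.4],
    [0.3, -0.4, 0.1, 0.0],
    [0.0, 0.05, -0.05, 0.0],
    [0.0, 0.0, 0.0, 0.0],
])
S = expm(R)                          # one-day transition matrix, rows sum to 1
p0 = np.array([1.0, 0.0, 0.0, 0.0])  # patient starts on Ambient Air
p1 = p0 @ S                          # p(t + 1) = p(t) S
print(S[0, 2], S[1, 2])              # indirect AA -> Ventilated vs direct Oxygen -> Ventilated
```

Even though the hypothetical AA → Ventilated rate is zero, the one-day probability `S[0, 2]` comes out positive (via Oxygen) and smaller than `S[1, 2]`, matching the ordering the post wants to enforce.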
We also find the rate formulation theoretically appealing: disease progression takes place in continuous time, and the discretization into individual days is an artifact of how we collected the data, not of reality.
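Finally, we let the non-zero off-diagonal elements of $R$ be described via a linear predictor, i.e. for discrete time $t$ and patient $k$ we have $R_{i,j;t,k} = \exp(\mu_{i,j;t,k})$, where $\mu_{i,j;t,k}$ is the linear predictor and the exponential keeps every rate positive. Each patient-day thus gets its own transition matrix $S_{t,k} = \exp(R_{t,k})$.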
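As a rough illustration of per-patient, per-day rates, the sketch below computes the rate matrix for one patient-day from a linear predictor made of a per-rate intercept plus covariate effects shared within a rate group. All names, coefficients and the grouping are hypothetical; in the actual implementation `brms` generates the linear predictor.

```python
import numpy as np

# Hypothetical 3-state example with three directly modelled transitions.
transitions = [(0, 1), (1, 0), (1, 2)]     # (from, to) for each rate
rate_group = [0, 1, 0]                     # group of each rate
intercept = np.array([-1.5, -1.0, -2.0])   # one intercept per rate (log scale)
beta = np.array([[0.5, -0.3],              # group 0: effects of two covariates
                 [-0.2, 0.8]])             # group 1: effects of the same covariates

def rate_matrix_for(x, n_states=3):
    """Rate matrix for one patient on one day, given that patient-day's covariates x."""
    R = np.zeros((n_states, n_states))
    for r, (i, j) in enumerate(transitions):
        mu = intercept[r] + x @ beta[rate_group[r]]  # linear predictor for this rate
        R[i, j] = np.exp(mu)                         # exp link keeps the rate positive
    np.fill_diagonal(R, -R.sum(axis=1))              # rows sum to zero
    return R

# Covariates (standardized age, treatment indicator) for one patient, before and after treatment
R_before = rate_matrix_for(np.array([0.7, 0.0]))
R_after = rate_matrix_for(np.array([0.7, 1.0]))
```

Switching the hypothetical treatment indicator on multiplies each rate by the exponential of its group's coefficient, e.g. by exp(-0.3) for the 1 → 2 transition here.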
This approach is taken from Williams et al.: https://doi.org/10.1080/01621459.2019.1594831 (thanks @vianeylb for pointing me to this and for other help along the way).
Describing the model as brms formula
In the current implementation, the available transitions (non-zero off-diagonal elements of $R$) are described in a `data.frame` like this (for the simple “a” model from above):
Here `.from`, `.to` and `.rate_id` are required, and `rate_group` is additional data we can use in the predictor.
The code then takes a data frame with predictors for each series (patient in this case) and time point, and forms its cross product with the rate data frame, giving a combined data frame. You then write a formula for the linear predictor in terms of this combined data frame, e.g.:
~0 + .rate_id + (0 + age + sex + took_favipiravir || rate_group)
You can see that I combine the properties of individual rates (`.rate_id`, `rate_group`), time-invariant predictors (`age`, `sex`) and time-varying predictors (`took_favipiravir`). Also note that I try to make the model less flexible (and thus easier to fit) by having covariates act exactly the same on all rates in the same group.
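A sketch of the cross product step in plain Python, using lists of dicts as stand-ins for the two R data frames (the column names mirror the ones above; the values are invented):

```python
from itertools import product

# Hypothetical predictors: one row per patient (series) and day
series_time = [
    {"series": 1, "time": 1, "age": 60, "took_favipiravir": 0},
    {"series": 1, "time": 2, "age": 60, "took_favipiravir": 1},
    {"series": 2, "time": 1, "age": 45, "took_favipiravir": 0},
]

# Hypothetical rate table: one row per directly modelled transition
rates = [
    {".from": "AA", ".to": "Oxygen", ".rate_id": "AA_Oxygen", "rate_group": "worsening"},
    {".from": "Oxygen", ".to": "AA", ".rate_id": "Oxygen_AA", "rate_group": "improving"},
    {".from": "Oxygen", ".to": "Ventilated", ".rate_id": "Oxygen_Ventilated", "rate_group": "worsening"},
]

# Cross product: every patient-day row is paired with every rate row
combined = [{**st, **rate} for st, rate in product(series_time, rates)]
```

Each row of `combined` corresponds to one rate for one patient on one day, so evaluating a formula on it yields one linear predictor value per rate, patient and day.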
My code then:

- Builds a mapping between each rate, time and series (patient) and the corresponding row in the combined data frame.
- Collapses all rows of the combined data frame that have the same values of all predictors that are actually used (for efficiency), and adds a “dummy” response column to trick `brms` (this is ignored later).
- Packages the mapping, observations, definition of the observation matrix, and other data outside the combined data frame via `stanvars` to be passed to the `brms` model.
- Uses `stanvars` to define additional parameters.
- Creates a new `family` object of class `rate_hmm` that overrides the `stan_log_lik` method from within `brms`. This needs to be done explicitly via `registerS3method("stan_log_lik", class = "rate_hmm", method = stan_log_lik.rate_hmm, envir = asNamespace("brms"))` (this is currently the biggest/only hack). The reason is that the regular `custom_family` in `brms` still keeps a loop over observations. This way I let `brms` write only the code that computes the `mu` vector and replace everything below it.
- Uses regular calls to `brm` etc. to run the model, and wraps the result in a new `brmshmmfit` object that also keeps the mapping and other necessary data.
Note that collapsing repeated rows is absolutely necessary, because it also lets me avoid recomputing the same transition matrix repeatedly, and the matrix exponential is almost certainly the most expensive part of the model. (Even after this optimization, computing the matrix exponentials for ~3000 posterior samples when running predictions from the fitted model easily takes several minutes, and this is without autodiff.)
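The collapsing and mapping steps can be sketched as follows in plain Python; `collapse` and `transition_keys` are hypothetical helpers, not functions from the actual code. Rows that agree on all used predictors share one element of `mu`, and patient-days whose rates all map to the same elements share one transition matrix, so the matrix exponential only needs to be computed once per distinct key:

```python
from itertools import product

def collapse(combined, used_columns):
    """Deduplicate rows on the predictors the formula actually uses.

    Returns the unique rows (one element of mu each) and a mapping from
    (series, time, rate_id) to the index of the corresponding unique row.
    """
    unique_index = {}
    unique_rows = []
    mapping = {}
    for row in combined:
        key = tuple(row[c] for c in used_columns)
        if key not in unique_index:
            unique_index[key] = len(unique_rows)
            unique_rows.append(dict(zip(used_columns, key)))
        mapping[(row["series"], row["time"], row[".rate_id"])] = unique_index[key]
    return unique_rows, mapping

def transition_keys(mapping, rate_ids):
    """Key each patient-day by the mu indices of all its rates.

    Patient-days with equal keys have identical rate matrices, so they can
    share one transition matrix (one matrix exponential).
    """
    keys = {}
    for series, time, _ in mapping:
        keys[(series, time)] = tuple(mapping[(series, time, r)] for r in rate_ids)
    return keys

# Hypothetical data: patient 2 on day 1 looks exactly like patient 1 on days 1-2
series_time = [
    {"series": 1, "time": 1, "age": 60, "took_favipiravir": 0},
    {"series": 1, "time": 2, "age": 60, "took_favipiravir": 0},
    {"series": 1, "time": 3, "age": 60, "took_favipiravir": 1},
    {"series": 2, "time": 1, "age": 60, "took_favipiravir": 0},
]
rate_ids = ["AA_Oxygen", "Oxygen_AA"]
combined = [{**st, ".rate_id": r} for st, r in product(series_time, rate_ids)]

unique_rows, mapping = collapse(combined, [".rate_id", "age", "took_favipiravir"])
keys = transition_keys(mapping, rate_ids)
n_matrix_exponentials = len(set(keys.values()))
```

In this toy example, 4 patient-days need only 2 matrix exponentials. If the formula included patient-level varying effects, the patient identifier would become a used predictor and far fewer rows would collapse.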
The actual HMM code is derived from the new HMM methods in Stan.
Observations need not be available for all time points, but predictors have to be.
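For orientation, here is a NumPy sketch of a scaled forward algorithm that takes a different transition matrix for every step and handles days without an observation by giving them an emission factor of one. It is written in the spirit of Stan's HMM functions but is not the code from the repository; the function name and argument layout are assumptions.

```python
import numpy as np

def log_marginal(init, transitions, obs_matrix, observations):
    """Scaled forward algorithm with a separate transition matrix for each step.

    init: initial hidden-state probabilities, shape (K,)
    transitions: list of row-stochastic (K, K) matrices; transitions[t] moves
        the hidden state from step t to step t + 1 (can differ by patient and day)
    obs_matrix: (K, M) array, obs_matrix[k, m] = P(observation m | hidden state k)
    observations: observed category per step, or None for an unobserved day
    Returns log P(all observations).
    """
    def emission(obs):
        # An unobserved day contributes a factor of 1 for every hidden state
        return np.ones(obs_matrix.shape[0]) if obs is None else obs_matrix[:, obs]

    alpha = init * emission(observations[0])
    log_lik = 0.0
    for t in range(1, len(observations)):
        scale = alpha.sum()  # rescale each step to avoid numerical underflow
        log_lik += np.log(scale)
        alpha = (alpha / scale) @ transitions[t - 1] * emission(observations[t])
    return log_lik + np.log(alpha.sum())

# Hypothetical 2-state example with a missing observation on the second day
init = np.array([0.6, 0.4])
T1 = np.array([[0.9, 0.1], [0.2, 0.8]])
T2 = np.array([[0.5, 0.5], [0.3, 0.7]])
O = np.array([[0.8, 0.2], [0.1, 0.9]])
ll = log_marginal(init, [T1, T2], O, [0, None, 1])
```

Rescaling at every step and accumulating the log normalizers keeps the computation stable for long series.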
Observation matrix
The current implementation uses only a very simple observation model: either a hidden state corresponds directly to a single observation (but multiple hidden states can produce the same observation), or there is a probability of it producing an adjacent observation (as in “The patient would actually be OK on Ambient Air, but the doctors keep them on Oxygen just to be sure”). So this is definitely a place for expansion. Continuous observation models would also be good to handle.
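A sketch of how such an observation matrix could be built (NumPy; the helper name, the state layout and the 0.1 probability are hypothetical). Each hidden state normally emits one observation category, several hidden states may share a category, and optionally a state is recorded as another, adjacent category with some probability:

```python
import numpy as np

def observation_matrix(state_to_obs, n_obs, adjacent=None):
    """Build a (hidden states x observations) emission matrix.

    state_to_obs[k]: observation normally produced by hidden state k
        (several hidden states may map to the same observation).
    adjacent: optional {k: (m, p)}; hidden state k is instead recorded as
        observation m with probability p.
    """
    adjacent = adjacent or {}
    O = np.zeros((len(state_to_obs), n_obs))
    for k, m in enumerate(state_to_obs):
        O[k, m] = 1.0
        if k in adjacent:
            m_adj, p = adjacent[k]
            O[k, m] -= p
            O[k, m_adj] += p
    return O

# Hypothetical observations: 0 = AA, 1 = Oxygen, 2 = Ventilated.
# Hypothetical hidden states: AA, Oxygen, and two Ventilated states sharing one observation.
# A patient truly fine on AA is recorded as on Oxygen with probability 0.1.
O = observation_matrix([0, 1, 2, 2], n_obs=3, adjacent={0: (1, 0.1)})
```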
Miscellaneous and next steps
I was able to validate the method (at least for small-scale models) via simulation-based calibration (SBC), but `brms` lets you express some wild models that you will almost certainly be unable to fit. And although I ended up with quite simple models, I was able to iterate quickly through models with monotonic effects and other features I almost certainly wouldn’t have tried if I had to implement them myself :-) I could also focus on debugging the HMM and not my linear predictor code…
The code for the HMM part is with the rest of the paper at github.com/cas-bioinf/covid19retrospective (usage can be seen in `manuscript/hmm.Rmd`).
There is also support for posterior predictions for new data, but a bunch of rough edges remain. I would totally love to develop this further (allow the categorical regression formulation, add flexibility for observation models, …), but honestly, I am terrible at completing projects where I am the main stakeholder. So if you (or someone you know) have a dataset you would love to analyze with time- and series-varying HMMs, let me know. Building on real use cases would definitely help me build a better interface/feature set, and having someone who cares about the results will help me actually ship the damn thing at some point :-)
I would also note that this approach (transform the dataset, build a mapping between the `mu` array and the parameters of interest, replace the likelihood) is potentially very general while not crazy hard to pull off, and should make it possible to use the richness of linear predictors and supporting functions provided by `brms` for any type of model (e.g. ODEs).
Thanks for reading!