I’m implementing a new lpdf function for HMMs with discrete latent states, using the adjoint method that @betanalpha and I have been working on. The signature is:
template <typename T_omega, typename T_Gamma, typename T_rho>
auto hmm_marginal_lpdf(const Eigen::Matrix<T_omega, -1, -1>& log_omegas,
                       const Eigen::Matrix<T_Gamma, -1, -1>& Gamma,
                       const Eigen::Matrix<T_rho, -1, 1>& rho,
                       int n_states) {...}
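For readers who want the quantity in concrete terms, here is a minimal sketch of the forward algorithm that a marginal HMM log density is typically built on. It is not the implementation discussed in this post: the name hmm_marginal_sketch, the Matrix alias, and the layout conventions (log_omegas is states by observations, Gamma[i][j] is the probability of moving from state i to state j, rho is the state distribution at the first observation) are all assumptions made for illustration. It uses plain doubles and std::vector, so it carries no autodiff.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major: m[row][col]

// Hypothetical stand-in, not the Stan Math function.
// log_omegas[k][n]: log density of observation n given state k (K x N).
// Gamma[i][j]:      transition probability from state i to state j (K x K).
// rho[k]:           probability of state k at the first observation (assumed).
double hmm_marginal_sketch(const Matrix& log_omegas, const Matrix& Gamma,
                           const std::vector<double>& rho) {
  const std::size_t K = rho.size();
  assert(log_omegas.size() == K && Gamma.size() == K);
  const std::size_t N = K == 0 ? 0 : log_omegas[0].size();

  std::vector<double> alpha(K), next(K);  // normalized forward probabilities
  double log_density = 0.0;

  for (std::size_t n = 0; n < N; ++n) {
    // Subtract the column maximum before exponentiating to limit underflow.
    double max_log = log_omegas[0][n];
    for (std::size_t k = 1; k < K; ++k)
      max_log = std::max(max_log, log_omegas[k][n]);

    double norm = 0.0;
    for (std::size_t j = 0; j < K; ++j) {
      // Predicted state probability: rho at n == 0, else (Gamma^T alpha)[j].
      double prior = rho[j];
      if (n > 0) {
        prior = 0.0;
        for (std::size_t i = 0; i < K; ++i) prior += alpha[i] * Gamma[i][j];
      }
      next[j] = prior * std::exp(log_omegas[j][n] - max_log);
      norm += next[j];
    }
    // Renormalize and accumulate the log of the normalizing constant.
    for (std::size_t j = 0; j < K; ++j) alpha[j] = next[j] / norm;
    log_density += std::log(norm) + max_log;
  }
  return log_density;
}
```

Calling hmm_marginal_sketch(log_omegas, Gamma, rho) returns the log of the likelihood with the latent states summed out. The adjoint method mentioned in this post would supply the derivatives of that value with respect to each of the three arguments, which this sketch does not attempt.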
We derived an adjoint method to get sensitivities for log_omegas, Gamma, and rho. Currently, I have written a custom vari class.
Benefit: I can calculate the derivatives when constructing the vari object, and then store only these derivatives (as opposed to the matrices required to construct them) before calling chain(). So the forward pass only stores the log density and the derivatives.
Drawback: I don’t always need sensitivities for all three arguments, and I’m not sure how to template the vari class adequately. I don’t want to write 8 vari classes, one for each combination of autodiff and plain double arguments.
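One way the trade-off above might be handled is sketched below, under heavy assumptions. It is a toy, not Stan Math code: ToyVar, Edge, and ProductNode are hypothetical names, and the function is a scalar product rather than an HMM marginal. The idea it illustrates is a single node class templated on the operand types, where partials are computed in the constructor and stored only for operands that are autodiff variables, so chain() is reduced to multiplying by the incoming adjoint.

```cpp
#include <cassert>
#include <type_traits>

// Toy autodiff variable (hypothetical): a value plus an adjoint accumulator.
struct ToyVar {
  double val;
  double adj;
  explicit ToyVar(double v) : val(v), adj(0.0) {}
};

// True only for operands that need a gradient.
template <typename T>
constexpr bool needs_grad = std::is_same<T, ToyVar>::value;

inline double value_of(double x) { return x; }
inline double value_of(const ToyVar& x) { return x.val; }

// Per-operand storage. For plain doubles it is empty and every call is a no-op.
template <typename T, bool = needs_grad<T>>
struct Edge {
  explicit Edge(const T&) {}
  void set_partial(double) {}
  void chain(double) {}
};

// For ToyVar operands it keeps a pointer to the operand and one partial.
template <typename T>
struct Edge<T, true> {
  ToyVar* operand;
  double partial;
  explicit Edge(ToyVar& x) : operand(&x), partial(0.0) {}
  void set_partial(double p) { partial = p; }
  void chain(double adj) { operand->adj += adj * partial; }
};

// One class covers every double/ToyVar combination of the two operands.
template <typename Ta, typename Tb>
struct ProductNode {
  double val;
  double adj;
  Edge<Ta> a;
  Edge<Tb> b;

  // Forward pass: compute the value and the partials, keep nothing else.
  ProductNode(Ta& x, Tb& y)
      : val(value_of(x) * value_of(y)), adj(0.0), a(x), b(y) {
    a.set_partial(value_of(y));  // d(x * y) / dx
    b.set_partial(value_of(x));  // d(x * y) / dy
  }

  // Reverse pass: only scale the stored partials by this node's adjoint.
  void chain() {
    a.chain(adj);
    b.chain(adj);
  }
};
```

With a ToyVar x holding 3 and a double y holding 4, a ProductNode<ToyVar, double> built from them stores a single partial for x, and setting its adj to 1 and calling chain() adds 4 to x.adj. A third operand would add one more Edge member rather than doubling the number of classes.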
Browsing the forum, there seem to be two other options: operands_and_partials and adj_jac_apply. The first gives me the requisite templating, but forces me to do all the calculations during the chain call (though it seems like there should be a way around this). The second seems fine, but it only works in reverse mode.
Do I have a reasonable assessment of the situation?