I’m implementing a new lpdf function for hidden Markov models (HMMs) with discrete latent states. It uses the adjoint method @betanalpha and I have been working on. The signature is:
```cpp
template <typename T_omega, typename T_Gamma, typename T_rho>
return_type_t<T_omega, T_Gamma, T_rho>
hmm_marginal_lpdf(const Eigen::Matrix<T_omega, -1, -1>& log_omegas,
                  const Eigen::Matrix<T_Gamma, -1, -1>& Gamma,
                  const Eigen::Matrix<T_rho, -1, 1>& rho,
                  int n_states) {...}
```
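For reference, here is a minimal sketch of the quantity such a function computes. It uses plain `std::vector` instead of Eigen so it stands alone. The sketch assumes conventions the post does not state:

- `log_omegas[k][n]` is the log density of observation n under state k.
- `Gamma[i][j]` is the probability of moving from state i to state j, so its rows sum to one.
- `rho` is the initial state distribution.

The number of states is taken from `rho.size()` rather than an `n_states` argument, and the name `hmm_marginal_log_density` is hypothetical. The sketch uses the standard forward algorithm, rescaling at each step to avoid underflow.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper types for this sketch (nested vectors, not Eigen).
using Matrix = std::vector<std::vector<double>>;  // Matrix[row][col]
using Vector = std::vector<double>;

// Marginal log density of an HMM via the scaled forward algorithm.
// Assumed conventions (not taken from the post):
//   log_omegas[k][n] = log p(y_n | z_n = k)
//   Gamma[i][j]      = P(z_{n+1} = j | z_n = i), rows sum to one
//   rho[k]           = P(z_1 = k)
double hmm_marginal_log_density(const Matrix& log_omegas, const Matrix& Gamma,
                                const Vector& rho) {
  const std::size_t K = rho.size();
  const std::size_t N = log_omegas.empty() ? 0 : log_omegas[0].size();
  double log_p = 0.0;
  Vector alpha(K), next(K);  // alpha holds the normalized filtered state probs
  for (std::size_t n = 0; n < N; ++n) {
    // Subtract the column max before exponentiating to avoid underflow.
    double max_lo = log_omegas[0][n];
    for (std::size_t k = 1; k < K; ++k)
      max_lo = std::max(max_lo, log_omegas[k][n]);
    for (std::size_t j = 0; j < K; ++j) {
      double prior = 0.0;  // P(z_n = j | y_1..y_{n-1})
      if (n == 0) {
        prior = rho[j];
      } else {
        for (std::size_t i = 0; i < K; ++i) prior += alpha[i] * Gamma[i][j];
      }
      next[j] = std::exp(log_omegas[j][n] - max_lo) * prior;
    }
    double norm = 0.0;
    for (double a : next) norm += a;
    for (std::size_t k = 0; k < K; ++k) alpha[k] = next[k] / norm;
    // Accumulate log p(y_n | y_1..y_{n-1}), undoing the max shift.
    log_p += max_lo + std::log(norm);
  }
  return log_p;
}
```

With two states, `rho = {0.5, 0.5}`, and a single observation whose densities are 0.2 and 0.4, the result is log(0.3).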
We derived an adjoint method to get the sensitivities with respect to log_omegas, Gamma, and rho. Currently, I have written a custom vari class.
Benefit: I can calculate the derivatives when the vari object is constructed and store only these derivatives (as opposed to the matrices required to compute them) until chain() is called. So the forward pass stores only the log density and the derivatives.
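The eager strategy can be illustrated with a toy reverse-mode node. This is not Stan Math's `vari`, and `Node`, `PrecomputedGradientNode`, and `one_step_hmm` are hypothetical names. To keep it short, the example uses the single-observation case log p = log Σ_k ρ_k exp(log ω_k). Its partials are computed before the node is built and then stored, so `chain()` only multiplies them by the incoming adjoint.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Toy reverse-mode node (hypothetical, NOT Stan Math's vari).
struct Node {
  double val;
  double adj = 0.0;
  explicit Node(double v) : val(v) {}
  virtual ~Node() = default;
  virtual void chain() {}
};

// Node whose partials were computed eagerly, in the forward pass.
// chain() only scales the stored partials by the incoming adjoint, so none
// of the intermediates used to derive them have to be kept alive.
struct PrecomputedGradientNode : Node {
  std::vector<Node*> operands;
  std::vector<double> partials;  // d(val) / d(operands[i]->val)
  PrecomputedGradientNode(double v, std::vector<Node*> ops,
                          std::vector<double> d)
      : Node(v), operands(std::move(ops)), partials(std::move(d)) {}
  void chain() override {
    for (std::size_t i = 0; i < operands.size(); ++i)
      operands[i]->adj += adj * partials[i];
  }
};

// Single-observation HMM: log p = log(sum_k rho_k * exp(log_omega_k)).
// Operands are ordered rho_0..rho_{K-1}, then log_omega_0..log_omega_{K-1}.
PrecomputedGradientNode one_step_hmm(const std::vector<Node*>& rho,
                                     const std::vector<Node*>& log_omega) {
  const std::size_t K = rho.size();
  double p = 0.0;
  for (std::size_t k = 0; k < K; ++k)
    p += rho[k]->val * std::exp(log_omega[k]->val);
  std::vector<Node*> ops(rho);
  ops.insert(ops.end(), log_omega.begin(), log_omega.end());
  std::vector<double> d(2 * K);
  for (std::size_t k = 0; k < K; ++k) {
    const double omega = std::exp(log_omega[k]->val);
    d[k] = omega / p;                    // d log p / d rho_k
    d[K + k] = rho[k]->val * omega / p;  // d log p / d log_omega_k
  }
  // Only the value and the 2K partials survive the forward pass.
  return PrecomputedGradientNode(std::log(p), std::move(ops), std::move(d));
}
```

After setting the result's adjoint to 1 and calling `chain()`, each operand's `adj` holds the corresponding partial of log p.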
Drawback: I don’t always need sensitivities for all three arguments, and I’m not sure how to template the vari class adequately. I don’t want to write 8 vari classes (one for each combination of var and double arguments).
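One way to avoid 8 hand-written classes is sketched here on a self-contained toy tape rather than Stan Math. All names are hypothetical, and the code needs C++17 for `if constexpr`. A single node type stores operands and partials in flat vectors, and each argument contributes only when its type is the autodiff type. The stand-in function a·b·c replaces the HMM density to keep the example short. In Stan Math, its own type traits (for example `is_var`) would presumably play the role of `is_var_v` here.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

// Minimal toy tape-based reverse mode (hypothetical, NOT Stan Math).
struct Node {
  double val;
  double adj = 0.0;
  explicit Node(double v) : val(v) {}
  virtual ~Node() = default;
  virtual void chain() {}
};
inline std::vector<std::unique_ptr<Node>>& tape() {
  static std::vector<std::unique_ptr<Node>> t;
  return t;
}
struct Var { Node* node; };
inline Var make_var(double v) {
  tape().push_back(std::make_unique<Node>(v));
  return Var{tape().back().get()};
}
// Seed the output adjoint and sweep the tape backwards.
inline void grad(Var out) {
  out.node->adj = 1.0;
  for (auto it = tape().rbegin(); it != tape().rend(); ++it) (*it)->chain();
}

template <typename T> constexpr bool is_var_v = std::is_same_v<T, Var>;
inline double value_of(double x) { return x; }
inline double value_of(Var x) { return x.node->val; }

// One node type for every double/Var combination.
struct PartialsNode : Node {
  std::vector<Node*> operands;
  std::vector<double> partials;
  using Node::Node;
  void chain() override {
    for (std::size_t i = 0; i < operands.size(); ++i)
      operands[i]->adj += adj * partials[i];
  }
};

// Record x only if it is a Var; the partial d() is never evaluated otherwise.
template <typename T, typename F>
void add_partial(PartialsNode& node, const T& x, F&& d) {
  if constexpr (is_var_v<T>) {
    node.operands.push_back(x.node);
    node.partials.push_back(d());
  }
}

// Stand-in for the HMM density: a * b * c. One template covers all
// 2^3 = 8 argument combinations.
template <typename A, typename B, typename C>
auto toy_density(const A& a, const B& b, const C& c) {
  const double va = value_of(a), vb = value_of(b), vc = value_of(c);
  const double val = va * vb * vc;
  if constexpr (!(is_var_v<A> || is_var_v<B> || is_var_v<C>)) {
    return val;  // all data: plain double, nothing pushed on the tape
  } else {
    auto node = std::make_unique<PartialsNode>(val);
    add_partial(*node, a, [&] { return vb * vc; });
    add_partial(*node, b, [&] { return va * vc; });
    add_partial(*node, c, [&] { return va * vb; });
    tape().push_back(std::move(node));
    return Var{tape().back().get()};
  }
}
```

Because each partial is wrapped in a lambda, it is not computed at all when the corresponding argument is `double`.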
From browsing the forum, there seem to be two other options: operands_and_partials and adj_jac_apply. The first gives me the requisite templating but forces me to do all the calculations during the chain() call (though it seems like there should be a way around this). The second seems fine, but it only works in reverse mode.
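For contrast, the deferred strategy keeps the inputs that the derivatives depend on and forms the gradient only when the adjoint arrives in the reverse pass. This sketch is loosely in the spirit of adj_jac_apply but does not reproduce its real functor interface. The class and method names are made up, and only the gradient with respect to `rho` is handled, using the standard backward recursion. The sketch also skips rescaling, so it can underflow on long sequences.

It assumes that `omegas[k][n]` holds densities (not log densities) of observation n under state k, and that `Gamma[i][j]` is the probability of moving from state i to state j. The class keeps `omegas` (K×N) and `Gamma` (K×K) alive until the reverse pass, which is exactly the memory cost the eager approach avoids.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // Matrix[row][col]
using Vector = std::vector<double>;

// Hypothetical deferred-gradient object for the HMM marginal w.r.t. rho.
class DeferredHmmRho {
 public:
  // Forward pass: unscaled forward algorithm; inputs are stored for later.
  DeferredHmmRho(Matrix omegas, Matrix Gamma, const Vector& rho)
      : omegas_(std::move(omegas)), Gamma_(std::move(Gamma)) {
    const std::size_t K = rho.size(), N = omegas_[0].size();
    Vector alpha(K);
    for (std::size_t k = 0; k < K; ++k) alpha[k] = omegas_[k][0] * rho[k];
    for (std::size_t n = 1; n < N; ++n) {
      Vector next(K, 0.0);
      for (std::size_t j = 0; j < K; ++j) {
        for (std::size_t i = 0; i < K; ++i) next[j] += alpha[i] * Gamma_[i][j];
        next[j] *= omegas_[j][n];
      }
      alpha = next;
    }
    p_ = 0.0;
    for (double a : alpha) p_ += a;
  }

  double value() const { return std::log(p_); }

  // Reverse pass: returns adj * d(log p)/d(rho), using
  // beta_{N-1} = 1, beta_{n-1}(i) = sum_j Gamma[i][j] * omega_n(j) * beta_n(j),
  // and d(log p)/d(rho_k) = omega_0(k) * beta_0(k) / p.
  Vector vjp_rho(double adj) const {
    const std::size_t K = Gamma_.size(), N = omegas_[0].size();
    Vector beta(K, 1.0);
    for (std::size_t n = N - 1; n > 0; --n) {
      Vector prev(K, 0.0);
      for (std::size_t i = 0; i < K; ++i)
        for (std::size_t j = 0; j < K; ++j)
          prev[i] += Gamma_[i][j] * omegas_[j][n] * beta[j];
      beta = prev;
    }
    Vector g(K);
    for (std::size_t k = 0; k < K; ++k)
      g[k] = adj * omegas_[k][0] * beta[k] / p_;
    return g;
  }

 private:
  Matrix omegas_, Gamma_;  // kept alive until the reverse pass
  double p_;
};
```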
Do I have a reasonable assessment of the situation?