Vari class, operands_and_partials, or adj_jac_apply

I’m implementing a new lpdf function for HMMs with discrete latent states, using the adjoint method @betanalpha and I have been working on. The signature is:

```cpp
template <typename T_omega, typename T_Gamma, typename T_rho>
return_type_t<T_omega, T_Gamma, T_rho>
hmm_marginal_lpdf(const Eigen::Matrix<T_omega, -1, -1>& log_omegas,
                  const Eigen::Matrix<T_Gamma, -1, -1>& Gamma,
                  const Eigen::Matrix<T_rho, -1, 1>& rho,
                  int n_states) {...}
```
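For context, here is a minimal sketch of the value such a function returns, computed with the standard forward algorithm in log space. It assumes (these layouts are my assumptions, not stated in the post) that `log_omegas[k][n] = log p(y_n | z_n = k)` is K × N, `Gamma[i][j] = P(z_{n+1} = j | z_n = i)` and `rho[k] = P(z_0 = k)`. It uses `std::vector` and plain `double` instead of Eigen and autodiff types, so it only shows the arithmetic; `hmm_marginal_log_density` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Numerically stable log(sum(exp(x))); assumes x is non-empty.
double log_sum_exp(const std::vector<double>& x) {
  const double m = *std::max_element(x.begin(), x.end());
  if (m == -std::numeric_limits<double>::infinity()) return m;
  double s = 0.0;
  for (double xi : x) s += std::exp(xi - m);
  return m + std::log(s);
}

// Marginal log density log p(y) of an HMM via the forward algorithm.
// Assumed layouts (not taken from the post):
//   log_omegas[k][n] = log p(y_n | z_n = k),    K x N
//   Gamma[i][j]      = P(z_{n+1} = j | z_n = i), K x K
//   rho[k]           = P(z_0 = k),               K
double hmm_marginal_log_density(const Matrix& log_omegas, const Matrix& Gamma,
                                const std::vector<double>& rho) {
  const std::size_t K = rho.size();  // number of states comes from the inputs
  assert(K > 0 && log_omegas.size() == K && Gamma.size() == K);
  const std::size_t N = log_omegas[0].size();
  assert(N > 0);

  // log_alpha[k] = log p(y_0, ..., y_n, z_n = k), starting at n = 0
  std::vector<double> log_alpha(K), next(K), terms(K);
  for (std::size_t k = 0; k < K; ++k)
    log_alpha[k] = std::log(rho[k]) + log_omegas[k][0];

  for (std::size_t n = 1; n < N; ++n) {
    for (std::size_t j = 0; j < K; ++j) {
      for (std::size_t i = 0; i < K; ++i)
        terms[i] = log_alpha[i] + std::log(Gamma[i][j]);
      next[j] = log_omegas[j][n] + log_sum_exp(terms);
    }
    log_alpha.swap(next);
  }
  return log_sum_exp(log_alpha);
}
```

Working in log space avoids the underflow that the raw forward recursion hits on long series.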

We derived an adjoint method to get the sensitivities with respect to log_omegas, Gamma, and rho. Currently, I have written a custom vari class.
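As an independent check on any implementation, here is a sketch of the three gradients of log p(y) computed with the textbook forward-backward recursions. This is not the adjoint derivation from this thread; it only produces the same quantities by a different route. Assumptions: `omegas[k][n] = p(y_n | z_n = k)` (probability scale, no rescaling, so it underflows on long series), `Gamma[i][j] = P(z_{n+1} = j | z_n = i)`, `rho[k] = P(z_0 = k)`, and Gamma's entries are treated as free (no simplex constraint on the rows). The gradient is reported with respect to log omega, matching the `log_omegas` argument. `hmm_gradients` and `HmmGradients` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

struct HmmGradients {
  double log_density;
  Matrix d_log_omegas;        // K x N: d log p / d log omega_n(k)
  Matrix d_Gamma;             // K x K: d log p / d Gamma(i, j)
  std::vector<double> d_rho;  // K:     d log p / d rho(k)
};

// Unscaled forward-backward pass returning log p(y) and its gradients.
HmmGradients hmm_gradients(const Matrix& omegas, const Matrix& Gamma,
                           const std::vector<double>& rho) {
  const std::size_t K = rho.size();
  assert(K > 0 && omegas.size() == K && Gamma.size() == K);
  const std::size_t N = omegas[0].size();
  assert(N > 0);

  // Forward: alpha[n][k] = p(y_0..y_n, z_n = k)
  Matrix alpha(N, std::vector<double>(K, 0.0));
  for (std::size_t k = 0; k < K; ++k) alpha[0][k] = rho[k] * omegas[k][0];
  for (std::size_t n = 1; n < N; ++n)
    for (std::size_t j = 0; j < K; ++j) {
      double s = 0.0;
      for (std::size_t i = 0; i < K; ++i) s += alpha[n - 1][i] * Gamma[i][j];
      alpha[n][j] = omegas[j][n] * s;
    }

  // Backward: beta[n][i] = p(y_{n+1}..y_{N-1} | z_n = i), beta[N-1][i] = 1
  Matrix beta(N, std::vector<double>(K, 1.0));
  for (std::size_t n = N - 1; n-- > 0;)
    for (std::size_t i = 0; i < K; ++i) {
      double s = 0.0;
      for (std::size_t j = 0; j < K; ++j)
        s += Gamma[i][j] * omegas[j][n + 1] * beta[n + 1][j];
      beta[n][i] = s;
    }

  double p = 0.0;
  for (std::size_t k = 0; k < K; ++k) p += alpha[N - 1][k];

  HmmGradients g{std::log(p), Matrix(K, std::vector<double>(N, 0.0)),
                 Matrix(K, std::vector<double>(K, 0.0)),
                 std::vector<double>(K, 0.0)};
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t k = 0; k < K; ++k)  // posterior marginal P(z_n = k | y)
      g.d_log_omegas[k][n] = alpha[n][k] * beta[n][k] / p;
  for (std::size_t n = 1; n < N; ++n)
    for (std::size_t i = 0; i < K; ++i)
      for (std::size_t j = 0; j < K; ++j)
        g.d_Gamma[i][j] += alpha[n - 1][i] * omegas[j][n] * beta[n][j] / p;
  for (std::size_t k = 0; k < K; ++k)
    g.d_rho[k] = omegas[k][0] * beta[0][k] / p;
  return g;
}
```

Comparing these against finite differences (or against the adjoint results) is a cheap way to catch indexing mistakes.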

Benefit: I can compute the derivatives while constructing the vari object and store only those derivatives (rather than the matrices needed to compute them) until chain() is called in the reverse pass. So the forward pass stores only the log density and the derivatives.

Drawback: I don’t always need sensitivities for all three arguments (some may be data), and I’m not sure how to template the vari class properly; I don’t want to write 8 vari classes, one for each combination of autodiff and data arguments.
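To make the benefit above concrete outside of Stan, here is a toy reverse-mode node that computes its partial derivatives once, in its constructor, and keeps only those numbers plus pointers to its operands; its `chain()` is then a single multiply-accumulate per operand. `Node`, its `value`/`adj` fields and `LogSumExpNode` are hypothetical stand-ins, not Stan's `vari` API, and log-sum-exp is used only because its partials (a softmax) are easy to check.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical minimal reverse-mode node: a value and an adjoint.
struct Node {
  double value;
  double adj = 0.0;
  explicit Node(double v) : value(v) {}
};

// Result node for f(x) = log(sum_i exp(x_i)). The partials df/dx_i are
// computed eagerly in the constructor; no other intermediates are kept.
class LogSumExpNode : public Node {
 public:
  explicit LogSumExpNode(const std::vector<Node*>& operands)
      : Node(0.0), operands_(operands), partials_(operands.size()) {
    assert(!operands_.empty());
    double m = operands_[0]->value;
    for (const Node* x : operands_) m = std::fmax(m, x->value);
    double s = 0.0;
    for (const Node* x : operands_) s += std::exp(x->value - m);
    value = m + std::log(s);
    for (std::size_t i = 0; i < operands_.size(); ++i)
      partials_[i] = std::exp(operands_[i]->value - value);  // softmax
  }

  // Reverse pass: propagate this node's adjoint using the stored partials.
  void chain() {
    for (std::size_t i = 0; i < operands_.size(); ++i)
      operands_[i]->adj += adj * partials_[i];
  }

 private:
  std::vector<Node*> operands_;
  std::vector<double> partials_;
};
```

The memory held between the forward and reverse pass is one double per operand, regardless of how much work the constructor did.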

Browsing the forum, there seem to be two other options: operands_and_partials and adj_jac_apply. The first gives me the templating I need, but seems to force me to do all the calculations during the chain() call (though it seems like there should be a way around this). The second seems fine, but only works in reverse mode.
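One way to avoid a class per argument combination is to template on the argument types and skip, at compile time, the work for arguments that are plain data. As I understand it, operands_and_partials follows a similar idea, with per-argument edges that do nothing for non-autodiff types. The sketch below uses a hypothetical `ADVar` type and `is_ad` trait (not Stan's `var` or its type traits) plus C++17 `if constexpr`; the function `f_with_partials` is a placeholder (a·b + c) standing in for the HMM density.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical autodiff scalar: a value plus a slot for d(result)/d(this).
struct ADVar {
  double value;
  double grad = 0.0;
};

template <typename T> struct is_ad : std::false_type {};
template <> struct is_ad<ADVar> : std::true_type {};

inline double value_of(double x) { return x; }
inline double value_of(const ADVar& x) { return x.value; }

// Placeholder f(a, b, c) = a * b + c with hand-coded partials.
// One template covers all 8 double/ADVar combinations; gradients are only
// computed and stored for the arguments whose type is ADVar.
template <typename Ta, typename Tb, typename Tc>
double f_with_partials(Ta& a, Tb& b, Tc& c) {
  const double av = value_of(a), bv = value_of(b), cv = value_of(c);
  if constexpr (is_ad<Ta>::value) a.grad += bv;   // df/da
  if constexpr (is_ad<Tb>::value) b.grad += av;   // df/db
  if constexpr (is_ad<Tc>::value) c.grad += 1.0;  // df/dc
  return av * bv + cv;
}
```

The discarded `if constexpr` branches are never instantiated for `double`, so the all-data instantiation compiles down to just the value computation.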

Do I have a reasonable assessment of the situation?

After further inspection, all three methods can achieve what I’m after.

Note that operands_and_partials doesn’t give you control over when the derivatives are computed; they are computed during the forward pass. That’s fine, since it is what I intend to do.

The n_states argument is redundant, since the number of states can be read off the dimensions of Gamma and rho.
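Dropping the argument is a small change: the number of states can be taken from Gamma and cross-checked against the other inputs. A minimal sketch with a hypothetical helper name, using `std::vector` in place of Eigen and assuming log_omegas has one row per state:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical helper: infer K from Gamma and check that all inputs agree,
// so no separate n_states argument is needed.
std::size_t infer_n_states(const Matrix& log_omegas, const Matrix& Gamma,
                           const std::vector<double>& rho) {
  const std::size_t K = Gamma.size();
  for (const auto& row : Gamma)
    if (row.size() != K) throw std::invalid_argument("Gamma must be square");
  if (rho.size() != K)
    throw std::invalid_argument("rho must have one entry per state");
  if (log_omegas.size() != K)
    throw std::invalid_argument("log_omegas must have one row per state");
  return K;
}
```

Validating the dimensions this way also removes the failure mode where a caller passes an n_states that disagrees with the matrices.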

When the function reduces its inputs to a single scalar output, the Jacobian is just a gradient with one entry per input parameter, so it is no larger than the inputs themselves. operands_and_partials should therefore be efficient enough.
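To make the size argument concrete, assume K states and N observations, so that log_omegas is K × N, Gamma is K × K and rho has K entries (the K × N layout is an assumption). The Jacobian of the scalar output is then a single row:

$$
J = \nabla \log p(y) \in \mathbb{R}^{1 \times (KN + K^2 + K)},
$$

so storing the partials takes exactly as many doubles as the inputs, and the reverse pass is one multiply-accumulate per input entry.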
