Changing grainsize yields "Gradient evaluated at the initial value is not finite"

Hi all,

I have a reasonably complex partial sum function that computes the forward algorithm for a hidden Markov model with two layers of hidden states. In the partial sum function, I use if statements either to compute the probabilities directly or to work with the log probabilities. I would like to have both methods available in the partial sum function.

/**
   * Compute partial sum over individuals of forward algorithm for hidden Markov model (HMM)
   *
   * @return  Log probability for subset of individuals
   */
  real partial_sum_lpmf(
    data array[] int seq_ind, data int start, data int end,
    data array[,,,] int y, data array[,,] int q, data int n_s, data int n_o, data int n_d, data int n_prim, data int n_sec, data array[] int first, data array[] int last, data array[] int first_sec, data vector tau,
    real eta_a, vector phi_a, real phi_b, vector psi_a, vector p_a, vector r, vector fr, matrix m, array[] matrix n_z, real sigma, data int log_p) {
      
      // number of individuals in partial sum and initialise partial target
      int n_ind = end - start + 1;
      real ptarget = 0;
      
      // parameters
      real eta;
      array[n_s - 1] vector[n_prim - 1] phi, psi;
      array[n_s - 1] matrix[n_sec, n_prim] p;
      array[n_ind] matrix[n_sec, n_prim] n;
      tuple(vector[n_prim], matrix[n_sec, n_prim]) delta;
      
      // transform intercepts
      vector[n_s - 1] log_phi_a = log(phi_a);
      
      // transition rate/probability matrices (trm/tpm)
      tuple(array[n_prim - 1] matrix[n_s, n_s - 1],
      array[n_prim, n_sec] matrix[n_s - 1, n_o],
      array[n_prim, n_sec] matrix[n_o - 1, n_d - 1]) tpm;
      array[n_prim - 1] matrix[n_s, n_s] trm;
      
      // initialise (log) probabilities
      array[n_prim] matrix[n_o - 1, n_sec] acc;
      matrix[n_s, n_prim] gamma;
      
      // for each individual
      for (i in 1:n_ind) {
        
        // index that maps to original n_ind
        int ii = i + start - 1;
        
        // probability of entering as infected
        eta = eta_a;
        
        // mortality rates
        phi[1] = rep_vector(phi_a[1], n_prim - 1);
        phi[2] = exp(log_phi_a[2] + phi_b * m[1:(n_prim - 1), ii]);
        
        // infection state transition rates
        psi[1] = rep_vector(psi_a[1], n_prim - 1);
        psi[2] = rep_vector(psi_a[2], n_prim - 1);
        
        // for successive primary occasions
        for (t in first[ii]:(n_prim - 1)) {
          
          // ecological trm
          trm[t][1, 1] = -(psi[1][t] + phi[1][t]);
          trm[t][1, 2] = psi[2][t];
          trm[t][2, 1] = psi[1][t];
          trm[t][2, 2] = -(psi[2][t] + phi[2][t]);
          trm[t][3, 1] = phi[1][t];
          trm[t][3, 2] = phi[2][t];
          trm[t][:, 3] = zeros_vector(3);
          
          // ecological (log) tpm
          tpm.1[t] = matrix_exp(trm[t] * tau[t])[:, 1:2];
          
        } // t
        
        // log-transform for log probabilities
        if (log_p) {
          tpm.1 = log(tpm.1);
        }
        
        // sample infection intensities
        ptarget += std_normal_lupdf(to_vector(n_z[ii]));
        n[i] = exp(rep_matrix(log(m[:, ii]), n_sec)' + n_z[ii] * sigma);
        
        // individual and sample infection detection probabilities
        delta.1 = 1 - pow(1 - r[1], m[:, ii]);
        delta.2 = 1 - pow(1 - r[2], n[i]);
        
        // detection probabilities (fixed at 1 for secondary of first capture)
        for (s in 1:(n_s - 1)) {
          p[s] = rep_matrix(p_a[s], n_sec, n_prim);
          p[s][first_sec[ii], first[ii]] = 1;
        } // s
        
        // for each primary occasion
        for (t in first[ii]:n_prim) {
          
          // for each secondary occasion
          for (k in 1:n_sec) {
            
            // observation and diagnostic processes
            if (log_p) {
              
              // observation tpm
              tpm.2[t, k][1, 1] = log(p[1][k, t]) + log1m(fr[1]);
              tpm.2[t, k][1, 2] = log(p[1][k, t]) + log(fr[1]);
              tpm.2[t, k][1, 3] = log1m(p[1][k, t]);
              tpm.2[t, k][2, 1] = log(p[2][k, t]) + log1m(delta.1[t]);
              tpm.2[t, k][2, 2] = log(p[2][k, t]) + log(delta.1[t]);
              tpm.2[t, k][2, 3] = log1m(p[2][k, t]);
              
              // diagnostic tpm
              tpm.3[t, k][1, 1] = log1m(fr[2]);
              tpm.3[t, k][1, 2] = log(fr[2]);
              tpm.3[t, k][2, 1] = log1m(delta.2[k, t]);
              tpm.3[t, k][2, 2] = log(delta.2[k, t]);
              
              // log probability of each alive observed state conditioned on data
              if (q[ii, t, k] == 1) {
                for (o in 1:(n_o - 1)) {
                  acc[t][o, k] = sum(tpm.3[t, k][o, y[ii, t, k]]);
                }
              }
            } else {
              
              // observation tpm
              tpm.2[t, k][1, 1] = p[1][k, t] * (1 - fr[1]);
              tpm.2[t, k][1, 2] = p[1][k, t] * fr[1];
              tpm.2[t, k][1, 3] = 1 - p[1][k, t];
              tpm.2[t, k][2, 1] = p[2][k, t] * (1 - delta.1[t]);
              tpm.2[t, k][2, 2] = p[2][k, t] * delta.1[t];
              tpm.2[t, k][2, 3] = 1 - p[2][k, t];
              
              // diagnostic tpm
              tpm.3[t, k][1, 1] = 1 - fr[2];
              tpm.3[t, k][1, 2] = fr[2];
              tpm.3[t, k][2, 1] = 1 - delta.2[k, t];
              tpm.3[t, k][2, 2] = delta.2[k, t];
              
              // probability of each alive observed state conditioned on data
              if (q[ii, t, k] == 1) {
                for (o in 1:(n_o - 1)) {
                  acc[t][o, k] = prod(tpm.3[t, k][o, y[ii, t, k]]);
                }
              }
            }
          } // k
          
          // initialise (log) marginal probability with first secondary (required to avoid gamma = 0)
          if (q[ii, t, 1] == 0) {
            gamma[1:2, t] = tpm.2[t, 1][:, 3];
          } else {
            if (log_p) {
              gamma[1:2, t] = log_mat_prod(tpm.2[t, 1][:, 1:2], acc[t][:, 1]);
            } else {
              gamma[1:2, t] = tpm.2[t, 1][:, 1:2] * acc[t][:, 1];
            }
          }
          
          // for successive secondary occasions
          for (k in 2:n_sec) {
            
            // if individual was not detected
            if (q[ii, t, k] == 0) {
              
              // marginal detection (log) probabilities
              if (log_p) {
                gamma[1:2, t] += tpm.2[t, k][:, 3];
              } else {
                gamma[1:2, t] .*= tpm.2[t, k][:, 3];
              }
              
            } else {
              
              // marginal (log) probability after each secondary for each alive ecological state
              if (log_p) {
                gamma[1:2, t] += log_mat_prod(tpm.2[t, k][:, 1:2], acc[t][:, k]);
              } else {
                gamma[1:2, t] .*= tpm.2[t, k][1:2, 1:2] * acc[t][:, k];
              }
              
            }
          } // k
        } // t
        
        // marginal (log) probability of alive states at first capture
        if (log_p) {
          gamma[1:2, first[ii]] += [ log1m(eta), log(eta) ]';
        } else {
          gamma[1:2, first[ii]] .*= [ 1 - eta, eta ]';
        }
        
        // marginal (log) probability of alive states between first and last primary with captures
        if (first[ii] < last[ii]) {
          for (t in (first[ii] + 1):last[ii]) {
            if (log_p) {
              gamma[1:2, t] += log_mat_prod(tpm.1[t - 1][1:2, :], gamma[1:2, t - 1]);
            } else {
              gamma[1:2, t] .*= tpm.1[t - 1][1:2, 1:2] * gamma[1:2, t - 1];
            }
          } // t
        }
        
        // if last captured in the last primary
        if (last[ii] == n_prim) {
          
          // increment target with marginal log probabilities of alive states
          if (log_p) {
            ptarget += log_sum_exp(gamma[1:2, n_prim]);
          } else {
            ptarget += log(sum(gamma[1:2, n_prim]));
          }
        } else {
          
          // marginal (log) probabilities of all ecological states in primary after last capture
          if (log_p) {
            gamma[1:2, last[ii] + 1] += log_mat_prod(tpm.1[last[ii]][1:2, :], gamma[1:2, last[ii]]);
            gamma[3, last[ii] + 1] = log_mat_prod(tpm.1[last[ii]][3, :], gamma[1:2, last[ii]]);
          } else {
            gamma[1:2, last[ii] + 1] .*= tpm.1[last[ii]][1:2, 1:2] * gamma[1:2, last[ii]];
            gamma[3, last[ii] + 1] = tpm.1[last[ii]][3, 1:2] * gamma[1:2, last[ii]];
          }
          
          // if first occasion after last capture is the last primary
          if ((last[ii] + 1) == n_prim) {
            
            // increment target with marginal log probabilities of all ecological states
            if (log_p) {
              ptarget += log_sum_exp(gamma[:, n_prim]);
            } else {
              ptarget += log(sum(gamma[:, n_prim]));
            }
            
          }
          else {
            
            // marginal (log) probabilities of all ecological states until last primary
            for (t in (last[ii] + 2):n_prim) {
              if (log_p) {
                gamma[1:2, t] += log_mat_prod(tpm.1[t - 1][1:2, :], gamma[1:2, t - 1]);
                gamma[3, t] = log_sum_exp(log_mat_prod(tpm.1[t - 1][3, :], gamma[1:2, t - 1]), gamma[3, t - 1]);
              } else {
                gamma[1:2, t] .*= tpm.1[t - 1][1:2, 1:2] * gamma[1:2, t - 1];
                gamma[3, t] = tpm.1[t - 1][3, 1:2] * gamma[1:2, t - 1] + gamma[3, t - 1];
              }
            } // t
            
            // increment target with marginal probabilities of all ecological states
            if (log_p) {
              ptarget += log_sum_exp(gamma[:, n_prim]);
            } else {
              ptarget += log(sum(gamma[:, n_prim]));
            }
          }
        }
        
      } // i
      
      return(ptarget);
      
    }

In the model block, I call the partial sum function as follows to allow experimentation with different grainsizes:

  if (grainsize == 1) {
    target += reduce_sum(partial_sum_lupmf, seq_ind, grainsize, y, q, n_s, n_o, n_d, n_prim, n_sec, first, last, first_sec, tau, eta_a, phi_a, phi_b, psi_a, p_a, r, fr, m, n_z, sigma[2], log_p);
  } else {
    target += reduce_sum_static(partial_sum_lupmf, seq_ind, grainsize, y, q, n_s, n_o, n_d, n_prim, n_sec, first, last, first_sec, tau, eta_a, phi_a, phi_b, psi_a, p_a, r, fr, m, n_z, sigma[2], log_p);
  }
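For reference, reduce_sum splits seq_ind into slices and calls the partial sum function once per slice with that slice's start and end. It then adds up the results, so the total should not depend on grainsize, apart from floating-point rounding caused by a different order of additions. Below is a minimal Python sketch of that contract. It uses consecutive slices of at most grainsize individuals, which simplifies how the real schedulers choose slice boundaries. `partial_sum` and `sum_over_slices` are hypothetical stand-ins, not Stan functions.

```python
import math

def partial_sum(loglik, start, end):
    """Hypothetical stand-in for partial_sum_lupmf: adds the log-likelihood
    contributions of individuals start..end (1-based, inclusive) and uses
    nothing outside that slice."""
    return sum(loglik[start - 1:end])

def sum_over_slices(loglik, grainsize):
    """Cut 1..n into consecutive slices of at most `grainsize` individuals
    and add the partial sums (a simplification of what reduce_sum does)."""
    if grainsize < 1:
        raise ValueError("grainsize must be at least 1")
    n = len(loglik)
    total = 0.0
    for start in range(1, n + 1, grainsize):
        end = min(start + grainsize - 1, n)
        total += partial_sum(loglik, start, end)
    return total

# Made-up per-individual log-likelihood contributions.
loglik = [-1.3, -0.7, -2.1, -0.4, -1.9, -0.8, -1.1]
n_ind, n_threads = len(loglik), 2

# The heuristic from the thread, clamped so it never drops below 1
# (Stan expects a positive grainsize).
grainsize = max(1, math.floor(n_ind / n_threads / 2))

# Every grainsize from 1 to n_ind should give the same total.
totals = [sum_over_slices(loglik, g) for g in range(1, n_ind + 1)]
```

If a partial sum function only ever reads its own slice, all entries of `totals` agree. A result that changes with grainsize therefore points to something inside one call that depends on how many individuals share a slice.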

Now, when I pass grainsize = 1 from R when calling cmdstanr::sample(), the model runs fine with both log_p = 0 (likelihood calculations on the probability scale) and log_p = 1 (likelihood calculations on the log scale), and I get identical results. However, when I change grainsize to something else, say grainsize = floor(n_ind / n_threads / 2), I get the following long list of errors:

Chain 1 Rejecting initial value:
Chain 1   Gradient evaluated at the initial value is not finite.
Chain 1   Stan can't start sampling from this initial value.
Chain 1 Initialization between (-0.1, 0.1) failed after 100 attempts. 
Chain 1  Try specifying initial values, reducing ranges of constrained values, or reparameterizing the model.
Chain 1 Initialization failed.

So, it strikes me that the error is coming from changing the grainsize in the partial sum function. I can share the whole Stan program including data simulation, but in the meantime, is this known/expected behaviour?

thanks,

Matt

Changing grainsize shouldn’t cause a failure, so this looks like a bug.

The User’s Guide doc says

A grainsize of 1 leaves the partitioning entirely up to the scheduler. This should be the default way of using reduce_sum unless time is spent carefully picking grainsize. For picking a grainsize, see details below.

The examples later all show grain sizes that divide the data size evenly, which seems very restrictive.

Hi Bob,

Thanks for chiming in. I also realised that I didn’t need the control flow for optionally using reduce_sum_static(), so the likelihood in my model block is now simply:

target += reduce_sum(partial_sum_lupmf, seq_ind, grainsize, y, q, n_s, n_o, n_d, n_prim, n_sec, first, last, first_sec, tau, eta_a, phi_a, phi_b, psi_a, p_a, r, fr, m, n_z, sigma[2], lp);

Again, setting grainsize to any integer in the Stan data list I pass to cmdstanr::sample() works fine with lp = 0 (using the probabilities directly in my likelihood). When I set lp = 1, I still get the same error for any grainsize other than 1 or 2. Using floor(n_ind / n_threads) ensures the grainsize is always an integer. This does feel like a bug.

I was going to open an issue for this, but I could not reproduce the problem with a simpler example, and this model is too complicated for me to be confident that the problem isn’t something ill-defined in the model that only breaks with certain grainsizes.

If you can create a simpler example that has this same issue, we can file an issue.

Also, thanks for the clarifications about the model on the issue, which I moved to the stan-dev/math repo, because that’s where new function implementations start. I haven’t seen HMMs with continuously variable time before. I’ve seen them with missing observations, but nothing completely non-uniform. I’m sort of surprised that matrix_exp is both stable and performant enough for this, but it may be fine in low dimensions.

Hey Bob, I also tried to replicate it with a simpler example, but no luck. It’s quite strange, though: if I have separate Stan files using the log probabilities vs. the actual probabilities, there’s no problem with changing the grainsize. The issue only shows up when both are in the same function and selected with if() statements. Would it help if I shared my simulation and analysis script and the Stan programs with you?

Section 2.1 of this paper, written by people much more knowledgeable than myself, discusses the continuous-time approach. It’s actually the norm in ecological studies that our (primary) occasions are NOT evenly spaced at all. That is easy to correct for in survival probabilities, but impossible to correct for in state transitions between alive states.

I also see here that the motivation for the matrix exponential function in Stan seems to have been continuous time Markov models.
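Here is some background on the continuous-time formulation. Take a rate matrix Q whose columns sum to zero. The transition matrix over an interval of length tau is then P(tau) = exp(Q tau), which is what matrix_exp(trm[t] * tau[t]) computes in the model above. Below is a small Python sketch with made-up rates. It assumes the 3-state layout of trm in the thread: states 1 and 2 are alive, and state 3 is dead and absorbing. The `expm` function is a plain scaling-and-squaring Taylor approximation written for illustration. It is not how Stan computes matrix_exp.

```python
import math

def mat_mul(a, b):
    """Plain matrix product of two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def expm(a, terms=20):
    """exp(a) by scaling and squaring with a truncated Taylor series."""
    n = len(a)
    norm = max(sum(abs(a[i][j]) for i in range(n)) for j in range(n))  # max column sum
    s = max(0, math.ceil(math.log2(norm)) + 1) if norm > 0 else 0
    scaled = [[x / 2 ** s for x in row] for row in a]  # 1-norm now at most 1/2
    result = [[float(i == j) for j in range(n)] for i in range(n)]
    term = [row[:] for row in result]
    for k in range(1, terms + 1):
        term = [[x / k for x in row] for row in mat_mul(term, scaled)]  # scaled^k / k!
        result = [[r + t for r, t in zip(r_row, t_row)]
                  for r_row, t_row in zip(result, term)]
    for _ in range(s):
        result = mat_mul(result, result)  # undo the scaling: exp(a) = exp(a / 2^s)^(2^s)
    return result

def rate_matrix(psi1, psi2, phi1, phi2):
    """Rates laid out like trm[t] in the thread: column i holds the rates out
    of state i, so every column sums to zero and state 3 is absorbing."""
    return [[-(psi1 + phi1), psi2, 0.0],
            [psi1, -(psi2 + phi2), 0.0],
            [phi1, phi2, 0.0]]

def transition_matrix(q, tau):
    """P(tau) = exp(Q * tau), the analogue of matrix_exp(trm[t] * tau[t])."""
    return expm([[x * tau for x in row] for row in q])

Q = rate_matrix(psi1=0.3, psi2=0.2, phi1=0.05, phi2=0.15)  # made-up rates
P_short = transition_matrix(Q, 0.5)  # a short interval between primaries
P_long = transition_matrix(Q, 2.0)   # a longer interval, same rates
```

Each column of P(tau) sums to 1, and P(a)P(b) = P(a + b). Unevenly spaced occasions are therefore handled by plugging in each interval's own tau, without assuming equal spacing.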

That’s the motivation from the original issue, but I think the real motivation for the devs who added it was steady state solutions to linear ODEs. I’ve used it to convert skew-symmetric matrices into orthonormal matrices. It’s super useful, just hard to compute and not the most stable function.
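The skew-symmetric use case rests on a short argument. If A is skew-symmetric, then A and -A commute, so exp(A) is orthogonal with determinant 1, i.e. a rotation. In the 2x2 case this gives the familiar rotation matrix:

```latex
A^\top = -A
\;\Longrightarrow\;
\exp(A)^\top \exp(A) = \exp(A^\top)\,\exp(A) = \exp(-A)\,\exp(A) = \exp(-A + A) = I,
\qquad
\det \exp(A) = e^{\operatorname{tr} A} = e^{0} = 1.

\exp\begin{pmatrix} 0 & -\theta \\ \theta & 0 \end{pmatrix}
= \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}.
```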

I don’t think I’m going to have time to try to create a simple, reproducible example from your scripts. I’m afraid we have to prioritize things that are clearly demonstrable bugs on our side.

Hey Bob, I assume the Stan implementation of matrix_exp is at least a very good way to do it? Of course, in my example I only have a 3x3 matrix.

Re: the grainsize issue, no problem. I don’t mind the caveat that grainsize can’t be changed when using log probs. Hopefully it’ll come up for someone else who can clearly reproduce it.

Perhaps a workaround could be to completely separate the two functions, so that you have just one if (logp) ... else statement? My honest guess is that there’s a bug somewhere in one of your multiple if branches. If you have separate versions of the function (call them A and B), one on the log scale and one on the probability scale, and they both work with any grainsize, I would find it shocking if a function doing if (logp) {A} else {B} fails. So shocking that you might find somebody willing to try to hunt down that bug :)
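Here is one way the A/B suggestion could look, as a minimal Python sketch. It is not Stan and not the model from the thread. A plain forward algorithm is written twice, once on the probability scale and once on the log scale, and a single branch chooses between them. The sketch assumes a toy two-state HMM with strictly positive probabilities (so log(0) never occurs). It uses the same column convention as tpm.1 above: trans[j][i] is the probability of moving from state i to state j. The function names are illustrative.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))

def forward_prob(init, trans, emis):
    """Forward algorithm on the probability scale.

    init[j]     : P(state j at occasion 1)
    trans[j][i] : P(state j at t | state i at t - 1), columns sum to 1
    emis[t][j]  : P(observation at occasion t | state j)
    Returns the log-likelihood of all observations.
    """
    n = len(init)
    alpha = [init[j] * emis[0][j] for j in range(n)]
    for e in emis[1:]:
        alpha = [e[j] * sum(trans[j][i] * alpha[i] for i in range(n)) for j in range(n)]
    return math.log(sum(alpha))

def forward_log(init, trans, emis):
    """The same recursion with every quantity kept on the log scale."""
    n = len(init)
    log_trans = [[math.log(x) for x in row] for row in trans]
    log_alpha = [math.log(init[j]) + math.log(emis[0][j]) for j in range(n)]
    for e in emis[1:]:
        log_alpha = [math.log(e[j])
                     + log_sum_exp([log_trans[j][i] + log_alpha[i] for i in range(n)])
                     for j in range(n)]
    return log_sum_exp(log_alpha)

def forward(init, trans, emis, log_p):
    # One branch selecting between two complete implementations,
    # instead of many interleaved if (log_p) blocks.
    if log_p:
        return forward_log(init, trans, emis)
    return forward_prob(init, trans, emis)

init = [0.7, 0.3]
trans = [[0.9, 0.2],
         [0.1, 0.8]]
emis = [[0.5, 0.1], [0.4, 0.3], [0.2, 0.6]]
print(forward(init, trans, emis, log_p=False), forward(init, trans, emis, log_p=True))
```

Both branches return the same log-likelihood. Keeping each version complete means each scale can be tested on its own before they are combined.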

Hey,

Thanks, you motivated me to keep digging. I found the issue causing the problem. Taking the log of tpm.1[t] inside the t loop does not cause the grainsize issue:

        // primary occasion intervals
        for (t in first[ii]:(n_prim - 1)) {
          
          // ecological TRM
          trm[t][1, 1] = -(psi[1][t] + phi[1][t]);
          trm[t][1, 2] = psi[2][t];
          trm[t][2, 1] = psi[1][t];
          trm[t][2, 2] = -(psi[2][t] + phi[2][t]);
          trm[t][3, 1] = phi[1][t];
          trm[t][3, 2] = phi[2][t];
          trm[t][:, 3] = zeros_vector(3);
          
          // ecological TPM
          tpm.1[t] = matrix_exp(trm[t] * tau[t])[:, 1:2];
          if (lp) {
            tpm.1[t] = log(tpm.1[t]);
          }
          
        } // t

However, taking the log of all of tpm.1 outside the t loop does cause the issue. This makes sense, because there are NaN values for all t before first[ii], and that causes the familiar log probability issue.

So, as expected, it wasn’t Stan, but it was me. Thanks for contributing everyone, and sorry for the oversight.

Edit: Actually, no, that doesn’t make sense. The log version has no trouble getting going with grainsize = 1. At least I’ve identified that the trigger is doing tpm.1 = log(tpm.1) outside the t loop, and that it has nothing to do with the partial sum function. So I’ve isolated what triggers it, but I still think there’s a bug.

Edit2: FWIW, the following also works with any grainsize:

// log ecological TPM
if (lp) {
  tpm.1[first[ii]:(n_prim - 1)] = log(tpm.1[first[ii]:(n_prim - 1)]);
}

So the problem seems to occur when I do log(tpm.1), where tpm.1 is an object of type array[n_prim - 1] matrix[3, 2] and some of the matrices it contains are filled with NaN. Again, it still works with grainsize = 1.
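One possible mechanism is worth checking. It is a hypothesis, not something established in this thread. tpm.1 is declared once per partial sum call, outside the loop over individuals. So for individual ii, the entries before first[ii] still hold whatever an earlier individual in the same slice left there, or NaN, assumed here to be the initial value of unassigned local reals. Applying log to the whole array logs those stale entries again.

A value that has been logged three times passes through log(NaN), whose derivative 1/NaN is NaN. In a reverse sweep, the zero adjoint of the unused result divided by NaN is still NaN. That would give a finite log density but a non-finite gradient, which matches the error message. It would also make the failure depend on grainsize, since larger slices hold more individuals.

The Python sketch below uses plain floats in place of Stan's autodiff. `run_slice`, `adjoint_through_logs` and `stan_log` are hypothetical helpers. `run_slice` counts how often an entry holding a real probability gets logged, and `adjoint_through_logs` shows the NaN adjoint.

```python
import math

NAN = float("nan")

def stan_log(x):
    """Mimic elementwise log: NaN for negative or NaN input, -inf at 0."""
    if x > 0:
        return math.log(x)
    if x == 0:
        return -math.inf
    return NAN

def run_slice(firsts, n_prim, prob=0.5, log_whole_array=True):
    """Replay the tpm.1 bookkeeping for the individuals in one slice.

    firsts holds first[ii] (1-based) for each individual in the slice. One
    buffer is reused across individuals, as when tpm.1 is declared outside
    the loop over i. Returns the largest number of times log() was applied
    to an entry since it last received a real probability.
    """
    buf = [NAN] * (n_prim - 1)      # assumed NaN initial value
    n_logs = [None] * (n_prim - 1)  # None: never held a real probability
    worst = 0
    for first in firsts:
        # for (t in first[ii]:(n_prim - 1)) tpm.1[t] = ...;  (index t - 1 here)
        for t in range(first - 1, n_prim - 1):
            buf[t] = prob
            n_logs[t] = 0
        # tpm.1 = log(tpm.1)  versus  tpm.1[first[ii]:(n_prim - 1)] = log(...)
        start = 0 if log_whole_array else first - 1
        for t in range(start, n_prim - 1):
            buf[t] = stan_log(buf[t])
            if n_logs[t] is not None:
                n_logs[t] += 1
                worst = max(worst, n_logs[t])
    return worst

def adjoint_through_logs(p, n_logs):
    """Adjoint reaching p in a reverse sweep when log() is applied n_logs
    times and the final value is never used (its adjoint is 0)."""
    chain = [p]
    for _ in range(n_logs):
        chain.append(stan_log(chain[-1]))
    adj = 0.0
    for x in reversed(chain[:-1]):  # d log(x) / dx = 1 / x
        adj = adj / x
    return adj

print(run_slice([1, 3, 3], n_prim=5))                         # 3: whole array logged
print(run_slice([1, 3, 3], n_prim=5, log_whole_array=False))  # 1: the Edit2 version
print(adjoint_through_logs(0.5, 2), adjoint_through_logs(0.5, 3))
```

Under this hypothesis, three logs need at least three individuals in one slice. Both fixes reported above, logging inside the t loop and logging only tpm.1[first[ii]:(n_prim - 1)], log each real entry exactly once.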

This may also be a problem with tuples. We’re still ironing out some of the kinks in deeply nested usages. You might want to try just using two variables and see if that’s the problem.

I agree. The doc doesn’t say that there are restrictions on what the grainsize can be.