Two more ideas!
How about writing a custom log_sum_exp that skips infinite entries:
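real log_sum_exp_ignore_inf(vector v) {
  int N = size(v);
  vector[N] noninfs;  // scratch space; only the first K entries get filled
  int K = 0;          // number of finite entries copied so far
  for (n in 1:N) {
    // note: is_inf() is true for both +inf and -inf
    if (!is_inf(v[n])) {
      K += 1;
      noninfs[K] = v[n];
    }
  }
  // head(noninfs, K) drops the unfilled tail; log_sum_exp of an empty vector is -inf
  return log_sum_exp(head(noninfs, K));
}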
Or make two arrays that sparsely encode which states are possible at each time point (N time points, 4 states):
array[N] int number_of_states_at_T = {
  1,
  2,
  4,
  4,
  ...
};
// -1 marks an unused slot (not a state)
array[N, 4] int state_ids_at_time_T = {
  {1, -1, -1, -1}, // only one possible state
  {2, 3, -1, -1},  // only two possible states
  {1, 2, 3, 4},
  {1, 2, 3, 4},
  ...
};
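If you'd rather not type these arrays out by hand, they could be built in transformed data. This is a minimal sketch that assumes a hypothetical N x 4 indicator array `possible` (1 if state k can occur at time t) passed in as data; that input is not part of your model, just one convenient way to describe the allowed states.

```stan
data {
  int<lower=1> N;
  // hypothetical input: possible[t, k] = 1 if state k can occur at time t
  array[N, 4] int<lower=0, upper=1> possible;
}
transformed data {
  array[N] int number_of_states_at_T;
  array[N, 4] int state_ids_at_time_T;
  for (t in 1:N) {
    number_of_states_at_T[t] = 0;
    state_ids_at_time_T[t] = rep_array(-1, 4);  // start with every slot unused
    for (k in 1:4) {
      if (possible[t, k] == 1) {
        // append state k to the list of possible states at time t
        number_of_states_at_T[t] += 1;
        state_ids_at_time_T[t, number_of_states_at_T[t]] = k;
      }
    }
  }
}
```

The possible states end up packed at the front of each row in increasing order, with -1 padding after them.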
Then the two inner loops that currently run over K would instead run over number_of_states_at_T[t] (outer) and number_of_states_at_T[t - 1] (inner), using state_ids_at_time_T to look up the actual state ids. You’d also have to account for a variable-length log_sum_exp.
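Here's a rough sketch of what that could look like for a standard forward algorithm. The function name `sparse_hmm_marginal` and the inputs `log_rho` (log initial state probabilities), `log_Gamma` (log transition matrix, from-state by to-state) and `log_omega` (log emission likelihoods, state by time) are hypothetical stand-ins for whatever your model actually uses; adapt them to your parameterization.

```stan
functions {
  // Hypothetical sketch: forward algorithm that only visits states listed
  // as possible at each time point.
  //   n_states[t]     : number_of_states_at_T
  //   state_ids[t, i] : state_ids_at_time_T (padded with -1, never read)
  real sparse_hmm_marginal(vector log_rho, matrix log_Gamma, matrix log_omega,
                           array[] int n_states, array[,] int state_ids) {
    int K = rows(log_Gamma);
    int N = cols(log_omega);
    // log_alpha[k] = log p(y[1:t], z[t] = k); entries for impossible
    // states stay at -inf and are never read
    vector[K] log_alpha = rep_vector(negative_infinity(), K);
    vector[K] log_alpha_prev;
    vector[n_states[N]] final_terms;
    for (i in 1:n_states[1]) {
      int k = state_ids[1, i];
      log_alpha[k] = log_rho[k] + log_omega[k, 1];
    }
    for (t in 2:N) {
      log_alpha_prev = log_alpha;
      log_alpha = rep_vector(negative_infinity(), K);
      for (i in 1:n_states[t]) {        // outer: states possible at t
        int k = state_ids[t, i];
        // variable-length buffer: one term per state possible at t - 1
        vector[n_states[t - 1]] acc;
        for (j in 1:n_states[t - 1]) {  // inner: states possible at t - 1
          int k_prev = state_ids[t - 1, j];
          acc[j] = log_alpha_prev[k_prev] + log_Gamma[k_prev, k];
        }
        log_alpha[k] = log_sum_exp(acc) + log_omega[k, t];
      }
    }
    // marginal log likelihood: sum over the states possible at time N
    for (i in 1:n_states[N]) {
      final_terms[i] = log_alpha[state_ids[N, i]];
    }
    return log_sum_exp(final_terms);
  }
}
```

In the model block this would be used as something like `target += sparse_hmm_marginal(log_rho, log_Gamma, log_omega, number_of_states_at_T, state_ids_at_time_T);`. Sizing `acc` by `n_states[t - 1]` is what takes care of the variable-length log_sum_exp, since the padded -1 slots are never indexed.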