Advice requested on using categorical_lpmf() in a complex model

Hello,

I am looking for any advice/best practices on how to use a categorical likelihood in my Stan model.

In short, my question is: how should I compute the weights for categorical_lpmf(Y[ii] | weights[ii]) when the weights themselves are computed from a small set of latent parameters (and these parameters are the things I am estimating)?

  for (ii in 1:N) {
    target += categorical_lpmf(Y[ii] | weights[ii]);
  }

where N ~ 100k and weights[ii] is a 40-element vector

In the current implementation of my model, weights[ii] is being computed for each ii in my model{} block. (For details, see Clarke et al. 2022, PLoS Comp Bio. I’m happy to share the code too; I’m sure it could be improved.)

These weights depend on four parameters that are defined in the parameters block (actually more than four as I’m doing things multi-level, but I think that isn’t important to the question here).

My issue appears to be that if I define weights[ii] as a vector in the model block, it quite often does not sum to exactly 1 (e.g., 0.99999996) due to rounding errors, and the sampler is unhappy.

I considered using a simplex to solve this problem, but it is not possible to define a simplex in the model{} block (I understand why).

My next idea was that I could move all the code for generating weights[ii] to the transformed parameters{} block, but this feels quite clunky. First of all, the algorithm for calculating weights[ii] is quite complicated and my code makes use of a number of intermediate variables. My understanding is that any variables declared in the transformed parameters block will be saved in the output? Furthermore, each weights[ii] simplex is 40 elements long, and ii goes from 1 to 100,000, so saving this for every posterior sample feels like it would be quite unwieldy.

It’s hard to say what you should be doing here without seeing the rest of your Stan program. The problem with our categorical distribution is that it’s not vectorized, so you can’t speed up autodiff that way. There are ways to speed things up by brute force, grouping categorical observations into multinomials whenever weights[i] = weights[j] for some i != j.
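For example, a rough sketch of what that grouping might look like (here G, counts, and theta_g are hypothetical aggregates you would build in transformed data, with counts[g] tallying the categories chosen under the shared weight vector theta_g[g]):

for (g in 1:G) {
  // one multinomial term replaces all the categorical terms that share theta_g[g]
  target += multinomial_lpmf(counts[g] | theta_g[g]);
}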

If the weights are computed by applying softmax to a linear predictor, it’s better to use categorical_logit and pass the linear predictor directly as an argument.
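That is, something along these lines (a sketch; eta[n] here stands in for your unnormalized linear predictor):

for (n in 1:N) {
  // categorical_logit applies softmax internally, so no explicit
  // normalization or simplex check is needed
  target += categorical_logit_lpmf(Y[n] | eta[n]);
}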

You won’t have this problem with categorical_logit. With some other way of defining the weights, you can set

for (n in 1:N) {
  weights[n] /= sum(weights[n]);
}
If you’d prefer, you can avoid our error checking that the weights form a simplex and implement the log likelihood directly as

for (n in 1:N) {
  target += log(weights[n][Y[n]]);
}

The only difference between local variables in the model block and variables in the transformed parameters block is that the latter are saved in the output. It won’t affect speed, other than that the transformed parameter version will be a bit slower due to I/O overhead.

Yes, unless you make them local. So you can do this:

transformed parameters {
  vector[K] alpha;
  {
    vector[K] x = f(y, z);  // x is local and not saved
    alpha = softmax(x);
  }
  ...

The braces introduce a local scope, making x a local variable that is not saved. You can also use functions to define things and avoid this problem.
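For example, a sketch with a hypothetical helper function (make_weights and eta are placeholders):

functions {
  vector make_weights(vector eta) {
    vector[num_elements(eta)] w = exp(eta);  // intermediates declared here are never saved
    return w / sum(w);
  }
}
transformed parameters {
  vector[K] alpha = make_weights(eta);  // only alpha appears in the output
}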

In programming language terminology, you can’t declare a vector to be a simplex, but you can define a vector to be a simplex. That is, you can’t do this:

model {
  simplex[N] theta = softmax(x);  // declares theta to be a simplex
  ...   

but you can do this

model {
  vector[N] theta = softmax(x);  // defines theta to be a simplex
  ...

If you’re able to share more of your Stan program, we can perhaps find other ways to speed it up.

Oh wow

I hugely appreciate the time you have taken to offer these suggestions. All very smart and clear.

And yes, any improvements on how to speed things up would be appreciated.

A little background, that may make the code easier to follow:

We are modelling the visual foraging paradigm from cognitive psychology (see Clarke et al. 2022, PLoS Comp Bio, for further details). In short, a number of participants complete a number of trials. Trials can be in one of K conditions (usually 2).

In each trial, participants are presented with a number of items (around 40-80). Some of these are targets, some are distractors. The targets can be one of several different types (usually there are only two types of targets, so this is currently hardcoded). The participant’s task is simply to click on all the targets to “collect” them. I am interested in trying to predict the order in which these items are selected.

I model this as a sampling without replacement procedure. I currently have four main parameters:

  • bA measures whether you prefer target items of class A over class B.
  • b_stick measures whether you prefer to select an item of the same class as the previously selected target.
  • rho_delta is used to put more weight on items that are close to the previously selected item.
  • rho_psi is used to weight items that are ahead of or behind our direction of travel.

bA and b_stick are converted to probabilities using inv_logit, and the weights for the remaining items are calculated as w = pA .* p_stick. The two spatial components go through a negative exponential, which is also multiplied into w.
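As a rough sketch (names are illustrative, and this ignores the random effects and the special case of the first selection), each row's weight vector is built roughly like this:

vector[n_targets] w = inv_logit(bA * item_class)           // pA: class preference
                   .* inv_logit(b_stick * S)               // p_stick: stick/switch preference
                   .* exp(-rho_delta * delta - rho_psi * psi);  // spatial tuning
w = (w .* remaining_items) / sum(w .* remaining_items);    // renormalize over remaining items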

Here’s all the code from my current implementation. Sorry if it’s hard to work through. Until recently, most of the code in the transformed parameters block was in the model{} block. I wonder if I should move it back in.

I also wonder if I should keep pre-processing variables like delta (all the inter-target distances) in R, or if it would make more sense to code it up in the transformed data{} block.
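For the transformed data version, I imagine something like this (just a sketch; it assumes the previously selected item can be read off from Y and found_order, and that the trial's item coordinates are in item_x and item_y):

transformed data {
  array[N] vector[n_targets] delta_td;  // distance from the previously selected item
  for (n in 1:N) {
    if (found_order[n] == 1) {
      delta_td[n] = rep_vector(0, n_targets);  // no previous selection at the start of a trial
    } else {
      int prev = Y[n - 1];  // previously selected item within this trial
      for (t in 1:n_targets) {
        delta_td[n][t] = sqrt(square(item_x[trial[n]][t] - item_x[trial[n]][prev])
                            + square(item_y[trial[n]][t] - item_y[trial[n]][prev]));
      }
    }
  }
}

Anyway, here is the current model in full: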

// spatial foraging project

functions{

  vector standarise_weights(vector w, int n_targets, vector remaining_items) {

    /* set weights of found items to 0 and divide by the sum of 
    remaining weights so that they sum to 1 */
    vector[n_targets] w_s = w .* remaining_items;  

    w_s = w_s / sum(w_s);
    return(w_s);
  }

  vector init_spat_bias(int n_targets, vector x, vector y, 
    vector init_bias_params, real lambda) {

    vector[n_targets] w, w1, w2;

    /* for init selection we want to weight each item by
    how likely it is to come from a two component beta 
    mixture model */
    for (ii in 1:n_targets) {
      w1[ii] = beta_lpdf(x[ii] | init_bias_params[1], init_bias_params[2])
      + beta_lpdf(y[ii] | init_bias_params[3], init_bias_params[4]);

      w2[ii] = beta_lpdf(x[ii] | init_bias_params[5], init_bias_params[6])
      + beta_lpdf(y[ii] | init_bias_params[8], init_bias_params[7]);
    }

    w = lambda * exp(w1) + (1-lambda) * exp(w2);

    return(w);
  }

  vector compute_spatial_weights(int n, int n_targets, int ii,
   real rho_delta, real rho_psi, 
   real u_delta, real u_psi, vector delta, vector psi, vector phi,
   vector x, vector y, vector init_bias_params, real lambda) {

    vector[n_targets] w;

    w = rep_vector(1, n_targets); 
    // now start computing the weights
    if (n == 1) {

      // calculate initial selection weights based on spatial location
      if (lambda < 0) 
      {
        // if lambda <0, do not apply initial bias
        w = rep_vector(1, n_targets);
      } else {
          // if lambda >= 0 we apply initial bias
          w = init_spat_bias(n_targets, x, y, init_bias_params, lambda);
      }

      w = standarise_weights(w, n_targets, rep_vector(1, n_targets));

    } else {

      if (n == 2) {
        // for the second selected target, weight by distance from the first
        w =  exp(-(rho_delta + u_delta) * delta);
          
      } else {
        // for all later targets, also weight by direction
        w =  exp(-(rho_delta + u_delta) * delta - (rho_psi + u_psi) * psi);
      }
    }

    w = standarise_weights(w, n_targets, rep_vector(1, n_targets));
    
    return(w);
  }
}

data {
  int <lower = 1> N; // total number of selected targets over the whole experiment
  int <lower = 1> L; // number of participant levels 
  int <lower = 1> K; // number of experimental conditions  

  int <lower = 1> n_trials;  // total number of trials (overall)
  int <lower = 1> n_classes; // number of target classes - we assume this is constant over n_trials
  int <lower = 1> n_targets; // total number of targets per trial
  array[N] int <lower = 0, upper = n_targets> found_order; // selection number within a trial; = 1 means we are starting a new trial

  array[N] int <lower = 1> Y; // target IDs - which target was selected here? This is what we predict

  // (x, y) coordinates of each target
  array[n_trials] vector<lower=0,upper=1>[n_targets] item_x;
  array[n_trials] vector<lower=0,upper=1>[n_targets] item_y;

  array[N] vector<lower = 0>[n_targets] delta; // distance measures
  array[N] vector[n_targets] psi; // direction measures (relative)
  array[N] vector[n_targets] phi; // direction measures (absolute)

  array[n_trials] int <lower = 1, upper = K> X; // trial features (ie, which condition are we in)
  matrix<lower = -1, upper = 1>[n_trials, n_targets] item_class; // target class, one row per trial
  array[N] vector<lower = -1, upper = 1>[n_targets] S; // stick/switch (does this targ match prev targ) 
  array[N] int <lower = 1, upper = L> Z; // random effect levels
  array[N] int<lower = 1, upper = n_trials> trial; // what trial are we on? 

  int<lower = 0, upper = 1> fit_init_bias; // should we fit the initial bias?

  real prior_sd_bA; // param for class weight prior
  real prior_sd_b_stick; // prior for sd for bS
  real prior_mu_rho_delta;
  real prior_sd_rho_delta;
  real prior_mu_rho_psi;
  real prior_sd_rho_psi;
}

parameters {
  // These are all the parameters we want to fit to the data

  ////////////////////////////////////
  // fixed effects
  ////////////////////////////////////

  /* in order to allow for correlations between the
  variables, these are all stored in a list
  these include bA, bS (stick weight), and the two spatial 
  sigmas, along with the floor (chance of selecting an 
  item at random)
  */
  array[K] real bA; // weights for class A compared to B  
  array[K] real b_stick; // stick-switch rates 
  array[K] real rho_delta; // distance tuning
  array[K] real rho_psi; // direction tuning

  ///////////////////////////////
  // random effects
  ///////////////////////////////

  array[K] vector[L] uA; // weights for class A compared to B  
  array[K] vector[L] u_stick; // stick-switch rates 
  array[K] vector[L] u_delta; // distance tuning
  array[K] vector[L] u_psi; // direction tuning

  // initial bias parameters

  /* These are constant over participants so 
  should not be included in the random effect structure

  order of params:
  a, b for comp 1, x dimension
  a, b for comp 1, y dimension
  a, b for comp 2 x dimension
  a, b for comp 2 y dimension
  */
  vector<lower = -5, upper = 5>[8] init_bias_params;

  /* lambda varies from person to person, 
  so may want to have it correlated (potentially)
  with the b params.*/
  array[L] real<lower=0, upper=1> lambda;
}

transformed parameters {

  
  // some counters and index variables, etc.
  vector[n_targets] remaining_items; // binary vector that tracks which targets have been found
  vector[n_targets] m; // does this target match the previous target?
  real lambdall; 

  vector[8] init_bias_params2; // exp transform
  init_bias_params2 = exp(init_bias_params);

  array[N] simplex[n_targets] weights;

  
  //////////////////////////////////////////////////
  // step through data row by row and define LLH
  //////////////////////////////////////////////////
  for (ii in 1:N) {

    // check if we are at the start of a new trial
    // if we are, initialise a load of things
    if (found_order[ii] == 1) {
      // as we're at the start of a new trial, reset the remaining_items tracker
      remaining_items = rep_vector(1, n_targets);
    }

    // update the class weights to take random effects into account
    // set the weight of each target to be its class weight
    weights[ii] = (bA[X[trial[ii]]] + uA[X[trial[ii]], Z[ii]]) * to_vector(item_class[trial[ii]]);

    // apply spatial weighting

    // first of all, check if we should fit initial bias
    if (fit_init_bias == 0) {
      lambdall = -1;
    } else {
      lambdall = lambda[Z[ii]];
    }

    if (found_order[ii] == 1) {
      weights[ii] = inv_logit(weights[ii]);
    } else {
      // check which targets match the previously selected target
      // this is precomputed in S[ii]
      weights[ii] = inv_logit(weights[ii]) .* inv_logit((b_stick[X[trial[ii]]] + u_stick[X[trial[ii]], Z[ii]]) * S[ii]);
    }

    weights[ii] = weights[ii] .* compute_spatial_weights(found_order[ii], n_targets, ii,
      rho_delta[X[trial[ii]]], rho_psi[X[trial[ii]]], u_delta[X[trial[ii]], Z[ii]], u_psi[X[trial[ii]], Z[ii]],
      delta[ii], psi[ii], phi[ii],
      item_x[trial[ii]], item_y[trial[ii]], init_bias_params2, lambdall);

    // remove already-selected items, and standarise to sum = 1
    weights[ii] = standarise_weights(weights[ii], n_targets, remaining_items);

    // do I need this if statement?
    if (Y[ii] == n_targets + 1) {
      // trial completed
    } else {
      // remove found target from the list of remaining items
      remaining_items[Y[ii]] = 0;
    }
  }
}

model {

  /////////////////////////////////////////////////////
  // Define Priors
  ////////////////////////////////////////////////////

  //----- priors for initial item selection distributions -----

  
  for (ii in 1:K) {
    // priors for fixed effects
    target += normal_lpdf(bA[ii] | 0, prior_sd_bA);
    target += normal_lpdf(b_stick[ii] | 0, prior_sd_b_stick);
    target += normal_lpdf(rho_delta[ii] | prior_mu_rho_delta, prior_sd_rho_delta);
    target += normal_lpdf(rho_psi[ii] | prior_mu_rho_psi, prior_sd_rho_psi);

    // priors for random effects
    target += normal_lpdf(uA[ii]      | 0, 0.5);
    target += normal_lpdf(u_stick[ii] | 0, 0.5);
    target += normal_lpdf(u_delta[ii] | 0, 1);
    target += normal_lpdf(u_psi[ii]   | 0, 0.5);
  }

  // priors for initial bias
  init_bias_params[1] ~ normal(1.5, 0.1);
  init_bias_params[2] ~ normal(1.5, 0.1);
  init_bias_params[3] ~ normal(1.5, 0.1);
  init_bias_params[4] ~ normal(1.5, 0.1);
  init_bias_params[5] ~ normal(2.0, 0.25);
  init_bias_params[6] ~ normal(7.0, 0.25);
  init_bias_params[7] ~ normal(2.0, 0.25);
  init_bias_params[8] ~ normal(7.0, 0.25);

  lambda ~ normal(0.5, 0.25);

  //////////////////////////////////////////////////
  // step through data row by row and define LLH
  //////////////////////////////////////////////////
  for (ii in 1:N) {
    // print(sum(weights[ii]));
    // likelihood!
    target += categorical_lpmf(Y[ii] | weights[ii]);
  }
}

generated quantities {
  // here we can output our prior distributions
  real prior_bA = normal_rng(0, prior_sd_bA);
  real prior_b_stick = normal_rng(0, prior_sd_b_stick);
  real prior_rho_delta = normal_rng(prior_mu_rho_delta, prior_sd_rho_delta);
  real prior_rho_psi = normal_rng(prior_mu_rho_psi, prior_sd_rho_psi);
  real prior_direction_bias = normal_rng(-2, 3);
}

I have spent the past day tidying things up. It now runs quite smoothly on my synthetic data with no warnings. Here’s hoping it will scale up to real data without any problems.

// spatial foraging project

functions{

  vector standarise_weights(vector w, int n_targets, vector remaining_items) {

    /* set weights of found items to 0 and divide by the sum of 
    remaining weights so that they sum to 1 */
    vector[n_targets] w_s = w .* remaining_items;  

    w_s = w_s / sum(w_s);

    return(w_s);
  }

  vector init_spat_bias(int n_targets, vector x, vector y, vector ab, real lambda) {

    vector[n_targets] w, w1, w2;

    /* for init selection we want to weight each item by
    how likely it is to come from a two-component beta 
    mixture model */
    for (ii in 1:n_targets) {

      w1[ii] = beta_lpdf(x[ii] | ab[1], ab[2])
        + beta_lpdf(y[ii] | ab[3], ab[4]);

      w2[ii] = beta_lpdf(x[ii] | ab[5], ab[6])
        + beta_lpdf(y[ii] | ab[8], ab[7]);
    }

    w = lambda * exp(w1) + (1-lambda) * exp(w2);

    return(w);
  }

  vector compute_spatial_weights(
    int n, int n_targets, int ii,
    real u_delta, real u_psi, vector delta, vector psi, vector phi,
    vector x, vector y, vector init_bias_params, real lambda) {

    vector[n_targets] w;

    w = rep_vector(1, n_targets); 

    // now start computing the weights
    if (n == 1) {

      // calculate initial selection weights based on spatial location
      w = init_spat_bias(n_targets, x, y, init_bias_params, lambda);

    } 
    else if (n == 2) {

      // for the second selected target, weight by distance from the first
      w =  exp(-u_delta*delta);
          
    } else {

      // for all later targets, also weight by direction
      w =  exp(-u_delta*delta - u_psi*psi);

    }

    return(w);
  }
}

data {
  int <lower = 1> N; // total number of selected targets over the whole experiment
  int <lower = 1> L; // number of participant levels 
  int <lower = 1> K; // number of experimental conditions  

  int <lower = 1> n_trials;  // total number of trials (overall)
  int <lower = 1> n_classes; // number of target classes - we assume this is constant over n_trials
  int <lower = 1> n_targets; // total number of targets per trial
  array[N] int <lower = 0, upper = n_targets> found_order; // selection number within a trial; = 1 means we are starting a new trial

  array[N] int <lower = 1> Y; // target IDs - which target was selected here? This is what we predict

  // (x, y) coordinates of each target
  array[n_trials] vector<lower=0,upper=1>[n_targets] item_x;
  array[n_trials] vector<lower=0,upper=1>[n_targets] item_y;

  array[N] vector<lower = 0>[n_targets] delta; // distance measures
  array[N] vector[n_targets] psi; // direction measures (relative)
  array[N] vector[n_targets] phi; // direction measures (absolute)

  array[n_trials] int <lower = 1, upper = K> X; // trial features (ie, which condition are we in)
  matrix<lower = -1, upper = 1>[n_trials, n_targets] item_class; // target class, one row per trial
  array[N] vector<lower = -1, upper = 1>[n_targets] S; // stick/switch (does this targ match prev targ) 
  array[N] int <lower = 1, upper = L> Z; // random effect levels
  array[N] int<lower = 1, upper = n_trials> trial; // what trial are we on? 

  real prior_sd_bA; // param for class weight prior
  real prior_sd_b_stick; // prior for sd for bS
  real prior_mu_rho_delta;
  real prior_sd_rho_delta;
  real prior_mu_rho_psi;
  real prior_sd_rho_psi;
}

transformed data{

  array[N] vector[n_targets] remaining_items;

  for (n in 1:N) {
    // check if we are at the start of a new trial
    if (found_order[n] == 1) {
      // if so, reset the remaining_items tracker
      remaining_items[n] = rep_vector(1, n_targets);
    } else {
      // otherwise, carry the tracker forward and remove the previously found target
      remaining_items[n] = remaining_items[n-1];
      remaining_items[n][Y[n-1]] = 0;
    }
  }
}

parameters {
  // These are all the parameters we want to fit to the data

  ////////////////////////////////////
  // fixed effects
  ////////////////////////////////////

  /* in order to allow for correlations between the
  variables, these are all stored in a list
  these include bA, bS (stick weight), and the two spatial 
  sigmas, along with the floor (chance of selecting an 
  item at random)
  */
  array[K] real bA; // weights for class A compared to B  
  array[K] real b_stick; // stick-switch rates 
  array[K] real<lower = 0> rho_delta; // distance tuning
  array[K] real rho_psi; // direction tuning

  ///////////////////////////////
  // random effects
  ///////////////////////////////

  array[K] vector[L] uA; // weights for class A compared to B  
  array[K] vector[L] u_stick; // stick-switch rates 
  array[K] vector[L] u_delta; // distance tuning
  array[K] vector[L] u_psi; // direction tuning

  ///////////////////////////////
  // initial bias parameters
  ///////////////////////////////

  /* These are constant over participants so 
  should not be included in the random effect structure

  order of params:
  a, b for comp 1, x dimension
  a, b for comp 1, y dimension
  a, b for comp 2, x dimension
  a, b for comp 2, y dimension
  */
  vector<lower = -5, upper = 5>[8] init_bias_params;

  /* lambda varies from person to person, 
  so may want to have it correlated (potentially)
  with the b params.*/
  array[L] real<lower=0, upper=1> lambda;
}

transformed parameters {

  vector[8] init_bias_params2; // exp transform
  init_bias_params2 = exp(init_bias_params);

  // combine fixed and random effects
  array[K] vector[L] zA; 
  array[K] vector[L] z_stick; 
  array[K] vector[L] z_delta; 
  array[K] vector[L] z_psi; 

  for (kk in 1:K) {
    zA[kk] = bA[kk] + uA[kk];
    z_stick[kk] = b_stick[kk] + u_stick[kk];
    z_delta[kk] = rho_delta[kk] + u_delta[kk];
    z_psi[kk]   = rho_psi[kk] + u_psi[kk];
  }
}

model {

  /////////////////////////////////////////////////////
  // Define Priors
  ////////////////////////////////////////////////////
  for (ii in 1:K) {
    // priors for fixed effects
    target += normal_lpdf(bA[ii] | 0, prior_sd_bA);
    target += normal_lpdf(b_stick[ii] | 0, prior_sd_b_stick);
    target += normal_lpdf(rho_delta[ii] | prior_mu_rho_delta, prior_sd_rho_delta);
    target += normal_lpdf(rho_psi[ii] | prior_mu_rho_psi, prior_sd_rho_psi);

    // priors for random effects
    target += normal_lpdf(uA[ii]      | 0, 0.5);
    target += normal_lpdf(u_stick[ii] | 0, 0.5);
    target += normal_lpdf(u_delta[ii] | 0, 1);
    target += normal_lpdf(u_psi[ii]   | 0, 0.5);
  }

  //----- priors for initial item selection distributions -----
  init_bias_params[1] ~ normal(0.5, 0.1);
  init_bias_params[2] ~ normal(0.5, 0.1);
  init_bias_params[3] ~ normal(0.5, 0.1);
  init_bias_params[4] ~ normal(0.5, 0.1);
  init_bias_params[5] ~ normal(0.7, 0.25);
  init_bias_params[6] ~ normal(2.0, 0.25);
  init_bias_params[7] ~ normal(0.7, 0.25);
  init_bias_params[8] ~ normal(2.0, 0.25);

  lambda ~ normal(0.5, 0.20) T[0, 1];

  //////////////////////////////////////////////////
  // step through data row by row and define LLH
  //////////////////////////////////////////////////
  vector[n_targets] weights;

  // some counters and index variables, etc.
  int t; // trial counter

  for (ii in 1:N) {

    t = trial[ii];
 
    // set the weight of each target to be its class weight
    weights = zA[X[t], Z[ii]] * to_vector(item_class[t]);

    // multiply weights by stick/switch preference
    weights = inv_logit(weights) .* inv_logit(z_stick[X[t], Z[ii]] * S[ii]); 

    weights = weights .* compute_spatial_weights(found_order[ii], n_targets, ii,
       z_delta[X[t], Z[ii]], z_psi[X[t], Z[ii]], 
       delta[ii], psi[ii], phi[ii],
       item_x[t], item_y[t], init_bias_params2, lambda[Z[ii]]);
        
    // remove already-selected items, and standarise to sum = 1 
    weights = standarise_weights(weights, n_targets, remaining_items[ii]);   

    target += log(weights[Y[ii]]);
       
  }
}

generated quantities {
  // here we can output our prior distributions
  real prior_bA = normal_rng(0, prior_sd_bA);
  real prior_b_stick = normal_rng(0, prior_sd_b_stick);
  real prior_rho_delta = normal_rng(prior_mu_rho_delta, prior_sd_rho_delta);
  real prior_rho_psi = normal_rng(prior_mu_rho_psi, prior_sd_rho_psi);
  real prior_direction_bias = normal_rng(-2, 3);
}