Oh wow
I hugely appreciate the time you have taken to offer these suggestions. All very smart and clear.
And yes, any improvements on how to speed things up would be appreciated.
A little background, that may make the code easier to follow:
We are modelling the visual foraging paradigm from cognitive psychology. (See Clarke et al 2022, Comp Bio for further details). In short, a number of participants complete a number of trials. Trials can be in one of K conditions (usually 2).
In each trial, participants are presented with a number of items (around 40-80). Some of these are targets; the rest are distracters. Each target belongs to one of several types (usually there are only two target types, so this is currently hard-coded). The participant's task is simply to click on all the targets to "collect them". I am interested in predicting the order in which these items are selected.
I model this as a sampling without replacement procedure. I currently have four main parameters:
- bA measures whether you prefer target items of class A over class B.
- b_stick measures whether you prefer to select an item of the same class as the previously selected target.
- rho_delta is used to put more weight on items that are close to the previously selected item.
- rho_psi is used to weight items that are ahead of or behind our direction of travel.
bA and b_stick are converted to probabilities using inv_logit, and then the weights for the remaining items are calculated as w = pA .* p_stick. The two spatial components go into a negative exponential and are also multiplied by w.
Here’s all the code from my current implementation. Sorry if it’s hard to work through. Until recently, most of the code in the transformed parameters block was in the model{} block. I wonder if I should move it back in.
I also wonder if I should keep pre-processing variables like delta (all the inter-target distances) in R, or if it would make more sense to code it up in the transformed data{} block.
// spatial foraging project
functions {
  /* Zero the weights of items that have already been found and rescale the
     remaining weights so that they sum to 1.
       w               : non-negative, un-normalised item weights
       n_targets       : length of w
       remaining_items : 1 for items still available, 0 for items already found
     Returns the rescaled weights. If every remaining weight is 0 the result
     is 0/0, i.e. NaN entries. */
  vector standarise_weights(vector w, int n_targets, vector remaining_items) {
    vector[n_targets] w_s = w .* remaining_items;
    return w_s / sum(w_s);
  }

  /* Weight for the first selection of a trial: the density of each item's
     (x, y) location under a two-component mixture whose components are
     products of independent beta densities in x and y.
       init_bias_params : the eight beta shape parameters (positive scale)
       lambda           : mixing weight of component 1 */
  vector init_spat_bias(int n_targets, vector x, vector y,
                        vector init_bias_params, real lambda) {
    vector[n_targets] w1;  // component 1 log density per item
    vector[n_targets] w2;  // component 2 log density per item
    for (ii in 1:n_targets) {
      w1[ii] = beta_lpdf(x[ii] | init_bias_params[1], init_bias_params[2])
               + beta_lpdf(y[ii] | init_bias_params[3], init_bias_params[4]);
      // NOTE(review): the y shapes of component 2 are used in the order
      // [8], [7], unlike every other (a, b) pair. This may be a deliberate
      // mirror of component 2's x density - confirm it is intended.
      w2[ii] = beta_lpdf(x[ii] | init_bias_params[5], init_bias_params[6])
               + beta_lpdf(y[ii] | init_bias_params[8], init_bias_params[7]);
    }
    return lambda * exp(w1) + (1 - lambda) * exp(w2);
  }

  /* Spatial weights for selection number n within a trial, normalised over
     all n_targets items. Found items are removed later by the caller.
       n == 1    : initial-location bias, or uniform when lambda < 0
       n == 2    : exp(-(rho_delta + u_delta) * delta)
       otherwise : exp(-(rho_delta + u_delta) * delta - (rho_psi + u_psi) * psi)
     The arguments ii and phi are not used here. They are kept so the
     signature stays unchanged. */
  vector compute_spatial_weights(int n, int n_targets, int ii,
                                 real rho_delta, real rho_psi,
                                 real u_delta, real u_psi, vector delta, vector psi, vector phi,
                                 vector x, vector y, vector init_bias_params, real lambda) {
    vector[n_targets] w;
    if (n == 1) {
      // a negative lambda is the caller's signal not to apply the initial bias
      if (lambda < 0) {
        w = rep_vector(1, n_targets);
      } else {
        w = init_spat_bias(n_targets, x, y, init_bias_params, lambda);
      }
    } else if (n == 2) {
      // second selection: weight by distance from the previous selection only
      w = exp(-(rho_delta + u_delta) * delta);
    } else {
      // later selections: weight by distance and by direction
      w = exp(-(rho_delta + u_delta) * delta - (rho_psi + u_psi) * psi);
    }
    // normalise exactly once, whichever branch was taken
    return standarise_weights(w, n_targets, rep_vector(1, n_targets));
  }
}
data {
  int<lower = 1> N;          // total number of selected targets over the whole experiment
  int<lower = 1> L;          // number of participant levels
  int<lower = 1> K;          // number of experimental conditions
  int<lower = 1> n_trials;   // total number of trials (overall)
  int<lower = 1> n_classes;  // number of target classes, assumed constant over trials (not referenced elsewhere in this program)
  int<lower = 1> n_targets;  // total number of targets per trial
  // Position of this selection within its trial: 1 = first selection (starts
  // a new trial and resets the remaining items), 2 = second selection, and so on.
  array[N] int<lower = 0, upper = n_targets> found_order;
  // Selected target ID - this is what we predict. categorical_lpmf only
  // accepts values in 1..n_targets, so out-of-range IDs are rejected here
  // when the data are read, rather than at every log-density evaluation.
  array[N] int<lower = 1, upper = n_targets> Y;
  // (x, y) coordinates of each target
  array[n_trials] vector<lower = 0, upper = 1>[n_targets] item_x;
  array[n_trials] vector<lower = 0, upper = 1>[n_targets] item_y;
  array[N] vector<lower = 0>[n_targets] delta;  // distance measures
  array[N] vector[n_targets] psi;               // direction measures (relative)
  array[N] vector[n_targets] phi;               // direction measures (absolute); passed on but not used in any weight
  array[n_trials] int<lower = 1, upper = K> X;  // trial features (i.e. which condition we are in)
  matrix<lower = -1, upper = 1>[n_trials, n_targets] item_class;  // target class, one row per trial
  array[N] vector<lower = -1, upper = 1>[n_targets] S;  // stick/switch (does this target match the previous target?)
  array[N] int<lower = 1, upper = L> Z;             // random effect levels
  array[N] int<lower = 1, upper = n_trials> trial;  // which trial each row belongs to
  int<lower = 0, upper = 1> fit_init_bias;          // should we fit the initial bias?
  // prior hyper-parameters (standard deviations must be non-negative)
  real<lower = 0> prior_sd_bA;       // sd of the normal prior on bA
  real<lower = 0> prior_sd_b_stick;  // sd of the normal prior on b_stick
  real prior_mu_rho_delta;
  real<lower = 0> prior_sd_rho_delta;
  real prior_mu_rho_psi;
  real<lower = 0> prior_sd_rho_psi;
}
parameters {
  ////////////////////////////////////
  // fixed effects (one per condition)
  ////////////////////////////////////
  array[K] real bA;         // weights for class A compared to B
  array[K] real b_stick;    // stick-switch rates
  array[K] real rho_delta;  // distance tuning
  array[K] real rho_psi;    // direction tuning
  ///////////////////////////////
  // random effects (per condition, per participant level)
  ///////////////////////////////
  array[K] vector[L] uA;       // weights for class A compared to B
  array[K] vector[L] u_stick;  // stick-switch rates
  array[K] vector[L] u_delta;  // distance tuning
  array[K] vector[L] u_psi;    // direction tuning
  /* Initial bias parameters on the log scale (they are exponentiated before
     being used as beta shape parameters). They are constant over
     participants, so they are not part of the random-effect structure.
     Order:
       [1], [2] : a, b for component 1, x dimension
       [3], [4] : a, b for component 1, y dimension
       [5], [6] : a, b for component 2, x dimension
       [7], [8] : a, b for component 2, y dimension
                  (init_spat_bias passes these as [8], [7])
     The bounds must comfortably contain the priors. [6] and [8] have
     normal(7.0, 0.25) priors, so the previous upper bound of 5 pinned them
     at the boundary. */
  vector<lower = -10, upper = 10>[8] init_bias_params;
  /* Mixing weight of the initial-bias mixture. It varies from person to
     person, so it may be worth correlating it with the b parameters. */
  array[L] real<lower = 0, upper = 1> lambda;
}
model {
  /* The selection weights used to be transformed parameters. Every transformed
     parameter is validated (here: a simplex check) and written to the output
     on every draw, which meant N * n_targets extra values per draw. As
     model-block locals they are computed identically but never stored. If
     per-item weights are needed for post-processing, recompute them in
     generated quantities. */
  vector[8] init_bias_params2 = exp(init_bias_params);  // beta shapes on the positive scale
  vector[n_targets] remaining_items;  // 1 = not yet found in the current trial, 0 = found

  /////////////////////////////////////////////////////
  // Priors
  /////////////////////////////////////////////////////
  for (k in 1:K) {
    // fixed effects
    target += normal_lpdf(bA[k] | 0, prior_sd_bA);
    target += normal_lpdf(b_stick[k] | 0, prior_sd_b_stick);
    target += normal_lpdf(rho_delta[k] | prior_mu_rho_delta, prior_sd_rho_delta);
    target += normal_lpdf(rho_psi[k] | prior_mu_rho_psi, prior_sd_rho_psi);
    // random effects
    target += normal_lpdf(uA[k] | 0, 0.5);
    target += normal_lpdf(u_stick[k] | 0, 0.5);
    target += normal_lpdf(u_delta[k] | 0, 1);
    target += normal_lpdf(u_psi[k] | 0, 0.5);
  }
  // initial bias (log scale)
  init_bias_params[1:4] ~ normal(1.5, 0.1);
  init_bias_params[5] ~ normal(2.0, 0.25);
  init_bias_params[6] ~ normal(7.0, 0.25);
  init_bias_params[7] ~ normal(2.0, 0.25);
  init_bias_params[8] ~ normal(7.0, 0.25);
  lambda ~ normal(0.5, 0.25);

  //////////////////////////////////////////////////
  // Likelihood: step through the selections in order
  //////////////////////////////////////////////////
  for (ii in 1:N) {
    int t = trial[ii];
    int k = X[t];                // condition of this trial
    int z = Z[ii];               // random-effect level
    int n_sel = found_order[ii]; // position of this selection within the trial
    real lambdall;
    vector[n_targets] w;

    if (n_sel == 1) {
      // first selection of a new trial: every item is available again
      remaining_items = rep_vector(1, n_targets);
    }

    // class weight for every item
    w = inv_logit((bA[k] + uA[k, z]) * to_vector(item_class[t]));
    if (n_sel != 1) {
      // stick/switch weight; S[ii] precomputes which items match the previous target
      w = w .* inv_logit((b_stick[k] + u_stick[k, z]) * S[ii]);
    }

    // a negative lambda tells compute_spatial_weights to skip the initial bias
    if (fit_init_bias == 0) {
      lambdall = -1;
    } else {
      lambdall = lambda[z];
    }
    w = w .* compute_spatial_weights(n_sel, n_targets, ii,
                                     rho_delta[k], rho_psi[k], u_delta[k, z], u_psi[k, z],
                                     delta[ii], psi[ii], phi[ii],
                                     item_x[t], item_y[t], init_bias_params2, lambdall);

    // remove already-found items, renormalise, and score the observed choice
    target += categorical_lpmf(Y[ii] | standarise_weights(w, n_targets, remaining_items));

    // The chosen item is unavailable for the rest of the trial. An
    // out-of-range Y never reaches this line: categorical_lpmf rejects it first.
    remaining_items[Y[ii]] = 0;
  }
}
generated quantities {
  /* Draws from the priors of the fixed effects, for prior/posterior
     comparisons. The RNG calls stay in their original order so that results
     are reproducible for a fixed seed. */
  real prior_bA;
  real prior_b_stick;
  real prior_rho_delta;
  real prior_rho_psi;
  real prior_direction_bias;

  prior_bA = normal_rng(0, prior_sd_bA);
  prior_b_stick = normal_rng(0, prior_sd_b_stick);
  prior_rho_delta = normal_rng(prior_mu_rho_delta, prior_sd_rho_delta);
  prior_rho_psi = normal_rng(prior_mu_rho_psi, prior_sd_rho_psi);
  // NOTE(review): no parameter in this program corresponds to a direction
  // bias, so this normal(-2, 3) draw looks like a leftover. Confirm before using it.
  prior_direction_bias = normal_rng(-2, 3);
}