This works, but don’t ask me how I got it to work (it was messy).
functions {
/**
 * CDF-shaped helper for the short-time proposal: erfc(1 / sqrt(2 a t)).
 * Because of the `_cdf` suffix it is called as fs_cdf(t | a).
 *
 * @param t time at which to evaluate
 * @param a proposal width parameter; rejected when a < 1
 * @return erfc(inv_sqrt(2 * a * t))
 */
real fs_cdf(real t, real a) {
  if (a < 1) {
    reject("a must be >= 1, found a = ", a);
  }
  real arg = inv_sqrt(2 * a * t);
  return erfc(arg);
}
/**
 * Precomputes the constants shared by the first-passage-time samplers.
 *
 * @param mu non-negative drift of the standardized problem
 *           (fast_pt_rng passes abs() of its delta argument)
 * @return vector[10] laid out as
 *   [1] mu2       mu^2
 *   [2] t_tilde   split time between short- and long-time proposals
 *   [3] a         short-time proposal width (>= 1)
 *   [4] sqrtamu   sqrt((a - 1) * mu2 / a)
 *   [5] fourmu2pi (4 mu^2 + pi^2) / 8, long-time exponential rate
 *   [6] Cf1s      short-time proposal scale
 *   [7] Cf1l      long-time proposal scale
 *   [8] CF1st     short-time proposal mass up to t_tilde
 *   [9] F1lt      1 - exp(-t_tilde * fourmu2pi)
 *   [10] F1inf    total proposal mass
 */
vector make_vars(real mu) {
  vector[10] consts;
  real mu2 = pow(mu, 2);
  real t_tilde = 0.12 + 0.5 * exp(-mu2 / 3);
  real a = (3 + sqrt(9 + 4 * mu2)) / 6;
  real sqrtamu = sqrt((a - 1) * mu2 / a);
  real fourmu2pi = (4 * mu2 + pi() ^ 2) / 8;
  real Cf1s = sqrt(a) * exp(-sqrtamu);
  real Cf1l = pi() / (4 * fourmu2pi);
  real F1lt = -expm1(-t_tilde * fourmu2pi);
  real CF1st = Cf1s * fs_cdf(t_tilde | a);

  consts[1] = mu2;
  consts[2] = t_tilde;
  consts[3] = a;
  consts[4] = sqrtamu;
  consts[5] = fourmu2pi;
  consts[6] = Cf1s;
  consts[7] = Cf1l;
  consts[8] = CF1st;
  consts[9] = F1lt;
  consts[10] = CF1st + Cf1l * (1 - F1lt);
  return consts;
}
/**
 * Alternating-series accept/reject test.
 *
 * Draws z = ft * U(0, 1) and compares it with successive partial sums of
 *   S(c) = sum_k (-1)^k (2k + 1) exp(-c (2k + 1)^2),
 * which alternate between upper and lower bounds of S(c).
 *
 * @param t_star candidate time (not used by the test itself)
 * @param ft     envelope value scaling the uniform draw
 * @param c      series rate; must exceed 0.06385320297074884 (= log(5/3) / 8)
 * @return 1 to accept (z below a lower bound), 0 to reject (z above an upper bound)
 */
int acceptt_rng(real t_star, real ft, real c) {
  if (c <= 0.06385320297074884) {
    reject("c is ", c);
  }
  if (is_nan(c)) {
    reject("c is nan!");
  }
  real z_draw = ft * uniform_rng(0, 1);
  real partial = exp(-c);
  int k = 3;
  while (1) {
    // partial is currently an upper bound of S(c)
    if (z_draw > partial) {
      return 0;
    }
    partial -= k * exp(-c * k ^ 2);
    // partial is now a lower bound of S(c)
    if (z_draw < partial) {
      return 1;
    }
    partial += (k + 2) * exp(-c * (k + 2) ^ 2);
    k += 4;
  }
  return 0;
}
/**
 * Rejection sampler for the first-passage time of a unit-variance diffusion
 * leaving (-1, 1) from 0 with small drift (|mu| < 1 in the caller).
 *
 * Proposal: a short-time piece with CDF Cf1s * fs_cdf(t | a) on (0, t_tilde]
 * and an exponential long-time piece on (t_tilde, inf), with total mass F1inf.
 * Acceptance is decided by acceptt_rng on the matching series representation.
 * Both pieces target the same density
 *   exp(-mu^2 t / 2) * (2 pi t^3)^(-1/2) * sum_k (-1)^k (2k+1) exp(-(2k+1)^2 / (2t)).
 *
 * @param vars constants from make_vars()
 * @return one passage-time draw on the standardized time scale
 */
real sample_small_mu_rng(vector vars) {
  real mu2 = vars[1];
  real a = vars[3];
  real sqrtamu = vars[4];
  real fourmu2pi = vars[5];
  real Cf1s = vars[6];
  real Cf1l = vars[7];
  real CF1st = vars[8];
  real F1lt = vars[9];
  real F1inf = vars[10];
  real pi_sq = pi() ^ 2;
  while (1) {
    real p = F1inf * uniform_rng(0, 1);
    if (p <= CF1st) {
      // Short-time piece: invert Cf1s * erfc(1 / sqrt(2 a t)); gives t <= t_tilde.
      real t_star = 1. / (2 * a * pow(inv_erfc(p / Cf1s), 2));
      // p == 0 gives t_star == 0, for which the series test would never terminate.
      if (t_star > 0) {
        // Envelope / series ratio. The mu^2 t / 2 term matches the exp(-mu^2 t / 2)
        // factor of the target (the same convention as fourmu2pi and Cf1s).
        real ft = exp(-1. / (2 * a * t_star) - sqrtamu + 0.5 * mu2 * t_star);
        // Short-time series is in exp(-(2k+1)^2 / (2t)), so the rate is c = 1 / (2t).
        // With t_star <= t_tilde <= 0.62 this is >= 0.8, above acceptt_rng's threshold.
        if (acceptt_rng(t_star, ft, 0.5 / t_star) == 1) {
          return t_star;
        }
      }
    } else {
      // Long-time piece: invert the exponential CDF; gives t >= t_tilde >= 0.12,
      // so pisqt >= 0.148 is always above acceptt_rng's threshold.
      real t_star = -log1p(-(p - CF1st) / Cf1l - F1lt) / fourmu2pi;
      // Guard against log1p(-1) rounding at p ~ F1inf (t_star = inf or nan).
      if (t_star < positive_infinity()) {
        real pisqt = pi_sq * t_star / 8;
        if (acceptt_rng(t_star, exp(-pisqt), pisqt) == 1) {
          return t_star;
        }
      }
    }
  }
  return 0;
}
/**
 * Inverse-Gaussian draw via the Michael-Schucany-Haas transformation.
 * With mu_sq == mu^2 this samples IG(mean = mu, shape = 1). The caller in
 * this file passes inv(mu2) together with inv_sqrt(mu2).
 *
 * @param mu    mean of the distribution
 * @param mu_sq expected to be mu^2
 * @return one positive draw
 */
real inverse_gaussian_rng(real mu, real mu_sq) {
  real chi2 = pow(std_normal_rng(), 2);  // chi-square(1) variate
  real u = uniform_rng(0, 1);
  real root = sqrt(4 * mu * chi2 + mu_sq * chi2 ^ 2);
  real x = mu + 0.5 * mu_sq * chi2 - 0.5 * mu * root;
  // Pick the smaller root with probability mu / (mu + x), else the larger one.
  if (u <= (mu / (mu + x))) {
    return x;
  }
  return mu_sq / x;
}
/**
 * Rejection sampler for the first-passage time of a unit-variance diffusion
 * leaving (-1, 1) from 0 with large drift (|mu| >= 1 in the caller).
 *
 * Proposal: inverse Gaussian with mean 1/|mu| and shape 1, whose density is
 * (2 pi t^3)^(-1/2) exp(-1/(2t) + |mu| - mu^2 t / 2). The target is
 * exp(-mu^2 t / 2) (2 pi t^3)^(-1/2) S_short(t), so a draw is accepted iff
 *   U * exp(-1/(2t)) <= S_short(t)                                  (t <= 2.5)
 * or, equivalently in the long-time form,
 *   U * (4/pi) (2 pi t^3)^(-1/2) exp(-1/(2t)) <= S_long(t)          (t > 2.5).
 * acceptt_rng returns 1 for "accept".
 *
 * @param vars constants from make_vars(); only vars[1] = mu^2 is used
 * @return one passage-time draw on the standardized time scale
 */
real sample_large_mu_rng(vector vars) {
  real mu2 = vars[1];
  real invabsmu = inv_sqrt(mu2);
  real invmu2 = inv(mu2);
  // log of (4/pi) * (2 pi)^(-1/2), the long-time envelope constant
  real log_cf1l = -log(pi() * 0.25) - 0.5 * log(2 * pi());
  real pi_sq = pi() ^ 2;
  while (1) {
    real t_star = inverse_gaussian_rng(invabsmu, invmu2);
    if (is_nan(t_star)) {
      reject("t_star is nan! ", mu2);
    }
    real one2t = 0.5 * inv(t_star);
    if (t_star <= 2.5) {
      // Acceptance probability is 1 - 3 exp(-8 one2t) + ...; for one2t > 40 the
      // correction is < 1e-138, i.e. exactly 1 in double precision. Accepting
      // here also avoids exp(-one2t) underflowing to 0, which would make the
      // series test loop forever.
      if (one2t > 40) {
        return t_star;
      }
      // one2t >= 0.2 here, above acceptt_rng's threshold.
      if (acceptt_rng(t_star, exp(-one2t), one2t) == 1) {
        return t_star;
      }
    } else {
      real ft = exp(log_cf1l - one2t - 1.5 * log(t_star));
      // pi^2 t / 8 > 3 here, above acceptt_rng's threshold.
      if (acceptt_rng(t_star, ft, pi_sq * t_star / 8) == 1) {
        return t_star;
      }
    }
  }
  return 0;
}
/**
 * Draws one standardized first-passage time for drift abs(delta), switching
 * between the small-drift (< 1) and large-drift samplers.
 * alpha, tau and beta are not used; they are kept for the existing signature.
 * NOTE(review): the caller rescales the result by the squared half-width,
 * which suggests a unit half-width problem — confirm.
 */
real fast_pt_rng(real alpha, real tau, real beta, real delta) {
  real drift = abs(delta);
  vector[10] consts = make_vars(drift);
  if (drift < 1) {
    return sample_small_mu_rng(consts);
  }
  return sample_large_mu_rng(consts);
}
/**
 * Simulates one trial of the Wiener diffusion model with bounds 0 and alpha,
 * start beta * alpha, drift delta and non-decision time tau.
 *
 * The walk repeatedly samples the exit from the largest interval symmetric
 * around the current position that fits inside (0, alpha). The exit time is
 * a scaled draw from fast_pt_rng, and the exit side is logistic in 2 * drift * width.
 * The loop ends once the exit lands on 0 or alpha.
 *
 * @return [response time, boundary]' with boundary 1 = upper, 0 = lower
 */
vector wiener_rng(real alpha, real tau, real beta, real delta) {
  // Starting exactly on a boundary: absorbed immediately.
  if (beta == 0 || beta == 1) {
    return [tau, beta]';
  }
  real sign_delta = delta > 0 ? 1 : -1;
  real mu = abs(delta);
  real pos = beta * alpha;  // current position; lower bound at 0
  real elapsed = 0;         // accumulated passage time
  while (1) {
    real gap_lo = pos;
    real gap_hi = alpha - pos;
    int symmetric = abs(gap_lo - gap_hi) < 1e-6;
    int nearer_upper = gap_lo > gap_hi;
    // half-width of the symmetric sub-interval around pos
    real w = (symmetric || nearer_upper) ? gap_hi : gap_lo;
    real mutheta = w * mu;
    real pt = fast_pt_rng(alpha, tau, beta, w * abs(delta));
    // probability that the sub-interval is left through its upper end
    real p_up = sign_delta == 1 ? inv_logit(2 * mutheta) : 1 - inv_logit(2 * mutheta);
    if (symmetric) {
      real bound = uniform_rng(0, 1) < p_up ? 1 : 0;
      return [tau + elapsed + (square(w) * pt), bound]';
    }
    elapsed += square(w) * pt;
    if (nearer_upper) {
      // the sub-interval's upper end is alpha
      if (uniform_rng(0, 1) < p_up) {
        return [tau + elapsed, 1]';
      }
      pos -= w;
    } else {
      // the sub-interval's lower end is 0
      if (uniform_rng(0, 1) > p_up) {
        return [tau + elapsed, 0]';
      }
      pos += w;
    }
  }
  return [tau + elapsed, 0]';  // unreachable; satisfies the return check
}
}
data {
  // Settings used to simulate the data set in transformed data.
  // N >= 1 is required: transformed parameters uses min(rt), which is +inf
  // for an empty array.
  int<lower=1> N;                   // number of simulated trials
  real<lower=0> alpha_in;           // boundary separation used for simulation
  real<lower=0> tau_in;             // non-decision time used for simulation
  real<lower=0, upper=1> beta_in;   // relative starting point used for simulation
  real delta_in;                    // drift rate used for simulation
}
transformed data {
  // Simulate N trials from wiener_rng at the data-block settings.
  array[N] real rt;        // simulated response times
  array[N] int hit_upper;  // 1 = upper boundary, 0 = lower boundary
  for (n in 1:N) {
    vector[2] draw = wiener_rng(alpha_in, tau_in, beta_in, delta_in);
    rt[n] = draw[1];
    hit_upper[n] = draw[2] == 0 ? 0 : 1;
  }

  // Trial indices split by boundary, used by the two wiener_lpdf terms.
  int N_upper = sum(hit_upper);
  int N_lower = N - N_upper;
  array[N_lower] int id_lower;
  array[N_upper] int id_upper;
  {
    int i_lower = 0;
    int i_upper = 0;
    for (n in 1:N) {
      if (hit_upper[n] == 1) {
        i_upper += 1;
        id_upper[i_upper] = n;
      } else {
        i_lower += 1;
        id_lower[i_lower] = n;
      }
    }
  }
}
parameters {
// Declaration order is kept: it fixes the unconstrained parameter vector and output columns.
real tau_raw; // logit of tau / min(rt); tau itself is built in transformed parameters
real delta; // drift-rate
real<lower=0> alpha; // boundary separation
real<lower=0, upper=1> beta; // starting point, as a fraction of alpha
}
transformed parameters {
  // Non-decision time on the RT scale: the fraction inv_logit(tau_raw) of the
  // fastest simulated RT, so tau stays below every rt as wiener_lpdf needs.
  real tau = min(rt) * inv_logit(tau_raw);
}
model {
// Priors. beta is constrained to (0, 1), so its normal prior acts as a truncated one.
delta ~ normal(0, 3);
alpha ~ normal(1, 2);
beta ~ normal(0, 2);
tau_raw ~ normal(-.2, .48);
// Likelihood. wiener_lpdf is the upper-boundary density; lower-boundary
// trials use the mirrored process (start 1 - beta, drift -delta).
target += wiener_lpdf(rt[id_upper] | alpha, tau, beta, delta);
target += wiener_lpdf(rt[id_lower] | alpha, tau, 1 - beta, -delta);
}
In R, you can test it with:
library(cmdstanr)

# Settings passed to the Stan program's data block; the model simulates
# N trials from them and then tries to recover the parameters.
sim_data <- list(
  N = 400,
  alpha_in = 1.25,
  tau_in = 0.2,
  beta_in = 0.7,
  delta_in = -0.3
)

mod <- cmdstan_model("wiener_rng.stan")
out <- mod$sample(data = sim_data, parallel_chains = 4)
out
which gives the following results:
> out
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -138.60 -138.29 1.42 1.23 -141.28 -136.94 1.00 1991 2944
tau_raw 2.45 2.45 0.19 0.18 2.15 2.77 1.00 2494 2652
delta -0.36 -0.36 0.11 0.11 -0.54 -0.18 1.00 2315 2585
alpha 1.25 1.25 0.03 0.03 1.20 1.31 1.00 2853 2533
beta 0.72 0.72 0.01 0.02 0.69 0.74 1.00 2279 2817
tau 0.21 0.21 0.00 0.00 0.20 0.21 1.00 2494 2652
Edit: found an issue with large values of delta; fixed it by simplifying the code :)