Hello,
I want to fit an IRT-type model involving a custom probability function (name: espl_lpdf). For this custom probability function, I need to compute the integral of a function of the form exp(c*bspline(x)), where bspline(x) is a B-spline defined on a suitable set of knots.
When I run the model below on test_data, I get error messages saying that the error estimate of the integral is too large. I am struggling to find a proper solution to this problem. After all, the integral I compute (via the integrate_1d function, which uses quadrature) is well-defined, and quadrature should not have any problems with it: the integrand is smooth except at the knot locations, and the integral is one-dimensional.
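For the test configuration set up in the R code (extended knots a, a, b, b with a = min(Y) and b = max(Y); basis function 1 of order 2), the B-spline reduces to a linear ramp, so the integral can also be written in closed form. This derivation is a sketch under exactly these assumptions and does not cover general knot vectors:

```latex
\begin{aligned}
B(x) &= \frac{b - x}{b - a}, \qquad a \le x < b, \\
I(y) &= \int_a^{y} e^{c\,B(x)}\,dx
      = \frac{b - a}{c}\left(e^{c} - e^{c\,(b - y)/(b - a)}\right), \qquad c \neq 0, \\
I(y) &= y - a, \qquad c = 0, \\
\log p(y) &= \log \phi\bigl(I(y) - \theta\bigr) + c\,B(y), \qquad \text{since } I'(y) = e^{c\,B(y)}.
\end{aligned}
```

Here φ is the standard normal density, and the last line is what espl_lpdf evaluates (loc = I(y), deriv_spline = B(y)). Near the lower limit, I(y) ≈ (y - a)e^{c}, which tends to 0; since a = min(Y), at least one observation sits exactly at the lower limit of integration, which may be worth keeping in mind when reading the tolerance error.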
I would be very grateful for any comments or advice on this topic, and for suggestions on how to resolve this problem. For the sake of clarity, I have appended the complete Stan and R code, as well as the corresponding data set.
Please note that the function used to compute the B-spline is taken from Milad Kharratzadeh’s post on B-splines. I have also tested this function several times and I think it works properly, so I do not believe the problem lies in the definition of the B-spline.
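To double-check by hand what the test case actually integrates, here is a line-by-line Python port of build_b_spline. It is only a sketch for checking values outside Stan; the Python names mirror the Stan code but nothing here is part of Stan. With extended knots (a, a, b, b), basis function 1 of order 2 should come out as the decreasing ramp (b - x)/(b - a) on [a, b):

```python
from typing import List, Sequence


def build_b_spline(t: Sequence[float], ext_knots: Sequence[float],
                   ind: int, order: int) -> List[float]:
    """Python port of the Stan build_b_spline (Cox-de Boor recursion).

    `ind` is 1-based, exactly as in the Stan code.
    """
    def knot(i: int) -> float:
        # 1-based knot access, mirroring Stan's ext_knots[i]
        return ext_knots[i - 1]

    if order == 1:
        # order-1 B-splines are indicators of the half-open interval [knot(ind), knot(ind+1))
        return [1.0 if knot(ind) <= ti < knot(ind + 1) else 0.0 for ti in t]

    n = len(t)
    w1 = [0.0] * n  # weights stay 0 when the knot span is degenerate
    w2 = [0.0] * n
    if knot(ind) != knot(ind + order - 1):
        w1 = [(ti - knot(ind)) / (knot(ind + order - 1) - knot(ind)) for ti in t]
    if knot(ind + 1) != knot(ind + order):
        w2 = [1 - (ti - knot(ind + 1)) / (knot(ind + order) - knot(ind + 1)) for ti in t]

    # linear interpolation of the two lower-order splines
    left = build_b_spline(t, ext_knots, ind, order - 1)
    right = build_b_spline(t, ext_knots, ind + 1, order - 1)
    return [u1 * l + u2 * r for u1, l, u2, r in zip(w1, left, w2, right)]


knots = [0.0, 0.0, 2.0, 2.0]  # extended knots (a, a, b, b) with a = 0, b = 2
print(build_b_spline([0.0, 0.5, 1.0, 2.0], knots, 1, 2))  # [1.0, 0.75, 0.5, 0.0]
```

Note that the value at the right boundary x = b is 0 because of the half-open indicator, which agrees with the limit of the ramp, so the integrand exp(c * B(x)) stays continuous on [a, b].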
functions {
  // from: https://mc-stan.org/users/documentation/case-studies/splines_in_stan.html
  // forward declaration, required for recursive functions in older Stan versions
  vector build_b_spline(real[] t, real[] ext_knots, int ind, int order);
  vector build_b_spline(real[] t, real[] ext_knots, int ind, int order) {
    // INPUTS:
    //   t:         the points at which the b_spline is calculated
    //   ext_knots: the set of extended knots
    //   ind:       the index of the b_spline
    //   order:     the order of the b-spline
    vector[size(t)] b_spline;
    vector[size(t)] w1 = rep_vector(0, size(t));
    vector[size(t)] w2 = rep_vector(0, size(t));
    if (order == 1)
      for (i in 1:size(t)) // B-splines of order 1 are piece-wise constant
        b_spline[i] = (ext_knots[ind] <= t[i]) && (t[i] < ext_knots[ind+1]);
    else {
      if (ext_knots[ind] != ext_knots[ind+order-1])
        w1 = (to_vector(t) - rep_vector(ext_knots[ind], size(t))) /
             (ext_knots[ind+order-1] - ext_knots[ind]);
      if (ext_knots[ind+1] != ext_knots[ind+order])
        w2 = 1 - (to_vector(t) - rep_vector(ext_knots[ind+1], size(t))) /
                 (ext_knots[ind+order] - ext_knots[ind+1]);
      // Calculating the B-spline recursively as linear interpolation of two lower-order splines
      b_spline = w1 .* build_b_spline(t, ext_knots, ind, order-1) +
                 w2 .* build_b_spline(t, ext_knots, ind+1, order-1);
    }
    return b_spline;
  }
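  // define exp(coeff * bspline(x)), the integrand passed to integrate_1d
  // for testing purposes: just a single coefficient / B-spline basis function
  // x_r holds the extended knots, x_i = {index, order} of the basis function
  real exp_spline(real x, real xc, real[] theta, real[] x_r, int[] x_i) {
    real first[1];
    first[1] = x;
    return exp(theta[1] * build_b_spline(first, x_r, x_i[1], x_i[2])[1]);
  }
  // define a custom pdf using the integral of an exponential B-spline
  real espl_lpdf(real y, real theta, real coeff, data real[] ext_knots) {
    real first[1];
    real loc;
    real deriv_spline;
    real llik;
    int x_i[2];
    real theta_h[1];
    first[1] = y;
    x_i[1] = 1;  // index of the B-spline basis function
    x_i[2] = 2;  // order of the B-spline (2 = piecewise linear)
    theta_h[1] = coeff;
    // value of the basis function at y; coeff * deriv_spline is the log-Jacobian
    // log(d loc / dy) = coeff * bspline(y)
    deriv_spline = build_b_spline(first, ext_knots, 1, 2)[1];
    print(ext_knots[1], y, theta_h, ext_knots, x_i);  // debugging output, printed on every call
    // loc = integral of exp(coeff * bspline(x)) from ext_knots[1] to y, relative tolerance 1e-4
    loc = integrate_1d(exp_spline, ext_knots[1], y, theta_h, ext_knots, x_i, 1e-4);
    llik = std_normal_lpdf(loc - theta) + coeff * deriv_spline;
    return llik;
  }
}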
data {
  int<lower=0> K;            // number of items
  int<lower=1> N;            // number of test takers
  matrix[N, K] Y;            // responses of test takers to the items
  int num_knots;
  real ext_knots[num_knots];
}
parameters {
  vector[N] z;
  vector[K] intercepts;
  vector[K] coeffs;
  real<lower=0> sigma_theta;
}
transformed parameters {
  vector[N] theta = sigma_theta * z;
}
model {
  sigma_theta ~ cauchy(0, 1);
  intercepts ~ normal(0, 1);
  coeffs ~ normal(0, 1);
  z ~ normal(0, 1);
  for (i in 1:N) {
    for (j in 1:K) {
      Y[i, j] ~ espl(theta[i] - intercepts[j], coeffs[j], ext_knots);
    }
  }
}
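As a cross-check of the integral inside espl_lpdf, this Python sketch compares a plain composite Simpson rule with the closed form for the single ramp basis function (extended knots a, a, b, b as in the R code). The helpers ramp_basis, integral_closed_form and integral_simpson are hypothetical names for illustration, not Stan functions, and the limits a = -2, b = 2 are made-up stand-ins for min(Y) and max(Y). If the two agree, the closed form could in principle replace the integrate_1d call in this one-basis-function test case; that is a suggestion, not something the Stan code above does.

```python
import math


def ramp_basis(x: float, a: float, b: float) -> float:
    """Basis function 1 of order 2 for extended knots (a, a, b, b): (b - x)/(b - a) on [a, b)."""
    return (b - x) / (b - a) if a <= x < b else 0.0


def integral_closed_form(y: float, c: float, a: float, b: float) -> float:
    """I(y) = integral of exp(c * ramp_basis(x)) over [a, y], for a <= y <= b."""
    h = b - a
    if c == 0.0:
        return y - a
    u = (b - y) / h  # value of the basis function at y
    # (h / c) * (e^c - e^(c*u)), rewritten with expm1 to avoid cancellation for small c
    return h * math.exp(c * u) * math.expm1(c * (1.0 - u)) / c


def integral_simpson(y: float, c: float, a: float, b: float, n: int = 200) -> float:
    """Composite Simpson rule on [a, y] with an even number n of subintervals."""
    if y == a:
        return 0.0

    def f(x: float) -> float:
        return math.exp(c * ramp_basis(x, a, b))

    step = (y - a) / n
    total = f(a) + f(y)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * f(a + i * step)
    return total * step / 3.0


def espl_lpdf(y: float, theta: float, coeff: float, a: float, b: float) -> float:
    """Same log density as the Stan espl_lpdf, but with the closed-form integral."""
    loc = integral_closed_form(y, coeff, a, b)
    z = loc - theta
    log_phi = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)  # std_normal_lpdf(z)
    return log_phi + coeff * ramp_basis(y, a, b)  # plus log-Jacobian c * B(y)


a, b = -2.0, 2.0  # illustrative limits standing in for min(Y), max(Y)
for y in (-1.9, 0.0, 1.0, 2.0):
    print(y, integral_closed_form(y, 1.5, a, b), integral_simpson(y, 1.5, a, b))
```

The expm1 form is used because the textbook expression (h/c)(e^c - e^{cu}) subtracts two nearly equal numbers when c is close to 0 or y is close to a.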
Data file:
test_data.csv (5.4 KB)
In R, I then execute the following statements:
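library(rstan)
test_data <- read.csv("test_data.csv")
x <- as.matrix(test_data)   # numeric N x K matrix of responses
minmax <- range(x)          # smallest and largest observed response
# extended knots (min, min, max, max) for an order-2 (piecewise-linear) B-spline
irt_dat <- list(K = 3, Y = x, N = 100, num_knots = 4,
                ext_knots = c(minmax[1], minmax[1], minmax[2], minmax[2]))
fit_spl <- stan(file = './stan_models/exp_spl.stan', data = irt_dat, iter = 4000, chains = 1)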
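In the R code, ext_knots = c(min, min, max, max) is a clamped (extended) knot vector for piecewise-linear splines (order 2) with base knots (min, max) and no interior knots; num_knots = 4 counts the extended knots. Assuming the padding convention of the splines case study cited in the Stan code, which repeats each boundary knot spline_degree times (worth confirming against the case study), a small hypothetical Python helper shows how such a vector is built and how many basis functions it yields. The model above uses only the first of them.

```python
from typing import List, Sequence


def extend_knots(knots: Sequence[float], spline_degree: int) -> List[float]:
    """Clamp a knot vector: repeat each boundary knot spline_degree extra times."""
    return [knots[0]] * spline_degree + list(knots) + [knots[-1]] * spline_degree


def num_basis(ext_knots: Sequence[float], order: int) -> int:
    """Number of B-spline basis functions of the given order on ext_knots."""
    return len(ext_knots) - order


ext = extend_knots([-1.0, 3.0], spline_degree=1)  # base knots (min, max), piecewise linear
print(ext, num_basis(ext, order=2))  # [-1.0, -1.0, 3.0, 3.0] 2
```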