Is A known and you just need to solve for x? Or does A get sampled as well?
If you know A and want to construct a random x that satisfies the constraint, you can sample x from a multivariate normal (MVN) conditioned on the constraint. Everything in the transformed data block can be computed ahead of time, e.g. with sparse solvers. The input `mu` is optional. I think you could probably make it a parameter vector instead; `mu_new` would then have to move from transformed data into transformed parameters.
This constructs $X \sim N(\mu, \Gamma)$ conditional on $AX = b$ (`G` in the code is $\Gamma$):
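For reference, these are the standard Gaussian conditioning formulas the code targets. The transformed data block writes the conditional covariance in the equivalent pseudo-inverse form, using the projector $P$ onto the null space of $A$:

$$
\mu^* = \mu + \Gamma A^\top (A \Gamma A^\top)^{-1} (b - A\mu), \qquad
\Sigma^* = \Gamma - \Gamma A^\top (A \Gamma A^\top)^{-1} A \Gamma = \left(P \Gamma^{-1} P\right)^{+},
\qquad P = I - A^\top (A A^\top)^{-1} A .
$$

So `mu_new` is $\mu^*$, and `diag_post_multiply(Sigma, sigma)` is a square root of $(P\Gamma^{-1}P)^{+}$ built from its eigendecomposition (up to the `1e-8` jitter). Since $A\mu^* = b$ and $A\,\Sigma^* = 0$, every draw satisfies the constraint exactly.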
```stan
data {
  int<lower=0> n;        // number of constraints
  int<lower=0> N;        // dimension of x
  matrix[n, N] A;
  matrix[N, N] G;        // covariance Gamma
  vector[n] b;
  vector[N] mu;
}
transformed data {
  // P = diag(N) - t(A) %*% solve(A %*% t(A)) %*% A;  (projector onto null(A))
  matrix[N, N] P = add_diag(-A' * (tcrossprod(A) \ A), rep_vector(1, N));
  // eigendecomposition of P * G^{-1} * P (tuples and eigendecompose_sym need Stan 2.33+)
  tuple(matrix[N, N], vector[N]) eigen = eigendecompose_sym(P * (G \ P));
  matrix[N, N] Sigma = P * eigen.1;
  // the n zero eigenvalues get a small jitter; their columns of Sigma are ~0
  vector[N] sigma = 1 ./ sqrt(eigen.2 + 1e-8);
  // mu_new <- mu + G %*% t(A) %*% solve(A %*% G %*% t(A)) %*% (b - A %*% mu)
  vector[N] mu_new = mu + G * A' * (quad_form(G, A') \ (b - A * mu));
}
parameters {
  vector[N] x_raw;
}
transformed parameters {
  // only the N - n directions in the null space of A actually move x
  vector[N] x = mu_new + diag_post_multiply(Sigma, sigma) * x_raw;
}
model {
  x_raw ~ std_normal();
}
generated quantities {
  vector[n] b_out = A * x;  // should equal b in every draw
}
```
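A rough base-R mirror of the transformed data block, e.g. for precomputing these quantities outside Stan or checking them. It is a sketch under assumptions: it uses dense `solve()`/`eigen()` rather than the sparse solvers mentioned above, the helper names `constrained_mvn_setup()` and `draw_constrained_mvn()` are made up, and it adds two small safeguards the Stan code does not have (symmetrising $P\Gamma^{-1}P$ before `eigen()` and clamping tiny negative eigenvalues with `pmax()`).

```r
# Hypothetical helpers (not from the post): base-R version of the
# transformed data block, using dense linear algebra only.
constrained_mvn_setup <- function(A, G, b, mu, jitter = 1e-8) {
  N <- ncol(A)
  # Projector onto the null space of A: I - A' (A A')^{-1} A
  P <- diag(N) - t(A) %*% solve(tcrossprod(A), A)
  # P G^{-1} P, symmetrised to remove round-off asymmetry before eigen()
  M <- P %*% solve(G, P)
  e <- eigen((M + t(M)) / 2, symmetric = TRUE)
  Sigma <- P %*% e$vectors
  # pmax() guards against tiny negative eigenvalues from round-off
  sigma <- 1 / sqrt(pmax(e$values, 0) + jitter)
  # Conditional mean: mu + G A' (A G A')^{-1} (b - A mu)
  mu_new <- drop(mu + G %*% t(A) %*% solve(A %*% G %*% t(A), b - A %*% mu))
  # Equivalent of diag_post_multiply(Sigma, sigma): scale column j by sigma[j]
  L <- sweep(Sigma, 2, sigma, `*`)
  list(mu_new = mu_new, L = L)
}

# One draw per column: x = mu_new + L z with z ~ N(0, I)
draw_constrained_mvn <- function(ndraws, setup) {
  N <- length(setup$mu_new)
  z <- matrix(rnorm(N * ndraws), N, ndraws)
  setup$mu_new + setup$L %*% z
}
```

For example, with `A`, `G`, `b` and `mu` generated as in the R script, `X <- draw_constrained_mvn(4000, constrained_mvn_setup(A, G, b, mu))` gives one draw per column, and `max(abs(A %*% X - b))` should be at round-off level.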
Let’s say I generate the data as follows:

```r
N <- 8
n <- 3
mu <- c(1:N)
A <- matrix(rnorm(n * N), n, N)
b <- rnorm(n)
G <- rethinking::rlkjcorr(1, N)  # random N x N correlation matrix
library(cmdstanr)
mod <- cmdstan_model("mvn_linear_constraint.stan")
mod_fit <- mod$sample(
  data = list(
    n = n,
    N = N,
    A = A,
    G = G,
    b = b,
    mu = mu
  ),
  parallel_chains = 4
)
mod_fit$summary("b_out")
b
```
The constraint is satisfied exactly in every draw (`b_out` has zero sd and matches `b`), even though `x` is random:
```
> mod_fit$summary("b_out")
# A tibble: 3 × 10
  variable   mean median    sd   mad     q5    q95  rhat ess_bulk
  <chr>     <num>  <num> <num> <num>  <num>  <num> <num>    <num>
1 b_out[1] -0.966 -0.966     0     0 -0.966 -0.966    NA       NA
2 b_out[2] -1.45  -1.45      0     0 -1.45  -1.45     NA       NA
3 b_out[3] -0.552 -0.552     0     0 -0.552 -0.552    NA       NA
# ℹ 1 more variable: ess_tail <num>
> b
[1] -0.9655925 -1.4478979 -0.5517504
```
where `x` is:

```
> mod_fit$summary("x")
# A tibble: 8 × 10
  variable   mean median    sd   mad    q5    q95  rhat ess_bulk ess_tail
  <chr>     <num>  <num> <num> <num> <num>  <num> <num>    <num>    <num>
1 x[1]     -2.99  -2.99  0.567 0.572 -3.92 -2.05   1.00    6711.    2754.
2 x[2]     -6.08  -6.08  0.629 0.620 -7.11 -5.00   1.00    6416.    3129.
3 x[3]      3.07   3.07  0.897 0.899  1.59  4.56   1.00    6591.    3207.
4 x[4]     -0.619 -0.613 0.716 0.722 -1.78  0.579  1.00    6444.    3107.
5 x[5]      4.84   4.84  0.680 0.671  3.73  5.95   1.00    6681.    3112.
6 x[6]     10.4   10.4   0.926 0.909  8.84 11.9    1.00    6708.    2975.
7 x[7]     -4.44  -4.44  0.537 0.537 -5.31 -3.55   1.00    6267.    2979.
8 x[8]     -1.70  -1.71  0.726 0.712 -2.89 -0.503  1.00    6710.    2718.
```
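To check this summary against theory, the conditional mean and marginal sds can also be computed directly from the textbook covariance $\Gamma - \Gamma A^\top (A\Gamma A^\top)^{-1} A\Gamma$, independently of the eigendecomposition route used in the Stan model. A minimal base-R sketch; `conditional_moments()` is a hypothetical helper name, not something from the post:

```r
# Hypothetical helper: closed-form moments of x | A x = b when x ~ N(mu, G).
conditional_moments <- function(A, G, b, mu) {
  GAt <- G %*% t(A)                          # G A'
  S <- A %*% GAt                             # A G A'
  cond_mean <- drop(mu + GAt %*% solve(S, b - A %*% mu))
  cond_cov <- G - GAt %*% solve(S, t(GAt))   # G - G A' (A G A')^{-1} A G
  cond_cov <- (cond_cov + t(cond_cov)) / 2   # tidy round-off asymmetry
  list(mean = cond_mean,
       sd = sqrt(pmax(diag(cond_cov), 0)),   # clamp tiny negatives from round-off
       cov = cond_cov)
}
```

With the `A`, `G`, `b` and `mu` from the script, `m <- conditional_moments(A, G, b, mu)` gives `m$mean` and `m$sd` to set beside the `mean` and `sd` columns of `mod_fit$summary("x")`; they should agree up to Monte Carlo error.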