Stan equivalent to R's which()

Is there a function in Stan equivalent to R’s which() command, i.e. one that returns the indices of the elements of a vector that meet a certain condition?
E.g.:

> a <- c(0, 4, 5, 7, 0, 2, 45, 1, 0)
> which(a==0)
[1] 1 5 9

So I tried to write a function in Stan to match zeros:

functions{
    vector which0(vector x) { 
        int len = num_elements(x);  // size must be known before declaring matches
        vector[len] matches;
        for(i in 1:len){
            if(x[i] == 0)
                matches[i] = i;
            else
                matches[i] = 0;
        }
        return matches;
    }
}

But the problem with that is it generates a vector like this: 1, 0, 0, 0, 5, 0, 0, 0, 9. Those zeroes are a problem since I want to use this as an index to subset a vector.

Any suggestions on a better approach, or on how to remove those zeroes to get a vector 1, 5, 9, i.e. one with a variable length equal to the number of matches?

I found it easier to do stuff like this in R and feed the results to RStan as data. Then you’d have the number of observed Y, int N_Yobs, and a position vector (of where those observed Y’s are), vector[N_Yobs] Ypos (the result of R’s which). Do that for the “missing” values as well and then, as a parameter, have a vector[N_Ymis] Ymis. Then you can rebuild the full vector in Stan’s transformed parameters block. (I’ll post more detail on that in the other thread, probably this afternoon GMT+1.)


Eh, admittedly I didn’t think of that, but a complication is that I’m doing this row-wise in a loop. If I go with the option of precomputing it in R, I’ll be passing N vectors of varying length into Stan for Ypos and Ymis. Sounds very messy!

An inefficient way would be to write two functions.
(1) to calculate the number of matches
(2) to find the matches where (1) would allow you to specify the length of the matches vector.

You do have to loop over the target vector twice (or only until your last match), hence the inefficiency. It’s probably not going to affect your performance too much if the target vector is data.


So I may have to rethink it a bit. The target vector is data defined as an array:

vector[NvarsY] y [N];

I had intended to call which0() inside a for(n in 1:N) loop, because outside of the loop I don’t really care about the index. I can see that precomputing it in R or in a transformed data{} block would be faster, but then I have to contend with N indexes of varying length, and that seems just as hard to me?

EDIT: Ok I think this is gonna be really hard, so I’m gonna try @Max_Mantei’s solution to do it in R first! Thanks all for responses

How I would do it:

In R:

N <- 40
K <- 2

mus <- c(-1.5, 0.5)
sigmas <- c(1.7, 2.4)
rho <- 0.4
Omega <- diag(1, nrow = 2, ncol = 2)
Omega[2,1] <- rho
Omega[1,2] <- rho
L_Omega <- chol(Omega)
Z <- matrix(rnorm(N*K), nrow = N, ncol = K)

log_Y <- t(mus + diag(sigmas) %*% L_Omega %*% t(Z))

apply(log_Y, 2, mean)
apply(log_Y, 2, sd)
cor(log_Y) # when it's a small sample this can be pretty off...

Y <- exp(log_Y)

N_mis <- 10

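for (i in 1:N_mis){
  v <- sample(1:K, 1)
  obs_rows <- which(Y[,v] != 0)  # rows of column v that are still observed
  # which.min() indexes into the subset, so map it back to a row of Y
  Y[obs_rows[which.min(Y[obs_rows, v])], v] <- 0
}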

N_obs <- sum(Y != 0)

Y_obs_pos <- list()
Y_obs <- list()
for (k in 1:K){
  Y_obs_pos[[k]] <- cbind(n = which(Y[,k] != 0), k = k)
  Y_obs[[k]] <- Y[Y[,k] != 0, k]
}

Y_obs_pos <- do.call(rbind, Y_obs_pos)
Y_obs <- unlist(Y_obs)

# checking...
check <- matrix(0, nrow = N, ncol = K)
for (i in 1:N_obs)
  check[Y_obs_pos[i,1], Y_obs_pos[i,2]] <- Y_obs[i]
all.equal(Y, check)

Y_mis_pos <- list()
for (k in 1:K)
  Y_mis_pos[[k]] <- cbind(n = which(Y[,k] == 0), k = k)
Y_mis_pos <- do.call(rbind, Y_mis_pos)

standata <- list(
  N = N,
  K = K,  # declared in the Stan data block, so it has to be passed too
  N_obs = N_obs,
  N_mis = N_mis,
  Y_obs_pos = Y_obs_pos,
  Y_mis_pos = Y_mis_pos,
  Y_obs = Y_obs
)

Then in Stan:

data {
  int N;
  int K;
  int N_obs;
  int N_mis;
  int Y_obs_pos[N_obs,2];
  int Y_mis_pos[N_mis,2];
  vector[N_obs] Y_obs;
}
parameters {
  vector[K] mu;
  vector<lower=0>[K] sigma;
  vector<lower=0>[N_mis] Y_mis;
  cholesky_factor_corr[K] L_Omega;
}
transformed parameters{
  matrix[K,K] L_Sigma = diag_pre_multiply(sigma, L_Omega);
  vector[K] log_Y[N];
  for (n in 1:N_obs){
    //print("Observed Y @-- N: ", Y_obs_pos[n,1], " | K: ", Y_obs_pos[n,2]);
    log_Y[Y_obs_pos[n,1], Y_obs_pos[n,2]] = log(Y_obs[n]);
  }
  for (n in 1:N_mis){
    //print("Missing Y @-- N: ", Y_mis_pos[n,1], " | K: ", Y_mis_pos[n,2]);
    log_Y[Y_mis_pos[n,1], Y_mis_pos[n,2]] = log(Y_mis[n]);
    }
}
model {
  log_Y ~ multi_normal_cholesky(mu, L_Sigma);
  mu ~ normal(0, 2.5);
  sigma ~ exponential(1);
  L_Omega ~ lkj_corr_cholesky(4);
  for (n in 1:N)
    target += -log_Y[n];
}
generated quantities{
  corr_matrix[2] Omega = multiply_lower_tri_self_transpose(L_Omega);
  real<lower=-1,upper=1> rho = Omega[2,1];
}

Running the model:

library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

posterior <- stan(file = "multi_missing_value.stan", data = standata)
print(posterior, c("mu", "sigma", "rho", "Y_mis"))

…and the output:

Inference for Stan model: multi_missing_value.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

           mean se_mean     sd  2.5%   25%   50%   75%  97.5% n_eff Rhat
mu[1]     -1.54    0.00   0.27 -2.07 -1.71 -1.54 -1.35  -1.01  5415    1
mu[2]      0.24    0.00   0.39 -0.54 -0.03  0.25  0.50   1.01  6340    1
sigma[1]   1.56    0.00   0.20  1.23  1.41  1.54  1.68   1.98  4820    1
sigma[2]   2.41    0.00   0.29  1.94  2.21  2.38  2.58   3.08  5716    1
rho        0.11    0.00   0.15 -0.21  0.00  0.12  0.22   0.40  4489    1
Y_mis[1]   0.84    0.05   2.90  0.01  0.08  0.23  0.62   5.47  3174    1
Y_mis[2]   0.99    0.08   4.91  0.01  0.09  0.27  0.78   6.06  3741    1
Y_mis[3]   0.66    0.04   2.35  0.01  0.05  0.16  0.49   4.07  3365    1
Y_mis[4]   0.64    0.04   2.18  0.01  0.06  0.17  0.51   4.05  3674    1
Y_mis[5]   0.64    0.04   2.28  0.01  0.07  0.19  0.52   3.79  2854    1
Y_mis[6]   0.63    0.03   1.83  0.01  0.07  0.18  0.51   4.20  3381    1
Y_mis[7]  28.82    4.26 264.61  0.01  0.25  1.25  5.99 171.24  3860    1
Y_mis[8]  29.49    6.40 320.53  0.01  0.25  1.33  7.15 189.52  2511    1
Y_mis[9]  25.66    3.99 226.52  0.01  0.21  1.16  6.13 162.93  3219    1
Y_mis[10] 34.45    5.95 356.09  0.01  0.30  1.47  7.98 186.75  3580    1

Thanks @Max_Mantei, I’ve been working on something similar, but you just gave me an idea to improve it - many thanks!


Here check out this thread I started way back:


Thanks @saudiwin, I actually saw that before making this thread. It won’t work for what I wanted. I’ve come to the conclusion that a Stan version of which() is impossible, because the length of the which() output varies depending on the data; that would require storing the result in a ragged array, and Stan doesn’t have ragged arrays! I’m pursuing another approach, doing the donkey work in R similar to what @Max_Mantei suggested, but a bit different.
Thanks anyhow, all replies appreciated!
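
For reference, the lack of ragged arrays is usually worked around by flattening: concatenate the per-row which() results into one long integer array, pass the number of matches per row alongside it, and slice out each row’s indexes with segment(). Below is a minimal sketch, assuming the indexes were precomputed in R. The names n_idx, N_idx_total and idx_flat are made up for illustration, and the array syntax follows the older style used in this thread (newer Stan versions write array[N] int instead).

```stan
data {
  int<lower=1> N;                  // number of rows
  int<lower=1> NvarsY;             // length of each row vector
  vector[NvarsY] y[N];             // data, declared as earlier in the thread
  int<lower=0> n_idx[N];           // hypothetical: number of matches in row n
  int<lower=0> N_idx_total;        // hypothetical: sum of n_idx
  int<lower=1, upper=NvarsY> idx_flat[N_idx_total];  // hypothetical: all which() results, row after row
}
transformed data {
  int pos = 1;                     // start of the current row's slice in idx_flat
  for (n in 1:N) {
    if (n_idx[n] > 0) {            // skip rows without matches
      int idx[n_idx[n]] = segment(idx_flat, pos, n_idx[n]);
      // y[n][idx] is the subset of row n at the matched positions
      print("row ", n, ": ", y[n][idx]);
    }
    pos += n_idx[n];               // advance to the next row's slice
  }
}
```

The same loop should work in the model block, where y[n][idx] can be used directly in a sampling statement.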


Not yet. And you can’t write a generic one because we don’t have higher order functions. In Stan, if you have a particular test, such as a == 0, you can write it this way:

int num_matches(int[] x, int a) {
  int n = 0;
  for (i in 1:size(x))
    if (x[i] == a)
      n += 1;
  return n;
}
  
int[] which_equal(int[] x, int a) {
  int match_positions[num_matches(x, a)];  // int array, to match the int[] return type
  int pos = 1;
  for (i in 1:size(x)) {
    if (x[i] == a) {
      match_positions[pos] = x[i];
      pos += 1;
    }
  }
  return match_positions;
}

Warning: I didn’t try to compile or debug this, but the basic idea should work.


Thanks Bob. I ended up doing the indexing in R and passing it in as data, but good to know this solution too!

Just came across this when answering a different post. Here are slightly modified versions of Bob’s suggestions above, which I just tested when responding to “How to find the location of a value in a vector”. I changed them to work with a vector input (instead of an int array) and fixed a typo where match_positions[pos] = x[i] should be match_positions[pos] = i.

// Find the number of elements in the vector x that equal real number y
int num_matches(vector x, real y) {
  int n = 0;
  for (i in 1:rows(x))
    if (x[i] == y)
      n += 1;
  return n;
}
  
// Find the indexes of the elements in the vector x that equal real number y
int[] which_equal(vector x, real y) {
  int match_positions[num_matches(x, y)];
  int pos = 1;
  for (i in 1:size(x)) {
    if (x[i] == y) {
      match_positions[pos] = i;
      pos += 1;
    }
  }
  return match_positions;
}

// example usage later in stan program
int ids[num_matches(x, y)] = which_equal(x, y);

I tried to use this function but at line

for (i in 1:size(x)) {

I get the error

No matches for:
Available argument signatures for size:

How could I fix it? Perhaps using

rows(x)

Instead?

I am trying out the function this way and a weird thing happened: it will not work with the function mean(), but it will work with the function min(), though they both take the same input type. For mean I get the same error:

No matches for:
Available argument signatures for mean: