Hello everyone,
Is there an equivalent of R's `which` function in Stan? I would like to obtain the indices of the elements of a vector that meet a certain condition (without a loop, if possible).
For example, in R, if I have a vector A = c(9, 3, 0, 3, 0, 2), then which(A == 0) returns 3 5. Is there an easy way to do that in Stan?
Thank you in advance!
Not an easy way, I believe. You’d have to write your own function, probably using some of Stan's sorting functions inside, and probably some `if` statements.
What’s your use case?
It’s difficult to imagine how such a function would fit into Stan currently, because Stan requires every variable to be given a size when it is declared. What can you do with the result of a function like this if you can’t assign it anywhere? Not much, unfortunately.
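To illustrate the sizing constraint, here is a sketch (not a built-in) that reproduces `which(A == 0)` from the question inside `transformed data`. The matches are counted first, and only then is an array of that size declared, in a local block. The names `n_match`, `idx` and `pos` are illustrative.

```stan
transformed data {
  array[6] int A = {9, 3, 0, 3, 0, 2};
  int n_match = 0;
  // First pass: count the matches so the result can be given a size.
  for (i in 1:size(A)) {
    if (A[i] == 0) {
      n_match += 1;
    }
  }
  // Second pass, in a local block: declare an array of exactly
  // n_match elements and fill it with the matching indices.
  {
    array[n_match] int idx;
    int pos = 1;
    for (i in 1:size(A)) {
      if (A[i] == 0) {
        idx[pos] = i;
        pos += 1;
      }
    }
    print("idx = ", idx);  // idx holds the indices 3 and 5
  }
}
```

The first loop is what makes the declaration possible: Stan needs `n_match` before `idx` can exist.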
For example, I previously had a function like this, taking a `real`:
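```stan
functions {
  real f(real t, vector psi, int nb_param_str) {
    real TS;
    // psi is already a function argument, so it must not be redeclared here
    // (nb_param_str is therefore no longer needed inside the body)
    if (t <= 0) {
      TS = psi[1] * exp(psi[4] * t);
    } else {
      TS = psi[1] * (exp(psi[2] * t) + exp(-psi[3] * t) - 1);
    }
    return TS;
  }
}
```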
I wanted to transform it into this version, taking a `vector`:
```stan
functions {
  // Desired version: which() does NOT exist in Stan, so this is pseudocode.
  // The contiguous slices assume t is sorted (non-positive times first).
  vector f(vector t, vector psi, int nb_param_str) {
    vector[num_elements(t)] TS;  // vector of the SLD measurements
    // psi (individual parameters) is an argument and must not be redeclared
    int start_neg_time = min(which(t <= 0));
    int end_neg_time = max(which(t <= 0));
    int start_pos_time = min(which(t > 0));
    int end_pos_time = max(which(t > 0));
    TS[start_neg_time:end_neg_time] = psi[1] * exp(psi[4] * t[start_neg_time:end_neg_time]);
    TS[start_pos_time:end_pos_time] = psi[1] * (exp(psi[2] * t[start_pos_time:end_pos_time])
                                                + exp(-psi[3] * t[start_pos_time:end_pos_time]) - 1);
    return TS;
  }
}
```
But I did not find an equivalent of the `which` function in Stan.
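Without `which`, one way to recover the split point inside the function itself is to count the non-positive times. This sketch assumes, as the contiguous slicing in the desired version already does, that `t` is sorted in ascending order, so every `t[i] <= 0` comes first. The name `f_sorted` is hypothetical, and the unused `nb_param_str` argument is dropped.

```stan
functions {
  // Assumes t is sorted ascending: every t[i] <= 0 precedes every t[i] > 0.
  vector f_sorted(vector t, vector psi) {
    int N = num_elements(t);
    vector[N] TS;
    int n_neg = 0;
    // Count the non-positive times; they occupy t[1:n_neg].
    for (i in 1:N) {
      if (t[i] <= 0) {
        n_neg += 1;
      }
    }
    if (n_neg > 0) {
      TS[1:n_neg] = psi[1] * exp(psi[4] * t[1:n_neg]);
    }
    if (n_neg < N) {
      TS[(n_neg + 1):N] = psi[1] * (exp(psi[2] * t[(n_neg + 1):N])
                                    + exp(-psi[3] * t[(n_neg + 1):N]) - 1);
    }
    return TS;
  }
}
```

Here `n_neg` plays the role of `end_neg_time`, and `n_neg + 1` that of `start_pos_time`, so nothing has to be passed in through the data block.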
For this example, I computed those variables (start_neg_time, end_neg_time, start_pos_time, end_pos_time) in my R script and passed them to Stan through the data block.
In other cases, however, I cannot supply them through the data block like that. Should I then go back to a function that takes a `real` instead of a `vector`?
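Keeping the scalar version and looping over the elements is one reasonable option: it needs no index bookkeeping and makes no assumption about the order of `t`. A minimal sketch, with the hypothetical names `f_scalar` and `f_vec`:

```stan
functions {
  // Scalar version: one time point.
  real f_scalar(real t, vector psi) {
    if (t <= 0) {
      return psi[1] * exp(psi[4] * t);
    }
    return psi[1] * (exp(psi[2] * t) + exp(-psi[3] * t) - 1);
  }
  // Vector version: apply the scalar function element by element.
  vector f_vec(vector t, vector psi) {
    int N = num_elements(t);
    vector[N] TS;
    for (i in 1:N) {
      TS[i] = f_scalar(t[i], psi);
    }
    return TS;
  }
}
```

Unlike in R, loops in Stan are compiled, so this element-wise loop is often an acceptable cost.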
Thank you!
I was also missing such a function at some point. The `whichequals` function below should do it. Don’t ask me about efficiency, but since it only works on integers I assume it’s no big deal…
```stan
// Function: vecequals
// Returns a 0/1 array: check[i] = 1 if a[i] == test (comparisonType = 1)
// or if a[i] != test (comparisonType = 0), and 0 otherwise.
array[] int vecequals(array[] int a, int test, int comparisonType){ //do indices of a match test condition?
  array[size(a)] int check;
  for(i in 1:size(check)) check[i] = comparisonType ? (test==a[i]) : (test!=a[i]);
  return(check);
}
// Function: whichequals
// Parameters:
// - b: an array of integers
// - test: an integer value representing the test condition
// - comparisonType: an integer value representing the type of comparison (0 for !=, 1 for ==)
// Returns:
// - An array of integers representing the indices of elements in b that match the test condition.
array[] int whichequals(array[] int b, int test, int comparisonType){ //return array of indices of b matching test condition
  array[size(b)] int check = vecequals(b,test,comparisonType);
  array[sum(check)] int which;
  int counter = 1;
  if(size(b) > 0){
    for(i in 1:size(b)){
      if(check[i] == 1){
        which[counter] = i;
        counter += 1;
      }
    }
  }
  return(which);
}
```
You can compute the size first and then assign the result in a second step (unwieldy, but sometimes useful), or use the result directly as a multi-index, e.g. `x[whichequals(y, 2, 1)] = z[whichequals(y, 2, 1)];`
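`whichequals` only handles integer arrays and equality. For a condition like `t <= 0` on a `vector`, a similar helper can be written. The sketch below fills an oversized buffer in a single pass and returns only the filled part. `which_leq` is a hypothetical helper, not a Stan built-in. The `transformed data` block shows both patterns mentioned above: compute the size first, then use the result as a multi-index on both sides of an assignment.

```stan
functions {
  // Indices i with x[i] <= threshold, in increasing order.
  array[] int which_leq(vector x, real threshold) {
    int N = num_elements(x);
    array[N] int idx;  // large enough for every possible match
    int n = 0;         // number of matches found so far
    for (i in 1:N) {
      if (x[i] <= threshold) {
        n += 1;
        idx[n] = i;
      }
    }
    if (n == 0) {
      return rep_array(0, 0);  // empty integer array
    }
    return idx[1:n];  // keep only the filled part
  }
}
transformed data {
  vector[5] t = [-2.0, -1.0, 0.0, 1.0, 2.0]';
  vector[5] x = rep_vector(0, 5);
  // Compute the size first so the top-level array can be declared.
  int n_neg = size(which_leq(t, 0));
  array[n_neg] int neg = which_leq(t, 0);  // indices 1, 2, 3 here
  // Multi-indexing on both sides: copy only the selected elements.
  x[neg] = t[neg];
  print("x = ", x);  // t's values at indices 1-3, zeros elsewhere
}
```

Inside a function, the same multi-index would allow something like `TS[neg] = psi[1] * exp(psi[4] * t[neg]);` without assuming that `t` is sorted.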