Marginalizing over latent discrete parameters

Yes, I can share it. The Stan file has about 1500 lines, including several user-defined functions related to our new coalescent model (https://doi.org/10.1016/j.tpb.2021.09.003). The idea is to reparametrize the model using a simplex vector (of size N) over the total tree length (the sum of all branch lengths) to place N-1 change points.

Excluding the user-defined functions (approximately 1000 lines above the data block), it looks like this (the last part is similar to the PyStan Jukes-Cantor CTMC model):

data{
  real<lower=0.0, upper=2.0> K;//parameter for the model M_K
  int<lower=2> N;//number of populations
  int<lower=2> total_sample_size; // number of tips (one row of tipdata per tip)
  int <lower=0> L;// alignment length
  // Per tip and site: a code whose value is used as an index into a
  // length-4 partial-likelihood vector in transformed data.
  int<lower=0,upper=4> tipdata[total_sample_size,L];
  simplex[N] pop_sizes_proportion;
  //int<lower=0>  number_immigrants[N-1];
  // One row per coalescent event: {parent node, first child, second child}.
  // Every array derived from this in transformed data has
  // 2*total_sample_size-1 node columns, so node indices must lie in
  // 1..2*total_sample_size-1; the tighter bounds reject bad input at load time
  // instead of failing later with an index error.
  int<lower=1,upper=2*total_sample_size-1> topology[total_sample_size-1,3];
}
transformed data{
    // Number of branches of a rooted binary tree with total_sample_size tips.
    int number_branches = 2*total_sample_size-2;
    // Per node (tips 1..total_sample_size first): number of tips in its subtree.
    int number_of_tips_below[2*total_sample_size-1];
    // Per node: number of coalescent events in its subtree (0 for tips).
    int number_of_coalescent_times_below[2*total_sample_size-1] ;
    // Partial-likelihood vectors for the observed tips (only tips are needed).
    vector[4] tip_partials[total_sample_size,L];
    // Row i (coalescent event i): 1 for every node in the subtree of
    // topology[i,1], including that node itself.
    int<lower=0,upper=1> indexes_nodes_below[total_sample_size-1, 2*total_sample_size-1]= rep_array(0, total_sample_size-1,2*total_sample_size-1 );
    // Node index -> row of `topology` whose parent column is that node (0 for tips).
    int map_internal_node_topology_row[2*total_sample_size-1]= rep_array(0,2*total_sample_size-1);
    // Row i: 0/1 flags marking the lineages present after i-1 coalescent events.
    int indexes_candidate_MRCAs[total_sample_size, 2*total_sample_size-1] = rep_array(0, total_sample_size, 2*total_sample_size-1);
    // Row i: the node indices flagged in indexes_candidate_MRCAs[i], packed to the left.
    int indexes_candidate_MRCAs_array[total_sample_size,total_sample_size ] = rep_array(0, total_sample_size, total_sample_size) ;
    // Linear map from coalescent times to branch lengths: a child's branch is the
    // parent event time minus the child's event time (tip children get no subtraction).
    matrix[number_branches, total_sample_size-1 ] coal_times_to_branch_lengths= rep_matrix(0, number_branches, total_sample_size-1);
    // CSR form of coal_times_to_branch_lengths: 2*(total_sample_size-1) entries of +1
    // plus (total_sample_size-2) entries of -1.
    vector [3*total_sample_size-4] w;
    int v[3*total_sample_size-4];
    int u[number_branches+1];
    // Number of lineages present after i-1 coalescent events.
    int number_living_ancestors[total_sample_size]=rep_array(1, total_sample_size);
    int max_number_configurations = maximum_number_configurations(total_sample_size, N);

    // Before any coalescence every tip is its own lineage.
    for(i in 1:total_sample_size){
        number_of_tips_below[i]=1;
        number_of_coalescent_times_below[i]=0;
        indexes_candidate_MRCAs[1, i] = 1;
        indexes_candidate_MRCAs_array[1, i] = i;
    }
    number_living_ancestors[1]= total_sample_size;

    // Replay the coalescences: each event replaces its two children by the parent.
    for(i in 1:(total_sample_size-1)){
        indexes_candidate_MRCAs[i+1] = indexes_candidate_MRCAs[i];
        indexes_candidate_MRCAs[i+1, topology[i,1]] = 1;
        indexes_candidate_MRCAs[i+1, topology[i,2]] = 0;
        indexes_candidate_MRCAs[i+1, topology[i,3]] = 0;
    }

    // Left-pack the flagged node indices of every row after the first.
    for(i in 2:total_sample_size){
        int n_found = 0;
        for(j in 1:(2*total_sample_size-1)){
            if(indexes_candidate_MRCAs[i,j] == 1){
                n_found = n_found + 1;
                indexes_candidate_MRCAs_array[i, n_found] = j;
            }
        }
        number_living_ancestors[i]= total_sample_size+1-i;
    }

    // Rows of `topology` must be in post-order (children before parents):
    // map_internal_node_topology_row of a child is read in the same pass.
    for(i in 1:(total_sample_size-1)){
        number_of_tips_below[topology[i,1]]=number_of_tips_below[topology[i,2]]+ number_of_tips_below[topology[i,3]];
        number_of_coalescent_times_below[topology[i,1]]=number_of_coalescent_times_below[topology[i,2]]+ number_of_coalescent_times_below[topology[i,3]]+1;
        map_internal_node_topology_row[topology[i,1]]=i;
        coal_times_to_branch_lengths[2*i-1, i]=1.0;
        coal_times_to_branch_lengths[2*i, i]=1.0;
        if(topology[i,2] >total_sample_size){
            coal_times_to_branch_lengths[2*i-1,  map_internal_node_topology_row[topology[i,2]]]= -1.0;
        }
        if(topology[i,3] >total_sample_size){
            coal_times_to_branch_lengths[2*i,  map_internal_node_topology_row[topology[i,3]]]= -1.0;
        }
    }

    w=csr_extract_w(coal_times_to_branch_lengths);
    v=csr_extract_v(coal_times_to_branch_lengths);
    u=csr_extract_u(coal_times_to_branch_lengths);

    // Tip partials: one-hot on the observed code. Code 0 is allowed by the
    // data bounds but is not a valid index; it is treated here as an unknown
    // state (all-ones partial, i.e. summed over the four states) instead of
    // raising an index error.
    for( n in 1:total_sample_size ) {//rows
        for( i in 1:L ) {//columns
            if(tipdata[n,i] == 0){
                tip_partials[n,i]= rep_vector(1.0, 4);
            } else {
                tip_partials[n,i]= rep_vector(0.0, 4);
                tip_partials[n,i][tipdata[n,i]] = 1.0;
            }
        }
    }

    // Subtree membership per coalescent event, built from the children's rows.
    for (i in 1:(total_sample_size-1)){
        if (topology[i,2] > total_sample_size && topology[i,3] > total_sample_size){
            for (k in 1:(2*total_sample_size-1))
                indexes_nodes_below[i,k]= indexes_nodes_below[map_internal_node_topology_row[topology[i,2]],k] + indexes_nodes_below[map_internal_node_topology_row[topology[i,3]],k];
        }
        else{
            if (topology[i,2] > total_sample_size){
                indexes_nodes_below[i,]= indexes_nodes_below[map_internal_node_topology_row[topology[i,2]],];
            }
            else if (topology[i,3] > total_sample_size){
                indexes_nodes_below[i,]= indexes_nodes_below[map_internal_node_topology_row[topology[i,3]],];
            }
            indexes_nodes_below[i,topology[i,2]]=1;
            indexes_nodes_below[i,topology[i,3]]=1;
        }
        indexes_nodes_below[i,topology[i,1]]=1;
    }
}
parameters{
  // Rate of the exponential prior placed on `deltas` in the model block.
  real<lower=0.0> hyper_parameters_exponential;
  //vector<lower=0>[N] deltaTs;
  // One value per population: deltas[N] enters the oldest population's
  // origin-time density, deltas[1..N-1] those of the other populations.
  vector<lower=0.0001>[N] deltas;
  // Origin time of the oldest population; also scales the coalescent times.
  real<lower=0.0> torigin_oldest_population;
  // Multiplier applied to branch lengths when building the transition matrices.
  real<lower=0.0> theta;

  // Proportions whose first N-1 cumulative sums, times the largest coalescent
  // time, give the other populations' origin times.
  simplex[N] simplex_torigins;
  // Proportions of torigin_oldest_population; the first total_sample_size-1
  // cumulative sums give the coalescent event times.
  simplex[total_sample_size] simplex_coal_times;
}
transformed parameters {
  // Coalescent event times on the oldest population's time scale
  // (the declared ordering constraint is validated when the block ends).
  positive_ordered[total_sample_size-1] coal_event_times_model_time_oldest_population;
  // Origins of populations 1..N-1 on the oldest population's time scale.
  positive_ordered[N-1] torigins_model_time_oldest_population;
  // Length of every branch of the fixed topology.
  vector<lower=0.0>[number_branches] branch_lengths;
  // Origins of populations 1..N-1, each rescaled to its own model time.
  vector<lower=0.0>[N-1] torigins_model_time ;

  // Cumulative fractions of the oldest origin time give the event times.
  coal_event_times_model_time_oldest_population = torigin_oldest_population * head(cumulative_sum(simplex_coal_times), total_sample_size-1);

  // Sparse (CSR) linear map from event times to branch lengths.
  branch_lengths = csr_matrix_times_vector(number_branches, total_sample_size-1, w, v, u, coal_event_times_model_time_oldest_population);

  // Other origins: cumulative fractions of the largest event time, then
  // converted with the ratio pop_sizes_proportion[N] / pop_sizes_proportion[j].
  torigins_model_time_oldest_population = max(coal_event_times_model_time_oldest_population) * head(cumulative_sum(simplex_torigins), N-1);
  torigins_model_time = (pop_sizes_proportion[N] * torigins_model_time_oldest_population) ./ pop_sizes_proportion[1:(N-1)];
}
model{
  // Partial likelihoods of the internal nodes: node k (> total_sample_size)
  // is stored in row k - total_sample_size.
  vector[4] partials[total_sample_size-1,L];
  // Finite-time transition matrices, one per branch; branches 2n-1 and 2n
  // lead to the first and second child of coalescent event n.
  matrix[4,4] p_matrices[number_branches];
  // Row of `partials` holding the root (parent node of the last event).
  int root_row = topology[total_sample_size-1,1] - total_sample_size;
  vector[4] left;
  vector[4] right;

  // Scalar arguments: the prior is added once. (With length-N vector
  // arguments the scalar outcome is broadcast and the density is added N
  // times, which makes the implied prior improper at 0.)
  hyper_parameters_exponential ~ gamma(0.001, 0.001);
  deltas ~ exponential(hyper_parameters_exponential);
  theta ~ exponential(1);

  torigin_oldest_population~conditionalDensityTOrigin(deltas[N], total_sample_size);

  // Jacobian of (simplex_coal_times, torigin_oldest_population) -> event times:
  // t = T * cumulative_sum(s) over the simplex's free coordinates,
  // so log|J| = (total_sample_size-1) * log(T).
  target += (total_sample_size-1) * log(torigin_oldest_population);

  // Jacobian of simplex_torigins -> torigins_model_time: each entry is
  // max(event times) * cumulative_sum(s)[j] times a data constant, so
  // log|J| = (N-1) * log(max event time) + const. Needed because a density
  // is placed on the transformed torigins_model_time below.
  target += (N-1) * log(max(coal_event_times_model_time_oldest_population));

  for(j in 1:(N-1)){
      torigins_model_time[j]~conditionalDensityTOrigin(deltas[j], total_sample_size);
  }

  coal_event_times_model_time_oldest_population~structured_coalescent_tree_marginalized_MRCA(deltas, append_row(torigins_model_time, torigin_oldest_population), append_row(torigins_model_time_oldest_population, torigin_oldest_population),
                    pop_sizes_proportion, indexes_nodes_below, map_internal_node_topology_row, number_of_tips_below, number_of_coalescent_times_below, indexes_candidate_MRCAs_array, topology,
                    number_living_ancestors, K, total_sample_size, max_number_configurations);

  // Jukes-Cantor: a = 3/4 * (1 - exp(-4/3 * theta * length)) is the total
  // probability of ending in a different state; expm1 keeps precision for
  // very short branches.
  for( b in 1:number_branches){
      real a = -0.75 * expm1(-4.0 * branch_lengths[b] * theta / 3.0);
      p_matrices[b]= rep_matrix(a / 3.0, 4, 4);
      for(k in 1:4){
          p_matrices[b][k,k]= 1.0-a;
      }
  }

  // Felsenstein pruning per site, visiting events in the order of `topology`
  // (children must be computed before their parents).
  for( i in 1:L ){//columns
      for( n in 1:(total_sample_size-1) ) {//rows
          left  = (topology[n,2] > total_sample_size)? partials[topology[n,2]-total_sample_size,i] :  tip_partials[topology[n,2],i];
          right = (topology[n,3] > total_sample_size)? partials[topology[n,3]-total_sample_size,i] :  tip_partials[topology[n,3],i];
          partials[topology[n,1]-total_sample_size,i] = (p_matrices[2*n-1]*left) .* (p_matrices[2*n]*right);
      }
      // Weight the root states by the uniform distribution (1/4 each) and sum.
      target += log(0.25 * sum(partials[root_row,i]));
  }
}

2 Likes