Dear Stan users,
I’ve been working with a model of gene transcription and would like to use Stan to infer its parameters. The data are counts of the number of RNA transcripts of genes in cells, measured at a handful of times after the cells have been stimulated in a certain way and so I start with a list of real-valued times and a table of counts y_i,j, giving the number of cells observed in state i at time j. There are typically lots (a few hundred) of these states.
The underlying model has a transition rate matrix that depends on just a handful - typically four or five - parameters, though the matrix itself is big and sparse. I’ve written Stan functions as outlined below:
- Use the parameters to to build the transition rate matrix R.
- Use Stan’s scale_matrix_exp_multiply() to compute pi_j = exp( t_j * R ) pi_0. Here the pi’s are vectors of probabilities such that pi_ij gives the probability of finding the chain in state i at time j, given the initila distribution pi_i,0.
- Finally, model my count data at time j as being drawn from a multinomial distribution with underlying probabilities pi_j.
Sadly, the model fails to compile, dying with an error about the first argument to scale_matrix_exp_multiply(). As far as I can tell, the key line is
candidate function template not viable: no known conversion from ‘local_scalar_t__’ (aka ‘stan::math::var’) to ‘const double’ for 1st argument
scale_matrix_exp_multiply(const double& t, const Eigen::Matrix<Ta, -1, -1>& A,
I’d be grateful for any advice.
Thanks,
Mark
Code:
functions {
int state_to_idx( int n_active, int n_transcripts, int n_loci )
//********************************************************************
// Map a state specified by an (n_active, n_transcripts) pair to
// and index in the vector of probabilities.
{
return( (n_loci + 1)*n_transcripts + n_active ) ;
}
//______________________________________________________________
matrix build_rate_mat(
int n_loci,
int max_nRNA, // maximal number of transcripts
real mu_0, // transcription rate in off state
real mu_1, // transcription rate in on state
real k_a, // activation rate
real k_i, // inactivation rate
real delta // degradation rate
)
//*****************************************************************
// Given all the parameters of a model, build the (sparse) rate
// matrix.
//
{
// Work out the number of states in the model and
// initialise the result.
int n_inactive ; // used below
int n_states = (1 + n_loci) * (1 + max_nRNA) ;
matrix[n_states, n_states] rate_mat ;
// Loop over all states of the chain, adding entries
// describing outgoing reactions only.
for( m in 0:max_nRNA ) { // loop over transcript count
for( a in 0:n_loci ) { // loop over number of active sites
int from_idx = state_to_idx( a, m, n_loci ) ;
int to_idx ;
real rate ;
// Transcriptional activation
if( a < n_loci ) {
to_idx = state_to_idx( a+1, m, n_loci ) ;
rate = k_a * (n_loci - a) ;
rate_mat[to_idx, from_idx] = rate ;
rate_mat[from_idx, from_idx] -= rate ;
}
// Transcriptional inactivation
if( a > 0 ) {
to_idx = state_to_idx( a-1, m, n_loci ) ;
rate = k_i * a ;
rate_mat[to_idx, from_idx] = rate ;
rate_mat[from_idx, from_idx] -= rate ;
}
// Transcription
n_inactive = n_loci - a ;
rate = (n_inactive * mu_0) + (a * mu_1) ;
rate_mat[from_idx, from_idx] -= rate ;
if( m < max_nRNA ) {
to_idx = state_to_idx( a, m+1, n_loci ) ;
rate_mat[to_idx, from_idx] = rate ;
}
// Degradation of transcripts
if( m > 0 ) {
to_idx = state_to_idx( a, m-1, n_loci ) ;
rate = delta * m ;
rate_mat[to_idx, from_idx] = rate ;
rate_mat[from_idx, from_idx] -= rate ;
}
}
}
return( rate_mat ) ;
}
//______________________________________________________________
matrix compute_probs(
int n_loci, // Usually two
int max_nRNA, // maximal number of transcripts
real[] times, // Times at which we have observations
real mu_0, // transcription rate in off state
real mu_1, // transcription rate in on state
real k_a, // activation rate
real k_i, // inactivation rate
real delta // degradation rate
)
//*****************************************************************
// Given all the parameters of a model, build the (sparse) rate
// matrix and get the steady-state probs from its eigenspectrum.
//
{
// Declare all the variables.
int n_states = (1 + n_loci) * (1 + max_nRNA) ;
int n_times = size(times) ;
real crnt_time ;
real t_step ;
real pi_sum ;
matrix[n_states, 1] crnt_pi ;
matrix[n_states, 1] next_pi ;
matrix[n_states, n_times] pi_mat ;
matrix[n_states, n_states] rate_mat ;
// Set the distribution at t=0.
crnt_time = 0.0 ;
crnt_pi[1, 1] = 1.0 ;
// Build the rate matrix
rate_mat = build_rate_mat( max_nRNA, n_loci, mu_0, mu_1, k_a, k_i, delta ) ;
// Compute the pi's
for( j in 1:n_times ) {
// Compute exp(t * rate_mat ) * pi_zero ;
t_step = times[j] - crnt_time ;
next_pi = scale_matrix_exp_multiply( t_step, rate_mat, crnt_pi ) ;
// Correct the normalisation and record the result
pi_sum = 0.0 ;
for( i in 1:n_states ) { pi_sum += next_pi[i,j] ; }
next_pi /= pi_sum ;
for( i in 1:n_states ) { pi_mat[i,j] = next_pi[i, 1] ; }
// Advance time
crnt_time = times[j] ;
crnt_pi = next_pi ;
}
return( pi_mat ) ;
}
}
data {
// All times and rates are in minutes
int<lower=1> max_nRNA ; // Max number of transcripts
int<lower=1> n_loci ; // Typically 2
int<lower=1> n_times ;
int<lower=1> n_states ;
real times[n_times] ; // Times at which we have data
int<lower=0> counts[n_states, n_times] ; // Number of cells in each state
// For the current model this is fixed at zero.
real<lower=0> mu_0 ; // transcription rate in off state
}
parameters {
// All these are measured relative to the degradation rate
real<lower=0> mu_1 ; // transcription rate in on state
real<lower=0> k_a ; // activation rate
real<lower=0> k_i ; // inactivation rate
real<lower=0> delta ; // rate of mRNA degradation
}
transformed parameters {
matrix[n_states, n_times] pi_mat =
compute_probs( n_loci, max_nRNA, times, mu_0, mu_1, k_a, k_i, delta ) ;
}
model {
mu_1 ~ uniform(0.0, 20.0) ;
k_a ~ uniform(0.0, 10.0) ;
k_i ~ uniform(0.0, 10.0) ;
delta ~uniform(0.0, 1.0) ;
for( j in 1:n_times ) {
counts[:,j] ~ multinomial( pi_mat[:,j] ) ;
}
}