I implemented the model (to the best of my ability), but it doesn’t seem to perform well on simulated data. I simulated data with 2/3 of the neurons belonging to population 1 (with an effect) and 1/3 belonging to population 2 (no effect), but the estimated mixture probability is consistently around 0.995. I don’t get any warnings during fitting.
I’m using the following Stan code to define the model:
data {
  int<lower=1> nNeurons; // Number of neurons
  int<lower=1> N; // Number of total datapoints
  vector[N] Stimulation; // Stimulation indicator
  vector[N] Response; // Response data
  array[N] int<lower=1> neuronId; // Neuron indices for each observation
}
parameters {
  real baselineFR; // Average baseline FR
  real<lower=0> meanStimEffect; // Increment of FR with stimulation
  vector[nNeurons] neuronBaseFR; // Neuron-specific baseline FR
  vector[nNeurons] neuronEffect; // Neuron-specific effect of stimulation
  real<lower=0> sigmaBaselineFR; // Standard deviation of baseline FR
  real<lower=0> sigmaEffect; // Standard deviation of stimulation effect
  real<lower=0> sigmaResidual; // Standard deviation of residuals
  real<lower=0,upper=1> p; // Mixture weight of the first log_mix component (lp_1, with the effect)
}
model {
  // Priors
  baselineFR ~ normal(0, 30);
  sigmaBaselineFR ~ cauchy(0, 5);
  neuronBaseFR ~ normal(baselineFR, sigmaBaselineFR);
  meanStimEffect ~ normal(0, 30);
  sigmaEffect ~ cauchy(0, 5);
  sigmaResidual ~ cauchy(0, 5);
  p ~ beta(1, 1); // Uniform prior on the mixture weight
  for (i in 1:N) {
    // linear predictor with the effect
    real lp_1 = neuronBaseFR[neuronId[i]] + neuronEffect[neuronId[i]] * Stimulation[i];
    // linear predictor without the effect
    real lp_2 = neuronBaseFR[neuronId[i]];
    // Mixture model for the response
    target += log_mix(p,
                      normal_lpdf(Response[i] | lp_1, sigmaResidual),
                      normal_lpdf(Response[i] | lp_2, sigmaResidual));
  }
}
In the code above, neuronBaseFR and neuronEffect are the intercept and slope used to generate the linear predictors lp_1 and lp_2.
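To make the per-observation likelihood concrete, here is a minimal Python sketch of what Stan's log_mix(p, ...) evaluates for a single data point: the log of p times the first component density plus (1 - p) times the second. The helper names (normal_lpdf, log_mix, observation_log_lik) are illustrative stand-ins written for this post, not part of any library, and the sketch assumes 0 < p < 1.

```python
import math


def normal_lpdf(y, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at y."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
            - 0.5 * ((y - mu) / sigma) ** 2)


def log_mix(theta, lp1, lp2):
    """log(theta * exp(lp1) + (1 - theta) * exp(lp2)), computed stably.

    Assumes 0 < theta < 1; theta is the weight of the FIRST component.
    """
    a = math.log(theta) + lp1
    b = math.log1p(-theta) + lp2
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def observation_log_lik(y, base, effect, stim, sigma, p):
    """Log-likelihood contribution of one observation, mirroring the model block."""
    lp_1 = base + effect * stim  # linear predictor with the effect
    lp_2 = base                  # linear predictor without the effect
    return log_mix(p,
                   normal_lpdf(y, lp_1, sigma),
                   normal_lpdf(y, lp_2, sigma))
```

On trials with Stimulation = 0 the two linear predictors are identical, so the contribution of that observation does not depend on p at all; only stimulated trials carry information about the mixture weight.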
I use the following code to generate a dataset in Python (note, individuals in my case are neurons):
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from cmdstanpy import CmdStanModel
import seaborn as sns
import os
import arviz as az
###############
# GENERATE THE DATA
###############
# Dataset parameters
nTrials = 20 # number of trials per condition
neurons1N = 20 # number of neurons with effect (population 1)
neurons2N = 10 # number of neurons without effect (population 2)
baselineFiring = 10 # mean baseline firing rate
neuronBaselineSd = 3 # variability in baseline firing rate
effect = 10 # mean effect of stimulation on population 1
residualSd = 3 # trial response variability
nNeurons = neurons1N + neurons2N
neuronType = ['1'] * neurons1N + ['2'] * neurons2N
typeEffects = {'1': effect, '2': 0}
# Generate the dataset
data = []
# Generate the baseline firing rates for each neuron
neuronBaselines = np.random.normal(0, neuronBaselineSd, nNeurons) + \
    baselineFiring
# Define the means for each condition and group
for neuron in range(nNeurons):
    for stimulation in [0, 1]:
        samples = np.random.normal(neuronBaselines[neuron] +
                                   typeEffects[neuronType[neuron]] * stimulation,
                                   residualSd, nTrials)
        data.extend(zip([neuron] * nTrials,
                        [neuronType[neuron]] * nTrials,
                        range(1, nTrials + 1),
                        [stimulation] * nTrials, samples))
# Create a DataFrame from the generated data
expData = pd.DataFrame(data, columns=['Neuron', 'Type', 'Trial',
                                      'Stim', 'Response'])
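The fitting step itself is not shown here, so the following is only a sketch of how a DataFrame shaped like expData could be passed to the Stan model above. The helper names build_stan_data and fit_mixture and the Stan file path are hypothetical. The sketch assumes the Neuron column holds contiguous 0-based integer IDs, which are shifted by one because Stan indexes from 1.

```python
import pandas as pd


def build_stan_data(exp_data: pd.DataFrame) -> dict:
    """Map the simulated DataFrame onto the names declared in the Stan data block.

    Assumes 'Neuron' contains contiguous 0-based integer IDs.
    """
    return {
        "nNeurons": int(exp_data["Neuron"].nunique()),
        "N": int(len(exp_data)),
        "Stimulation": exp_data["Stim"].astype(float).tolist(),
        "Response": exp_data["Response"].astype(float).tolist(),
        # Stan arrays are 1-indexed, so shift the 0-based neuron IDs by one.
        "neuronId": (exp_data["Neuron"].astype(int) + 1).tolist(),
    }


def fit_mixture(stan_file: str, stan_data: dict, seed: int = 1):
    """Compile and sample the model, then summarise it with ArviZ.

    Requires a working CmdStan installation; stan_file is a hypothetical
    path to a file containing the Stan program shown earlier.
    """
    from cmdstanpy import CmdStanModel
    import arviz as az

    model = CmdStanModel(stan_file=stan_file)
    fit = model.sample(data=stan_data, seed=seed)
    return az.summary(az.from_cmdstanpy(posterior=fit))
```

Under these assumptions, a call such as fit_mixture("mixture.stan", build_stan_data(expData)) would return a summary table with the same columns as the one reported in this post.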
I get the following summary of the results with ArviZ:
                    mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
baselineFR        10.302   0.494   9.375   11.241      0.010    0.007    2605.0    1424.0    1.0
meanStimEffect    23.620  17.784   0.009   56.464      0.360    0.255    1771.0    1005.0    1.0
neuronBaseFR[0]   11.616   0.646  10.425   12.788      0.012    0.008    3015.0    1487.0    1.0
neuronBaseFR[1]   12.359   0.658  11.069   13.538      0.012    0.008    3052.0    1763.0    1.0
neuronBaseFR[2]    8.824   0.673   7.609   10.099      0.012    0.009    2959.0    1767.0    1.0
...                  ...     ...     ...      ...        ...      ...       ...       ...    ...
neuronEffect[29]   0.414   0.960  -1.331    2.192      0.017    0.018    3284.0    1596.0    1.0
sigmaBaselineFR    2.476   0.363   1.868    3.189      0.007    0.006    2571.0    1588.0    1.0
sigmaEffect       19.473  87.730   0.015   49.762      2.246    1.589    3465.0     905.0    1.0
sigmaResidual      3.036   0.065   2.910    3.155      0.001    0.001    3461.0    1321.0    1.0
p                  0.995   0.004   0.987    1.000      0.000    0.000    1805.0    1176.0    1.0