More IID data yields lots of divergences & terrible recovery

I have a periodic timeseries model that seems to work great when it only sees one timeseries, recovering the simulated frequency, phase, amplitude and noise parameters well.

But when I expand this model to expect multiple timeseries that differ only in their noise (not even in the magnitude of the noise, so they’re effectively IID samples from the one-timeseries model), I get lots of divergences and the different chains explore very different areas of the parameter space.

Below are the model and the R code for exploring this behaviour. Any ideas what might be going awry here? I feel like I must be missing something very simple, but I’ve banged my head on this for a few days now and haven’t worked out what, so any help would be greatly appreciated!

functions{
	// p(): an example periodic-but-asymmetric function
	vector p(vector x){
		return(sin(x+sin(x)/2)) ;
	}
}
data{
	// num_samples: number of samples on the x-axis
	int num_samples ;
	// num_signals: number of signals (columns in y)
	int num_signals ;
	// y: matrix of observation
	matrix[num_samples,num_signals] y ;
	// x: values on the x-axis associated with the observations in y
	vector[num_samples] x ;
	// sample_prior_only: helper to toggle on/off posterior sampling
	int<lower=0,upper=1> sample_prior_only ;
}
transformed data{
	// pre-compute part of the input to p()
	vector[num_samples] x_times_pi_times_2 = x*pi()*2 ;
	// compute some quantities for constraining hz
	real hz_max = 1/((x[2]-x[1])*5) ; //assumes equal spacing
	real hz_min = 1/(x[num_samples]-x[1]) ;
	real hz_max_minus_min = hz_max - hz_min ;
}
parameters{
	// noise: measurement noise
	real<lower=0> noise ;
	// hz_helper: helper variable for constraining hz
	real<lower=0,upper=1> hz_helper ;
	// phamp: phase-and-amplitude vector (actually a transform thereof)
	vector[2] phamp ;
}
transformed parameters{
	// get hz given hz_helper and constraints
	real hz = hz_helper * hz_max_minus_min + hz_min ;
	// extract phase from phamp
	real phase = atan2(phamp[1],phamp[2]) ;
	// extract amp from phamp (vector magnitude)
	real amp = sqrt(dot_self(phamp)) ;
	// compute the latent function given x, hz, amp & phase
	vector[num_samples] f = amp * p(x_times_pi_times_2*hz - phase) ;
}
model{
	// prior for hz_helper: inverted-u-shaped prior on support
	hz_helper ~ beta(2,2) ;
	// prior for noise: moderately-informed given simulated data
	noise ~ std_normal() ;
	// prior for phamp: implies a weakly-informed prior on amplitude
	phamp ~ std_normal() ;
	//if not simply seeking to sample from the prior,
	//likelihood is the observed data plus noise
	if(!sample_prior_only){
		for(this_signal in 1:num_signals){
			y[,this_signal] ~ normal(f,noise) ;
		}
	}
}
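The `phamp` comment says the standard-normal prior implies a weakly-informed prior on amplitude. A quick simulation in R (a sketch, not part of the original code) shows what that implied prior is: two independent standard normals give an `amp` that is Rayleigh(1) distributed (mean sqrt(pi/2), about 1.25) and a `phase` that is uniform on (-pi, pi].

```r
# Sketch: what prior on amp and phase does `phamp ~ std_normal()` imply?
set.seed(3)
n_draws <- 1e5

# Two independent standard normals per draw, as in the Stan model
phamp <- matrix(rnorm(2 * n_draws), ncol = 2)

# Same transforms as the transformed parameters block
amp <- sqrt(rowSums(phamp^2))
phase <- atan2(phamp[, 1], phamp[, 2])

# Rayleigh(1) reference values: mean sqrt(pi/2), P(amp < 1) = 1 - exp(-1/2)
print(c(mean_amp = mean(amp), rayleigh_mean = sqrt(pi / 2)))
print(c(p_below_1 = mean(amp < 1), rayleigh_p = 1 - exp(-0.5)))

# The phase should be flat on (-pi, pi]: hist(phase, breaks = 36)
```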

And R code for playing:

#preamble ----

#load packages used
library(cmdstanr)
library(tidyverse)

#load and compile the model
mod = cmdstan_model('hz_mat.stan')

# example function that's periodic but asymmetric
p = function(x,hz,phase){
	x = x*(2*pi)*hz - phase
	sin(x+sin(x)/2)
}

# Generate data ----
seed = 1 #vary this across simulations
num_signals = 1 #if 1, model works fine; if >1, lots of divergences

set.seed(seed)
dat = (
	tibble::tibble(
		x = seq(0,10,by=1/10)
		, f = p(x,hz=1,phase=0)
	)
	%>% tidyr::expand_grid(
		signal = sprintf('signal%02d',1:num_signals)
	)
	%>% dplyr::group_by(signal)
	%>% dplyr::mutate(
		y = f + rnorm(n(),0,.1)
	)
)

# quick viz:
(
	dat
	%>% ggplot()
	+ geom_point(aes(x=x,y=y),alpha = .5)
	+ geom_line(aes(x=x,y=f))
)

# reshape to wide
dat_wide =
	(
		dat
		%>% tidyr::pivot_wider(
			names_from = 'signal'
			, values_from = 'y'
		)
		#model assumes that the data is ordered by x
		%>% dplyr::arrange(x)
	)

# sample & check diagnostics ----
fit = mod$sample(
	data = list(
		num_samples = nrow(dat_wide)
		, num_signals = ncol(dat_wide)-2
		, y = (
			dat_wide
			%>% dplyr::select(starts_with('signal'))
			%>% as.matrix()
		)
		, x = dat_wide$x #assumed to be ordered
		, sample_prior_only = F
 	)
	, chains = parallel::detectCores()/2-1
	, parallel_chains = parallel::detectCores()/2-1
	, refresh = 1000
	, seed = seed
	, iter_warmup = 1e3
	, iter_sampling = 1e3
)

diagnostics =
	(
		fit$summary()
		%>% dplyr::select(variable,rhat,contains('ess'))
		%>% dplyr::filter(substr(variable,1,1)!='f')
		%>% dplyr::mutate(
			diagnostics = fit$cmdstan_diagnose()$stdout #annoyingly not quiet-able
			, treedepth_exceeded = str_detect(diagnostics,'transitions hit the maximum')
			, ebmfi_low = str_detect(diagnostics,' is below the nominal threshold')
		)
		%>% dplyr::select(-diagnostics)
	)
print(diagnostics)

# Visualize the posteriors on parameters ---
bayesplot::mcmc_hist(fit$draws(variables=c('phase','hz','amp','noise')))
bayesplot::mcmc_hist_by_chain(fit$draws(variables=c('phase')))
bayesplot::mcmc_hist_by_chain(fit$draws(variables=c('hz')))
bayesplot::mcmc_hist_by_chain(fit$draws(variables=c('amp')))
bayesplot::mcmc_hist_by_chain(fit$draws(variables=c('noise')))

# extract draws and visualize the posterior on the latent function ----
draws =
	(
		fit$draws(
			variables = 'f'
		)
		%>% posterior::as_draws_df()
		%>% tibble::as_tibble()
		%>% dplyr::select(-.iteration)
		%>% tidyr::pivot_longer(
			cols = starts_with('f')
			, names_to = 'x'
		)
		%>% dplyr::mutate(
			x = str_replace(x,fixed('f['),'')
			, x = str_replace(x,fixed(']'),'')
			, x = dat_wide$x[as.numeric(x)]
			, .chain = factor(.chain)
		)
	)


(
	draws
	%>% dplyr::group_by(
		x
		, .chain
	)
	%>% dplyr::summarise(
		med = median(value)
		, lo50 = quantile(value,.25)
		, hi50 = quantile(value,.75)
		, lo95 = quantile(value,.025)
		, hi95 = quantile(value,.975)
		, .groups = 'drop'
	)
	%>% ggplot()
	+ facet_grid(.chain~.)
	+ geom_line(
		data = dat
		, mapping = aes(
			x = x
			, y = f
		)
		# , size = 2
		# , linetype=2
	)
	+ geom_point(
		data = dat
		, mapping = aes(
			x = x
			, y = y
		)
		, alpha = .5
	)
	+ geom_line(
		data =
			(
				draws
				%>% dplyr::nest_by(.chain,.draw)
				%>% dplyr::group_by(.chain)
				%>% dplyr::slice_sample(n=10)
				%>% tidyr::unnest(cols=data)
			)
		, mapping = aes(
			x = x
			, y = value
			, group = .draw
			# , colour = .chain
		)
		, alpha = .2
	)
	+ geom_ribbon(
		mapping = aes(
			x = x
			, ymin = lo50
			, ymax = hi50
			, group = .chain
			, fill = .chain
		)
		, alpha = .5
	)
	+ geom_ribbon(
		mapping = aes(
			x = x
			, ymin = lo95
			, ymax = hi95
			, group = .chain
			, fill = .chain
		)
		, alpha = .5
	)

)

With anything sinusoidal I’d worry about aliasing.

I assume that if the likelihoods/priors/everything for the two different sinusoids are totally independent, there shouldn’t be an interaction, other than that more parameters make it more likely to accidentally sample something that is aliased.
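To make the aliasing concern concrete, here is a small R sketch using the asymmetric `p()` and the sampling grid from the original post: a frequency shifted by exactly the sampling rate is indistinguishable at the sample points. It also computes the model's own upper bound on `hz` (from the transformed data block), which sits below the Nyquist frequency, so exact aliases like `hz + fs` fall outside the prior support. That suggests, though doesn't prove, that the problem modes are not exact aliases.

```r
# Asymmetric periodic function from the original post
p <- function(x, hz, phase) {
  z <- x * (2 * pi) * hz - phase
  sin(z + sin(z) / 2)
}

x <- seq(0, 10, by = 1/10)
fs <- 1 / (x[2] - x[1])  # sampling rate: 10 samples per unit of x

# A frequency shifted by exactly fs gives the same values at every sample point
alias_gap <- max(abs(p(x, hz = 1, phase = 0) - p(x, hz = 1 + fs, phase = 0)))
print(alias_gap)  # numerically zero

# Upper bound on hz from the Stan model's transformed data block
hz_max <- 1 / ((x[2] - x[1]) * 5)
print(c(hz_max = hz_max, nyquist = fs / 2))
```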

Maybe just run a large number of chains with the single sinusoid (with different parameters) and make sure you don’t get problems there? If that doesn’t turn anything up, I’d kinda expect a bug somewhere.

The data were generated from the same sinusoid (same hz & phase and sampling times), just with different random measurement noise (but same noise distribution). So they really are IID. Possibly a better set of terms I should use for this is to contrast between the single-noisy-replicate-at-each-timepoint data/model and multiple-noisy-replicates-at-each-timepoint data/model.

Yup, I have confirmed that across many different values for the data-generation/model parameters, the single-noisy-replicate-at-each-timepoint model does fine, it’s only when I add multiple noisy replicates at each timepoint that it fails. I’ve even explored adding replicates to isolated timepoints, and observe some pretty expected behaviour (they’re more informative when adding at the peak/trough points than at a zero-crossing point), but I still observe much greater fragility in the sampling regardless of where the replicates occur.

One thought I just had:
When the model samples poorly in the presence of multiple replicates per timepoint, it seems that the likelihood has a hard time distinguishing between (1) model parameter configurations that combine high SNR (low ‘noise’, high ‘amp’) with the true values of frequency & phase and (2) model configurations that combine low SNR with arbitrary values of frequency & phase. I don’t know why this identifiability issue would seemingly only crop up once there is more data at each timepoint; I’d have thought that such data would help steer towards the data-generating values for noise and amp. I also thought that sufficiently strong priors on at least noise, amp, and hz would help steer the sampling away from the latter configurations, but I haven’t had any luck in that regard.

Oh so this is multiple observations of the same sinusoid. Yeah this does seem weird.

I guess multiple observations will have a similar effect as observations with lower noise. Can you try single data with a lower noise?
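One way to see why replicates act like lower noise: for a fixed noise level sigma, the identity sum_k (y_ik - f_i)^2 = K (ybar_i - f_i)^2 + sum_k (y_ik - ybar_i)^2 means the K-replicate Gaussian log-likelihood, as a function of the latent f, differs from the likelihood of the per-timepoint means ybar with noise sigma/sqrt(K) only by a constant. The R sketch below checks this numerically, using a plain sine as an illustrative signal (the identity doesn't depend on f). When `noise` is a free parameter, the within-replicate spread additionally informs it, so the equivalence is exact only for fixed sigma.

```r
# Sketch: K replicates vs. one series of per-timepoint means with noise / sqrt(K)
set.seed(2)
x <- seq(0, 10, by = 1/10)
truth <- sin(2 * pi * x)
K <- 10
noise <- 0.1

# Rows = timepoints, columns = replicates (truth is recycled down each column)
y <- matrix(truth + rnorm(length(x) * K, 0, noise), nrow = length(x))
y_bar <- rowMeans(y)

# Log-likelihood of a candidate latent function f under each formulation
ll_reps <- function(f) sum(dnorm(y, mean = f, sd = noise, log = TRUE))
ll_mean <- function(f) sum(dnorm(y_bar, mean = f, sd = noise / sqrt(K), log = TRUE))

# Two arbitrary candidates: the same log-likelihood difference under both
f_a <- sin(2 * pi * 1.0 * x)
f_b <- sin(2 * pi * 1.3 * x + 0.5)
diff_reps <- ll_reps(f_a) - ll_reps(f_b)
diff_mean <- ll_mean(f_a) - ll_mean(f_b)
print(c(diff_reps = diff_reps, diff_mean = diff_mean))
```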

It could be an initial conditions thing, and the priors don’t determine those.

Yeah, I thought I’d set up the model to have parameters where initial conditions in the typical range would all be fine, but supplying inits based on the true values does indeed allow even the multiple replicates to sample well again. I can probably work out approximations to this for real-world data, but still find this behaviour very strange.
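For reference, a sketch of how inits at the true values can be mapped onto the model's parameterisation. The helper `true_to_inits()` is hypothetical (not from the post); it mirrors the transformed data block of `hz_mat.stan` and inverts the `hz`, `phase` and `amp` transforms. cmdstanr's `sample()` accepts a function returning such a named list as `init`. Note that `hz_helper` must land strictly inside (0, 1) to be a valid init.

```r
# Hypothetical helper: convert generative values into the parameters of hz_mat.stan
true_to_inits <- function(hz, phase, amp, noise, x) {
  # Same bounds as the Stan model's transformed data block
  hz_max <- 1 / ((x[2] - x[1]) * 5)
  hz_min <- 1 / (x[length(x)] - x[1])
  list(
    noise = noise,
    # Invert hz = hz_helper * (hz_max - hz_min) + hz_min
    hz_helper = (hz - hz_min) / (hz_max - hz_min),
    # Chosen so that atan2(phamp[1], phamp[2]) == phase and |phamp| == amp
    phamp = c(amp * sin(phase), amp * cos(phase))
  )
}

x <- seq(0, 10, by = 1/10)
inits <- true_to_inits(hz = 1, phase = 0, amp = 1, noise = 0.1, x = x)
str(inits)

# Usage with the original post's objects (not run; stan_data is a placeholder):
# fit = mod$sample(data = stan_data, init = function() inits, chains = 4)
```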

The planetary example in the workflow preprint (Chapter 11) describes, IMHO, quite a similar situation: there is multimodality in the likelihood that is exacerbated with more data. There the unwanted modes are formed by a much higher-frequency rotation fitting some data points very well and others terribly. Any local change then decreases the fit to the well-fitted data points more than it improves the fit for the others, creating a mode with overall very low likelihood. Something similar might be happening here… (but I am not sure, just guessing).


Yeah this. I was trying to think of how to recommend the case study but it wasn’t quite up yet. I forgot it was part of the workflow paper :/

Thanks for pointing me to this paper, I somehow missed it!

I’ll definitely have to do some exploration of the likelihood. I’ll report back if this yields any useful insights.

@betanalpha commented in another thread:

IID data just amplifies the shape of the individual likelihood function. If each individual likelihood function is mildly multimodal – but not so much that the sampler can’t move between the modes and find all of the probability mass – then the product of likelihood functions can be extremely multimodal – such that the sampler gets stuck in bad modes due to initialization. In other words peaks get amplified and the valleys between the peaks get suppressed.
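A rough way to write this down, assuming each replicate's likelihood has approximately the same shape $\mathcal{L}(\theta)$ (the replicates share the latent function and noise level and differ only in the noise realisation), and ignoring the prior, which does not grow with $K$:

$$\log p(y_{1:K} \mid \theta) = \sum_{k=1}^{K} \log \mathcal{L}_k(\theta) \approx K \log \mathcal{L}(\theta)$$

So for a good mode $\theta_a$, a poorer local mode $\theta_b$, and a valley $\theta_v$ between them:

$$\frac{p(y_{1:K} \mid \theta_a)}{p(y_{1:K} \mid \theta_b)} \approx \left(\frac{\mathcal{L}(\theta_a)}{\mathcal{L}(\theta_b)}\right)^{K}, \qquad \log\frac{p(y_{1:K} \mid \theta_b)}{p(y_{1:K} \mid \theta_v)} \approx K \log\frac{\mathcal{L}(\theta_b)}{\mathcal{L}(\theta_v)}$$

The second expression is the log-density barrier a chain initialised near $\theta_b$ has to cross: a 2-nat dip with one replicate becomes roughly a 20-nat dip with $K = 10$, which makes it much harder for the chain to leave a poor mode.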

I coded up the computation of just the likelihood for a model with inference on the phase and frequency of a sine. Here’s the topography with just one IID sample-per-timepoint:


And for 10 IID samples per timepoint:

Note the color scale has changed by a factor of 10. Here are the two on the same scale:

So I think the explanation of what goes wrong with periodic models is those diagonal troughs, into which chains can get stuck depending on where they’re initialized, and (as @betanalpha pointed out) the inclusion of more IID data only exacerbates the problem because the troughs get deeper.
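The code for these surfaces isn't in the thread, so here is a minimal R sketch of the same kind of computation, under stated assumptions: the asymmetric `p()` from the original post, `amp` and `noise` fixed at their true values (1 and 0.1), and an illustrative grid over `hz` in [0.1, 2] (the range the Stan model allows for this x) and `phase` in [-pi, pi]. Because the noise is fixed, the peak-to-trough depth of the log-likelihood should scale roughly with the number of replicates, so the printed ratio should come out close to 10.

```r
# Asymmetric periodic function from the original post
p <- function(x, hz, phase) {
  z <- x * (2 * pi) * hz - phase
  sin(z + sin(z) / 2)
}

# Simulate num_reps noisy replicates of one latent signal (columns = replicates)
simulate_reps <- function(x, num_reps, hz = 1, phase = 0, noise = 0.1) {
  f <- p(x, hz, phase)
  matrix(f + rnorm(length(x) * num_reps, 0, noise), nrow = length(x))
}

# Gaussian log-likelihood on a (hz, phase) grid; amp and noise held fixed
loglik_grid <- function(y, x, hz_grid, phase_grid, amp = 1, noise = 0.1) {
  ll <- matrix(NA_real_, length(hz_grid), length(phase_grid))
  for (i in seq_along(hz_grid)) {
    for (j in seq_along(phase_grid)) {
      f <- amp * p(x, hz_grid[i], phase_grid[j])
      # dnorm recycles f down each column, i.e. across every replicate
      ll[i, j] <- sum(dnorm(y, mean = f, sd = noise, log = TRUE))
    }
  }
  ll
}

set.seed(1)
x <- seq(0, 10, by = 1/10)
hz_grid <- seq(0.1, 2, by = 0.01)
phase_grid <- seq(-pi, pi, length.out = 73)

ll_1 <- loglik_grid(simulate_reps(x, 1), x, hz_grid, phase_grid)
ll_10 <- loglik_grid(simulate_reps(x, 10), x, hz_grid, phase_grid)

# Peak-to-trough depth of each surface (additive constants cancel)
depth_ratio <- diff(range(ll_10)) / diff(range(ll_1))
print(depth_ratio)

# Location of the global maximum of the 10-replicate surface
best <- which(ll_10 == max(ll_10), arr.ind = TRUE)
best_hz <- hz_grid[best[1, 1]]
best_phase <- phase_grid[best[1, 2]]
print(c(hz = best_hz, phase = best_phase))

# To view the topography: image(hz_grid, phase_grid, ll_10)
```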

I’m going to play a bit to see if I can discern why those troughs are there; they’re not at the frequencies I’d have expected from simple harmonic behaviours, plus they’re sensitive to phase too. Hopefully by working out why the modes are there in the first place a transform of some sort will suggest itself that will help improve the geometry.


It may have something to do with how the test model is set up.

p = function(x,hz,phase){
	x = x*(2*pi)*hz - phase
	sin(x+sin(x)/2)
}
x = seq(0,10,by=1/10) # same sampling grid as the original post
f <- p(x, 1, 0)
stats::spectrum(f, method="ar")

gives


and from your first plot it looks like some chains are picking up the second energy peak (possibly with a phase shift). To test whether this is the case, one can fit the same model but use a sampling/acquisition frequency much higher than the signal frequency, as

x <- seq(0, 10, 0.001)
f <- p(x, 1, 0)
stats::spectrum(f, method="ar")

gives
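The second energy peak can also be quantified directly. By the Jacobi-Anger expansion, sin(z + a sin z) = sum_n J_n(a) sin((n + 1) z), so with a = 1/2 the function `p()` has a fundamental of amplitude J_0(1/2) - J_2(1/2) (about 0.91) and a second harmonic of amplitude J_1(1/2) + J_3(1/2) (about 0.24). The R sketch below checks this with an FFT over exactly ten periods; the 1 kHz sampling rate and window length are illustrative choices.

```r
# Asymmetric periodic function from the original post
p <- function(x, hz, phase) {
  z <- x * (2 * pi) * hz - phase
  sin(z + sin(z) / 2)
}

# Exactly 10 periods of a 1 Hz signal sampled at 1 kHz (endpoint excluded)
fs <- 1000
x <- (0:(10 * fs - 1)) / fs
f <- p(x, hz = 1, phase = 0)

# One-sided amplitude spectrum; bin k + 1 corresponds to k / 10 Hz
amp_spec <- 2 * Mod(fft(f)) / length(f)
harmonics <- amp_spec[c(11, 21, 31)]  # 1, 2 and 3 Hz
names(harmonics) <- c("1 Hz", "2 Hz", "3 Hz")
print(round(harmonics, 4))

# Jacobi-Anger prediction: sine coefficient at frequency m is J_{m-1}(a) - J_{m+1}(a)
# for odd m and J_{m-1}(a) + J_{m+1}(a) for even m, with a = 1/2
bessel_pred <- c(
  besselJ(0.5, 0) - besselJ(0.5, 2),
  besselJ(0.5, 1) + besselJ(0.5, 3),
  besselJ(0.5, 2) - besselJ(0.5, 4)
)
print(round(bessel_pred, 4))
```

This harmonic content is a property of `p()` itself, independent of the sampling grid, as long as the grid can resolve it.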


Thanks for your time/input, though I’ve actually replicated this with even simple sin() functions (that’s what my most recent thread shows).