Stan on M4 Mac?

The new M4 Apple Silicon chips look to have pretty impressive performance.
Has anyone tried Stan on an M4 Mac yet?
If so, have you had any technical issues?
Are you getting noteworthy speed ups compared to your old set-up?

The M2 and M3 Macs perform really well for Stan. That's because we're largely memory-bottlenecked in large models, and the ARM series is much better at dealing with asynchronous memory access.

1 Like

I got one of the base-level $500 M4 Mac Minis to play around with – I've only run a few Stan models on it so far, but I'm seeing a 1.5–3x speedup relative to my laptop (a 2019 MBP with an i9-9980HK), which is consistent with the ~2.5x difference in their Geekbench 6 scores. Note also that this was using 8 chains in parallel and taking the average across chains (to better reflect the 4P + 6E core composition of the 10-core M4 chip).
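If anyone wants to replicate that kind of timing run, here's a rough CmdStanR sketch (the model, data, and iteration counts are just stand-ins, not what I actually timed):

library(cmdstanr)

# a trivial stand-in model, just so the snippet is self-contained
stan_file <- write_stan_file("
data { int<lower=1> N; array[N] real y; }
parameters { real mu; real<lower=0> sigma; }
model { mu ~ normal(0, 10); sigma ~ exponential(1); y ~ normal(mu, sigma); }
")
mod <- cmdstan_model(stan_file)

# 8 chains run concurrently to load all 10 cores (4 performance + 6 efficiency),
# then average the per-chain wall times
fit <- mod$sample(
  data = list(N = 1000, y = rnorm(1000, 2, 1.5)),
  chains = 8, parallel_chains = 8,
  seed = 1, iter_warmup = 1000, iter_sampling = 1000
)
mean(fit$time()$chains$total)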

4 Likes

@jroon @Bob_Carpenter this might also interest you – I did a quick test of within-chain parallelization via reduce_sum in a toy model, comparing my MacBook's performance to the new M4 Mac Mini's:

functions {
  real partial_sum_lpdf(array[] real y_subset, int start, int end, real mu, real sigma) {
    return normal_lpdf(y_subset | mu, sigma);
  }
}
data {
  int<lower=1> N;
  array[N] real y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  int grainsize = 1;  // 1 = let the scheduler choose the slice sizes
  mu ~ normal(0, 10);
  sigma ~ exponential(1);
  target += reduce_sum(partial_sum_lpdf, y, grainsize, mu, sigma);
}
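The fits themselves were launched from CmdStanR with threading enabled – roughly like the sketch below (the model file name and seed are placeholders, and I'm assuming one chain per fit):

library(cmdstanr)

# compile with STAN_THREADS so reduce_sum can actually use multiple threads
# ("reduce_sum_normal.stan" is a placeholder name for the model above)
mod <- cmdstan_model("reduce_sum_normal.stan",
                     cpp_options = list(stan_threads = TRUE))

# one of the simulated data sets: N draws from a normal(2, 1.5)
N <- 5e4
set.seed(1)
stan_data <- list(N = N, y = rnorm(N, mean = 2, sd = 1.5))

# a single chain using k threads of within-chain parallelization;
# fit$time()$total is the completion time used for the scaling plots
k <- 4
fit <- mod$sample(
  data = stan_data,
  chains = 1,
  threads_per_chain = k,
  seed = 1,
  iter_warmup = 500,
  iter_sampling = 500
)
fit$time()$total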

I ran this partly as an empirical “benchmark” of Amdahl's law. (With the Black Friday sales I could get a base M4 Mac Mini for ~$400, so I snagged another three; they're currently daisy-chained to one another and to my MacBook over a Thunderbolt 4 bridge, and I distribute jobs – including Stan runs – across them with GNU Parallel. That way I can spread embarrassingly parallelizable runs across all of them, run e.g. four independent chains – one per machine – each with within-chain parallelization, or some combination thereof.)
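For reference, Amdahl's law puts the best-case speedup from k threads at 1 / ((1 − p) + p / k), where p is the fraction of the work that actually parallelizes; the “theoretical ideal” curves below are the p = 1 case, i.e. completion time falling as 1/k.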

So anyway, I simulate N ∈ {5E2, 5E3, 5E4, 5E5} samples from a normal(2, 1.5) with the same seed and fit the above model with 1, 2, 3, 4… 12 threads, with 50 replicate fits from independent initializations (500 warmup, 500 sampling iterations), using CmdStanR calls to CmdStan. On the MBP I get this scaling:

and on the Mac Mini I get this scaling:

The left axis is in “relative time”: I divided each average completion time by the maximum time across 1–12 threads, so everything lands in (0, 1]. Specifically (see the short code sketch after this list):

Grey closed circles: observed average completion time for a single run (including warmup), rescaled to a proportion of a single-threaded run.

Red open circles: theoretical ideal completion time with no other bottlenecks (e.g. I/O), i.e. just 1 / (1, 2, 3, …, 12).

Blue triangles: relative number of CPU-hours used with within-chain parallelization, equivalent to the average completion time × the number of threads.

Pink pluses: marginal per-thread speed-up of each subsequent thread relative to the theoretical speed-up – in other words, the ratio of the differenced grey values to the differenced red values. This series uses the right axis, and a horizontal pink band marks the [0, 1] interval. When it's near 1, adding a core at the margin improves completion time near the theoretical upper bound. When it's at 0, adding a core does nothing to speed up the total completion time. When it's < 0, adding a core makes the total completion time go up – worse than doing nothing, because the overhead costs more than the extra parallel compute buys. When it's > 1, the average over 50 runs was just too noisy: the preceding observed average was probably unusually bad, so the current one regressed to (or past) the mean.
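In case it's easier to read as code, those plotted quantities come out of the average timings roughly like this (the avg_time values are made-up placeholders just so the snippet runs):

# avg_time: mean wall time (warmup + sampling) over the 50 replicate fits
# at each thread count, ordered from 1 to 12 threads (made-up numbers here)
avg_time <- c(100, 52, 36, 28, 27, 26, 25, 25, 24, 24, 24, 24)
k <- seq_along(avg_time)

rel_time <- avg_time / max(avg_time)  # grey closed circles
ideal    <- 1 / k                     # red open circles (no overhead at all)
cpu_cost <- rel_time * k              # blue triangles: relative CPU-hours

# pink pluses: marginal speed-up of each added thread relative to the
# theoretical speed-up, i.e. differenced grey over differenced red values
marginal_eff <- diff(rel_time) / diff(ideal)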

Overall, a couple of things surprised me:

  1. There are benefits to within-chain parallelization in this implementation even with very small models (N = 500) – I'd have expected overhead to dominate even at 2 threads.

  2. The Mac Mini plateaued at 4 threads. This makes some sense – it has 4 performance cores, so maybe once those are used up the efficiency cores can only just “pay for themselves” – but I still wouldn't have expected such an exact balance, and it was odd that the curve did not bend back up even past 10 threads (the chip's total core count).

2 Likes

Thanks for the analysis, @NikVetr.

You see this with JAX models, too. It depends on whether your hardware can keep the memory local enough for this to be a win.

I have no idea about the performance vs. efficiency cores, but given what you say, this shouldn’t be too surprising.