Stan SIMD & Performance

Ah, well: it looks like the vari type interleaves values with adjoints. This is (IMO) better than Julia’s ForwardDiff.Dual, which interleaves a value with a vector of (forward-mode) partials, but the interleaving is still not ideal for SIMD.

If we have an Eigen matrix of vars, what is the actual data layout?
Given that stan::math::var is a pointer to a vari, does that mean it is a matrix of pointers to (value, adjoint) pairs?
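To make the layout question concrete, here is a minimal sketch using hypothetical stand-in types (`Vari`, `Var`), not Stan’s actual classes. Assuming a var is nothing but a pointer to a node holding a (value, adjoint) pair, a contiguous column of vars is a contiguous array of pointers, and reading the values takes one indirection per element (effectively a gather) instead of one contiguous vector load.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an autodiff node (not Stan's vari):
// the value and the adjoint are stored next to each other.
struct Vari {
  double val;
  double adj;
};

// Hypothetical stand-in for a var: nothing but a pointer to its node.
struct Var {
  Vari* vi;
};

// Sum the values of n vars stored contiguously, e.g. one column of a
// column-major matrix of vars. The pointers are contiguous, but every value
// is fetched through its own pointer (a gather), so the compiler cannot
// turn the loop body into plain contiguous vector loads.
double sum_values(const Var* x, std::size_t n) {
  assert(x != nullptr || n == 0);
  double s = 0.0;
  for (std::size_t i = 0; i < n; ++i) s += x[i].vi->val;
  return s;
}

// Build handles that point into an existing pool of nodes.
std::vector<Var> make_handles(std::vector<Vari>& nodes) {
  std::vector<Var> handles;
  handles.reserve(nodes.size());
  for (Vari& node : nodes) handles.push_back(Var{&node});
  return handles;
}
```

With this layout `sizeof(Var)` is the size of a pointer, so a matrix of such handles stores only pointers, while the (value, adjoint) pairs live wherever the nodes were allocated.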

From my (admittedly poor) understanding of what Stan does under the hood, it doesn’t look like it can take advantage of SIMD.
Comparisons of runtime performance against SIMD implementations, as well as comparisons of the clock speeds observed while Stan runs (watch -n1 "cat /proc/cpuinfo | grep MHz") against the BIOS clock settings for non-AVX, AVX2, and AVX-512 workloads, further support that conclusion: the cores stay at the non-AVX clock, which suggests little wide-vector code is executing.

So does the fact that the “vectorized” bernoulli_logit_glm function was several times slower than a Julia version I vectorized. Again, I’m fairly ignorant about the Stan version’s implementation and did not look at its assembly (although I did look at the Julia version’s), so all I can point to is the observed performance difference, which implies unrealized potential in Stan.
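For reference, the quantity being computed is the Bernoulli-logit log-likelihood, sum_i [ y_i * eta_i - log(1 + exp(eta_i)) ] with eta = alpha + X * beta. Below is a minimal SIMD-friendly sketch, assuming a dense column-major `X` (n rows by k columns) and a scalar intercept; `bernoulli_logit_glm_lpmf_sketch` is a hypothetical name, this is not Stan’s implementation, and it computes no gradients. The linear predictor is accumulated one column of `X` at a time, so every inner loop walks contiguous memory; whether the exp/log1p loop vectorizes depends on the compiler and its vector math library.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)).
inline double softplus(double x) {
  return std::fmax(x, 0.0) + std::log1p(std::exp(-std::fabs(x)));
}

// Log-likelihood of 0/1 outcomes y under logistic regression with intercept
// alpha and coefficients beta. X is column-major, n rows by k columns.
double bernoulli_logit_glm_lpmf_sketch(const std::vector<int>& y,
                                       const std::vector<double>& X,
                                       std::size_t n, std::size_t k,
                                       double alpha,
                                       const std::vector<double>& beta) {
  assert(y.size() == n && X.size() == n * k && beta.size() == k);
  // eta = alpha + X * beta, one axpy per column so X is read contiguously.
  std::vector<double> eta(n, alpha);
  for (std::size_t j = 0; j < k; ++j) {
    const double bj = beta[j];
    const double* col = X.data() + j * n;
    for (std::size_t i = 0; i < n; ++i) eta[i] += col[i] * bj;
  }
  // Sum of y_i * eta_i - log(1 + exp(eta_i)).
  double lp = 0.0;
  for (std::size_t i = 0; i < n; ++i) lp += y[i] * eta[i] - softplus(eta[i]);
  return lp;
}
```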

Even with doubles instead of the stan::math::var type, Eigen as a library does not appear to do a good job with small, fixed-size matrices (compared to libraries like PaddedMatrices.jl, Blaze, or Intel MKL JIT). I haven’t investigated small, dynamically sized Eigen matrices (which Stan uses), but it’d be a little odd if they performed better.

I’m not sure if I have anything constructive to say. Making good use of SIMD starts with the data structures: the algorithms need to be written with it in mind, so there (normally) aren’t any easy fixes or recommendations.
That is especially true since I haven’t spent the time to really study Stan’s internals and algorithms. I’m also hopefully wrong about some of the things I’ve said here.


I can’t say whether Eigen matrices of vars get SIMD instructions, but I have an example from godbolt, and it looks like Eigen does fine at generating code using VMOVUPD and other AVX instructions (I’m not the best at reading assembly, so if that’s false, my bad).

You bring up some interesting questions, though! I’ll check the gcc output for a Stan example next week to see what sort of instructions Eigen is generating. I could also probably bring the base files for Stan into godbolt and do it all there. I also wonder what adding vectors of vars would look like. We keep the val and adj next to each other because scalar arithmetic on vars tends to use each var’s val and adj together, so it’s better to have them next to each other in memory.
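For contrast, here is a minimal sketch of the alternative layout, assuming a hypothetical struct-of-arrays container (`SoAVars`, not part of Stan) in which all values and all adjoints live in two separate contiguous arrays. The forward and reverse passes of an elementwise sum z = x + y then become plain contiguous loops that a compiler can vectorize. The trade-off is the one described above: scalar code that touches one var’s value and adjoint together now reads from two places.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical struct-of-arrays storage for n scalar AD variables:
// all values are contiguous, and all adjoints are contiguous.
struct SoAVars {
  std::vector<double> val;
  std::vector<double> adj;
  explicit SoAVars(std::vector<double> v)
      : val(std::move(v)), adj(val.size(), 0.0) {}
};

// Forward pass of z = x + y: one contiguous, vectorizable loop.
SoAVars add_forward(const SoAVars& x, const SoAVars& y) {
  assert(x.val.size() == y.val.size());
  std::vector<double> z(x.val.size());
  for (std::size_t i = 0; i < z.size(); ++i) z[i] = x.val[i] + y.val[i];
  return SoAVars(std::move(z));
}

// Reverse pass of z = x + y: dz/dx = dz/dy = 1, so z's adjoints are added
// into x's and y's adjoints, again with contiguous loops.
void add_reverse(const SoAVars& z, SoAVars& x, SoAVars& y) {
  assert(z.adj.size() == x.adj.size() && z.adj.size() == y.adj.size());
  for (std::size_t i = 0; i < z.adj.size(); ++i) {
    x.adj[i] += z.adj[i];
    y.adj[i] += z.adj[i];
  }
}
```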

Also, on a side note, IMO it would be a bit easier to see what’s going on in your post if you made some graphs of performance across the different sizes.

I’d do more but on mobile atm

Back in the day, Eigen needed compiler flags to take advantage of chip-specific vectorization and optimizations. I’m not sure if that’s still the case, but it’s worth verifying before discussing changing architectures (perhaps along with investigating the assembly code).
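One quick way to verify what a given set of compiler flags actually enables is to check the compiler’s predefined macros, since Eigen decides which SIMD code paths to use from the same kind of macros (Eigen also reports its own view through `Eigen::SimdInstructionSetsInUse()`). A minimal sketch, assuming GCC or Clang, which define `__SSE2__`, `__AVX__`, `__AVX2__`, `__FMA__`, and `__AVX512F__` when the matching `-m`/`-march` flags are in effect:

```cpp
#include <cassert>
#include <string>

// Returns the SIMD extensions enabled for this translation unit, as seen by
// the preprocessor. These macros reflect compiler flags such as -mavx2 or
// -march=skylake-avx512, not the CPU the binary eventually runs on.
std::string enabled_simd_extensions() {
  std::string s;
#ifdef __SSE2__
  s += "SSE2 ";
#endif
#ifdef __AVX__
  s += "AVX ";
#endif
#ifdef __AVX2__
  s += "AVX2 ";
#endif
#ifdef __FMA__
  s += "FMA ";
#endif
#ifdef __AVX512F__
  s += "AVX512F ";
#endif
  return s.empty() ? "none" : s;
}
```

Compiling the same file with and without `-march=native` (or with the exact flags the Stan build uses) and printing the result shows whether the wider instruction sets are even available to Eigen.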


Eigen does use some SIMD instructions, but (as I said) it does a bad job of it.

Here are some benchmarks with neat sizes (16x24 * 24x14 = 16x14) and awkward prime sizes (31x37 * 37x29 = 31x29). Pardon the awkward code (particularly creating and allocating the matrices): Eigen uses vmovapd instead of vmovupd. They’re more or less the same (I’ve not noticed a performance difference), except that vmovapd segfaults when the address is not aligned (to 64 bytes for a 512-bit zmm register), so the matrices have to be placed at aligned addresses.
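The 64-byte alignment that the Julia `align_64` helper in the session below provides can be sketched in C++ as follows. This is a minimal sketch with hypothetical names (`align_up`, and a C++ `align_64` counterpart); it rounds an address up with the same `(x + 63) & ~63` trick, so the underlying buffer needs at least 63 spare bytes. `std::aligned_alloc` from `<cstdlib>` (C++17) is an alternative that returns aligned memory directly, provided the requested size is a multiple of the alignment.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Round x up to the next multiple of `alignment` (must be a power of two).
// With alignment = 64 this is the (x + 63) & ~63 trick used by align_64.
constexpr std::uintptr_t align_up(std::uintptr_t x, std::uintptr_t alignment) {
  return (x + alignment - 1) & ~(alignment - 1);
}

// Return the first 64-byte-aligned double* at or after p, i.e. an address
// that is safe for vmovapd on a 512-bit zmm register.
double* align_64(void* p) {
  auto addr = reinterpret_cast<std::uintptr_t>(p);
  double* out = reinterpret_cast<double*>(align_up(addr, 64));
  assert(reinterpret_cast<std::uintptr_t>(out) % 64 == 0);
  return out;
}
```

Chaining allocations the way the Julia session does would look like `double* B = align_64(A + 16 * 24);` for a 24x14 matrix placed after a 16x24 one.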

julia> using PaddedMatrices, LinearAlgebra, Random, BenchmarkTools

julia> import PaddedMatrices: AbstractMutableFixedSizePaddedMatrix

julia> # Compiled with
       # g++ -O3 -fno-signed-zeros -fno-trapping-math -fassociative-math -march=skylake-avx512 -mprefer-vector-width=512 -I/usr/include/eigen3 -shared -fPIC eigen_matmul_test.cpp -o libgppeigentest.so
       # clang++ -O3 -fno-signed-zeros -fno-trapping-math -fassociative-math -march=skylake-avx512 -mprefer-vector-width=512 -I/usr/include/eigen3 -shared -fPIC eigen_matmul_test.cpp -o libclangeigentest.so
       const gpplib = "/home/chriselrod/Documents/progwork/Cxx/libgppeigentest.so"
"/home/chriselrod/Documents/progwork/Cxx/libgppeigentest.so"

julia> const clanglib = "/home/chriselrod/Documents/progwork/Cxx/libclangeigentest.so"
"/home/chriselrod/Documents/progwork/Cxx/libclangeigentest.so"

julia> function gppeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{16,14,Float64,16},
           A::AbstractMutableFixedSizePaddedMatrix{16,24,Float64,16},
           B::AbstractMutableFixedSizePaddedMatrix{24,14,Float64,24}
       )
           ccall(
               (:mul_16x24times24x14, gpplib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
gppeigen! (generic function with 1 method)

julia> function gppeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{31,29,Float64,31},
           A::AbstractMutableFixedSizePaddedMatrix{31,37,Float64,31},
           B::AbstractMutableFixedSizePaddedMatrix{37,29,Float64,37}
       )
           ccall(
               (:mul_31x37times37x29, gpplib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
gppeigen! (generic function with 2 methods)

julia> function clangeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{16,14,Float64,16},
           A::AbstractMutableFixedSizePaddedMatrix{16,24,Float64,16},
           B::AbstractMutableFixedSizePaddedMatrix{24,14,Float64,24}
       )
           ccall(
               (:mul_16x24times24x14, clanglib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
clangeigen! (generic function with 1 method)

julia> function clangeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{31,29,Float64,31},
           A::AbstractMutableFixedSizePaddedMatrix{31,37,Float64,31},
           B::AbstractMutableFixedSizePaddedMatrix{37,29,Float64,37}
       )
           ccall(
               (:mul_31x37times37x29, clanglib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
clangeigen! (generic function with 2 methods)

julia> align_64(x) = (x + 63) & ~63
align_64 (generic function with 1 method)

julia> align_64(x, ptr::Ptr{T}) where {T} = align_64(x*sizeof(T) + ptr)
align_64 (generic function with 2 methods)

julia> const PTR = Base.unsafe_convert(Ptr{Float64}, Libc.malloc(1<<20)); # more than enough space

julia> align_64(x::Ptr{T}) where {T} = reinterpret(Ptr{T},align_64(reinterpret(UInt,x)))
align_64 (generic function with 3 methods)

julia> align_64(PTR)
Ptr{Float64} @0x00007fad8b585040

julia> A16x24 = PtrMatrix{16,24,Float64,16}(align_64(PTR)); randn!(A16x24);

julia> B24x14 = PtrMatrix{24,14,Float64,24}(align_64(16*24,pointer(A16x24))); randn!(B24x14);

julia> C16x14 = PtrMatrix{16,14,Float64,16}(align_64(24*14,pointer(B24x14)));

julia> clangeigen!(C16x14, A16x24, B24x14)

julia> @benchmark gppeigen!($C16x14, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     446.863 ns (0.00% GC)
  median time:      454.548 ns (0.00% GC)
  mean time:        457.793 ns (0.00% GC)
  maximum time:     701.843 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     197

julia> @benchmark clangeigen!($C16x14, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     470.245 ns (0.00% GC)
  median time:      474.612 ns (0.00% GC)
  mean time:        477.438 ns (0.00% GC)
  maximum time:     663.571 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     196

julia> @benchmark mul!($C16x14, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     88.141 ns (0.00% GC)
  median time:      90.651 ns (0.00% GC)
  mean time:        91.153 ns (0.00% GC)
  maximum time:     142.651 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     960

julia> # Prime sizes
       A31x37 = PtrMatrix{31,37,Float64,31}(align_64(16*14,pointer(C16x14))); randn!(A31x37);

julia> B37x29 = PtrMatrix{37,29,Float64,37}(align_64(31*37,pointer(A31x37))); randn!(B37x29);

julia> C31x29 = PtrMatrix{31,29,Float64,31}(align_64(37*29,pointer(B37x29)));

julia> @benchmark gppeigen!($C31x29, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.435 μs (0.00% GC)
  median time:      2.458 μs (0.00% GC)
  mean time:        2.494 μs (0.00% GC)
  maximum time:     6.107 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark clangeigen!($C31x29, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.346 μs (0.00% GC)
  median time:      2.373 μs (0.00% GC)
  mean time:        2.388 μs (0.00% GC)
  maximum time:     6.090 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark mul!($C31x29, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     585.606 ns (0.00% GC)
  median time:      607.706 ns (0.00% GC)
  mean time:        610.846 ns (0.00% GC)
  maximum time:     1.466 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     180

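That is roughly a four- to five-fold difference in median time (454.5 ns with g++ and 474.6 ns with clang++ for Eigen vs 90.7 ns for Julia at 16x14; 2.46 μs and 2.37 μs vs 0.61 μs at 31x29), and a roughly four-fold difference was common across an array of small sizes.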

You can look at the C++ code and the assembly here.
For comparison, here is the assembly Julia produced (I deleted all lines that only contained comments for brevity’s sake):

julia> @code_native mul!(C16x14, A16x24, B24x14)
	.text
	movq	(%rsi), %rax
	vmovupd	(%rax), %zmm26
	vmovupd	64(%rax), %zmm28
	movq	(%rdx), %rcx
	vbroadcastsd	(%rcx), %zmm1
	vmulpd	%zmm26, %zmm1, %zmm0
	vmulpd	%zmm28, %zmm1, %zmm1
	vbroadcastsd	192(%rcx), %zmm3
	vmulpd	%zmm26, %zmm3, %zmm2
	vmulpd	%zmm28, %zmm3, %zmm3
	vbroadcastsd	384(%rcx), %zmm5
	vmulpd	%zmm26, %zmm5, %zmm4
	vmulpd	%zmm28, %zmm5, %zmm5
	vbroadcastsd	576(%rcx), %zmm7
	vmulpd	%zmm26, %zmm7, %zmm6
	vmulpd	%zmm28, %zmm7, %zmm7
	vbroadcastsd	768(%rcx), %zmm9
	vmulpd	%zmm26, %zmm9, %zmm8
	vmulpd	%zmm28, %zmm9, %zmm9
	vbroadcastsd	960(%rcx), %zmm11
	vmulpd	%zmm26, %zmm11, %zmm10
	vmulpd	%zmm28, %zmm11, %zmm11
	vbroadcastsd	1152(%rcx), %zmm13
	vmulpd	%zmm26, %zmm13, %zmm12
	vmulpd	%zmm28, %zmm13, %zmm13
	vbroadcastsd	1344(%rcx), %zmm15
	vmulpd	%zmm26, %zmm15, %zmm14
	vmulpd	%zmm28, %zmm15, %zmm15
	vbroadcastsd	1536(%rcx), %zmm17
	vmulpd	%zmm26, %zmm17, %zmm16
	vmulpd	%zmm28, %zmm17, %zmm17
	vbroadcastsd	1728(%rcx), %zmm19
	vmulpd	%zmm26, %zmm19, %zmm18
	vmulpd	%zmm28, %zmm19, %zmm19
	vbroadcastsd	1920(%rcx), %zmm21
	vmulpd	%zmm26, %zmm21, %zmm20
	vmulpd	%zmm28, %zmm21, %zmm21
	vbroadcastsd	2112(%rcx), %zmm23
	vmulpd	%zmm26, %zmm23, %zmm22
	vmulpd	%zmm28, %zmm23, %zmm23
	vbroadcastsd	2304(%rcx), %zmm25
	vmulpd	%zmm26, %zmm25, %zmm24
	vmulpd	%zmm28, %zmm25, %zmm25
	vbroadcastsd	2496(%rcx), %zmm29
	vmulpd	%zmm26, %zmm29, %zmm27
	vmulpd	%zmm28, %zmm29, %zmm26
	addq	$192, %rax
	movq	$-23, %rdx
	nopw	%cs:(%rax,%rax)
	nopl	(%rax,%rax)
L336:
	vmovapd	%zmm27, %zmm28
	vmovupd	-64(%rax), %zmm27
	vmovupd	(%rax), %zmm29
	vbroadcastsd	192(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm0 # zmm0 = (zmm27 * zmm30) + zmm0
	vfmadd231pd	%zmm30, %zmm29, %zmm1 # zmm1 = (zmm29 * zmm30) + zmm1
	vbroadcastsd	384(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm2 # zmm2 = (zmm27 * zmm30) + zmm2
	vfmadd231pd	%zmm30, %zmm29, %zmm3 # zmm3 = (zmm29 * zmm30) + zmm3
	vbroadcastsd	576(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm4 # zmm4 = (zmm27 * zmm30) + zmm4
	vfmadd231pd	%zmm30, %zmm29, %zmm5 # zmm5 = (zmm29 * zmm30) + zmm5
	vbroadcastsd	768(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm6 # zmm6 = (zmm27 * zmm30) + zmm6
	vfmadd231pd	%zmm30, %zmm29, %zmm7 # zmm7 = (zmm29 * zmm30) + zmm7
	vbroadcastsd	960(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm8 # zmm8 = (zmm27 * zmm30) + zmm8
	vfmadd231pd	%zmm30, %zmm29, %zmm9 # zmm9 = (zmm29 * zmm30) + zmm9
	vbroadcastsd	1152(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm10 # zmm10 = (zmm27 * zmm30) + zmm10
	vfmadd231pd	%zmm30, %zmm29, %zmm11 # zmm11 = (zmm29 * zmm30) + zmm11
	vbroadcastsd	1344(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm12 # zmm12 = (zmm27 * zmm30) + zmm12
	vfmadd231pd	%zmm30, %zmm29, %zmm13 # zmm13 = (zmm29 * zmm30) + zmm13
	vbroadcastsd	1536(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm14 # zmm14 = (zmm27 * zmm30) + zmm14
	vfmadd231pd	%zmm30, %zmm29, %zmm15 # zmm15 = (zmm29 * zmm30) + zmm15
	vbroadcastsd	1728(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm16 # zmm16 = (zmm27 * zmm30) + zmm16
	vfmadd231pd	%zmm30, %zmm29, %zmm17 # zmm17 = (zmm29 * zmm30) + zmm17
	vbroadcastsd	1920(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm18 # zmm18 = (zmm27 * zmm30) + zmm18
	vfmadd231pd	%zmm30, %zmm29, %zmm19 # zmm19 = (zmm29 * zmm30) + zmm19
	vbroadcastsd	2112(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm20 # zmm20 = (zmm27 * zmm30) + zmm20
	vfmadd231pd	%zmm30, %zmm29, %zmm21 # zmm21 = (zmm29 * zmm30) + zmm21
	vbroadcastsd	2304(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm22 # zmm22 = (zmm27 * zmm30) + zmm22
	vfmadd231pd	%zmm30, %zmm29, %zmm23 # zmm23 = (zmm29 * zmm30) + zmm23
	vbroadcastsd	2496(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm24 # zmm24 = (zmm27 * zmm30) + zmm24
	vfmadd231pd	%zmm30, %zmm29, %zmm25 # zmm25 = (zmm29 * zmm30) + zmm25
	vbroadcastsd	2688(%rcx,%rdx,8), %zmm30
	vfmadd213pd	%zmm28, %zmm30, %zmm27 # zmm27 = (zmm30 * zmm27) + zmm28
	vfmadd231pd	%zmm30, %zmm29, %zmm26 # zmm26 = (zmm29 * zmm30) + zmm26
	subq	$-128, %rax
	incq	%rdx
	jne	L336
	movq	(%rdi), %rax
	vmovupd	%zmm0, (%rax)
	vmovupd	%zmm1, 64(%rax)
	vmovupd	%zmm2, 128(%rax)
	vmovupd	%zmm3, 192(%rax)
	vmovupd	%zmm4, 256(%rax)
	vmovupd	%zmm5, 320(%rax)
	vmovupd	%zmm6, 384(%rax)
	vmovupd	%zmm7, 448(%rax)
	vmovupd	%zmm8, 512(%rax)
	vmovupd	%zmm9, 576(%rax)
	vmovupd	%zmm10, 640(%rax)
	vmovupd	%zmm11, 704(%rax)
	vmovupd	%zmm12, 768(%rax)
	vmovupd	%zmm13, 832(%rax)
	vmovupd	%zmm14, 896(%rax)
	vmovupd	%zmm15, 960(%rax)
	vmovupd	%zmm16, 1024(%rax)
	vmovupd	%zmm17, 1088(%rax)
	vmovupd	%zmm18, 1152(%rax)
	vmovupd	%zmm19, 1216(%rax)
	vmovupd	%zmm20, 1280(%rax)
	vmovupd	%zmm21, 1344(%rax)
	vmovupd	%zmm22, 1408(%rax)
	vmovupd	%zmm23, 1472(%rax)
	vmovupd	%zmm24, 1536(%rax)
	vmovupd	%zmm25, 1600(%rax)
	vmovupd	%zmm27, 1664(%rax)
	vmovupd	%zmm26, 1728(%rax)
	vzeroupper
	retq
	nopl	(%rax)

And for the larger, awkwardly sized matrices:

julia> @code_native mul!(C31x29, A31x37, B37x29)
	.text
	pushq	%rbp
	pushq	%r15
	pushq	%r14
	pushq	%r13
	pushq	%r12
	pushq	%rbx
	movq	%rdx, -8(%rsp)
	movq	%rsi, -16(%rsp)
	movl	$1, %r9d
	movl	$1488, %esi             # imm = 0x5D0
	movb	$127, %r10b
	nopw	%cs:(%rax,%rax)
	nopl	(%rax)
L48:
	leaq	-1(%r9), %rcx
	movq	-16(%rsp), %rax
	movq	(%rax), %r11
	movq	-8(%rsp), %rax
	movq	(%rax), %r14
	imulq	$1488, %rcx, %rax       # imm = 0x5D0
	imulq	$1776, %rcx, %r13       # imm = 0x6F0
	addq	(%rdi), %rax
	leaq	(%r14,%rsi), %rcx
	leaq	440(%r11), %rdx
	xorl	%r12d, %r12d
	nopw	%cs:(%rax,%rax)
	nopl	(%rax)
L112:
	imulq	$248, %r12, %rbx
	leaq	(%r11,%rbx), %rbp
	vmovupd	(%r11,%rbx), %zmm19
	vmovupd	64(%r11,%rbx), %zmm22
	vmovupd	128(%r11,%rbx), %zmm23
	vbroadcastsd	(%r14,%r13), %zmm3
	vmovupd	192(%r11,%rbx), %zmm24
	vmulpd	%zmm19, %zmm3, %zmm0
	vmulpd	%zmm22, %zmm3, %zmm1
	vmulpd	%zmm23, %zmm3, %zmm2
	vbroadcastsd	296(%r14,%r13), %zmm7
	vmulpd	%zmm24, %zmm3, %zmm6
	vmulpd	%zmm19, %zmm7, %zmm3
	vmulpd	%zmm22, %zmm7, %zmm4
	vmulpd	%zmm23, %zmm7, %zmm5
	vbroadcastsd	592(%r14,%r13), %zmm11
	vmulpd	%zmm24, %zmm7, %zmm10
	vmulpd	%zmm19, %zmm11, %zmm7
	vmulpd	%zmm22, %zmm11, %zmm8
	vmulpd	%zmm23, %zmm11, %zmm9
	vbroadcastsd	888(%r14,%r13), %zmm15
	vmulpd	%zmm24, %zmm11, %zmm14
	vmulpd	%zmm19, %zmm15, %zmm11
	vmulpd	%zmm22, %zmm15, %zmm12
	vmulpd	%zmm23, %zmm15, %zmm13
	vbroadcastsd	1184(%r14,%r13), %zmm20
	vmulpd	%zmm24, %zmm15, %zmm18
	vmulpd	%zmm19, %zmm20, %zmm15
	vmulpd	%zmm22, %zmm20, %zmm16
	vmulpd	%zmm23, %zmm20, %zmm17
	vbroadcastsd	1480(%r14,%r13), %zmm25
	vmulpd	%zmm24, %zmm20, %zmm20
	vmulpd	%zmm19, %zmm25, %zmm21
	vmulpd	%zmm22, %zmm25, %zmm22
	vmulpd	%zmm23, %zmm25, %zmm23
	vmulpd	%zmm24, %zmm25, %zmm19
	movq	$-35, %r8
	movq	%rdx, %r15
	nopl	(%rax)
L368:
	vmovapd	%zmm23, %zmm24
	vmovapd	%zmm22, %zmm25
	vmovapd	%zmm21, %zmm26
	vmovupd	-192(%r15), %zmm21
	vmovupd	-128(%r15), %zmm22
	vmovupd	-64(%r15), %zmm23
	vmovupd	(%r15), %zmm27
	vbroadcastsd	-1200(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm0 # zmm0 = (zmm21 * zmm28) + zmm0
	vfmadd231pd	%zmm28, %zmm22, %zmm1 # zmm1 = (zmm22 * zmm28) + zmm1
	vfmadd231pd	%zmm28, %zmm23, %zmm2 # zmm2 = (zmm23 * zmm28) + zmm2
	vfmadd231pd	%zmm28, %zmm27, %zmm6 # zmm6 = (zmm27 * zmm28) + zmm6
	vbroadcastsd	-904(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm3 # zmm3 = (zmm21 * zmm28) + zmm3
	vfmadd231pd	%zmm28, %zmm22, %zmm4 # zmm4 = (zmm22 * zmm28) + zmm4
	vfmadd231pd	%zmm28, %zmm23, %zmm5 # zmm5 = (zmm23 * zmm28) + zmm5
	vfmadd231pd	%zmm28, %zmm27, %zmm10 # zmm10 = (zmm27 * zmm28) + zmm10
	vbroadcastsd	-608(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm7 # zmm7 = (zmm21 * zmm28) + zmm7
	vfmadd231pd	%zmm28, %zmm22, %zmm8 # zmm8 = (zmm22 * zmm28) + zmm8
	vfmadd231pd	%zmm28, %zmm23, %zmm9 # zmm9 = (zmm23 * zmm28) + zmm9
	vfmadd231pd	%zmm28, %zmm27, %zmm14 # zmm14 = (zmm27 * zmm28) + zmm14
	vbroadcastsd	-312(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm11 # zmm11 = (zmm21 * zmm28) + zmm11
	vfmadd231pd	%zmm28, %zmm22, %zmm12 # zmm12 = (zmm22 * zmm28) + zmm12
	vfmadd231pd	%zmm28, %zmm23, %zmm13 # zmm13 = (zmm23 * zmm28) + zmm13
	vfmadd231pd	%zmm28, %zmm27, %zmm18 # zmm18 = (zmm27 * zmm28) + zmm18
	vbroadcastsd	-16(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm15 # zmm15 = (zmm21 * zmm28) + zmm15
	vfmadd231pd	%zmm28, %zmm22, %zmm16 # zmm16 = (zmm22 * zmm28) + zmm16
	vfmadd231pd	%zmm28, %zmm23, %zmm17 # zmm17 = (zmm23 * zmm28) + zmm17
	vfmadd231pd	%zmm28, %zmm27, %zmm20 # zmm20 = (zmm27 * zmm28) + zmm20
	vbroadcastsd	280(%rcx,%r8,8), %zmm28
	vfmadd213pd	%zmm26, %zmm28, %zmm21 # zmm21 = (zmm28 * zmm21) + zmm26
	vfmadd213pd	%zmm25, %zmm28, %zmm22 # zmm22 = (zmm28 * zmm22) + zmm25
	vfmadd213pd	%zmm24, %zmm28, %zmm23 # zmm23 = (zmm28 * zmm23) + zmm24
	vfmadd231pd	%zmm28, %zmm27, %zmm19 # zmm19 = (zmm27 * zmm28) + zmm19
	addq	$248, %r15
	incq	%r8
	jne	L368
	vmovupd	8928(%rbp), %zmm24
	vmovupd	8992(%rbp), %zmm25
	vmovupd	9056(%rbp), %zmm26
	kmovd	%r10d, %k1
	vmovupd	9120(%rbp), %zmm27 {%k1} {z}
	vbroadcastsd	288(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm0 # zmm0 = (zmm24 * zmm28) + zmm0
	vfmadd231pd	%zmm28, %zmm25, %zmm1 # zmm1 = (zmm25 * zmm28) + zmm1
	vfmadd231pd	%zmm28, %zmm26, %zmm2 # zmm2 = (zmm26 * zmm28) + zmm2
	vfmadd231pd	%zmm28, %zmm27, %zmm6 # zmm6 = (zmm27 * zmm28) + zmm6
	vbroadcastsd	584(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm3 # zmm3 = (zmm24 * zmm28) + zmm3
	vfmadd231pd	%zmm28, %zmm25, %zmm4 # zmm4 = (zmm25 * zmm28) + zmm4
	vfmadd231pd	%zmm28, %zmm26, %zmm5 # zmm5 = (zmm26 * zmm28) + zmm5
	vfmadd231pd	%zmm28, %zmm27, %zmm10 # zmm10 = (zmm27 * zmm28) + zmm10
	vbroadcastsd	880(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm7 # zmm7 = (zmm24 * zmm28) + zmm7
	vfmadd231pd	%zmm28, %zmm25, %zmm8 # zmm8 = (zmm25 * zmm28) + zmm8
	vfmadd231pd	%zmm28, %zmm26, %zmm9 # zmm9 = (zmm26 * zmm28) + zmm9
	vfmadd231pd	%zmm28, %zmm27, %zmm14 # zmm14 = (zmm27 * zmm28) + zmm14
	vbroadcastsd	1176(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm11 # zmm11 = (zmm24 * zmm28) + zmm11
	vfmadd231pd	%zmm28, %zmm25, %zmm12 # zmm12 = (zmm25 * zmm28) + zmm12
	vfmadd231pd	%zmm28, %zmm26, %zmm13 # zmm13 = (zmm26 * zmm28) + zmm13
	vfmadd231pd	%zmm28, %zmm27, %zmm18 # zmm18 = (zmm27 * zmm28) + zmm18
	vbroadcastsd	1472(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm15 # zmm15 = (zmm24 * zmm28) + zmm15
	vfmadd231pd	%zmm28, %zmm25, %zmm16 # zmm16 = (zmm25 * zmm28) + zmm16
	vfmadd231pd	%zmm28, %zmm26, %zmm17 # zmm17 = (zmm26 * zmm28) + zmm17
	vfmadd231pd	%zmm28, %zmm27, %zmm20 # zmm20 = (zmm27 * zmm28) + zmm20
	vbroadcastsd	1768(%r14,%r13), %zmm28
	vfmadd213pd	%zmm21, %zmm28, %zmm24 # zmm24 = (zmm28 * zmm24) + zmm21
	vfmadd213pd	%zmm22, %zmm28, %zmm25 # zmm25 = (zmm28 * zmm25) + zmm22
	vfmadd213pd	%zmm23, %zmm28, %zmm26 # zmm26 = (zmm28 * zmm26) + zmm23
	vmovupd	%zmm0, (%rax,%rbx)
	vmovupd	%zmm1, 64(%rax,%rbx)
	vmovupd	%zmm2, 128(%rax,%rbx)
	vmovupd	%zmm6, 192(%rax,%rbx) {%k1}
	vmovupd	%zmm3, 248(%rax,%rbx)
	vmovupd	%zmm4, 312(%rax,%rbx)
	vmovupd	%zmm5, 376(%rax,%rbx)
	vmovupd	%zmm10, 440(%rax,%rbx) {%k1}
	vmovupd	%zmm7, 496(%rax,%rbx)
	vmovupd	%zmm8, 560(%rax,%rbx)
	vmovupd	%zmm9, 624(%rax,%rbx)
	vmovupd	%zmm14, 688(%rax,%rbx) {%k1}
	vmovupd	%zmm11, 744(%rax,%rbx)
	vmovupd	%zmm12, 808(%rax,%rbx)
	vmovupd	%zmm13, 872(%rax,%rbx)
	vmovupd	%zmm18, 936(%rax,%rbx) {%k1}
	vmovupd	%zmm15, 992(%rax,%rbx)
	vmovupd	%zmm16, 1056(%rax,%rbx)
	vmovupd	%zmm17, 1120(%rax,%rbx)
	vmovupd	%zmm20, 1184(%rax,%rbx) {%k1}
	vmovupd	%zmm24, 1240(%rax,%rbx)
	vmovupd	%zmm25, 1304(%rax,%rbx)
	vmovupd	%zmm26, 1368(%rax,%rbx)
	vfmadd231pd	%zmm28, %zmm27, %zmm19 # zmm19 = (zmm27 * zmm28) + zmm19
	vmovupd	%zmm19, 1432(%rax,%rbx) {%k1}
	addq	$248, %rdx
	testq	%r12, %r12
	leaq	1(%r12), %r12
	jne	L112
	addq	$1776, %rsi             # imm = 0x6F0
	cmpq	$4, %r9
	leaq	1(%r9), %r9
	jne	L48
	movq	(%rdi), %rax
	movq	-16(%rsp), %rcx
	movq	(%rcx), %rsi
	movq	-8(%rsp), %rcx
	movq	(%rcx), %rcx
	vmovupd	(%rsi), %zmm16
	vmovupd	64(%rsi), %zmm18
	vmovupd	128(%rsi), %zmm19
	vmovupd	192(%rsi), %zmm20
	vbroadcastsd	7104(%rcx), %zmm3
	vmulpd	%zmm16, %zmm3, %zmm0
	vmulpd	%zmm18, %zmm3, %zmm1
	vmulpd	%zmm19, %zmm3, %zmm2
	vmulpd	%zmm20, %zmm3, %zmm3
	vbroadcastsd	7400(%rcx), %zmm7
	vmulpd	%zmm16, %zmm7, %zmm4
	vmulpd	%zmm18, %zmm7, %zmm5
	vmulpd	%zmm19, %zmm7, %zmm6
	vmulpd	%zmm20, %zmm7, %zmm7
	vbroadcastsd	7696(%rcx), %zmm11
	vmulpd	%zmm16, %zmm11, %zmm8
	vmulpd	%zmm18, %zmm11, %zmm9
	vmulpd	%zmm19, %zmm11, %zmm10
	vmulpd	%zmm20, %zmm11, %zmm11
	vbroadcastsd	7992(%rcx), %zmm15
	vmulpd	%zmm16, %zmm15, %zmm12
	vmulpd	%zmm18, %zmm15, %zmm13
	vmulpd	%zmm19, %zmm15, %zmm14
	vmulpd	%zmm20, %zmm15, %zmm15
	vbroadcastsd	8288(%rcx), %zmm21
	vmulpd	%zmm16, %zmm21, %zmm17
	vmulpd	%zmm18, %zmm21, %zmm18
	vmulpd	%zmm19, %zmm21, %zmm19
	vmulpd	%zmm20, %zmm21, %zmm16
	leaq	440(%rsi), %rdx
	movq	$-35, %rdi
	nopw	%cs:(%rax,%rax)
	nopl	(%rax,%rax)
L1408:
	vmovapd	%zmm19, %zmm20
	vmovapd	%zmm18, %zmm21
	vmovapd	%zmm17, %zmm22
	vmovupd	-192(%rdx), %zmm17
	vmovupd	-128(%rdx), %zmm18
	vmovupd	-64(%rdx), %zmm19
	vmovupd	(%rdx), %zmm23
	vbroadcastsd	7392(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm0 # zmm0 = (zmm17 * zmm24) + zmm0
	vfmadd231pd	%zmm24, %zmm18, %zmm1 # zmm1 = (zmm18 * zmm24) + zmm1
	vfmadd231pd	%zmm24, %zmm19, %zmm2 # zmm2 = (zmm19 * zmm24) + zmm2
	vfmadd231pd	%zmm24, %zmm23, %zmm3 # zmm3 = (zmm23 * zmm24) + zmm3
	vbroadcastsd	7688(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm4 # zmm4 = (zmm17 * zmm24) + zmm4
	vfmadd231pd	%zmm24, %zmm18, %zmm5 # zmm5 = (zmm18 * zmm24) + zmm5
	vfmadd231pd	%zmm24, %zmm19, %zmm6 # zmm6 = (zmm19 * zmm24) + zmm6
	vfmadd231pd	%zmm24, %zmm23, %zmm7 # zmm7 = (zmm23 * zmm24) + zmm7
	vbroadcastsd	7984(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm8 # zmm8 = (zmm17 * zmm24) + zmm8
	vfmadd231pd	%zmm24, %zmm18, %zmm9 # zmm9 = (zmm18 * zmm24) + zmm9
	vfmadd231pd	%zmm24, %zmm19, %zmm10 # zmm10 = (zmm19 * zmm24) + zmm10
	vfmadd231pd	%zmm24, %zmm23, %zmm11 # zmm11 = (zmm23 * zmm24) + zmm11
	vbroadcastsd	8280(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm12 # zmm12 = (zmm17 * zmm24) + zmm12
	vfmadd231pd	%zmm24, %zmm18, %zmm13 # zmm13 = (zmm18 * zmm24) + zmm13
	vfmadd231pd	%zmm24, %zmm19, %zmm14 # zmm14 = (zmm19 * zmm24) + zmm14
	vfmadd231pd	%zmm24, %zmm23, %zmm15 # zmm15 = (zmm23 * zmm24) + zmm15
	vbroadcastsd	8576(%rcx,%rdi,8), %zmm24
	vfmadd213pd	%zmm22, %zmm24, %zmm17 # zmm17 = (zmm24 * zmm17) + zmm22
	vfmadd213pd	%zmm21, %zmm24, %zmm18 # zmm18 = (zmm24 * zmm18) + zmm21
	vfmadd213pd	%zmm20, %zmm24, %zmm19 # zmm19 = (zmm24 * zmm19) + zmm20
	vfmadd231pd	%zmm24, %zmm23, %zmm16 # zmm16 = (zmm23 * zmm24) + zmm16
	addq	$248, %rdx
	incq	%rdi
	jne	L1408
	vmovupd	8928(%rsi), %zmm20
	vmovupd	8992(%rsi), %zmm21
	vmovupd	9056(%rsi), %zmm22
	movb	$127, %dl
	kmovd	%edx, %k1
	vmovupd	9120(%rsi), %zmm23 {%k1} {z}
	vbroadcastsd	7392(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm0 # zmm0 = (zmm20 * zmm24) + zmm0
	vfmadd231pd	%zmm24, %zmm21, %zmm1 # zmm1 = (zmm21 * zmm24) + zmm1
	vfmadd231pd	%zmm24, %zmm22, %zmm2 # zmm2 = (zmm22 * zmm24) + zmm2
	vfmadd231pd	%zmm24, %zmm23, %zmm3 # zmm3 = (zmm23 * zmm24) + zmm3
	vbroadcastsd	7688(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm4 # zmm4 = (zmm20 * zmm24) + zmm4
	vfmadd231pd	%zmm24, %zmm21, %zmm5 # zmm5 = (zmm21 * zmm24) + zmm5
	vfmadd231pd	%zmm24, %zmm22, %zmm6 # zmm6 = (zmm22 * zmm24) + zmm6
	vfmadd231pd	%zmm24, %zmm23, %zmm7 # zmm7 = (zmm23 * zmm24) + zmm7
	vbroadcastsd	7984(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm8 # zmm8 = (zmm20 * zmm24) + zmm8
	vfmadd231pd	%zmm24, %zmm21, %zmm9 # zmm9 = (zmm21 * zmm24) + zmm9
	vfmadd231pd	%zmm24, %zmm22, %zmm10 # zmm10 = (zmm22 * zmm24) + zmm10
	vfmadd231pd	%zmm24, %zmm23, %zmm11 # zmm11 = (zmm23 * zmm24) + zmm11
	vbroadcastsd	8280(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm12 # zmm12 = (zmm20 * zmm24) + zmm12
	vfmadd231pd	%zmm24, %zmm21, %zmm13 # zmm13 = (zmm21 * zmm24) + zmm13
	vfmadd231pd	%zmm24, %zmm22, %zmm14 # zmm14 = (zmm22 * zmm24) + zmm14
	vfmadd231pd	%zmm24, %zmm23, %zmm15 # zmm15 = (zmm23 * zmm24) + zmm15
	vbroadcastsd	8576(%rcx), %zmm24
	vfmadd213pd	%zmm17, %zmm24, %zmm20 # zmm20 = (zmm24 * zmm20) + zmm17
	vfmadd213pd	%zmm18, %zmm24, %zmm21 # zmm21 = (zmm24 * zmm21) + zmm18
	vfmadd213pd	%zmm19, %zmm24, %zmm22 # zmm22 = (zmm24 * zmm22) + zmm19
	vfmadd231pd	%zmm24, %zmm23, %zmm16 # zmm16 = (zmm23 * zmm24) + zmm16
	vmovupd	%zmm0, 5952(%rax)
	vmovupd	%zmm1, 6016(%rax)
	vmovupd	%zmm2, 6080(%rax)
	vmovupd	%zmm3, 6144(%rax) {%k1}
	vmovupd	%zmm4, 6200(%rax)
	vmovupd	%zmm5, 6264(%rax)
	vmovupd	%zmm6, 6328(%rax)
	vmovupd	%zmm7, 6392(%rax) {%k1}
	vmovupd	%zmm8, 6448(%rax)
	vmovupd	%zmm9, 6512(%rax)
	vmovupd	%zmm10, 6576(%rax)
	vmovupd	%zmm11, 6640(%rax) {%k1}
	vmovupd	%zmm12, 6696(%rax)
	vmovupd	%zmm13, 6760(%rax)
	vmovupd	%zmm14, 6824(%rax)
	vmovupd	%zmm15, 6888(%rax) {%k1}
	vmovupd	%zmm20, 6944(%rax)
	vmovupd	%zmm21, 7008(%rax)
	vmovupd	%zmm22, 7072(%rax)
	vmovupd	%zmm16, 7136(%rax) {%k1}
	popq	%rbx
	popq	%r12
	popq	%r13
	popq	%r14
	popq	%r15
	popq	%rbp
	vzeroupper
	retq
	nop

The most striking difference in the assembly is the far higher density of vmulpd and vfmadd instructions in the Julia version. There they sit in unrolled blocks within the loops, with the blocks sized to avoid register spills. The Eigen assembly is far harder to follow and much more complex, and it runs around 4 times slower.
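The pattern visible in the Julia assembly above (load a column chunk of A into vector registers, broadcast one element of B with vbroadcastsd, and vfmadd231pd it into a fixed grid of accumulators that stays in registers for the whole k loop) can be sketched in portable C++. This is a minimal sketch assuming column-major storage and compile-time sizes; `matmul_fixed` is a hypothetical name, not Eigen’s or PaddedMatrices.jl’s code, and whether a given compiler actually keeps `acc` in registers and emits FMAs depends on the flags (e.g. `-O3 -march=skylake-avx512`) and on the block size.

```cpp
#include <cassert>
#include <cstddef>

// C (M x N) = A (M x K) * B (K x N), all column-major with leading
// dimensions M, M and K. For M = 16, N = 14 the accumulator grid is
// 16 * 14 = 224 doubles, i.e. 28 AVX-512 registers of 8 doubles each.
template <std::size_t M, std::size_t K, std::size_t N>
void matmul_fixed(double* C, const double* A, const double* B) {
  static_assert(M > 0 && K > 0 && N > 0, "sizes must be positive");
  assert(C != nullptr && A != nullptr && B != nullptr);
  double acc[N][M] = {};  // accumulators, intended to live in registers
  for (std::size_t k = 0; k < K; ++k) {
    const double* a = A + k * M;  // column k of A: contiguous vector loads
    for (std::size_t j = 0; j < N; ++j) {
      const double b = B[k + j * K];  // one element of B, broadcast
      for (std::size_t i = 0; i < M; ++i) acc[j][i] += a[i] * b;  // FMA
    }
  }
  for (std::size_t j = 0; j < N; ++j)
    for (std::size_t i = 0; i < M; ++i) C[i + j * M] = acc[j][i];
}
```

Usage for the benchmark’s neat sizes would be `matmul_fixed<16, 24, 14>(C, A, B);`. The awkward 31x29 case is where the Julia code additionally uses masked loads and stores (the `{%k1}` operands) for the last partial vector of each column.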


A quick workaround could be to have Eigen call an external BLAS (by defining EIGEN_USE_BLAS and linking a BLAS library); see: https://eigen.tuxfamily.org/dox/TopicUsingBlasLapack.html

Where specifically are the bad versions of the instructions? Something like vmovapd usually gives you better access because alignment is guaranteed (otherwise it throws).

  1. Thanks for the code! I’ll try to replicate this week, though the last time I tried installing Julia it was a bit of a mess. Could you try this with dynamically sized matrices? We use dynamically sized matrices in the math library because we don’t know the data shape at compile time.

  2. I’d recommend making summary charts and graphs out of your results. As they are, it’s very hard for me to draw inferences from your raw printouts.

  3. I’d post this to the Eigen mailing list; I think they would have better insight into what’s going on here.

  4. A 4x speedup over Eigen would put you off the top of the y-axis of the Blaze benchmarks. That’s pretty wild! I’ll look into this, but I’d double-check your work to make sure everything adds up.

Actually, looking at the mat-mul benchmark here, your results would be faster than the multi-threaded version of Blaze.

My tests were on a CPU with the AVX-512 instruction set.

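( 16*14*(2*24-1) / 90 ) * 10^3 is approximately 117,000: the 16x14 result has 16*14 entries, each needing 24 multiplies and 23 additions, so at roughly 90 ns per product that is about 117 GFLOP/s, or 117,000 MFLOP/s.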
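The same count as a small C++ helper (the names `matmul_flops` and `mflops` are hypothetical), so the other timings in the thread can be converted the same way, for instance the 31x37 * 37x29 product or the Eigen medians:

```cpp
#include <cassert>
#include <cstddef>

// Floating-point operations in an (m x k) * (k x n) product, counting each of
// the m * n outputs as k multiplies and k - 1 additions.
constexpr double matmul_flops(std::size_t m, std::size_t k, std::size_t n) {
  return static_cast<double>(m) * static_cast<double>(n) * (2.0 * k - 1.0);
}

// Convert a flop count and a runtime in nanoseconds to MFLOP/s:
// flops per ns is GFLOP/s, so multiply by 10^3.
constexpr double mflops(double flops, double nanoseconds) {
  return flops / nanoseconds * 1e3;
}

static_assert(matmul_flops(16, 24, 14) == 10528.0, "16 * 14 * 47 flops");
```

The 2k - 1 convention assumes the first product of each dot product is not added to anything; benchmarks that count 2k flops per output report slightly higher numbers for the same runtime.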

Those Blaze benchmarks were run on an E5 Haswell EP CPU at 2.3 GHz.
I ran my benchmarks on a 7900X. I’d have to check what the clock speed is set to in the BIOS, but (a) it’s probably more than 50% higher, and (b) the 7900X has the AVX-512 instruction set. AVX-512 doubles the vector width (doubling theoretical FLOPS) and doubles the number of vector registers (allowing for larger kernels, which makes it easier to reach peak FLOPS in matrix multiplication).
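As a rough model of why the hardware difference matters, peak double-precision throughput per core is (FMA units) x (doubles per vector) x (2 flops per FMA) x (clock). A minimal sketch, assuming two FMA units per core on both chips (Haswell has two 256-bit FMA units; the 7900X’s Skylake-X core has two 512-bit ones); the 4.0 GHz figure for the 7900X is a placeholder assumption, since the BIOS setting is not stated in the thread:

```cpp
#include <cassert>

// Peak double-precision GFLOP/s for one core: each FMA counts as 2 flops,
// and each vector holds (vector_bits / 64) doubles.
constexpr double peak_gflops_per_core(int fma_units, int vector_bits,
                                      double clock_ghz) {
  return fma_units * (vector_bits / 64) * 2 * clock_ghz;
}

// Haswell E5 at 2.3 GHz (the Blaze benchmark machine): 2 x 4 x 2 = 16 flops/cycle.
constexpr double haswell_e5_peak = peak_gflops_per_core(2, 256, 2.3);

// 7900X with AVX-512 at a placeholder 4.0 GHz (assumption, not measured):
// 2 x 8 x 2 = 32 flops/cycle.
constexpr double i9_7900x_peak = peak_gflops_per_core(2, 512, 4.0);
```

Under this model, the roughly 117,000 MFLOP/s (117 GFLOP/s) figure above would need an AVX-512 clock of at least 117 / 32 ≈ 3.7 GHz on a single core, which is one concrete way to check that the numbers add up.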

In reply to “Where specifically are the bad versions of the instructions? For things like vmovapd usually gives you better access because it has guaranteed alignment (else the throw)”:
Lines 177-780 from my earlier Godbolt link:

.L24:
        xor     ebx, ebx
        xor     r11d, r11d
        jmp     .L2
Eigen::internal::gemm_pack_rhs<double, long, Eigen::internal::const_blas_data_mapper<double, long, 0>, 4, 0, false, false>::operator()(double*, Eigen::internal::const_blas_data_mapper<double, long, 0> const&, long, long, long, long) [clone .constprop.0]:
        push    rbp
        mov     rbp, rsp
        push    r15
        mov     r15, rdx
        push    r14
        push    r13
        mov     r13, rdi
        push    r12
        push    rbx
        and     rsp, -64
        sub     rsp, 8
        mov     QWORD PTR [rsp-88], rsi
        mov     QWORD PTR [rsp-96], rcx
        test    rcx, rcx
        lea     rax, [rcx+3]
        cmovns  rax, rcx
        add     rdx, 3
        test    r15, r15
        cmovns  rdx, r15
        and     rdx, -4
        and     rax, -4
        mov     QWORD PTR [rsp-64], rdx
        mov     QWORD PTR [rsp-80], rax
        jle     .L84
        mov     rax, rdx
        dec     rax
        shr     rax, 2
        lea     rsi, [4+rax*4]
        sal     rax, 4
        mov     QWORD PTR [rsp-112], rsi
        mov     QWORD PTR [rsp-104], rax
        mov     QWORD PTR [rsp-8], 0
        mov     QWORD PTR [rsp-24], 0
        mov     QWORD PTR [rsp], r15
        vmovdqa64       zmm5, ZMMWORD PTR .LC0[rip]
        vmovdqa64       zmm4, ZMMWORD PTR .LC1[rip]
        mov     r14, rdi
.L55:
        mov     rax, QWORD PTR [rsp-88]
        mov     rbx, QWORD PTR [rsp-24]
        mov     r10, QWORD PTR [rax+8]
        mov     r9, QWORD PTR [rax]
        lea     rax, [rbx+2]
        imul    rax, r10
        mov     r15, rbx
        lea     r11, [rbx+1]
        mov     QWORD PTR [rsp-32], rax
        lea     rdi, [r9+rax*8]
        lea     rax, [rbx+3]
        imul    rax, r10
        imul    r15, r10
        imul    r11, r10
        mov     QWORD PTR [rsp-40], rax
        lea     r8, [r9+rax*8]
        mov     rax, QWORD PTR [rsp-8]
        mov     r10, QWORD PTR [rsp-64]
        lea     rdx, [r14+rax*8]
        xor     eax, eax
        cmp     QWORD PTR [rsp-64], 0
        lea     rcx, [r9+r15*8]
        lea     rsi, [r9+r11*8]
        jle     .L85
.L47:
        vmovupd ymm7, YMMWORD PTR [rcx+rax*8]
        sub     rdx, -128
        vunpckhpd       ymm2, ymm7, YMMWORD PTR [rsi+rax*8]
        vunpcklpd       ymm0, ymm7, YMMWORD PTR [rsi+rax*8]
        vmovupd ymm7, YMMWORD PTR [rdi+rax*8]
        vunpckhpd       ymm3, ymm7, YMMWORD PTR [r8+rax*8]
        vunpcklpd       ymm1, ymm7, YMMWORD PTR [r8+rax*8]
        vshuff64x2      ymm6, ymm2, ymm3, 0
        add     rax, 4
        vshuff64x2      ymm2, ymm2, ymm3, 3
        vshuff64x2      ymm3, ymm0, ymm1, 0
        vshuff64x2      ymm0, ymm0, ymm1, 3
        vmovupd YMMWORD PTR [rdx-128], ymm3
        vmovupd YMMWORD PTR [rdx-96], ymm6
        vmovupd YMMWORD PTR [rdx-64], ymm0
        vmovupd YMMWORD PTR [rdx-32], ymm2
        cmp     r10, rax
        jg      .L47
        mov     rax, QWORD PTR [rsp-8]
        mov     rbx, QWORD PTR [rsp-104]
        mov     rdx, QWORD PTR [rsp-112]
        lea     rax, [rax+16+rbx]
        mov     QWORD PTR [rsp-8], rax
.L48:
        cmp     QWORD PTR [rsp], rdx
        jle     .L45
        mov     rbx, QWORD PTR [rsp]
        mov     r12, QWORD PTR [rsp-8]
        sub     rbx, rdx
        mov     rax, rbx
        dec     rax
        lea     r10, [r12+rbx*4]
        mov     QWORD PTR [rsp-48], rax
        mov     r13, QWORD PTR [rsp-32]
        lea     rax, [r14+r12*8]
        lea     r12, [r14+r10*8]
        lea     r10, [r15+rdx]
        mov     QWORD PTR [rsp-56], rbx
        lea     rbx, [r9+r10*8]
        lea     r10, [r11+rdx]
        mov     QWORD PTR [rsp-16], r12
        lea     r12, [r9+r10*8]
        lea     r10, [r13+0+rdx]
        lea     r13, [r9+r10*8]
        mov     r10, QWORD PTR [rsp-40]
        add     r15, QWORD PTR [rsp]
        add     r10, rdx
        lea     r10, [r9+r10*8]
        mov     QWORD PTR [rsp-72], r10
        lea     r10, [r9+r15*8]
        cmp     rax, r10
        setnb   r15b
        cmp     rbx, QWORD PTR [rsp-16]
        setnb   r10b
        or      r15d, r10d
        mov     r10, QWORD PTR [rsp]
        add     r10, r11
        lea     r10, [r9+r10*8]
        cmp     rax, r10
        setnb   r11b
        cmp     QWORD PTR [rsp-16], r12
        setbe   r10b
        or      r10d, r11d
        and     r10d, r15d
        cmp     QWORD PTR [rsp-48], 6
        seta    r11b
        and     r10d, r11d
        mov     r11, QWORD PTR [rsp-32]
        add     r11, QWORD PTR [rsp]
        lea     r11, [r9+r11*8]
        cmp     rax, r11
        setnb   r11b
        cmp     QWORD PTR [rsp-16], r13
        setbe   r15b
        or      r11d, r15d
        test    r10b, r11b
        je      .L49
        mov     r11, QWORD PTR [rsp-40]
        add     r11, QWORD PTR [rsp]
        lea     r9, [r9+r11*8]
        mov     r11, QWORD PTR [rsp-72]
        cmp     rax, r9
        setnb   r9b
        cmp     QWORD PTR [rsp-16], r11
        setbe   r10b
        or      r9b, r10b
        je      .L49
        mov     r10, QWORD PTR [rsp-56]
        xor     r9d, r9d
        shr     r10, 3
        sal     r10, 6
.L50:
        vmovupd zmm0, ZMMWORD PTR [rbx+r9]
        vmovupd zmm3, ZMMWORD PTR [r13+0+r9]
        vmovupd zmm1, ZMMWORD PTR [r12+r9]
        vmovupd zmm6, ZMMWORD PTR [r11+r9]
        vmovapd zmm2, zmm0
        vpermt2pd       zmm2, zmm5, zmm3
        vpermt2pd       zmm0, zmm4, zmm3
        vmovapd zmm3, zmm1
        vpermt2pd       zmm3, zmm5, zmm6
        vpermt2pd       zmm1, zmm4, zmm6
        vmovapd zmm6, zmm2
        vpermt2pd       zmm2, zmm4, zmm3
        vmovupd ZMMWORD PTR [rax+64+r9*4], zmm2
        vmovapd zmm2, zmm0
        vpermt2pd       zmm6, zmm5, zmm3
        vpermt2pd       zmm2, zmm5, zmm1
        vpermt2pd       zmm0, zmm4, zmm1
        vmovupd ZMMWORD PTR [rax+r9*4], zmm6
        vmovupd ZMMWORD PTR [rax+128+r9*4], zmm2
        vmovupd ZMMWORD PTR [rax+192+r9*4], zmm0
        add     r9, 64
        cmp     r9, r10
        jne     .L50
        mov     rbx, QWORD PTR [rsp-56]
        mov     r15, QWORD PTR [rsp-8]
        mov     rax, rbx
        and     rax, -8
        add     rdx, rax
        lea     r10, [r15+rax*4]
        cmp     rbx, rax
        je      .L52
        vmovsd  xmm0, QWORD PTR [rcx+rdx*8]
        lea     rax, [0+r10*8]
        vmovsd  QWORD PTR [r14+r10*8], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+rdx*8]
        mov     rbx, QWORD PTR [rsp]
        vmovsd  QWORD PTR [r14+8+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+rdx*8]
        lea     r11, [rdx+1]
        vmovsd  QWORD PTR [r14+16+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+rdx*8]
        add     r10, 4
        vmovsd  QWORD PTR [r14+24+rax], xmm0
        cmp     rbx, r11
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+r11*8]
        lea     rax, [0+r10*8]
        vmovsd  QWORD PTR [r14+r10*8], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+r11*8]
        lea     r10, [rdx+2]
        vmovsd  QWORD PTR [r14+8+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+r11*8]
        lea     r9, [0+r11*8]
        vmovsd  QWORD PTR [r14+16+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+r11*8]
        vmovsd  QWORD PTR [r14+24+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+8+r9]
        lea     r10, [rdx+3]
        vmovsd  QWORD PTR [r14+32+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+8+r9]
        vmovsd  QWORD PTR [r14+40+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+8+r9]
        vmovsd  QWORD PTR [r14+48+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+8+r9]
        vmovsd  QWORD PTR [r14+56+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+16+r9]
        lea     r10, [rdx+4]
        vmovsd  QWORD PTR [r14+64+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+16+r9]
        vmovsd  QWORD PTR [r14+72+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+16+r9]
        vmovsd  QWORD PTR [r14+80+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+16+r9]
        vmovsd  QWORD PTR [r14+88+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+24+r9]
        lea     r10, [rdx+5]
        vmovsd  QWORD PTR [r14+96+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+24+r9]
        vmovsd  QWORD PTR [r14+104+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+24+r9]
        vmovsd  QWORD PTR [r14+112+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+24+r9]
        vmovsd  QWORD PTR [r14+120+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+32+r9]
        add     rdx, 6
        vmovsd  QWORD PTR [r14+128+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+32+r9]
        vmovsd  QWORD PTR [r14+136+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+32+r9]
        vmovsd  QWORD PTR [r14+144+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+32+r9]
        vmovsd  QWORD PTR [r14+152+rax], xmm0
        cmp     rbx, rdx
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+40+r9]
        vmovsd  QWORD PTR [r14+160+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+40+r9]
        vmovsd  QWORD PTR [r14+168+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+40+r9]
        vmovsd  QWORD PTR [r14+176+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+40+r9]
        vmovsd  QWORD PTR [r14+184+rax], xmm0
.L52:
        mov     rax, QWORD PTR [rsp-8]
        mov     rsi, QWORD PTR [rsp-48]
        lea     rax, [rax+4+rsi*4]
        mov     QWORD PTR [rsp-8], rax
.L45:
        add     QWORD PTR [rsp-24], 4
        mov     rax, QWORD PTR [rsp-24]
        cmp     QWORD PTR [rsp-80], rax
        jg      .L55
        mov     r15, QWORD PTR [rsp]
        mov     r13, r14
.L54:
        mov     rsi, QWORD PTR [rsp-80]
        cmp     QWORD PTR [rsp-96], rsi
        jle     .L82
        mov     rax, QWORD PTR [rsp-88]
        mov     r8, QWORD PTR [rax]
        mov     r14, QWORD PTR [rax+8]
        test    r15, r15
        jle     .L82
        mov     r9, r15
        and     r9, -8
        lea     rdi, [r9+3]
        mov     QWORD PTR [rsp-40], rdi
        lea     rdi, [r9+4]
        mov     QWORD PTR [rsp-48], rdi
        lea     rdi, [r9+5]
        mov     rcx, rsi
        mov     QWORD PTR [rsp-56], rdi
        lea     rdi, [r9+6]
        imul    rcx, r14
        mov     QWORD PTR [rsp-64], rdi
        lea     rdi, [r15-1]
        mov     rbx, QWORD PTR [rsp-8]
        lea     rax, [0+r14*8]
        mov     r10, rsi
        mov     QWORD PTR [rsp-24], rdi
        mov     rsi, r15
        mov     QWORD PTR [rsp-8], r14
        lea     r12, [r9+2]
        mov     QWORD PTR [rsp], rax
        shr     rsi, 3
        lea     rax, [0+r15*8]
        mov     QWORD PTR [rsp-16], rax
        mov     QWORD PTR [rsp-32], r12
        lea     rdx, [r8+rcx*8]
        lea     rax, [r13+0+rbx*8]
        sal     rsi, 6
        lea     r11, [r9+1]
        mov     r12, rbx
.L60:
        lea     rdi, [rdx+64]
        cmp     rax, rdi
        lea     rdi, [rax+64]
        setnb   bl
        cmp     rdx, rdi
        setnb   dil
        or      bl, dil
        mov     edi, 0
        je      .L56
        cmp     QWORD PTR [rsp-24], 6
        jbe     .L56
.L57:
        vmovupd zmm4, ZMMWORD PTR [rdx+rdi]
        vmovupd ZMMWORD PTR [rax+rdi], zmm4
        add     rdi, 64
        cmp     rdi, rsi
        jne     .L57
        lea     rdi, [r12+r9]
        cmp     r9, r15
        je      .L59
        lea     rbx, [rcx+r9]
        vmovsd  xmm0, QWORD PTR [r8+rbx*8]
        vmovsd  QWORD PTR [r13+0+rdi*8], xmm0
        inc     rdi
        cmp     r15, r11
        jle     .L59
        lea     r14, [rcx+r11]
        vmovsd  xmm0, QWORD PTR [r8+r14*8]
        lea     rbx, [0+rdi*8]
        vmovsd  QWORD PTR [r13+0+rdi*8], xmm0
        mov     rdi, QWORD PTR [rsp-32]
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-40]
        vmovsd  QWORD PTR [r13+8+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-48]
        vmovsd  QWORD PTR [r13+16+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-56]
        vmovsd  QWORD PTR [r13+24+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-64]
        vmovsd  QWORD PTR [r13+32+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        vmovsd  QWORD PTR [r13+40+rbx], xmm0
.L59:
        inc     r10
        add     r12, r15
        add     rdx, QWORD PTR [rsp]
        add     rax, QWORD PTR [rsp-16]
        add     rcx, QWORD PTR [rsp-8]
        cmp     QWORD PTR [rsp-96], r10
        jne     .L60
.L82:
        vzeroupper
        lea     rsp, [rbp-40]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     r15
        pop     rbp
        ret
.L56:
        vmovsd  xmm0, QWORD PTR [rdx+rdi*8]
        vmovsd  QWORD PTR [rax+rdi*8], xmm0
        inc     rdi
        cmp     r15, rdi
        jne     .L56
        jmp     .L59
.L49:
        mov     r9, QWORD PTR [rsp]
.L53:
        vmovsd  xmm0, QWORD PTR [rcx+rdx*8]
        add     rax, 32
        vmovsd  QWORD PTR [rax-32], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+rdx*8]
        vmovsd  QWORD PTR [rax-24], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+rdx*8]
        vmovsd  QWORD PTR [rax-16], xmm0
        vmovsd  xmm0, QWORD PTR [r8+rdx*8]
        inc     rdx
        vmovsd  QWORD PTR [rax-8], xmm0
        cmp     r9, rdx
        jne     .L53
        jmp     .L52
.L85:
        xor     edx, edx
        jmp     .L48
.L84:
        mov     QWORD PTR [rsp-8], 0
        jmp     .L54
Eigen::internal::gebp_kernel<double, double, long, Eigen::internal::blas_data_mapper<double, long, 0, 0>, 12, 4, false, false>::operator()(Eigen::internal::blas_data_mapper<double, long, 0, 0> const&, double const*, double const*, long, long, long, double, long, long, long, long) [clone .constprop.0]:
        push    rbp
        lea     rax, [r9+3]
        vmovapd zmm19, zmm0
        mov     rbp, rsp
        push    r15
        mov     r15, rdx
        push    r14
        mov     r14, rdi
        mov     rdi, rcx
        push    r13
        mov     r13, r8
        push    r12
        push    rbx
        mov     rbx, r8
        and     rsp, -64
        sub     rsp, 136
        test    r9, r9
        cmovns  rax, r9
        mov     r11, rax
        mov     QWORD PTR [rsp-8], rdx
        mov     rax, rcx
        movabs  rdx, 3074457345618258603
        imul    rdx
        sar     rdi, 63
        and     r11, -4
        mov     rax, rdx
        sar     rax
        sub     rax, rdi
        lea     rax, [rax+rax*2]
        lea     rdi, [0+rax*4]
        mov     rax, rcx
        sub     rax, rdi
        mov     rdx, rax
        lea     rax, [rax+7]
        cmovns  rax, rdx
        mov     QWORD PTR [rsp], rcx
        and     rax, -8
        add     rax, rdi
        mov     QWORD PTR [rsp-24], rax
        test    rcx, rcx
        lea     rax, [rcx+3]
        cmovns  rax, rcx
        xor     edx, edx
        and     rax, -4
        mov     QWORD PTR [rsp-16], rax
        mov     eax, 1012
        sub     rax, r8
        sal     rax, 5
        mov     rcx, rax
        mov     QWORD PTR [rsp-56], rax
        lea     rax, [r8+r8*2]
        sal     rax, 5
        mov     r10, rax
        mov     QWORD PTR [rsp+96], rax
        mov     rax, rcx
        div     r10
        mov     QWORD PTR [rsp-32], rsi
        mov     QWORD PTR [rsp+56], r9
        mov     QWORD PTR [rsp+64], r11
        mov     QWORD PTR [rsp-40], rdi
        and     rbx, -8
        mov     edx, 1
        test    rax, rax
        cmovle  rax, rdx
        lea     rax, [rax+rax*2]
        sal     rax, 2
        mov     QWORD PTR [rsp-80], rax
        test    rdi, rdi
        jle     .L87
        imul    rax, r8
        imul    r11, r8
        mov     QWORD PTR [rsp+40], rsi
        sal     rax, 3
        mov     QWORD PTR [rsp-72], rax
        mov     rax, r8
        sal     rax, 5
        mov     QWORD PTR [rsp+8], rax
        lea     rax, [rbx-1]
        shr     rax, 3
        inc     rax
        mov     rcx, rax
        sal     rcx, 8
        lea     rdx, [rax+rax*2]
        sal     rax, 6
        mov     QWORD PTR [rsp+32], rcx
        mov     QWORD PTR [rsp+24], rax
        lea     rcx, [0+r8*8]
        mov     rax, rbx
        mov     QWORD PTR [rsp+16], rcx
        mov     r12, rdx
        lea     rcx, [r15+r11*8]
        neg     rax
        sal     r12, 8
        sal     rax, 3
        mov     QWORD PTR [rsp-64], rcx
        mov     QWORD PTR [rsp-48], 0
        mov     QWORD PTR [rsp+72], rax
        mov     rax, r12
        vbroadcastsd    ymm20, xmm0
        mov     r12, r8
        mov     r13, rax
.L101:
        mov     rcx, QWORD PTR [rsp-48]
        mov     rdi, QWORD PTR [rsp-80]
        mov     QWORD PTR [rsp+48], rcx
        mov     rax, rcx
        add     rcx, rdi
        mov     rdi, QWORD PTR [rsp-40]
        mov     QWORD PTR [rsp-48], rcx
        cmp     rdi, rcx
        cmovle  rcx, rdi
        cmp     QWORD PTR [rsp+64], 0
        mov     QWORD PTR [rsp+128], rcx
        jle     .L88
        cmp     rcx, rax
        jle     .L89
        xor     eax, eax
        mov     r11, QWORD PTR [rsp-8]
        mov     QWORD PTR [rsp+80], r13
        mov     r13, rax
.L92:
        lea     rax, [r13+1]
        mov     QWORD PTR [rsp+120], rax
        lea     rax, [r13+2]
        mov     QWORD PTR [rsp+112], rax
        lea     rax, [r13+3]
        mov     QWORD PTR [rsp+104], rax
        mov     rax, QWORD PTR [rsp+32]
        mov     r10, QWORD PTR [rsp+40]
        add     rax, r11
        mov     QWORD PTR [rsp+88], rax
        mov     r9, QWORD PTR [rsp+48]
.L90:
        mov     rdx, QWORD PTR [r14+8]
        mov     rcx, QWORD PTR [r14]
        mov     rsi, rdx
        imul    rsi, r13
        mov     rax, r10
        prefetcht0      [r10]
        add     rsi, r9
        lea     r8, [rcx+rsi*8]
        mov     rsi, QWORD PTR [rsp+120]
        prefetcht0      [r8]
        imul    rsi, rdx
        prefetcht0      [r11]
        add     rsi, r9
        lea     rdi, [rcx+rsi*8]
        mov     rsi, QWORD PTR [rsp+112]
        prefetcht0      [rdi]
        imul    rsi, rdx
        imul    rdx, QWORD PTR [rsp+104]
        add     rsi, r9
        add     rdx, r9
        lea     rsi, [rcx+rsi*8]
        lea     rcx, [rcx+rdx*8]
        prefetcht0      [rsi]
        prefetcht0      [rcx]
        test    rbx, rbx
        jle     .L207
        vxorpd  xmm0, xmm0, xmm0
        lea     rdx, [r10+512]
        mov     rax, r11
        vmovapd ymm2, ymm0
        vmovapd ymm3, ymm0
        vmovapd ymm18, ymm0
        vmovapd ymm6, ymm0
        vmovapd ymm7, ymm0
        vmovapd ymm8, ymm0
        vmovapd ymm9, ymm0
        vmovapd ymm10, ymm0
        vmovapd ymm11, ymm0
        vmovapd ymm12, ymm0
        vmovapd ymm13, ymm0
        xor     r15d, r15d

For fixed-size matrix multiplication at these sizes, packing is unnecessary. It is necessary at large sizes, where it prevents memory throttling. And because packing costs O(N^2) while matrix multiplication costs O(N^3), packing is effectively “free” at those larger sizes.
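To make the packing trade-off concrete, here is a minimal sketch of what packing means for a column-major GEMM. It assumes a simple panel layout, and `pack_panel` is a hypothetical helper written for illustration, not Eigen's internal routine:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not Eigen's API): copy rows [row0, row0 + mr) and
// columns [0, k) of a column-major matrix A with leading dimension lda into
// one contiguous buffer, so the micro-kernel can stream it with unit stride
// instead of jumping lda elements between columns.
std::vector<double> pack_panel(const double* A, std::size_t lda,
                               std::size_t row0, std::size_t mr, std::size_t k) {
    std::vector<double> packed;
    packed.reserve(mr * k);
    for (std::size_t col = 0; col < k; ++col)
        for (std::size_t r = 0; r < mr; ++r)
            packed.push_back(A[(row0 + r) + col * lda]);
    return packed;
}
```

Packing both inputs this way copies each element once, about 2N^2 values for N x N matrices, while the product needs N^3 multiply-adds. At N = 1000 that is 2 million copies against 1 billion FMAs; at N = 16 it is 512 copies against 4096 FMAs, which is why packing is not “free” at small sizes.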

Otherwise, the kernel code itself looks more or less fine, except that it doesn’t take advantage of the AVX-512 instruction set:

.L97:
        prefetcht0      [rax]
        vmovapd ymm16, YMMWORD PTR [rdx-448]
        vmovapd ymm15, YMMWORD PTR [rdx-512]
        vmovapd ymm14, YMMWORD PTR [rdx-480]
        vbroadcastsd    ymm1, QWORD PTR [rax]
        vmovapd ymm5, ymm16
        vfmadd231pd ymm15, ymm1, ymm13
        vfmadd231pd ymm5, ymm1, ymm11
        vfmadd231pd ymm14, ymm1, ymm12
        vbroadcastsd    ymm1, QWORD PTR [rax+8]
        vfmadd231pd ymm5, ymm1, ymm8
        vmovapd ymm5, ymm18
        vbroadcastsd    ymm18, QWORD PTR [rax+24]
        vmovapd ymm4, ymm16
        vfmadd231pd ymm15, ymm1, ymm10
        vfmadd231pd ymm14, ymm1, ymm9
        prefetcht0      [rdx]
        vbroadcastsd    ymm1, QWORD PTR [rax+16]
        vfmadd231pd ymm4, ymm1, ymm5
        vfmadd231pd ymm15, ymm1, ymm7
        vmovapd ymm4, ymm18
        vfmadd231pd ymm14, ymm1, ymm6
        vfmadd231pd ymm15, ymm4, ymm3

Note the 256-bit ymm registers instead of 512-bit zmm registers, despite compiling with -march=skylake-avx512 -mprefer-vector-width=512.
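The loop at `.L97` is built from `vbroadcastsd` (splat one scalar across a register) and `vfmadd231pd` (fused multiply-add), the usual broadcast-and-FMA structure of a GEMM micro-kernel. The `12, 4` template arguments in the `gebp_kernel<...>` symbol above appear to be the block sizes (mr = 12, nr = 4); with 4-double ymm registers, a 12-row column of the block needs 3 registers, whereas 8-double zmm registers would hold the same data in fewer registers. Below is a scalar sketch of that structure; MR = 8 and NR = 4 are illustrative assumptions, and a real kernel keeps `acc` in vector registers:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t MR = 8;  // rows per block (one zmm register of doubles)
constexpr std::size_t NR = 4;  // columns per block

// a_packed: k groups of MR contiguous doubles (a packed panel of A)
// b_packed: k groups of NR contiguous doubles (a packed panel of B)
// c: the MR x NR block of C, column-major with leading dimension ldc
void micro_kernel(std::size_t k, const double* a_packed, const double* b_packed,
                  double* c, std::size_t ldc) {
    std::array<std::array<double, MR>, NR> acc{};   // accumulators, zero-initialized
    for (std::size_t p = 0; p < k; ++p) {
        const double* a = a_packed + p * MR;         // one vector load of MR values
        for (std::size_t j = 0; j < NR; ++j) {
            const double b = b_packed[p * NR + j];   // broadcast (vbroadcastsd)
            for (std::size_t i = 0; i < MR; ++i)
                acc[j][i] += a[i] * b;               // FMA (vfmadd231pd)
        }
    }
    for (std::size_t j = 0; j < NR; ++j)             // write the block back once
        for (std::size_t i = 0; i < MR; ++i)
            c[i + j * ldc] += acc[j][i];
}
```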

That said, my statement should be qualified as “poor vectorization at small fixed sizes with the AVX-512 instruction set”: packing isn’t helpful yet at those sizes, and Eigen doesn’t seem to use that instruction set. Given larger matrices (where you want packing) and a CPU with only AVX2 (such as the Haswell E5-2650V3 from Blaze’s benchmarks), my criticisms don’t apply.
[I looked at Eigen’s asm more closely while writing this post. Earlier I had only seen that the asm was long, that it included instructions like `vunpckhpd` and `vshuff64x2`, that it didn’t use `zmm` registers, and that performance was slow in those examples; I may have judged it unfairly without first looking more closely at what it was trying to do.]

  1. My Julia matrix multiplication library doesn’t really support dynamically sized matrices yet. Intel MKL JIT would probably be a great choice if it can be combined with Eigen as in @andre.pfeuffer’s link, although MKL may perform poorly on AMD CPUs (i.e., it may simply use SSE instead of AVX? I’m not sure about this).
  2. I’ll try and get around to it.

It would be nice to see your library run alongside the Blaze benchmarks on your system! Tone often gets lost on the internet, so I hope you don’t think I was dismissing your work at all; it’s very cool! You certainly seem to know more about these things than I do.

After this comment I tried to find where in Eigen’s packet math this kernel gets built. It looks like the AVX-512 GEMM code may have an update in the pipeline for Eigen 3.4:

http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1642

// yadayada
        vmovapd zmm11, ZMMWORD PTR [rdx+256]
        vfmadd231pd     zmm6, zmm11, QWORD PTR [rax-256]{1to8}
        vfmadd231pd     zmm4, zmm11, QWORD PTR [rax-248]{1to8}
        vfmadd231pd     zmm9, zmm11, QWORD PTR [rax-240]{1to8}
        vfmadd132pd     zmm11, zmm2, QWORD PTR [rax-232]{1to8}
// yadayada

Though there is still a bunch of yadayada, plus the unpacking, which will probably leave your implementation faster at small sizes.

Again no worries! It’s good to be excited about things

Intel “tries” to have good performance on AMD, but yeah, it probably won’t be as nice. Maybe we can look into checking whether a user has MKL set up, and build Eigen with it if they do.

Thanks for diving down into the assembly. Some of this sounds like it might be better discussed on the Eigen lists. They’re very responsive if there are general techniques that can speed things up. You might also ask them how to get the most out of their code, as it may require some flags to set up.

It’s worse than that—it totally defeats locality and we have to do a lot of copying to get things into vector form to pass into other algorithms.

If it’s Eigen<var, -1, -1>, then it’s effectively Eigen<vari*, -1, -1>. Each pointer points to a contiguous pair of value and adjoint, but the pointers themselves needn’t point to contiguous memory (just imagine an assignment, which sets a new pointer).

Only after unpacking that mess into an Eigen<double, -1, -1> do we get a contiguous array of double values in column-major order.
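A simplified model of that layout and of the unpacking step. These structs are stand-ins written for illustration, not Stan’s classes: the real `stan::math::vari` also has a virtual `chain()` method and is allocated in Stan’s arena.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for a vari: a value/adjoint pair stored together.
struct vari_sketch {
    double val;
    double adj;
};

// Stand-in for a var: just a pointer to the pair.
struct var_sketch {
    vari_sketch* vi;
};

// Column-major "matrix of vars": the pointers are contiguous,
// but the pairs they point to can live anywhere.
struct var_matrix_sketch {
    std::size_t rows, cols;
    std::vector<var_sketch> data;
    var_sketch& operator()(std::size_t i, std::size_t j) { return data[i + j * rows]; }
};

// Unpacking: gather the values into a contiguous column-major array of
// doubles, at the cost of one pointer chase per element.
std::vector<double> unpack_values(const var_matrix_sketch& m) {
    std::vector<double> out(m.data.size());
    for (std::size_t k = 0; k < m.data.size(); ++k)
        out[k] = m.data[k].vi->val;
    return out;
}
```

Assigning `m(1, 0) = var_sketch{&other}` only swaps a pointer, which is why neighboring elements of the matrix can end up far apart in memory.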

The vectorization here is of the function (it accepts whole vectors and matrices as arguments), not of the assembly. We need clearer names. There’s a reduction inside bernoulli_logit_glm so that it returns a single log probability mass. Was the Julia version doing reverse-mode autodiff?
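For reference, a plain-double sketch of what such a function computes, assuming the usual definition log p(y) = Σ_n [y_n η_n − log(1 + exp(η_n))] with η = α + xβ. This is not Stan’s implementation; it just shows the reduction to one scalar and the analytical gradient, whose per-observation factor is y_n − logistic(η_n).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable log(1 + exp(x)).
double log1p_exp(double x) {
    return x > 0 ? x + std::log1p(std::exp(-x)) : std::log1p(std::exp(x));
}

struct GlmResult {
    double lp;                  // single reduced log probability mass
    std::vector<double> dbeta;  // gradient with respect to beta
    double dalpha;              // gradient with respect to alpha
};

// x is N x K, column-major; y holds 0/1 outcomes.
GlmResult bernoulli_logit_glm_sketch(const std::vector<int>& y, const std::vector<double>& x,
                                     double alpha, const std::vector<double>& beta) {
    const std::size_t N = y.size(), K = beta.size();
    GlmResult r{0.0, std::vector<double>(K, 0.0), 0.0};
    for (std::size_t n = 0; n < N; ++n) {
        double eta = alpha;
        for (std::size_t k = 0; k < K; ++k) eta += x[n + k * N] * beta[k];
        r.lp += y[n] * eta - log1p_exp(eta);                   // reduce to one scalar
        const double resid = y[n] - 1.0 / (1.0 + std::exp(-eta)); // d lp / d eta
        r.dalpha += resid;
        for (std::size_t k = 0; k < K; ++k) r.dbeta[k] += x[n + k * N] * resid;
    }
    return r;
}
```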

It’s really more geared toward big dynamic matrices. Partly, we just don’t care as much when the matrices are small, since everything’s fast.

Of course—it’s just general memory locality issues to start.

Nobody’s found a good data structure that lets you do everything we want, like have matrices where you can set elements for reverse-mode autodiff and where you get memory locality. You can design systems that do matrix derivatives, but then accessing their elements is painful. Everyone wants to go down this path when they first see our (and others’) code. I invite you to try, because if someone could genuinely solve the locality problem in a useful way, it’d be huge.

The best thing to study for our autodiff aside from the code is the arXiv paper.

One of the reasons we used Eigen was that it supports templating. But we wind up doing most of the packing/unpacking on Eigen<var, ...> structures, which eats up the gains from something like SIMD.

Sure; here are the same three sizes I tested with Eigen:

julia> using BenchmarkTools, LinearAlgebra, PaddedMatrices, Random

julia> using PaddedMatrices: AbstractMutableFixedSizePaddedMatrix

julia> const CXXDIR = "/home/chriselrod/Documents/progwork/Cxx";

julia> const PADDED_BLAZE_LIB = joinpath(CXXDIR, "libblazemul_padded.so");

julia> const UNPADDED_BLAZE_LIB = joinpath(CXXDIR, "libblazemul_unpadded.so");

julia> @inline aligned_8(N) = (N & 7) == 0
aligned_8 (generic function with 1 method)

julia> @generated function blazemul!(
           C::AbstractMutableFixedSizePaddedMatrix{M,N,T,PC},
           A::AbstractMutableFixedSizePaddedMatrix{M,K,T,PA},
           B::AbstractMutableFixedSizePaddedMatrix{K,N,T,PB}
       ) where {M,N,K,T,PC,PA,PB}
           lib = if aligned_8(PC) && aligned_8(PA) && aligned_8(PB)
               PADDED_BLAZE_LIB
           else
               UNPADDED_BLAZE_LIB
           end
           func = QuoteNode(Symbol(:mul_,M,:x,K,:times,K,:x,N))
           quote
               ccall(
                   ($func, $lib), Cvoid,
                   (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble}),
                   C, A, B
               )
           end
       end
blazemul! (generic function with 1 method)

julia> align_64(x) = (x + 63) & ~63
align_64 (generic function with 1 method)

julia> align_64(x, ptr::Ptr{T}) where {T} = align_64(x*sizeof(T) + ptr)
align_64 (generic function with 2 methods)

julia> const PTR = Base.unsafe_convert(Ptr{Float64}, Libc.malloc(1<<21)); # more than enough space

julia> align_64(x::Ptr{T}) where {T} = reinterpret(Ptr{T},align_64(reinterpret(UInt,x)))
align_64 (generic function with 3 methods)

julia> approx_equal(A,B) = all(x -> isapprox(x[1],x[2]), zip(A,B))
approx_equal (generic function with 1 method)

julia> A16x24 = PtrMatrix{16,24,Float64,16}(align_64(PTR)); randn!(A16x24);

julia> B24x14 = PtrMatrix{24,14,Float64,24}(align_64(16*24,pointer(A16x24))); randn!(B24x14);

julia> C16x14b = PtrMatrix{16,14,Float64,16}(align_64(24*14,pointer(B24x14)));

julia> C16x14j = PtrMatrix{16,14,Float64,16}(align_64(16*14,pointer(C16x14b)));

julia> @benchmark blazemul!($C16x14b, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     101.626 ns (0.00% GC)
  median time:      102.874 ns (0.00% GC)
  mean time:        105.372 ns (0.00% GC)
  maximum time:     512.652 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     942

julia> @benchmark mul!($C16x14j, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     91.679 ns (0.00% GC)
  median time:      91.904 ns (0.00% GC)
  mean time:        93.996 ns (0.00% GC)
  maximum time:     136.302 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     955

julia> approx_equal(C16x14b, C16x14j)
true

julia> # Prime sizes
       A31x37 = PtrMatrix{31,37,Float64,31}(align_64(16*14,pointer(C16x14j))); randn!(A31x37);

julia> B37x29 = PtrMatrix{37,29,Float64,37}(align_64(31*37,pointer(A31x37))); randn!(B37x29);

julia> C31x29b = PtrMatrix{31,29,Float64,31}(align_64(37*29,pointer(B37x29)));

julia> C31x29j = PtrMatrix{31,29,Float64,31}(align_64(31*29,pointer(C31x29b)));

julia> @benchmark blazemul!($C31x29b, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.868 μs (0.00% GC)
  median time:      1.886 μs (0.00% GC)
  mean time:        1.929 μs (0.00% GC)
  maximum time:     4.423 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark mul!($C31x29j, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     582.324 ns (0.00% GC)
  median time:      602.297 ns (0.00% GC)
  mean time:        612.824 ns (0.00% GC)
  maximum time:     2.398 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     182

julia> approx_equal(C31x29b, C31x29j)
true

julia> A31x37p = PtrMatrix{31,37,Float64,32}(align_64(31*29,pointer(C31x29j))); randn!(A31x37p);

julia> B37x29p = PtrMatrix{37,29,Float64,40}(align_64(32*37,pointer(A31x37p))); randn!(B37x29p);

julia> C31x29bp = PtrMatrix{31,29,Float64,32}(align_64(40*29,pointer(B37x29p)));

julia> C31x29jp = PtrMatrix{31,29,Float64,32}(align_64(32*29,pointer(C31x29bp)));

julia> @benchmark blazemul!($C31x29bp, $A31x37p, $B37x29p)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     628.894 ns (0.00% GC)
  median time:      648.347 ns (0.00% GC)
  mean time:        663.371 ns (0.00% GC)
  maximum time:     972.388 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     170

julia> @benchmark mul!($C31x29jp, $A31x37p, $B37x29p)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     556.533 ns (0.00% GC)
  median time:      575.101 ns (0.00% GC)
  mean time:        584.645 ns (0.00% GC)
  maximum time:     1.001 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     184

julia> approx_equal(C31x29bp, C31x29jp)
true 

Summary (minimum times in nanoseconds):
16x24 * 24x14
Eigen: 446.9 (g++)
PaddedMatrices.jl: 91.5
Blaze: 101.7

31x37 * 37x29
Eigen: 2346 (clang++)
PaddedMatrices: 594.2
Blaze: 1885

31x37 * 37x29 with padding:
PaddedMatrices: 570.8
Blaze: 643.8

I compiled Blaze with

g++ -Ofast -march=native -mprefer-vector-width=512 -shared -fPIC blaze_test.cpp -o libblazemul_padded.so
g++ -DBLAZE_USE_PADDING=0 -Ofast -march=native -mprefer-vector-width=512 -shared -fPIC blaze_test.cpp -o libblazemul_unpadded.so

for the padded and unpadded builds, respectively.
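The padded runs above give each column a stride rounded up to a multiple of 8 doubles (31 becomes 32, 37 becomes 40), so with a 64-byte-aligned base pointer every column starts on a cache-line boundary and can presumably be processed with full 8-wide vectors instead of a scalar remainder loop. A small sketch of that indexing; `round_up` and `PaddedView` are hypothetical names, not PaddedMatrices.jl or Blaze APIs:

```cpp
#include <cassert>
#include <cstddef>

// Round n up to a multiple of `multiple` (8 doubles = 64 bytes here).
constexpr std::size_t round_up(std::size_t n, std::size_t multiple) {
    return (n + multiple - 1) / multiple * multiple;
}

// Column-major view with a padded leading dimension (stride >= rows).
// Elements with row index in [rows, stride) are padding.
struct PaddedView {
    double* data;
    std::size_t rows, cols, stride;
    double& operator()(std::size_t i, std::size_t j) { return data[i + j * stride]; }
};
```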

No worries. I did think you might be implying that my results were invalid, i.e., that the benchmarked code wasn’t doing what I thought it was doing: either that the matrix multiplication code was broken and producing invalid results, or that the compiler was able to recognize constants in the benchmark and optimize them away or hoist them out.

I hope I don’t come across as dismissive or disrespectful either. I’m here because I’m a fan of the Stan project / MCMC and HMC/NUTS. It would be great if Stan could be made faster.

It may still be behind until they add kernels to take advantage of all 32 floating point registers:

It could be reduced by exploiting the additional registers offered by AVX512 but this has yet to be implemented.

At least, until relatively large sizes. I’ll eventually look into adding packing/unpacking to my library, but it isn’t a high priority. Libraries like MKL are so well optimized that I may as well use them (i.e., use whichever library Julia is linked with, and let users choose which library that is, to the extent they can or care to).

Thanks for those explanations. That all sounds far from ideal.

Does this happen automatically? If so, it’ll at least scale well: the unpacking will be asymptotically “free” relative to the matrix multiplication.

Well, the confusing thing is that, in my experience, that is normally what people coming from interpreted languages mean by vectorization, while people coming from compiled languages normally mean the assembly / autovectorizer.
Most folks I talk to in person have R experience, so it’d be nice if there were less ambiguous language to use in general.

It returns the log density (reduced to fit into one register, i.e., down to 4 or 8 doubles) and analytical gradients (equivalent to the Jacobians here, since the functions are N-to-1).
Multiple dispatch makes it easy to take advantage of sparsity; I can define a lot of common lpdfs and other commonly used functions, and have them return diagonal, block-diagonal, or even specially-typed objects representing the Jacobians.
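Returning the log density “reduced to fit into one register” means the function stops at W per-lane partial sums and leaves the final horizontal add to the caller (the Julia session below calls `SIMDPirates.vsum(lp)` for this). A scalar sketch assuming W = 8; the modulo indexing only mimics how a SIMD loop assigns terms to lanes:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t W = 8;  // lanes: 8 doubles in a 512-bit register (4 for 256-bit)

// Keep per-lane partial sums, as a SIMD loop would keep them in one register.
std::array<double, W> lanewise_sum(const std::vector<double>& terms) {
    std::array<double, W> acc{};
    for (std::size_t i = 0; i < terms.size(); ++i)
        acc[i % W] += terms[i];
    return acc;
}

// Final horizontal reduction, done once by the caller.
double horizontal_sum(const std::array<double, W>& acc) {
    double s = 0.0;
    for (double v : acc) s += v;
    return s;
}
```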

That is the gist of my approach:

  1. Write optimized versions of functions that also return (sparse, when possible) representations of the Jacobians.
  2. Use source to source transformation to call these functions.
  3. Multiple dispatch dispatches to the appropriate update functions for each Jacobian type, returning the derivative.
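Steps 1 and 3 can be sketched in C++, with function overloading standing in for Julia’s multiple dispatch (step 2, the source-to-source transformation, is not shown). All type and function names here are hypothetical: an elementwise function returns a diagonal Jacobian, and each Jacobian type gets its own reverse-mode update x_adj += J^T y_adj.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct DiagonalJacobian {  // y_i depends only on x_i
    std::vector<double> d;
};
struct DenseJacobian {     // general case, row-major m x n
    std::size_t m, n;
    std::vector<double> J;
};

// Reverse-mode updates, selected by the Jacobian's type.
void pullback(const DiagonalJacobian& jac, const std::vector<double>& y_adj,
              std::vector<double>& x_adj) {
    for (std::size_t i = 0; i < jac.d.size(); ++i)
        x_adj[i] += jac.d[i] * y_adj[i];  // O(n), no zeros stored
}

void pullback(const DenseJacobian& jac, const std::vector<double>& y_adj,
              std::vector<double>& x_adj) {
    for (std::size_t i = 0; i < jac.m; ++i)
        for (std::size_t j = 0; j < jac.n; ++j)
            x_adj[j] += jac.J[i * jac.n + j] * y_adj[i];
}

// Step 1: an elementwise function returns its value plus a structured Jacobian.
struct ExpResult {
    std::vector<double> y;
    DiagonalJacobian jac;
};

ExpResult exp_with_jacobian(const std::vector<double>& x) {
    ExpResult r;
    r.y.resize(x.size());
    r.jac.d.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        r.y[i] = std::exp(x[i]);
        r.jac.d[i] = r.y[i];  // d exp(x_i) / d x_i = exp(x_i)
    }
    return r;
}
```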

This lets me avoid having to use any dual number types / data layouts that may interfere with vectorization.
It isn’t as flexible as Stan’s approach (I don’t even have loops or control flow like if/else working!), but I plan on moving to a hybrid approach where I just use some other autodiff library to differentiate the parts that my library can’t. Probably Zygote once it gets more mature.

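As an example of “writing optimized versions of functions”, take the multivariate normal log density parameterized by a Cholesky factor. Here is the Stan Math version, exposed through a C interface: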

#include <cmath>
#include <stan/math.hpp>

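extern "C" {
// Log density of the N columns of Y (P x N, column-major) under a multivariate
// normal with mean mu and lower-triangular Cholesky factor L (P x P,
// column-major). Writes the gradients with respect to Y, mu, and L into dY,
// dmu, and dL, and returns the log density.
double normal_lpdf_cholesky(double* dY, double* dmu, double* dL, double* Y, double* mu, double* L, long N, long P){
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,Eigen::Dynamic> vY(P,N), vL(P,P);
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,1> vmu(P), vY1col(P);
  // Copy mu and the lower triangle of L into autodiff variables; zero the upper triangle.
  for (long p = 0; p < P; p++){
    vmu(p) = mu[p];
    for (long pr = 0; pr < p; pr++){
      vL(pr,p) = 0;
    }
    for (long pr = p; pr < P; pr++){
      vL(pr,p) = L[pr + p*P];
    }
  }

  stan::math::var lp = 0;

  // Accumulate the log density one observation (column of Y) at a time,
  // storing each column's vars in vY so their adjoints can be read later.
  for (long n = 0; n < N; n++){
    for (long p = 0; p < P; p++){
      vY1col(p) = Y[p + n*P];
    }
    lp += stan::math::multi_normal_cholesky_log<true>(vY1col, vmu, vL);
    for (long p = 0; p < P; p++){
      vY(p,n) = vY1col(p);
    }
  }

  // Reverse pass: propagate adjoints from lp back to every input var.
  lp.grad();

  // Copy the adjoints (gradients) out to the caller's buffers.
  for (long p = 0; p < P; p++){
    dmu[p] = vmu(p).adj();
    for (long pr = p; pr < P; pr++){
      dL[pr + p*P] = vL(pr,p).adj();
    }
  }
  for (long n = 0; n < N; n++){
    for (long p = 0; p < P; p++){
      dY[p + n*P] = vY(p,n).adj();
    }
  }
  double lp_val = lp.val();
  // Free the autodiff arena. Without this, every call leaves its varis on
  // Stan's global stack, so memory grows across calls and later grad() calls
  // also sweep over the varis left behind by earlier calls.
  stan::math::recover_memory();
  return lp_val;
}

}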

which I compiled with

g++ -O3 -march=native -shared -fPIC -DNDEBUG -DEIGEN_NO_DEBUG -DADEPT_STACK_THREAD_UNSAFE -fno-signed-zeros -fno-trapping-math -fassociative-math -pipe -feliminate-unused-debug-types -I$STAN_MATH -I$STAN_MATH/lib/eigen_3.3.3 -I$STAN_MATH/lib/boost_1.69.0 -I$STAN_MATH/lib/sundials_4.1.0/include stan_normal_test.cpp -o libstan_normal_chol.so

If those flags can be improved, let me know.

julia> using ProbabilityDistributions, BenchmarkTools, PaddedMatrices, LinearAlgebra, SIMDPirates, VectorizationBase

julia> using ProbabilityDistributions: Normal, ∂Normal

julia> using StructuredMatrices: MutableLowerTriangularMatrix, MutableSymmetricMatrixL, choltest!

julia> approx_equal(A,B) = all(x -> isapprox(x[1],x[2]), zip(A,B))
approx_equal (generic function with 1 method)

julia> N, P = 480, 20;

julia> Y = @Mutable randn(N, P);

julia> μ = @Mutable randn(P);

julia> Σ = (@Constant randn(P,3P>>1)) |> x -> x * x' |> MutableSymmetricMatrixL;

julia> L = MutableLowerTriangularMatrix{P,Float64}(undef);

julia> choltest!(L, Σ); # test because the function is a work in progress

julia> Ya′ = Array(Y');

julia> μa = Array(μ);

julia> La = LowerTriangular(Array(L));

julia> const SPTR = PaddedMatrices.StackPointer(VectorizationBase.align(Libc.malloc(1<<26) + 63));

julia> # Not calling ∂Normal directly as workaround for https://github.com/JuliaLang/julia/issues/32414
       function normal_chol(sptr, Y, μ, L) 
           sptr2, (lp, ∂Y, ∂μ, ∂L) = ∂Normal(sptr, Y, μ, L, Val{(true,true,true)}())
           SIMDPirates.vsum(lp), ∂Y, ∂μ, ∂L 
       end
normal_chol (generic function with 1 method)

julia> # g++ -O3 -march=native -shared -fPIC -DNDEBUG -DEIGEN_NO_DEBUG -DADEPT_STACK_THREAD_UNSAFE -fno-signed-zeros -fno-trapping-math -fassociative-math -pipe -feliminate-unused-debug-types -I$STAN_MATH -I$STAN_MATH/lib/eigen_3.3.3 -I$STAN_MATH/lib/boost_1.69.0 -I$STAN_MATH/lib/sundials_4.1.0/include stan_normal_test.cpp -o libstan_normal_chol.so
       const STANDIR = "/home/chriselrod/Documents/progwork/Cxx";

julia> const STANLIB0 = joinpath(STANDIR, "libstan_normal_chol.so");

julia> function stan_multinormal_cholesky!(∂Y, ∂μ, ∂L, Y, μ, L)
           P, N = size(Y)
           ccall(
               (:normal_lpdf_cholesky, STANLIB0), Cdouble,
               (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Clong,Clong),
               ∂Y, ∂μ, ∂L.data, Y, μ, L.data, N, P
           )
       end
stan_multinormal_cholesky! (generic function with 1 method)

julia> ∂Ya′ = similar(Ya′);

julia> ∂μa = similar(μa);

julia> ∂La = similar(La);

julia> lp, ∂Y, ∂μ, ∂L = normal_chol(SPTR, Y, μ, L); lp
-15591.837040912393

julia> stan_multinormal_cholesky!(∂Ya′, ∂μa, ∂La, Ya′, μa, La)
-15591.837040912407

julia> approx_equal(∂Y', ∂Ya′)
true

julia> approx_equal(∂μ', ∂μa')
true

julia> approx_equal(∂L, ∂La)
true

julia> @benchmark normal_chol(SPTR, $Y, $μ, $L)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.845 μs (0.00% GC)
  median time:      7.961 μs (0.00% GC)
  mean time:        8.073 μs (0.00% GC)
  maximum time:     105.388 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

julia> @benchmark stan_multinormal_cholesky!($∂Ya′, $∂μa, $∂La, $Ya′, $μa, $La)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     70.675 ms (0.00% GC)
  median time:      81.732 ms (0.00% GC)
  mean time:        82.230 ms (0.00% GC)
  maximum time:     94.341 ms (0.00% GC)
  --------------
  samples:          61
  evals/sample:     1

That is 8 microseconds (for the Julia code) vs 70 ms for Stan’s autodiff, or an 8,750-fold difference, to get the log density (dropping constants) and the gradients with respect to all the arguments.
(Edit: I improved Julia’s performance; I had forgotten to align Julia’s stack. I guess that’s the reason to use vmovapd instead of vmovupd: AFAIK they’re equally fast on aligned data, but on unaligned data the latter is slower while the former crashes.)
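The aligned vs unaligned distinction comes down to whether an address is a multiple of the vector width in bytes (32 for AVX2, 64 for AVX-512). Below is a minimal sketch of rounding a raw `malloc` block up to a 64-byte boundary, similar in spirit to the align step in the REPL session. The helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Round an address up to the next multiple of `alignment` (must be a power of two).
inline std::uintptr_t align_up(std::uintptr_t addr, std::uintptr_t alignment) {
  return (addr + alignment - 1) & ~(alignment - 1);
}

// Carve a 64-byte-aligned region of `bytes` out of a plain malloc'd block.
// The caller must free() `raw`, not the returned pointer.
inline double* aligned_region(std::size_t bytes, void*& raw) {
  raw = std::malloc(bytes + 63);  // over-allocate so the aligned region still fits
  if (raw == nullptr) return nullptr;
  auto p = align_up(reinterpret_cast<std::uintptr_t>(raw), 64);
  return reinterpret_cast<double*>(p);
}
```

Data aligned this way can safely be read with aligned loads such as vmovapd. C++17's `std::aligned_alloc` is an alternative, with the restriction that the size must be a multiple of the alignment.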

I thought stan::math would be much faster if, instead of calling stan::math::multi_normal_cholesky_log<true>, I computed the density directly as -0.5 * ((Y - μ) / L').squaredNorm() - N * sum(log(diag(L))) (with observations as the rows of Y), but as the following post illustrates, this doesn’t seem to be the case.
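For reference, both the Julia and Stan versions compute the multivariate normal log density with constants dropped. Here Σ = LL' and y_n is the n-th observation. The log-determinant term follows from log det Σ = 2 Σ_p log L_pp.

```latex
\log p(Y \mid \mu, L) = -\frac{1}{2} \sum_{n=1}^{N} \left\lVert L^{-1}(y_n - \mu) \right\rVert_2^2 \; - \; N \sum_{p=1}^{P} \log L_{pp} \; + \; \text{const}
```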

I’m definitely a non-expert when it comes to autodiff and their implementations, so I seriously doubt I’ve solved the locality problem. But are there any problems with the general hybrid approach of:

  1. As the first line of attack, use source-to-source transformation, calling optimized methods when available (like the Bernoulli logit or multivariate normal Cholesky) as described in steps 1-3 earlier. Keep all data packed to take advantage of SIMD when possible.
  2. Fall back to more general autodiff for complicated things like control flow and array indexing (especially assignments). If you need dual number types for this, try to use struct-of-array type data structures – both because they often enable SIMD, and because that stops you from having to pack and unpack. I would do this if I had to rely on Julia’s ForwardDiff, but Zygote (another source to source library) doesn’t need this.

I also describe the approach a little here. I show an example model and the code transformations that produce the reverse-mode AD code: largely a wall of ProbabilityModels.PaddedMatrices.RESERVED_INCREMENT_SEED_RESERVED calls, but because of multiple dispatch (like operator overloading), each call goes to a different method appropriate to the argument types.

I broke the code recently while working on tracking changes to DynamicHMC.jl, but I plan on fixing it over Labor Day weekend.


Compiling the following function:

double normal_lpdf_cholesky2(double* dY, double* dmu, double* dL, double* Y, double* mu, double* L, long N, long P){
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,Eigen::Dynamic> vY(N,P), vL(P,P);
  Eigen::Matrix<stan::math::var,1,Eigen::Dynamic> vmu(P);
  //  std::cout << "N: " << N << " ; P: " << P << std::endl;
  for (long p = 0; p < P; p++){
    vmu(p) = mu[p];
    for (long pr = p; pr < P; pr++){
      vL(pr,p) = L[pr + p*P];
    }
    for (long n = 0; n < N; n++){
      vY(n,p) = Y[n + p*N];
    }
  }

  
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,Eigen::Dynamic> vDelta = vY.rowwise() - vmu;

  vL.triangularView<Eigen::Lower>().adjoint().solveInPlace<Eigen::OnTheRight>(vDelta);
  
  stan::math::var lp = -0.5*(vDelta.squaredNorm()) - N * (vL.diagonal().array().log().sum());

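  lp.grad();
  // Note: as above, stan::math::recover_memory() is never called, so the AD stack keeps
  // growing across calls and each grad() also sweeps the nodes left by earlier calls.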

  for (long p = 0; p < P; p++){
    dmu[p] = vmu(p).adj();
    for (long pr = p; pr < P; pr++){
      dL[pr + p*P] = vL(pr,p).adj();
    }
    for (long n = 0; n < N; n++){
      dY[n + p*N] = vY(n,p).adj();
    }
  }
  return lp.val();
}

and using the same data layout for Y as my Julia version, I now get:

julia> function stan_multinormal_cholesky2!(∂Y, ∂μ, ∂L, Y, μ, L)
           N, P = size(Y)
           ccall(
               (:normal_lpdf_cholesky2, STANLIB0), Cdouble,
               (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Clong,Clong),
               ∂Y, ∂μ, ∂L.data, Y, μ, L.data, N, P
           )
       end
stan_multinormal_cholesky2! (generic function with 1 method)

julia> Ya = Array(Y);

julia> ∂Ya = similar(Ya);

julia> ∂μa = similar(μa);

julia> ∂La = similar(La);

julia> lp, ∂Y, ∂μ, ∂L = normal_chol(SPTR, Y, μ, L); lp
-15591.837040912393

julia> stan_multinormal_cholesky2!(∂Ya, ∂μa, ∂La, Ya, μa, La)
-15591.837040912395

julia> approx_equal(∂Y, ∂Ya)
true

julia> approx_equal(∂μ, ∂μa)
true

julia> approx_equal(∂L, ∂La)
true

julia> @benchmark normal_chol(SPTR, $Y, $μ, $L)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.955 μs (0.00% GC)
  median time:      8.004 μs (0.00% GC)
  mean time:        8.101 μs (0.00% GC)
  maximum time:     15.369 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

julia> @benchmark stan_multinormal_cholesky2!($∂Ya, $∂μa, $∂La, $Ya, $μa, $La)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     172.438 ms (0.00% GC)
  median time:      193.080 ms (0.00% GC)
  mean time:        191.596 ms (0.00% GC)
  maximum time:     214.988 ms (0.00% GC)
  --------------
  samples:          27
  evals/sample:     1

Or about a 20,000-fold difference in performance to compute the same thing.
For reference, I’m on the latest development version of Stan-math. I compiled the code using:

g++ -O3 -march=native -shared -fPIC -DNDEBUG -DEIGEN_NO_DEBUG -DADEPT_STACK_THREAD_UNSAFE -fno-signed-zeros -fno-trapping-math -fassociative-math -pipe -feliminate-unused-debug-types -I$STAN_MATH -I$STAN_MATH/lib/eigen_3.3.3 -I$STAN_MATH/lib/boost_1.69.0 -I$STAN_MATH/lib/sundials_4.1.0/include stan_normal_test.cpp -o libstan_normal_chol.so

The macro definitions were just copied and pasted from the Stan Math paper. The rest were generic optimization flags or include dirs (I installed Stan Math and its dependencies by git cloning CmdStan).

I’m on the latest g++:

> g++ --version
g++ (Clear Linux OS for Intel Architecture) 9.2.1 20190828 gcc-9-branch@274989
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

EDIT:
For the sake of it, here are the analytical gradients using Eigen (note that the Eigen code uses dynamically sized arrays, while the Julia code uses statically sized ones; additionally, the Eigen code copies a little excessively, but I didn’t want to look too deeply into interop for dynamically sized arrays):

julia> function stan_multinormal_cholesky3!(∂Y, ∂μ, ∂L, Y, μ, L)
           N, P = size(Y)
           ccall(
               (:normal_lpdf_cholesky3, STANLIB0), Cdouble,
               (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Clong,Clong),
               ∂Y, ∂μ, ∂L.data, Y, μ, L.data, N, P
           )
       end
stan_multinormal_cholesky3! (generic function with 1 method)

julia> lp, ∂Y, ∂μ, ∂L = normal_chol(SPTR, Y, μ, L); lp
-14785.977219901622

julia> stan_multinormal_cholesky3!(∂Ya, ∂μa, ∂La, Ya, μa, La)
-14785.97721990162

julia> approx_equal(∂Y, ∂Ya)
true

julia> approx_equal(∂μ, ∂μa)
true

julia> approx_equal(∂L, ∂La)
true

julia> @benchmark normal_chol(SPTR, $Y, $μ, $L)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     8.098 μs (0.00% GC)
  median time:      8.360 μs (0.00% GC)
  mean time:        8.414 μs (0.00% GC)
  maximum time:     13.498 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

julia> @benchmark stan_multinormal_cholesky3!($∂Ya, $∂μa, $∂La, $Ya, $μa, $La)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     57.995 μs (0.00% GC)
  median time:      60.663 μs (0.00% GC)
  mean time:        60.579 μs (0.00% GC)
  maximum time:     181.221 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

That is still over 7 times slower than the Julia version, but the analytical gradient is far faster than Stan’s AD here.
C++ code:


double normal_lpdf_cholesky3(double* dY, double* dmu, double* dL, double* Y, double* mu, double* L, long N, long P){
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> vY(N,P), vL(P,P);
  Eigen::Matrix<double,1,Eigen::Dynamic> vmu(P);
  //  std::cout << "N: " << N << " ; P: " << P << std::endl;
  for (long p = 0; p < P; p++){
    vmu(p) = mu[p];
    dmu[p] = 0;
  }
  for (long p = 0; p < P; p++){
    for (long pr = p; pr < P; pr++){
      vL(pr,p) = L[pr + p*P];
    }
    for (long n = 0; n < N; n++){
      vY(n,p) = Y[n + p*N];
    }
  }

  
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> vDelta = vY.rowwise() - vmu;

  vL.triangularView<Eigen::Lower>().adjoint().solveInPlace<Eigen::OnTheRight>(vDelta);
  
  double lp = -0.5*(vDelta.squaredNorm()) - N * (vL.diagonal().array().log().sum());
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> dlpdy = vL.triangularView<Eigen::Lower>().solve<Eigen::OnTheRight>(vDelta);
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> dlpdl = dlpdy.adjoint() * vDelta;

  for (long p = 0; p < P; p++){
    for (long n = 0; n < N; n++){
      dmu[p] += dlpdy(n,p);
      dY[n + p*N] = -dlpdy(n,p);
    }
    dL[p + p*P] = dlpdl(p,p) - N / L[p + p*P];
    for (long pr = p+1; pr < P; pr++){
      dL[pr + p*P] = dlpdl(pr,p);
    }
  }
  return lp;
}

I compiled with the same flags as before.
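These are the analytical gradients that `normal_lpdf_cholesky3` computes. Observations are stored as the rows of Δ = Y − 1μ' (N × P). X = ΔL'^{-1} is the solved `vDelta`, and Z = XL^{-1} = ΔΣ^{-1} is `dlpdy`, with z_n' its n-th row. Only the lower triangle of the L gradient is kept, because L is lower triangular.

```latex
\frac{\partial \log p}{\partial Y} = -Z, \qquad
\frac{\partial \log p}{\partial \mu} = \sum_{n=1}^{N} z_n, \qquad
\frac{\partial \log p}{\partial L} = \operatorname{tril}\!\left( Z^{\top} X \right) - N \operatorname{diag}\!\left( \frac{1}{L_{11}}, \dots, \frac{1}{L_{PP}} \right)
```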


I’m not sure what you mean by automatically. What happens is that we construct two new matrices and do a loop to copy out the values and adjoints. It’s expensive and has to hit the heap for the matrices.
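Below is a toy model of that copy-out step. It assumes a simplified `ToyVari` that holds just a value and an adjoint; Stan's real `vari` also has a virtual `chain()` method and lives in an arena. Every element of the matrix is a pointer, so extracting dense `double` buffers costs one indirect load per element before any SIMD-friendly work can start.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-in for Stan's vari: value and adjoint side by side.
struct ToyVari {
  double val;
  double adj;
};

// A "matrix of vars" in this toy is a column-major array of pointers.
struct ToyVarMatrix {
  std::size_t rows, cols;
  std::vector<ToyVari*> ptrs;
};

// Gather values and adjoints into two dense, contiguous buffers.
// Each element costs a pointer dereference; the outputs are what dense kernels want.
void copy_out(const ToyVarMatrix& m, std::vector<double>& vals, std::vector<double>& adjs) {
  const std::size_t n = m.rows * m.cols;
  vals.resize(n);
  adjs.resize(n);
  for (std::size_t i = 0; i < n; ++i) {
    vals[i] = m.ptrs[i]->val;
    adjs[i] = m.ptrs[i]->adj;
  }
}
```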

Yes, we know, and thus are sorry we ever used that term. Do you have a better term for what we’re doing with things like our probability functions?

Sparsity is key. The alternative is to use the adjoint-vector multiply operation that never has to construct an explicit Jacobian. That’s what we do for most of our matrix operations.
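Here is the contrast for an elementwise function y = exp(x) of a length-n vector. Its Jacobian is diagonal, so the reverse pass can be written as the adjoint update xbar += exp(x) ∘ ybar in O(n), without ever building the n × n matrix. The function names are illustrative, not Stan's.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Explicit-Jacobian reverse step: builds the full n x n matrix (mostly zeros).
void exp_reverse_dense(const std::vector<double>& x, const std::vector<double>& ybar,
                       std::vector<double>& xbar) {
  const std::size_t n = x.size();
  std::vector<double> J(n * n, 0.0);  // row-major
  for (std::size_t i = 0; i < n; ++i) J[i * n + i] = std::exp(x[i]);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j) xbar[j] += J[i * n + j] * ybar[i];
}

// Adjoint-vector product: same result, O(n) work, no temporary matrix,
// and a single elementwise pass over contiguous arrays.
void exp_reverse_avp(const std::vector<double>& x, const std::vector<double>& ybar,
                     std::vector<double>& xbar) {
  for (std::size_t i = 0; i < x.size(); ++i) xbar[i] += std::exp(x[i]) * ybar[i];
}
```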

That’s a lot! What is the speedup coming from? And it looks even bigger in later posts. Is this something we could do in Stan?

I don’t know how to do this.


How easy is it nowadays to call Julia functions from C? Would an approach like this one work to get Julia’s speed and Stan’s HMC?

Turing is supposed to have a similar NUTS implementation https://github.com/TuringLang/Turing.jl

@Bob_Carpenter by “struct-of-array type data structures” I think @chriselrod just means storing a matrix of vars as a struct containing two MatrixXds, one for values and one for adjoints.
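Below is a minimal sketch of that layout. The names are hypothetical; this is not an existing Stan type. Values and adjoints each live in their own contiguous column-major buffer. A kernel that only needs values reads one dense array, and no copy-out step is needed before handing it to a vectorized routine.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical struct-of-arrays replacement for Matrix<var, Dynamic, Dynamic>.
struct SoAVarMatrix {
  std::size_t rows, cols;
  std::vector<double> val;  // column-major values
  std::vector<double> adj;  // column-major adjoints, same shape

  SoAVarMatrix(std::size_t r, std::size_t c)
      : rows(r), cols(c), val(r * c, 0.0), adj(r * c, 0.0) {}

  double& v(std::size_t i, std::size_t j) { return val[i + j * rows]; }
  double& a(std::size_t i, std::size_t j) { return adj[i + j * rows]; }
};

// Forward: s = sum of squares of all values (one contiguous loop over val).
double squared_norm_forward(const SoAVarMatrix& m) {
  double s = 0.0;
  for (double x : m.val) s += x * x;
  return s;
}

// Reverse: ds/dval = 2 * val, scaled by the upstream adjoint sbar and accumulated into adj.
void squared_norm_reverse(SoAVarMatrix& m, double sbar) {
  for (std::size_t k = 0; k < m.val.size(); ++k) m.adj[k] += 2.0 * m.val[k] * sbar;
}
```

The reverse loop is a plain elementwise update. The forward loop is a floating-point reduction, which GCC will only vectorize when reassociation is allowed, as with the -fassociative-math -fno-signed-zeros -fno-trapping-math flags used in the compile commands in this thread.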

We do a few other things inefficiently in Stan around Eigen as well; for example, we destruct and reallocate matrices with exactly the same sizes as previous log_prob evals on every leapfrog step. All told, I think Eigen reallocation has accounted for around 50% of execution time in the Stan models I’ve profiled.

I don’t know the Julia interpreter, but if it’s like in other languages I would suspect there is significant overhead when leaving the interpreter (to call out to e.g. C) and then returning to it. Have you tested within C++? We have some example Stan Math benchmarks using Google’s benchmark tool set up here: https://github.com/seantalts/perf-math

@chriselrod Your approach of completely separating Jacobian computation from value computation seems like it would be hard to integrate with Stan’s AD system: the API always expects one to calculate the value and then call a special global function on just the inputs to get their gradients. I can’t think of any way to adapt that API in a general-purpose way to what you’re proposing. However, and I think this is where you were going, the Stan Math AD system might get the same architectural benefits w.r.t. SIMD if we just do as you mentioned and swap out Eigen::Matrix<var, ...> for (a struct containing?) two parallel MatrixXds, right?

If we can figure out how to, at the same time, stop destroying and reallocating so many Eigen matrices each leapfrog step, that would be great.
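One low-tech way to avoid reallocating same-sized buffers on every log-density evaluation is a workspace object. It outlives the leapfrog steps and only grows when a larger size is requested. Below is a sketch with hypothetical names, using `std::vector`, whose `resize` does not reallocate when the new size fits in the existing capacity. Eigen's `resize` is likewise documented not to allocate when the total number of coefficients is unchanged.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-model scratch space, kept alive across log_prob evaluations.
struct Workspace {
  std::vector<double> delta;  // e.g. N x P residuals
  std::vector<double> tmp;

  // Size the buffers for this evaluation; no allocation if capacity already suffices.
  void prepare(std::size_t n_delta, std::size_t n_tmp) {
    delta.resize(n_delta);
    tmp.resize(n_tmp);
  }
};

// Stand-in for one log-density evaluation that needs N*P scratch doubles.
double fake_log_prob(Workspace& ws, std::size_t N, std::size_t P) {
  ws.prepare(N * P, P);
  double s = 0.0;
  for (std::size_t i = 0; i < ws.delta.size(); ++i) {
    ws.delta[i] = 0.5;  // placeholder computation
    s += ws.delta[i];
  }
  return s;
}
```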


That question wasn’t clear.
Possibilities I was wondering about (sorted from best to worst):

  1. Eigen already “packs” doubles while performing matrix multiplication (at least for large matrices). Would be nice to be able to combine this with what you describe. I recall the folks at BLIS talking about taking advantage of the packing step in this way (in BLIS) for better performance for different data types (eg, complex values).
  2. What you just described.
  3. The user would have to define code doing that manually.

“2)” is much better than “3)”, but “1)” would be neat.

I’m not really sure.
Julia has a syntax for element-wise operations (adding a dot to a function call or operator); dotted code lowers to calls to broadcast, so when talking about element-wise operations folks normally talk about broadcasting.
Stan also has broadcasting, but it’s explicit (row/column vector -> matrix) and the “vectorized” functions don’t use that convention.

Perhaps it’s best to use “SIMD” when talking about low level/CPU vectorization, although it’s awkward to use as a verb. “SIMDize”?

I could try to write a C++ version, or if someone else would like to try, I could explain the algorithm.
Some of the advantage came from the fact that I was benchmarking on a computer with avx512, which Eigen did not take advantage of (at least at that time; I haven’t rechecked), while my own code simply uses the architecture’s largest vector width.

Something that can’t be reproduced in Stan (where sizes are only known at runtime), but that significantly simplifies the implementation, is that I parameterized on the dimensionality of the normal (ie, dimensionality is effectively a template parameter in C++). I should write a Julia version with a dynamic dimension to see how much of an impact this makes.
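In C++, parameterizing on the dimension means making P a template parameter. Every loop bound is then a compile-time constant that the compiler can fully unroll, at the cost of one instantiation per P. Below is a minimal sketch contrasting that with a runtime P; the function names are illustrative.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Compile-time dimension: the trip count P is known, so the loop can be fully unrolled.
template <std::size_t P>
double sum_sq_static(const std::array<double, P>& x) {
  double s = 0.0;
  for (std::size_t p = 0; p < P; ++p) s += x[p] * x[p];
  return s;
}

// Runtime dimension (what dynamically sized Eigen types correspond to):
// the generated code must handle any length, including remainder iterations.
double sum_sq_dynamic(const std::vector<double>& x) {
  double s = 0.0;
  for (double v : x) s += v * v;
  return s;
}
```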
I also used an awkward packed data layout for the Cholesky factor of the correlation matrix, which I’m not yet sure whether I regret. It makes writing optimized code more difficult, because indexing is inefficient. It also makes parameterizing on the size more important than it would be with a standard dense layout that fills only the upper or lower triangle.
It does take up less memory, and I store the diagonal separately, so that some calculations like logdet are easier to optimize. On the other hand, I could use the same amount of memory with a data structure for lower triangular matrices that simply writes the gradient (the adjoint), as an upper triangular matrix, into the upper triangle of a P x (P + 1) matrix.
If I did things over again – or when I rewrite things later – I may try that approach. Would make it easier to call other libraries, for example, and make indexing random elements more efficient. I could also still separate the diagonals.
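Below is a sketch of the P × (P + 1) idea, under one possible indexing choice; it is hypothetical, not what the library does. L goes in the lower triangle, including the diagonal, of the first P columns of a column-major P × (P + 1) array. The adjoint is stored transposed and shifted one column right, so adjoint element (i, j) with i ≥ j lands at row j, column i + 1, which is strictly above the diagonal. The two triangles together fill all P(P + 1) slots. The L part is still an ordinary column-major matrix with leading dimension P, so a BLAS routine such as `dtrsv` with uplo = 'L', which reads only the lower triangle, could use it directly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical combined storage for a lower-triangular P x P matrix L and its adjoint.
// Column-major P x (P + 1) buffer: L in the lower triangle (row >= col),
// adjoint transposed and shifted one column right (row < col).
struct TriWithAdjoint {
  std::size_t P;
  std::vector<double> buf;

  explicit TriWithAdjoint(std::size_t p) : P(p), buf(p * (p + 1), 0.0) {}

  // Element L(i, j), valid for i >= j.
  double& L(std::size_t i, std::size_t j) { return buf[i + j * P]; }
  // Adjoint of L(i, j), i >= j, stored at row j, column i + 1.
  double& adj(std::size_t i, std::size_t j) { return buf[j + (i + 1) * P]; }
};
```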

For evaluating the log density and gradients, the evaluation is vectorized across samples (the matrix Y). That is, for N observations of a P-variate normal, it evaluates multiple observations at a time / it vectorizes across N.
It calculates all results (density and all gradients) in 1-ish pass over Y.
The “-ish” caveat on 1 is that if it needs the full N x P gradient with respect to Y, for each block across N, it sweeps forward and then backward across P to calculate the corresponding block of (Y - mu) / L' / L.
In calculating everything in one pass, I also tried to order computations so that values in registers can be reused rather than stored somewhere and reloaded later.
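Below is a scalar sketch of that forward-then-backward sweep for a single observation, assuming a column-major lower-triangular L. A real kernel would process a block of observations at once, one SIMD lane per observation, and keep the block in registers. The forward pass gives x = L⁻¹(y − μ), whose squared norm is the quadratic-form term. The backward pass turns x into z = L'⁻¹x = Σ⁻¹(y − μ), which is minus the gradient with respect to y. The function name is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One observation of the multivariate normal: returns -0.5 * ||L^{-1}(y - mu)||^2
// and writes z = Sigma^{-1}(y - mu), where Sigma = L L'. L is P x P, column-major, lower.
double quad_and_precision_times_delta(const std::vector<double>& L, const std::vector<double>& y,
                                      const std::vector<double>& mu, std::vector<double>& z) {
  const std::size_t P = y.size();
  std::vector<double> x(P);
  double qf = 0.0;
  // Forward substitution: solve L x = (y - mu), accumulating ||x||^2 on the way.
  for (std::size_t i = 0; i < P; ++i) {
    double s = y[i] - mu[i];
    for (std::size_t k = 0; k < i; ++k) s -= L[i + k * P] * x[k];
    x[i] = s / L[i + i * P];
    qf += x[i] * x[i];
  }
  // Backward substitution: solve L' z = x, using (L')(i, k) = L(k, i).
  z.assign(P, 0.0);
  for (std::size_t i = P; i-- > 0;) {
    double s = x[i];
    for (std::size_t k = i + 1; k < P; ++k) s -= L[k + i * P] * z[k];
    z[i] = s / L[i + i * P];
  }
  return -0.5 * qf;
}
```

For example, with L = [[2, 0], [1, 1]], y = (2, 3), and μ = 0, the forward pass gives x = (1, 2), so the returned term is −2.5. The backward pass gives z = (−0.5, 2), which matches Σ⁻¹y for Σ = [[4, 2], [2, 2]].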

The implementation isn’t great for large “P”, because it currently only loops over a microkernel, rather than a microkernel within a macrokernel.
That is, for block of observations it’s evaluating, it passes over the covariance matrix. If P is around 100, passing over the matrix is expensive enough that it would be better to block over this dimension as well in a macrokernel.
You could use runtime checks on N and P to dispatch to different implementations.
Eg, choosing a specialized version when N == 1 (a common case, eg in Gaussian process models), and then switching between a version using Eigen (like this) when P > 100 and a C++ version of my Julia code otherwise.
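The dispatch itself is just a branch on the runtime sizes. Below is a sketch with hypothetical kernel names, using the thresholds from the text; the right cutoffs would need benchmarking.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical kernel variants for the multivariate normal log density.
enum class NormalKernel { SingleObservation, EigenBased, SimdMicrokernel };

// Pick an implementation from the runtime sizes N (observations) and P (dimension).
NormalKernel choose_kernel(std::size_t N, std::size_t P) {
  if (N == 1) return NormalKernel::SingleObservation;  // e.g. Gaussian process models
  if (P > 100) return NormalKernel::EigenBased;        // large P: block over P as well
  return NormalKernel::SimdMicrokernel;                // otherwise: vectorize across N
}
```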

The block sizes are determined at compile time, based on CPU architecture and the value of P (but to let P be dynamic, as required by Stan, they could be standardized as a function of CPU architecture alone). Standard microkernel sizes (n x p) may be, for example:
For avx512: loop blocks of 40x5 -> loop blocks of 8x30 -> 8x30 + mask*
For avx2: loop blocks of 12x4 -> loop blocks of 4x14 -> 4x14 + mask*

These are determined by vector width and number of floating point registers.
But for dynamic P, 30 and 14 are extreme. This would probably require tuning.

*For avx512, I think I should change the loop of 8 followed by a masked 8, into a single masked loop of 8, where each lane is turned on until it hits the end. This would reduce code size. For avx2, where masking is less efficient, I would have to benchmark to see which is better, but I think I ought to err on the side of smaller code size.
What I gave above would be the current behavior when P is large. (If P were known and small, it would often select different values; eg, for P=12 on a CPU supporting avx512, it would use 32x6.)
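The block shapes above follow from register counts. AVX-512 has 32 vector registers of 8 doubles, and AVX2 has 16 of 4. An n × p block of accumulators, vectorized along n with width W, needs (n / W) × p registers. Below is a compile-time check of the quoted shapes. The helper is illustrative and counts accumulators only, ignoring the registers needed for loads and for elements of L.

```cpp
#include <cassert>

// Vector registers needed to hold an n x p block of accumulators
// when n (the observation dimension) is vectorized with W doubles per register.
constexpr int accumulator_registers(int n, int p, int W) { return (n / W) * p; }

// AVX-512: W = 8 doubles, 32 zmm registers.
static_assert(accumulator_registers(40, 5, 8) == 25, "40x5 block");
static_assert(accumulator_registers(8, 30, 8) == 30, "8x30 block");
static_assert(accumulator_registers(32, 6, 8) == 24, "32x6 block (P = 12)");
// AVX2: W = 4 doubles, 16 ymm registers.
static_assert(accumulator_registers(12, 4, 4) == 12, "12x4 block");
static_assert(accumulator_registers(4, 14, 4) == 14, "4x14 block");
```

The 8x30 and 4x14 shapes use nearly every register, which is why they are extreme when P is dynamic.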

@seantalts is correct, that’s exactly what I meant. If the values and adjoints can both be packed densely in their respective arrays, this would save you from copying for matrix multiplication, and allow other loops and functions using vectorized/array notation (like y ~ normal(mu, sigma) when some of these are vectors) to use SIMD (again, without copying).

If you’re looking for speed, I would recommend my own suite of libraries ;)
InplaceDHMC.jl
ProbabilityModels.jl
I hope to have them registered and start recommending others try them out in the next few months. They’re not production ready yet.
They are much faster than the other Julia libraries (Turing or DynamicHMC, the latter of which I forked and modified), but my approach has also been less flexible.
My intention is to make the subset it can optimize reasonably flexible and extensible, and then default to slow approaches otherwise.

I will add more documentation and benchmarks as I get closer to release. In particular, I wrote a custom allocator for InplaceDHMC to manage memory while evaluating trees / performing the leapfrog steps. It requires only a few more slots than the max treedepth, and finding empty slots is fast. It also frees the position and momentum as soon as they’re no longer needed, without needing reference counting, etc.

Code defining the tree object and sampling from the tree is here.
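The InplaceDHMC allocator isn't reproduced here, so the following is only a guess at the general shape of such a thing, in C++ with hypothetical names. It uses a fixed pool of slots (a few more than the maximum tree depth), a bitmask recording which slots are in use, and O(1) release with no reference counting.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical fixed pool of equally sized slots (e.g. one position + momentum each).
// Up to 64 slots, tracked by a bitmask: bit k set means slot k is in use.
class SlotPool {
 public:
  SlotPool(int n_slots, int doubles_per_slot)
      : n_(n_slots), width_(doubles_per_slot), data_(n_slots * doubles_per_slot) {
    assert(n_slots > 0 && n_slots <= 64);
  }

  // Returns the index of the lowest free slot, or -1 if the pool is exhausted.
  int acquire() {
    for (int k = 0; k < n_; ++k) {
      if (!((used_ >> k) & 1u)) {
        used_ |= (std::uint64_t{1} << k);
        return k;
      }
    }
    return -1;
  }

  // Free a slot as soon as its contents are no longer needed.
  void release(int k) { used_ &= ~(std::uint64_t{1} << k); }

  double* slot(int k) { return data_.data() + static_cast<std::size_t>(k) * width_; }

 private:
  int n_, width_;
  std::vector<double> data_;
  std::uint64_t used_ = 0;
};
```

With C++20, `std::countr_zero(~used_)` from `<bit>` would find the lowest free slot without the loop.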

Julia normally doesn’t run on an interpreter (if you want to interpret Julia, you must use the library JuliaInterpreter.jl). Julia methods are compiled with LLVM the first time a function x argument-type combination gets called (with the caveat that sometimes it’ll compile a generic version that accepts multiple types; it uses heuristics but you can also encourage or force specialization). Ie, function definitions are templates, each argument being another template parameter.

According to this benchmark, Julia actually has lower FFI overhead than C or C++. This is because Julia dereferences the function pointer before compiling, which is not possible for a precompiled binary or shared library to do.

I’m away from home until the end of the month, and shutdown my home computers before leaving (so no ssh). I could try later, but I’d expect results to be similar.

Yikes!
There’s a potential 2x performance improvement right there. The fix is perhaps obvious (“don’t reallocate”), although I’m sure it will take a lot of work to carry out.
Could you use some type of custom allocator to get better performance?

In ProbabilityModels.jl, I write everything to a custom stack, simply incrementing the stack pointer each time. I know exactly what escapes (the gradient, and nothing else), so I can simply restart the stack pointer at the same place each time.
This does make an assumption about how much memory the log density evaluation actually needs. It would crash if I don’t preallocate enough space.
But this is a simple approach, and faster than doing a bunch of heap allocations (incrementing a pointer vs finding an empty slot in memory, etc). In Julia, which has a GC, it is also vastly faster than triggering it, but C++ doesn’t have that performance issue.

I’ve written many of my functions to accept such a pointer. Some of them also return an (incremented) pointer to support this pattern. This all gets resolved during compile time.
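Below is a C++ sketch of that pattern, with hypothetical names. It uses a preallocated buffer and a bump pointer that only moves forward. Functions take the current pointer and return the advanced one, and the pointer is reset to the same starting point at the start of each log-density evaluation. Like the Julia version, it assumes the buffer is large enough; here an assert stands in for the crash.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical preallocated scratch stack.
struct ScratchStack {
  std::vector<double> buffer;
  explicit ScratchStack(std::size_t n) : buffer(n) {}
  double* begin() { return buffer.data(); }
  double* end() { return buffer.data() + buffer.size(); }
};

// Allocation is just pointer arithmetic: hand out the block at sp, then advance sp.
inline double* bump(double*& sp, double* limit, std::size_t n) {
  assert(static_cast<std::size_t>(limit - sp) >= n && "scratch stack too small");
  double* block = sp;
  sp += n;
  return block;
}

// Example of a function that accepts and returns the stack pointer:
// writes y - mu into scratch memory and returns the advanced pointer for the next callee.
inline double* residuals(double* sp, double* limit, const double* y, const double* mu,
                         std::size_t P, double*& out) {
  out = bump(sp, limit, P);
  for (std::size_t p = 0; p < P; ++p) out[p] = y[p] - mu[p];
  return sp;
}
```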

Is that the sort of thing that a C++ version of the tree I’m using in Julia can help with?
I have already been considering implementing a C++ version of this.
My colleagues all use R. Once I publish releases of my libraries, I was considering compiling the @code_llvm raw=true dump_module=true output of logdensity functions with llc and clang, and linking to a C++ library to support them in making apps that run MCMC.
Compile times for some of my code are poor, which demands a massive refactoring or a rewrite. But I only have so much time.

Yeah. It goes without saying that speeding up computation can only be worth as much as the % of runtime taken up by computation. With about 50% of the time spent on (re)allocations, that’ll be more profitable as a first optimization target.
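This is Amdahl's law. If a fraction f of the runtime is computation and that part is made s times faster, the overall speedup is

```latex
S = \frac{1}{(1 - f) + f / s}, \qquad \lim_{s \to \infty} S = \frac{1}{1 - f}
```

With f = 0.5, even infinitely fast computation caps the overall speedup at 2x. Symmetrically, eliminating the 50% spent on reallocation alone also gives 2x.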

Struct-of-arrays containers also normally overload indexing and assignment, so that they still behave like an array of structs while using a different data layout under the hood.
A Julia example (taken from this discussion):

using StaticArrays, BenchmarkTools, StructArrays
struct Point3D{T} <: FieldVector{3, T}
    x::T
    y::T
    z::T
end
function f!(v)
    @inbounds @simd ivdep for i in eachindex(v)
        x, y, z = v[i]
        v[i] = Point3D(2x, x*y, x*z)
    end
end
v_aos = [Point3D(rand(3))  for i = 1:1000]; # array of structs
v_soa = StructArray(v_aos); # struct of arrays

Benchmarks:

julia> @benchmark f!($v_aos) # array of structs
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.637 μs (0.00% GC)
  median time:      1.643 μs (0.00% GC)
  mean time:        1.649 μs (0.00% GC)
  maximum time:     3.972 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark f!($v_soa) # struct of arrays
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     532.674 ns (0.00% GC)
  median time:      534.689 ns (0.00% GC)
  mean time:        538.171 ns (0.00% GC)
  maximum time:     1.065 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     190

This was with an avx2 (Haswell) laptop. Same function computing the same result, but the struct of arrays layout enables simd.
We can see how the assembly changes when we change the data layout; array of structs:

# julia> @code_native debuginfo=:none f!(v_aos)
	.text
	movq	%rsi, -8(%rsp)
	movq	(%rsi), %rcx
	movq	24(%rcx), %rax
	testq	%rax, %rax
	jle	L102
	movq	%rax, %rdx
	sarq	$63, %rdx
	andnq	%rax, %rdx, %rax
	movq	(%rcx), %rcx
	addq	$16, %rcx
	xorl	%edx, %edx
	movabsq	$.rodata.cst8, %rsi
	vmovsd	(%rsi), %xmm0           # xmm0 = mem[0],zero
	nopw	%cs:(%rax,%rax)
	nop
L64:
	vmovupd	-16(%rcx), %xmm1
	vunpcklpd	%xmm1, %xmm0, %xmm2 # xmm2 = xmm0[0],xmm1[0]
	vmulpd	%xmm2, %xmm1, %xmm2
	vmulsd	(%rcx), %xmm1, %xmm1
	vmovupd	%xmm2, -16(%rcx)
	vmovsd	%xmm1, (%rcx)
	incq	%rdx
	addq	$24, %rcx
	cmpq	%rax, %rdx
	jb	L64
L102:
	movabsq	$jl_system_image_data, %rax
	retq
	nopw	%cs:(%rax,%rax)
	nopl	(%rax,%rax)

Struct of arrays:

# julia> @code_native debuginfo=:none f!(v_soa)
	.text
	movq	%rsi, -8(%rsp)
	movq	(%rsi), %rax
	movq	(%rax), %rax
	movq	(%rax), %rcx
	movq	24(%rcx), %rdx
	testq	%rdx, %rdx
	jle	L186
	movq	%rdx, %rsi
	sarq	$63, %rsi
	andnq	%rdx, %rsi, %r8
	movq	8(%rax), %rdx
	movq	16(%rax), %rax
	movq	(%rcx), %rcx
	movq	(%rdx), %rdx
	movq	(%rax), %rsi
	cmpq	$4, %r8
	jae	L66
	xorl	%edi, %edi
	jmp	L144
L66:
	movabsq	$9223372036854775804, %rdi # imm = 0x7FFFFFFFFFFFFFFC
	andq	%r8, %rdi
	xorl	%eax, %eax
	nopw	%cs:(%rax,%rax)
	nopl	(%rax,%rax)
L96:
	vmovupd	(%rcx,%rax,8), %ymm0
	vaddpd	%ymm0, %ymm0, %ymm1
	vmulpd	(%rdx,%rax,8), %ymm0, %ymm2
	vmulpd	(%rsi,%rax,8), %ymm0, %ymm0
	vmovupd	%ymm1, (%rcx,%rax,8)
	vmovupd	%ymm2, (%rdx,%rax,8)
	vmovupd	%ymm0, (%rsi,%rax,8)
	addq	$4, %rax
	cmpq	%rax, %rdi
	jne	L96
	cmpq	%rdi, %r8
	je	L186
L144:
	vmovsd	(%rcx,%rdi,8), %xmm0    # xmm0 = mem[0],zero
	vaddsd	%xmm0, %xmm0, %xmm1
	vmulsd	(%rdx,%rdi,8), %xmm0, %xmm2
	vmulsd	(%rsi,%rdi,8), %xmm0, %xmm0
	vmovsd	%xmm1, (%rcx,%rdi,8)
	vmovsd	%xmm2, (%rdx,%rdi,8)
	vmovsd	%xmm0, (%rsi,%rdi,8)
	incq	%rdi
	cmpq	%r8, %rdi
	jb	L144
L186:
	movabsq	$jl_system_image_data, %rax
	vzeroupper
	retq

The array of structs loop body (L64) uses some sd (scalar double) suffixed instructions and a few vectorized pd (packed double) instructions, but those use at most an xmm register (128 bits, or 2 doubles).
The struct of arrays layout results in a loop body (L96) where all the operations are on 256 bits (ymm registers) of packed doubles (plus a loop for the remainder); with avx512 these would be 512 bit zmm registers.

It would be the same idea with Stan, where it shouldn’t be any harder than swapping one container type for another (which for functions templated with respect to container types will be easy), and should result in a similar performance benefit on many computations.
(But so long as 50% of the time is spent reallocating arrays rather than performing computations, the possible performance gain from speeding up computation is limited.)
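Below is the same idea in C++, as a minimal sketch with hypothetical types; this is not an existing library. The SoA container stores x, y, and z in three separate arrays. `get` assembles a `Point3D` value and `set` scatters one back, so the loop is written against points while each field stays contiguous. A compiler is not guaranteed to vectorize the SoA loop (it may need runtime alias checks between the three arrays), but the unit-stride access pattern is what makes vectorization possible.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Point3D {
  double x, y, z;
};

// Array of structs: fields of one point are adjacent in memory.
struct Point3DAoS {
  std::vector<Point3D> pts;
  explicit Point3DAoS(std::size_t n) : pts(n) {}
  std::size_t size() const { return pts.size(); }
  Point3D get(std::size_t i) const { return pts[i]; }
  void set(std::size_t i, const Point3D& p) { pts[i] = p; }
};

// Struct of arrays with the same point-wise interface.
struct Point3DSoA {
  std::vector<double> x, y, z;
  explicit Point3DSoA(std::size_t n) : x(n), y(n), z(n) {}
  std::size_t size() const { return x.size(); }
  Point3D get(std::size_t i) const { return {x[i], y[i], z[i]}; }
  void set(std::size_t i, const Point3D& p) { x[i] = p.x; y[i] = p.y; z[i] = p.z; }
};

// Same computation as the Julia f!: (x, y, z) -> (2x, x*y, x*z), for either layout.
template <typename Container>
void f(Container& v) {
  for (std::size_t i = 0; i < v.size(); ++i) {
    Point3D p = v.get(i);
    v.set(i, {2 * p.x, p.x * p.y, p.x * p.z});
  }
}
```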


I should just caveat: I haven’t profiled that many models and they can have very different bottlenecks. But yeah. I tried to use a custom allocator (I could try to find the branch) for std::vector and didn’t see much of a speedup, but didn’t investigate too hard. I don’t think I tried anything similar for Eigen types; it’s much less clear what the best way to switch the allocator for those is (is there anything documented?). I think further experimentation here could be very profitable, especially if we assume it’ll be a while until the math library is refactored to be generically templated over Eigen types (see this thread or this one for some commentary on that; not sure if there’s a definitive thread).

We do the same thing for Stan’s AD primitive, the var. So, if I recall correctly, we’re doing something much worse (from a SIMD perspective) than having a matrix full of structs as you were thinking; we actually have a matrix full of pointers to structs :P (each var is a pointer, allocated on our custom AD stack memory pool, to a vari struct containing a value and an adjoint). Those Eigen matrices are allocated on the heap, though the two doubles for each cell are allocated (individually) in our custom memory pool.

Actually, I recently started working on that. Here is a link to the GitHub issue: https://github.com/stan-dev/math/issues/1470. The first PR is already waiting for review. However, it will take a long time to completely fix this, as it needs to touch almost every single function in Stan Math.

After that is complete I have an intention to look into matrix storage format, but this could be months or a year from now. So if anyone wants to take it up in the meantime they are welcome to. We might be able to get away with writing a specialization for Matrix<var,...> that would internally be struct of arrays. If not, we need to make a custom type and this is again something that will touch whole codebase.


In the README of ProbabilityModels.jl, you write

Is the source-to-source approach you used possible / does it make sense for the sort of linear algebra derivatives Stan implements? Or is it orthogonal (in the design sense)?