Stan SIMD & Performance

Ah, well – looks like the vari type interleaves values with adjoints. This is (IMO) better than Julia’s ForwardDiff.Dual, which interleaves a value with a vector of (forward-mode) partials, but the interleaving is still not ideal.
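Concretely, the interleaving I mean looks roughly like this (a simplified sketch for illustration, not Stan’s actual headers):

struct vari_sketch {
  double val_;   // value
  double adj_;   // adjoint, accumulated during the reverse pass
};

struct var_sketch {
  vari_sketch* vi_;   // a var is just a pointer to its node on the autodiff tape
};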

If we have an Eigen matrix, what is the actual data layout?
Given that stan::math::var is a pointer to the vari, does that mean it is a matrix of pointers to value, adjoint pairs?

From my (admittedly poor) understanding of what Stan does under the hood, it doesn’t look like it can take advantage of SIMD.
Assessments of runtime performance against SIMD implementations, as well as comparisons of observed clock speeds (watch -n1 "cat /proc/cpuinfo | grep MHz") against the BIOS settings for non-AVX, AVX2, and AVX-512 workloads, further support that conclusion.

As does the fact that the “vectorized” bernoulli_logit_glm function was several times slower than a Julia version I vectorized. Again, I’m fairly ignorant about the Stan version’s implementation and did not look at the assembly (although I did of the Julia version), so all I can speak of there is the observed performance difference implying unrealized potential in Stan.

Even with doubles instead of the stan::math::var type, Eigen as a library does not appear to do a good job for small, fixed-size matrices (compared to libraries like PaddedMatrices.jl, Blaze, or Intel MKL JIT). I haven’t investigated small, dynamically-sized Eigen matrices (which Stan uses), but it’d be a little odd if those performed better.

I’m not sure if I have anything constructive to say. Making good use of SIMD starts from the data structures; the algorithms need to be written with it in mind, so there (normally) aren’t any easy fixes or recommendations.
Especially when I haven’t spent the time to really study Stan’s internals and algorithms. And I’m also hopefully wrong about some of the things I’ve said here.

I can’t say whether Eigen matrices of vars receive SIMD instructions, but I have an example below from Godbolt, and it looks like Eigen does fine at generating code using VMOVUPD and other AVX instructions (I’m not the best at reading assembly, so if that’s false, my bad).

You bring up some interesting Qs though! I’ll check the gcc output with a Stan example next week to see what sort of instructions Eigen is generating. Could also probs bring the base files for Stan into Godbolt and do it all there. Also wonder what adding vectors of vars would look like. We keep the val and adj next to each other because, in scalar arithmetic on vars, we tend to use the val and adj of each var together, so it’s better to have them next to each other.
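For example, the reverse-pass update for a product reads each operand’s val along with the result’s adj and writes each operand’s adj, so keeping {val, adj} adjacent matches the access pattern (a generic sketch, not Stan’s actual vari code):

struct node { double val, adj; };

// c = a * b: propagate the result's adjoint back to both operands
void chain_multiply(node& a, node& b, const node& c) {
  a.adj += c.adj * b.val;   // dL/da += dL/dc * dc/da
  b.adj += c.adj * a.val;   // dL/db += dL/dc * dc/db
}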

Also, on a side note, it would be a bit easier to see what’s going on in your post if you made some graphs over the different sizes / performance, imo.

I’d do more but on mobile atm

Back in the day, Eigen needed compiler flags to enable taking advantage of chip-specific vectorization and optimizations. I’m not sure whether that’s still the case, but it is worth verifying before discussing changing architectures (perhaps along with investigations of the assembly code).

Eigen does use some SIMD instructions, but (as I said) it does a bad job of it.

Here are some benchmarks with neat sizes (16x24 * 24x14 = 16x14) and awkward prime sizes (31x37 * 37x29 = 31x29). Pardon the awkward code (particularly the manual creation and alignment of the matrices): Eigen uses vmovapd instead of vmovupd, and while they’re more or less the same (I’ve not noticed a performance difference), vmovapd segfaults without alignment.
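For reference, the C++ functions being called below presumably look something like the following (a hypothetical reconstruction; the actual source is behind the Godbolt link):

#include <Eigen/Dense>

// Wrap the raw pointers passed from Julia in fixed-size Eigen maps and multiply.
extern "C" void mul_16x24times24x14(double* C, const double* A, const double* B) {
  Eigen::Map<Eigen::Matrix<double, 16, 14>> c(C);
  Eigen::Map<const Eigen::Matrix<double, 16, 24>> a(A);
  Eigen::Map<const Eigen::Matrix<double, 24, 14>> b(B);
  c.noalias() = a * b;
}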

julia> using PaddedMatrices, LinearAlgebra, Random, BenchmarkTools

julia> import PaddedMatrices: AbstractMutableFixedSizePaddedMatrix

julia> # Compiled with
       # g++ -O3 -fno-signed-zeros -fno-trapping-math -fassociative-math -march=skylake-avx512 -mprefer-vector-width=512 -I/usr/include/eigen3 -shared -fPIC eigen_matmul_test.cpp -o libgppeigentest.so
       # clang++ -O3 -fno-signed-zeros -fno-trapping-math -fassociative-math -march=skylake-avx512 -mprefer-vector-width=512 -I/usr/include/eigen3 -shared -fPIC eigen_matmul_test.cpp -o libclangeigentest.so
       const gpplib = "/home/chriselrod/Documents/progwork/Cxx/libgppeigentest.so"
"/home/chriselrod/Documents/progwork/Cxx/libgppeigentest.so"

julia> const clanglib = "/home/chriselrod/Documents/progwork/Cxx/libclangeigentest.so"
"/home/chriselrod/Documents/progwork/Cxx/libclangeigentest.so"

julia> function gppeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{16,14,Float64,16},
           A::AbstractMutableFixedSizePaddedMatrix{16,24,Float64,16},
           B::AbstractMutableFixedSizePaddedMatrix{24,14,Float64,24}
       )
           ccall(
               (:mul_16x24times24x14, gpplib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
gppeigen! (generic function with 1 method)

julia> function gppeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{31,29,Float64,31},
           A::AbstractMutableFixedSizePaddedMatrix{31,37,Float64,31},
           B::AbstractMutableFixedSizePaddedMatrix{37,29,Float64,37}
       )
           ccall(
               (:mul_31x37times37x29, gpplib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
gppeigen! (generic function with 2 methods)

julia> function clangeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{16,14,Float64,16},
           A::AbstractMutableFixedSizePaddedMatrix{16,24,Float64,16},
           B::AbstractMutableFixedSizePaddedMatrix{24,14,Float64,24}
       )
           ccall(
               (:mul_16x24times24x14, clanglib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
clangeigen! (generic function with 1 method)

julia> function clangeigen!(
           C::AbstractMutableFixedSizePaddedMatrix{31,29,Float64,31},
           A::AbstractMutableFixedSizePaddedMatrix{31,37,Float64,31},
           B::AbstractMutableFixedSizePaddedMatrix{37,29,Float64,37}
       )
           ccall(
               (:mul_31x37times37x29, clanglib), Cvoid,
               (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}),
               pointer(C), pointer(A), pointer(B)
           )
       end
clangeigen! (generic function with 2 methods)

julia> align_64(x) = (x + 63) & ~63
align_64 (generic function with 1 method)

julia> align_64(x, ptr::Ptr{T}) where {T} = align_64(x*sizeof(T) + ptr)
align_64 (generic function with 2 methods)

julia> const PTR = Base.unsafe_convert(Ptr{Float64}, Libc.malloc(1<<20)); # more than enough space

julia> align_64(x::Ptr{T}) where {T} = reinterpret(Ptr{T},align_64(reinterpret(UInt,x)))
align_64 (generic function with 3 methods)

julia> align_64(PTR)
Ptr{Float64} @0x00007fad8b585040

julia> A16x24 = PtrMatrix{16,24,Float64,16}(align_64(PTR)); randn!(A16x24);

julia> B24x14 = PtrMatrix{24,14,Float64,24}(align_64(16*24,pointer(A16x24))); randn!(B24x14);

julia> C16x14 = PtrMatrix{16,14,Float64,16}(align_64(24*14,pointer(B24x14)));

julia> clangeigen!(C16x14, A16x24, B24x14)

julia> @benchmark gppeigen!($C16x14, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     446.863 ns (0.00% GC)
  median time:      454.548 ns (0.00% GC)
  mean time:        457.793 ns (0.00% GC)
  maximum time:     701.843 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     197

julia> @benchmark clangeigen!($C16x14, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     470.245 ns (0.00% GC)
  median time:      474.612 ns (0.00% GC)
  mean time:        477.438 ns (0.00% GC)
  maximum time:     663.571 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     196

julia> @benchmark mul!($C16x14, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     88.141 ns (0.00% GC)
  median time:      90.651 ns (0.00% GC)
  mean time:        91.153 ns (0.00% GC)
  maximum time:     142.651 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     960

julia> # Prime sizes
       A31x37 = PtrMatrix{31,37,Float64,31}(align_64(16*14,pointer(C16x14))); randn!(A31x37);

julia> B37x29 = PtrMatrix{37,29,Float64,37}(align_64(31*37,pointer(A31x37))); randn!(B37x29);

julia> C31x29 = PtrMatrix{31,29,Float64,31}(align_64(37*29,pointer(B37x29)));

julia> @benchmark gppeigen!($C31x29, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.435 μs (0.00% GC)
  median time:      2.458 μs (0.00% GC)
  mean time:        2.494 μs (0.00% GC)
  maximum time:     6.107 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark clangeigen!($C31x29, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     2.346 μs (0.00% GC)
  median time:      2.373 μs (0.00% GC)
  mean time:        2.388 μs (0.00% GC)
  maximum time:     6.090 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark mul!($C31x29, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     585.606 ns (0.00% GC)
  median time:      607.706 ns (0.00% GC)
  mean time:        610.846 ns (0.00% GC)
  maximum time:     1.466 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     180

This roughly four-fold difference in performance was common across an array of small sizes.

You can look at the C++ code and the assembly here.
For comparison, here is the assembly Julia produced (I deleted all lines that only contained comments for brevity’s sake):

julia> @code_native mul!(C16x14, A16x24, B24x14)
	.text
	movq	(%rsi), %rax
	vmovupd	(%rax), %zmm26
	vmovupd	64(%rax), %zmm28
	movq	(%rdx), %rcx
	vbroadcastsd	(%rcx), %zmm1
	vmulpd	%zmm26, %zmm1, %zmm0
	vmulpd	%zmm28, %zmm1, %zmm1
	vbroadcastsd	192(%rcx), %zmm3
	vmulpd	%zmm26, %zmm3, %zmm2
	vmulpd	%zmm28, %zmm3, %zmm3
	vbroadcastsd	384(%rcx), %zmm5
	vmulpd	%zmm26, %zmm5, %zmm4
	vmulpd	%zmm28, %zmm5, %zmm5
	vbroadcastsd	576(%rcx), %zmm7
	vmulpd	%zmm26, %zmm7, %zmm6
	vmulpd	%zmm28, %zmm7, %zmm7
	vbroadcastsd	768(%rcx), %zmm9
	vmulpd	%zmm26, %zmm9, %zmm8
	vmulpd	%zmm28, %zmm9, %zmm9
	vbroadcastsd	960(%rcx), %zmm11
	vmulpd	%zmm26, %zmm11, %zmm10
	vmulpd	%zmm28, %zmm11, %zmm11
	vbroadcastsd	1152(%rcx), %zmm13
	vmulpd	%zmm26, %zmm13, %zmm12
	vmulpd	%zmm28, %zmm13, %zmm13
	vbroadcastsd	1344(%rcx), %zmm15
	vmulpd	%zmm26, %zmm15, %zmm14
	vmulpd	%zmm28, %zmm15, %zmm15
	vbroadcastsd	1536(%rcx), %zmm17
	vmulpd	%zmm26, %zmm17, %zmm16
	vmulpd	%zmm28, %zmm17, %zmm17
	vbroadcastsd	1728(%rcx), %zmm19
	vmulpd	%zmm26, %zmm19, %zmm18
	vmulpd	%zmm28, %zmm19, %zmm19
	vbroadcastsd	1920(%rcx), %zmm21
	vmulpd	%zmm26, %zmm21, %zmm20
	vmulpd	%zmm28, %zmm21, %zmm21
	vbroadcastsd	2112(%rcx), %zmm23
	vmulpd	%zmm26, %zmm23, %zmm22
	vmulpd	%zmm28, %zmm23, %zmm23
	vbroadcastsd	2304(%rcx), %zmm25
	vmulpd	%zmm26, %zmm25, %zmm24
	vmulpd	%zmm28, %zmm25, %zmm25
	vbroadcastsd	2496(%rcx), %zmm29
	vmulpd	%zmm26, %zmm29, %zmm27
	vmulpd	%zmm28, %zmm29, %zmm26
	addq	$192, %rax
	movq	$-23, %rdx
	nopw	%cs:(%rax,%rax)
	nopl	(%rax,%rax)
L336:
	vmovapd	%zmm27, %zmm28
	vmovupd	-64(%rax), %zmm27
	vmovupd	(%rax), %zmm29
	vbroadcastsd	192(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm0 # zmm0 = (zmm27 * zmm30) + zmm0
	vfmadd231pd	%zmm30, %zmm29, %zmm1 # zmm1 = (zmm29 * zmm30) + zmm1
	vbroadcastsd	384(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm2 # zmm2 = (zmm27 * zmm30) + zmm2
	vfmadd231pd	%zmm30, %zmm29, %zmm3 # zmm3 = (zmm29 * zmm30) + zmm3
	vbroadcastsd	576(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm4 # zmm4 = (zmm27 * zmm30) + zmm4
	vfmadd231pd	%zmm30, %zmm29, %zmm5 # zmm5 = (zmm29 * zmm30) + zmm5
	vbroadcastsd	768(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm6 # zmm6 = (zmm27 * zmm30) + zmm6
	vfmadd231pd	%zmm30, %zmm29, %zmm7 # zmm7 = (zmm29 * zmm30) + zmm7
	vbroadcastsd	960(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm8 # zmm8 = (zmm27 * zmm30) + zmm8
	vfmadd231pd	%zmm30, %zmm29, %zmm9 # zmm9 = (zmm29 * zmm30) + zmm9
	vbroadcastsd	1152(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm10 # zmm10 = (zmm27 * zmm30) + zmm10
	vfmadd231pd	%zmm30, %zmm29, %zmm11 # zmm11 = (zmm29 * zmm30) + zmm11
	vbroadcastsd	1344(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm12 # zmm12 = (zmm27 * zmm30) + zmm12
	vfmadd231pd	%zmm30, %zmm29, %zmm13 # zmm13 = (zmm29 * zmm30) + zmm13
	vbroadcastsd	1536(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm14 # zmm14 = (zmm27 * zmm30) + zmm14
	vfmadd231pd	%zmm30, %zmm29, %zmm15 # zmm15 = (zmm29 * zmm30) + zmm15
	vbroadcastsd	1728(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm16 # zmm16 = (zmm27 * zmm30) + zmm16
	vfmadd231pd	%zmm30, %zmm29, %zmm17 # zmm17 = (zmm29 * zmm30) + zmm17
	vbroadcastsd	1920(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm18 # zmm18 = (zmm27 * zmm30) + zmm18
	vfmadd231pd	%zmm30, %zmm29, %zmm19 # zmm19 = (zmm29 * zmm30) + zmm19
	vbroadcastsd	2112(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm20 # zmm20 = (zmm27 * zmm30) + zmm20
	vfmadd231pd	%zmm30, %zmm29, %zmm21 # zmm21 = (zmm29 * zmm30) + zmm21
	vbroadcastsd	2304(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm22 # zmm22 = (zmm27 * zmm30) + zmm22
	vfmadd231pd	%zmm30, %zmm29, %zmm23 # zmm23 = (zmm29 * zmm30) + zmm23
	vbroadcastsd	2496(%rcx,%rdx,8), %zmm30
	vfmadd231pd	%zmm30, %zmm27, %zmm24 # zmm24 = (zmm27 * zmm30) + zmm24
	vfmadd231pd	%zmm30, %zmm29, %zmm25 # zmm25 = (zmm29 * zmm30) + zmm25
	vbroadcastsd	2688(%rcx,%rdx,8), %zmm30
	vfmadd213pd	%zmm28, %zmm30, %zmm27 # zmm27 = (zmm30 * zmm27) + zmm28
	vfmadd231pd	%zmm30, %zmm29, %zmm26 # zmm26 = (zmm29 * zmm30) + zmm26
	subq	$-128, %rax
	incq	%rdx
	jne	L336
	movq	(%rdi), %rax
	vmovupd	%zmm0, (%rax)
	vmovupd	%zmm1, 64(%rax)
	vmovupd	%zmm2, 128(%rax)
	vmovupd	%zmm3, 192(%rax)
	vmovupd	%zmm4, 256(%rax)
	vmovupd	%zmm5, 320(%rax)
	vmovupd	%zmm6, 384(%rax)
	vmovupd	%zmm7, 448(%rax)
	vmovupd	%zmm8, 512(%rax)
	vmovupd	%zmm9, 576(%rax)
	vmovupd	%zmm10, 640(%rax)
	vmovupd	%zmm11, 704(%rax)
	vmovupd	%zmm12, 768(%rax)
	vmovupd	%zmm13, 832(%rax)
	vmovupd	%zmm14, 896(%rax)
	vmovupd	%zmm15, 960(%rax)
	vmovupd	%zmm16, 1024(%rax)
	vmovupd	%zmm17, 1088(%rax)
	vmovupd	%zmm18, 1152(%rax)
	vmovupd	%zmm19, 1216(%rax)
	vmovupd	%zmm20, 1280(%rax)
	vmovupd	%zmm21, 1344(%rax)
	vmovupd	%zmm22, 1408(%rax)
	vmovupd	%zmm23, 1472(%rax)
	vmovupd	%zmm24, 1536(%rax)
	vmovupd	%zmm25, 1600(%rax)
	vmovupd	%zmm27, 1664(%rax)
	vmovupd	%zmm26, 1728(%rax)
	vzeroupper
	retq
	nopl	(%rax)

and for the larger, awkwardly sized matrices:

julia> @code_native mul!(C31x29, A31x37, B37x29)
	.text
	pushq	%rbp
	pushq	%r15
	pushq	%r14
	pushq	%r13
	pushq	%r12
	pushq	%rbx
	movq	%rdx, -8(%rsp)
	movq	%rsi, -16(%rsp)
	movl	$1, %r9d
	movl	$1488, %esi             # imm = 0x5D0
	movb	$127, %r10b
	nopw	%cs:(%rax,%rax)
	nopl	(%rax)
L48:
	leaq	-1(%r9), %rcx
	movq	-16(%rsp), %rax
	movq	(%rax), %r11
	movq	-8(%rsp), %rax
	movq	(%rax), %r14
	imulq	$1488, %rcx, %rax       # imm = 0x5D0
	imulq	$1776, %rcx, %r13       # imm = 0x6F0
	addq	(%rdi), %rax
	leaq	(%r14,%rsi), %rcx
	leaq	440(%r11), %rdx
	xorl	%r12d, %r12d
	nopw	%cs:(%rax,%rax)
	nopl	(%rax)
L112:
	imulq	$248, %r12, %rbx
	leaq	(%r11,%rbx), %rbp
	vmovupd	(%r11,%rbx), %zmm19
	vmovupd	64(%r11,%rbx), %zmm22
	vmovupd	128(%r11,%rbx), %zmm23
	vbroadcastsd	(%r14,%r13), %zmm3
	vmovupd	192(%r11,%rbx), %zmm24
	vmulpd	%zmm19, %zmm3, %zmm0
	vmulpd	%zmm22, %zmm3, %zmm1
	vmulpd	%zmm23, %zmm3, %zmm2
	vbroadcastsd	296(%r14,%r13), %zmm7
	vmulpd	%zmm24, %zmm3, %zmm6
	vmulpd	%zmm19, %zmm7, %zmm3
	vmulpd	%zmm22, %zmm7, %zmm4
	vmulpd	%zmm23, %zmm7, %zmm5
	vbroadcastsd	592(%r14,%r13), %zmm11
	vmulpd	%zmm24, %zmm7, %zmm10
	vmulpd	%zmm19, %zmm11, %zmm7
	vmulpd	%zmm22, %zmm11, %zmm8
	vmulpd	%zmm23, %zmm11, %zmm9
	vbroadcastsd	888(%r14,%r13), %zmm15
	vmulpd	%zmm24, %zmm11, %zmm14
	vmulpd	%zmm19, %zmm15, %zmm11
	vmulpd	%zmm22, %zmm15, %zmm12
	vmulpd	%zmm23, %zmm15, %zmm13
	vbroadcastsd	1184(%r14,%r13), %zmm20
	vmulpd	%zmm24, %zmm15, %zmm18
	vmulpd	%zmm19, %zmm20, %zmm15
	vmulpd	%zmm22, %zmm20, %zmm16
	vmulpd	%zmm23, %zmm20, %zmm17
	vbroadcastsd	1480(%r14,%r13), %zmm25
	vmulpd	%zmm24, %zmm20, %zmm20
	vmulpd	%zmm19, %zmm25, %zmm21
	vmulpd	%zmm22, %zmm25, %zmm22
	vmulpd	%zmm23, %zmm25, %zmm23
	vmulpd	%zmm24, %zmm25, %zmm19
	movq	$-35, %r8
	movq	%rdx, %r15
	nopl	(%rax)
L368:
	vmovapd	%zmm23, %zmm24
	vmovapd	%zmm22, %zmm25
	vmovapd	%zmm21, %zmm26
	vmovupd	-192(%r15), %zmm21
	vmovupd	-128(%r15), %zmm22
	vmovupd	-64(%r15), %zmm23
	vmovupd	(%r15), %zmm27
	vbroadcastsd	-1200(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm0 # zmm0 = (zmm21 * zmm28) + zmm0
	vfmadd231pd	%zmm28, %zmm22, %zmm1 # zmm1 = (zmm22 * zmm28) + zmm1
	vfmadd231pd	%zmm28, %zmm23, %zmm2 # zmm2 = (zmm23 * zmm28) + zmm2
	vfmadd231pd	%zmm28, %zmm27, %zmm6 # zmm6 = (zmm27 * zmm28) + zmm6
	vbroadcastsd	-904(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm3 # zmm3 = (zmm21 * zmm28) + zmm3
	vfmadd231pd	%zmm28, %zmm22, %zmm4 # zmm4 = (zmm22 * zmm28) + zmm4
	vfmadd231pd	%zmm28, %zmm23, %zmm5 # zmm5 = (zmm23 * zmm28) + zmm5
	vfmadd231pd	%zmm28, %zmm27, %zmm10 # zmm10 = (zmm27 * zmm28) + zmm10
	vbroadcastsd	-608(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm7 # zmm7 = (zmm21 * zmm28) + zmm7
	vfmadd231pd	%zmm28, %zmm22, %zmm8 # zmm8 = (zmm22 * zmm28) + zmm8
	vfmadd231pd	%zmm28, %zmm23, %zmm9 # zmm9 = (zmm23 * zmm28) + zmm9
	vfmadd231pd	%zmm28, %zmm27, %zmm14 # zmm14 = (zmm27 * zmm28) + zmm14
	vbroadcastsd	-312(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm11 # zmm11 = (zmm21 * zmm28) + zmm11
	vfmadd231pd	%zmm28, %zmm22, %zmm12 # zmm12 = (zmm22 * zmm28) + zmm12
	vfmadd231pd	%zmm28, %zmm23, %zmm13 # zmm13 = (zmm23 * zmm28) + zmm13
	vfmadd231pd	%zmm28, %zmm27, %zmm18 # zmm18 = (zmm27 * zmm28) + zmm18
	vbroadcastsd	-16(%rcx,%r8,8), %zmm28
	vfmadd231pd	%zmm28, %zmm21, %zmm15 # zmm15 = (zmm21 * zmm28) + zmm15
	vfmadd231pd	%zmm28, %zmm22, %zmm16 # zmm16 = (zmm22 * zmm28) + zmm16
	vfmadd231pd	%zmm28, %zmm23, %zmm17 # zmm17 = (zmm23 * zmm28) + zmm17
	vfmadd231pd	%zmm28, %zmm27, %zmm20 # zmm20 = (zmm27 * zmm28) + zmm20
	vbroadcastsd	280(%rcx,%r8,8), %zmm28
	vfmadd213pd	%zmm26, %zmm28, %zmm21 # zmm21 = (zmm28 * zmm21) + zmm26
	vfmadd213pd	%zmm25, %zmm28, %zmm22 # zmm22 = (zmm28 * zmm22) + zmm25
	vfmadd213pd	%zmm24, %zmm28, %zmm23 # zmm23 = (zmm28 * zmm23) + zmm24
	vfmadd231pd	%zmm28, %zmm27, %zmm19 # zmm19 = (zmm27 * zmm28) + zmm19
	addq	$248, %r15
	incq	%r8
	jne	L368
	vmovupd	8928(%rbp), %zmm24
	vmovupd	8992(%rbp), %zmm25
	vmovupd	9056(%rbp), %zmm26
	kmovd	%r10d, %k1
	vmovupd	9120(%rbp), %zmm27 {%k1} {z}
	vbroadcastsd	288(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm0 # zmm0 = (zmm24 * zmm28) + zmm0
	vfmadd231pd	%zmm28, %zmm25, %zmm1 # zmm1 = (zmm25 * zmm28) + zmm1
	vfmadd231pd	%zmm28, %zmm26, %zmm2 # zmm2 = (zmm26 * zmm28) + zmm2
	vfmadd231pd	%zmm28, %zmm27, %zmm6 # zmm6 = (zmm27 * zmm28) + zmm6
	vbroadcastsd	584(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm3 # zmm3 = (zmm24 * zmm28) + zmm3
	vfmadd231pd	%zmm28, %zmm25, %zmm4 # zmm4 = (zmm25 * zmm28) + zmm4
	vfmadd231pd	%zmm28, %zmm26, %zmm5 # zmm5 = (zmm26 * zmm28) + zmm5
	vfmadd231pd	%zmm28, %zmm27, %zmm10 # zmm10 = (zmm27 * zmm28) + zmm10
	vbroadcastsd	880(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm7 # zmm7 = (zmm24 * zmm28) + zmm7
	vfmadd231pd	%zmm28, %zmm25, %zmm8 # zmm8 = (zmm25 * zmm28) + zmm8
	vfmadd231pd	%zmm28, %zmm26, %zmm9 # zmm9 = (zmm26 * zmm28) + zmm9
	vfmadd231pd	%zmm28, %zmm27, %zmm14 # zmm14 = (zmm27 * zmm28) + zmm14
	vbroadcastsd	1176(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm11 # zmm11 = (zmm24 * zmm28) + zmm11
	vfmadd231pd	%zmm28, %zmm25, %zmm12 # zmm12 = (zmm25 * zmm28) + zmm12
	vfmadd231pd	%zmm28, %zmm26, %zmm13 # zmm13 = (zmm26 * zmm28) + zmm13
	vfmadd231pd	%zmm28, %zmm27, %zmm18 # zmm18 = (zmm27 * zmm28) + zmm18
	vbroadcastsd	1472(%r14,%r13), %zmm28
	vfmadd231pd	%zmm28, %zmm24, %zmm15 # zmm15 = (zmm24 * zmm28) + zmm15
	vfmadd231pd	%zmm28, %zmm25, %zmm16 # zmm16 = (zmm25 * zmm28) + zmm16
	vfmadd231pd	%zmm28, %zmm26, %zmm17 # zmm17 = (zmm26 * zmm28) + zmm17
	vfmadd231pd	%zmm28, %zmm27, %zmm20 # zmm20 = (zmm27 * zmm28) + zmm20
	vbroadcastsd	1768(%r14,%r13), %zmm28
	vfmadd213pd	%zmm21, %zmm28, %zmm24 # zmm24 = (zmm28 * zmm24) + zmm21
	vfmadd213pd	%zmm22, %zmm28, %zmm25 # zmm25 = (zmm28 * zmm25) + zmm22
	vfmadd213pd	%zmm23, %zmm28, %zmm26 # zmm26 = (zmm28 * zmm26) + zmm23
	vmovupd	%zmm0, (%rax,%rbx)
	vmovupd	%zmm1, 64(%rax,%rbx)
	vmovupd	%zmm2, 128(%rax,%rbx)
	vmovupd	%zmm6, 192(%rax,%rbx) {%k1}
	vmovupd	%zmm3, 248(%rax,%rbx)
	vmovupd	%zmm4, 312(%rax,%rbx)
	vmovupd	%zmm5, 376(%rax,%rbx)
	vmovupd	%zmm10, 440(%rax,%rbx) {%k1}
	vmovupd	%zmm7, 496(%rax,%rbx)
	vmovupd	%zmm8, 560(%rax,%rbx)
	vmovupd	%zmm9, 624(%rax,%rbx)
	vmovupd	%zmm14, 688(%rax,%rbx) {%k1}
	vmovupd	%zmm11, 744(%rax,%rbx)
	vmovupd	%zmm12, 808(%rax,%rbx)
	vmovupd	%zmm13, 872(%rax,%rbx)
	vmovupd	%zmm18, 936(%rax,%rbx) {%k1}
	vmovupd	%zmm15, 992(%rax,%rbx)
	vmovupd	%zmm16, 1056(%rax,%rbx)
	vmovupd	%zmm17, 1120(%rax,%rbx)
	vmovupd	%zmm20, 1184(%rax,%rbx) {%k1}
	vmovupd	%zmm24, 1240(%rax,%rbx)
	vmovupd	%zmm25, 1304(%rax,%rbx)
	vmovupd	%zmm26, 1368(%rax,%rbx)
	vfmadd231pd	%zmm28, %zmm27, %zmm19 # zmm19 = (zmm27 * zmm28) + zmm19
	vmovupd	%zmm19, 1432(%rax,%rbx) {%k1}
	addq	$248, %rdx
	testq	%r12, %r12
	leaq	1(%r12), %r12
	jne	L112
	addq	$1776, %rsi             # imm = 0x6F0
	cmpq	$4, %r9
	leaq	1(%r9), %r9
	jne	L48
	movq	(%rdi), %rax
	movq	-16(%rsp), %rcx
	movq	(%rcx), %rsi
	movq	-8(%rsp), %rcx
	movq	(%rcx), %rcx
	vmovupd	(%rsi), %zmm16
	vmovupd	64(%rsi), %zmm18
	vmovupd	128(%rsi), %zmm19
	vmovupd	192(%rsi), %zmm20
	vbroadcastsd	7104(%rcx), %zmm3
	vmulpd	%zmm16, %zmm3, %zmm0
	vmulpd	%zmm18, %zmm3, %zmm1
	vmulpd	%zmm19, %zmm3, %zmm2
	vmulpd	%zmm20, %zmm3, %zmm3
	vbroadcastsd	7400(%rcx), %zmm7
	vmulpd	%zmm16, %zmm7, %zmm4
	vmulpd	%zmm18, %zmm7, %zmm5
	vmulpd	%zmm19, %zmm7, %zmm6
	vmulpd	%zmm20, %zmm7, %zmm7
	vbroadcastsd	7696(%rcx), %zmm11
	vmulpd	%zmm16, %zmm11, %zmm8
	vmulpd	%zmm18, %zmm11, %zmm9
	vmulpd	%zmm19, %zmm11, %zmm10
	vmulpd	%zmm20, %zmm11, %zmm11
	vbroadcastsd	7992(%rcx), %zmm15
	vmulpd	%zmm16, %zmm15, %zmm12
	vmulpd	%zmm18, %zmm15, %zmm13
	vmulpd	%zmm19, %zmm15, %zmm14
	vmulpd	%zmm20, %zmm15, %zmm15
	vbroadcastsd	8288(%rcx), %zmm21
	vmulpd	%zmm16, %zmm21, %zmm17
	vmulpd	%zmm18, %zmm21, %zmm18
	vmulpd	%zmm19, %zmm21, %zmm19
	vmulpd	%zmm20, %zmm21, %zmm16
	leaq	440(%rsi), %rdx
	movq	$-35, %rdi
	nopw	%cs:(%rax,%rax)
	nopl	(%rax,%rax)
L1408:
	vmovapd	%zmm19, %zmm20
	vmovapd	%zmm18, %zmm21
	vmovapd	%zmm17, %zmm22
	vmovupd	-192(%rdx), %zmm17
	vmovupd	-128(%rdx), %zmm18
	vmovupd	-64(%rdx), %zmm19
	vmovupd	(%rdx), %zmm23
	vbroadcastsd	7392(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm0 # zmm0 = (zmm17 * zmm24) + zmm0
	vfmadd231pd	%zmm24, %zmm18, %zmm1 # zmm1 = (zmm18 * zmm24) + zmm1
	vfmadd231pd	%zmm24, %zmm19, %zmm2 # zmm2 = (zmm19 * zmm24) + zmm2
	vfmadd231pd	%zmm24, %zmm23, %zmm3 # zmm3 = (zmm23 * zmm24) + zmm3
	vbroadcastsd	7688(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm4 # zmm4 = (zmm17 * zmm24) + zmm4
	vfmadd231pd	%zmm24, %zmm18, %zmm5 # zmm5 = (zmm18 * zmm24) + zmm5
	vfmadd231pd	%zmm24, %zmm19, %zmm6 # zmm6 = (zmm19 * zmm24) + zmm6
	vfmadd231pd	%zmm24, %zmm23, %zmm7 # zmm7 = (zmm23 * zmm24) + zmm7
	vbroadcastsd	7984(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm8 # zmm8 = (zmm17 * zmm24) + zmm8
	vfmadd231pd	%zmm24, %zmm18, %zmm9 # zmm9 = (zmm18 * zmm24) + zmm9
	vfmadd231pd	%zmm24, %zmm19, %zmm10 # zmm10 = (zmm19 * zmm24) + zmm10
	vfmadd231pd	%zmm24, %zmm23, %zmm11 # zmm11 = (zmm23 * zmm24) + zmm11
	vbroadcastsd	8280(%rcx,%rdi,8), %zmm24
	vfmadd231pd	%zmm24, %zmm17, %zmm12 # zmm12 = (zmm17 * zmm24) + zmm12
	vfmadd231pd	%zmm24, %zmm18, %zmm13 # zmm13 = (zmm18 * zmm24) + zmm13
	vfmadd231pd	%zmm24, %zmm19, %zmm14 # zmm14 = (zmm19 * zmm24) + zmm14
	vfmadd231pd	%zmm24, %zmm23, %zmm15 # zmm15 = (zmm23 * zmm24) + zmm15
	vbroadcastsd	8576(%rcx,%rdi,8), %zmm24
	vfmadd213pd	%zmm22, %zmm24, %zmm17 # zmm17 = (zmm24 * zmm17) + zmm22
	vfmadd213pd	%zmm21, %zmm24, %zmm18 # zmm18 = (zmm24 * zmm18) + zmm21
	vfmadd213pd	%zmm20, %zmm24, %zmm19 # zmm19 = (zmm24 * zmm19) + zmm20
	vfmadd231pd	%zmm24, %zmm23, %zmm16 # zmm16 = (zmm23 * zmm24) + zmm16
	addq	$248, %rdx
	incq	%rdi
	jne	L1408
	vmovupd	8928(%rsi), %zmm20
	vmovupd	8992(%rsi), %zmm21
	vmovupd	9056(%rsi), %zmm22
	movb	$127, %dl
	kmovd	%edx, %k1
	vmovupd	9120(%rsi), %zmm23 {%k1} {z}
	vbroadcastsd	7392(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm0 # zmm0 = (zmm20 * zmm24) + zmm0
	vfmadd231pd	%zmm24, %zmm21, %zmm1 # zmm1 = (zmm21 * zmm24) + zmm1
	vfmadd231pd	%zmm24, %zmm22, %zmm2 # zmm2 = (zmm22 * zmm24) + zmm2
	vfmadd231pd	%zmm24, %zmm23, %zmm3 # zmm3 = (zmm23 * zmm24) + zmm3
	vbroadcastsd	7688(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm4 # zmm4 = (zmm20 * zmm24) + zmm4
	vfmadd231pd	%zmm24, %zmm21, %zmm5 # zmm5 = (zmm21 * zmm24) + zmm5
	vfmadd231pd	%zmm24, %zmm22, %zmm6 # zmm6 = (zmm22 * zmm24) + zmm6
	vfmadd231pd	%zmm24, %zmm23, %zmm7 # zmm7 = (zmm23 * zmm24) + zmm7
	vbroadcastsd	7984(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm8 # zmm8 = (zmm20 * zmm24) + zmm8
	vfmadd231pd	%zmm24, %zmm21, %zmm9 # zmm9 = (zmm21 * zmm24) + zmm9
	vfmadd231pd	%zmm24, %zmm22, %zmm10 # zmm10 = (zmm22 * zmm24) + zmm10
	vfmadd231pd	%zmm24, %zmm23, %zmm11 # zmm11 = (zmm23 * zmm24) + zmm11
	vbroadcastsd	8280(%rcx), %zmm24
	vfmadd231pd	%zmm24, %zmm20, %zmm12 # zmm12 = (zmm20 * zmm24) + zmm12
	vfmadd231pd	%zmm24, %zmm21, %zmm13 # zmm13 = (zmm21 * zmm24) + zmm13
	vfmadd231pd	%zmm24, %zmm22, %zmm14 # zmm14 = (zmm22 * zmm24) + zmm14
	vfmadd231pd	%zmm24, %zmm23, %zmm15 # zmm15 = (zmm23 * zmm24) + zmm15
	vbroadcastsd	8576(%rcx), %zmm24
	vfmadd213pd	%zmm17, %zmm24, %zmm20 # zmm20 = (zmm24 * zmm20) + zmm17
	vfmadd213pd	%zmm18, %zmm24, %zmm21 # zmm21 = (zmm24 * zmm21) + zmm18
	vfmadd213pd	%zmm19, %zmm24, %zmm22 # zmm22 = (zmm24 * zmm22) + zmm19
	vfmadd231pd	%zmm24, %zmm23, %zmm16 # zmm16 = (zmm23 * zmm24) + zmm16
	vmovupd	%zmm0, 5952(%rax)
	vmovupd	%zmm1, 6016(%rax)
	vmovupd	%zmm2, 6080(%rax)
	vmovupd	%zmm3, 6144(%rax) {%k1}
	vmovupd	%zmm4, 6200(%rax)
	vmovupd	%zmm5, 6264(%rax)
	vmovupd	%zmm6, 6328(%rax)
	vmovupd	%zmm7, 6392(%rax) {%k1}
	vmovupd	%zmm8, 6448(%rax)
	vmovupd	%zmm9, 6512(%rax)
	vmovupd	%zmm10, 6576(%rax)
	vmovupd	%zmm11, 6640(%rax) {%k1}
	vmovupd	%zmm12, 6696(%rax)
	vmovupd	%zmm13, 6760(%rax)
	vmovupd	%zmm14, 6824(%rax)
	vmovupd	%zmm15, 6888(%rax) {%k1}
	vmovupd	%zmm20, 6944(%rax)
	vmovupd	%zmm21, 7008(%rax)
	vmovupd	%zmm22, 7072(%rax)
	vmovupd	%zmm16, 7136(%rax) {%k1}
	popq	%rbx
	popq	%r12
	popq	%r13
	popq	%r14
	popq	%r15
	popq	%rbp
	vzeroupper
	retq
	nop

The most striking difference in the assembly is the far higher density of vmul and vfmadd instructions in the Julia code; these occur in blocks within loops (the blocks sized to avoid register spills). The Eigen assembly is far harder to follow, looks much more complex, and runs around four times slower.

A quick workaround could be to use an external BLAS from Eigen; see: https://eigen.tuxfamily.org/dox/TopicUsingBlasLapack.html
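A minimal sketch based on that page (define one of the macros before including Eigen and link the corresponding BLAS/LAPACK or MKL libraries, and Eigen routes supported dense operations to the external implementation):

#define EIGEN_USE_BLAS   // or: #define EIGEN_USE_MKL_ALL
#include <Eigen/Dense>

int main() {
  Eigen::MatrixXd a = Eigen::MatrixXd::Random(200, 200);
  Eigen::MatrixXd b = Eigen::MatrixXd::Random(200, 200);
  Eigen::MatrixXd c = a * b;   // products like this can now be dispatched to the external BLAS
  return 0;
}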

Where specifically are the bad versions of the instructions? Things like vmovapd usually give you better access because of the guaranteed alignment (otherwise they throw).

  1. Thanks for the code! I’ll try to replicate this week, though last time I tried installing Julia it was a bit of a mess. Could you try this with non-fixed-size matrices? We use dynamically sized matrices in the math library because we don’t know the data shape at compile time.

  2. I’d recommend making summary charts and graphs out of your results. As they are, it’s very hard for me to draw inferences from your raw printouts.

  3. I’d post this to the Eigen email chain; I think they would have better insight into what’s going on here.

  4. A 4x speedup over Eigen would put you off the top of the y-axis of the Blaze benchmarks. That’s pretty wild! I’ll look into this, but I’d very much check your work to make sure everything adds up.

Actually, looking at the mat-mul benchmark here, your results would be faster than the multi-threaded version of Blaze.

My tests were on a CPU with the AVX-512 instruction set.

( 16*14*(2*24-1) / 90 ) * 10^3 is approximately 117,000 MFLOPS: each of the 16*14 outputs takes 24 multiplies and 23 additions, and the whole product completes in about 90 ns.

Those Blaze benchmarks were run on an E5 Haswell EP CPU at 2.3 GHz.
I ran my benchmarks on a 7900X. I’d have to check what the clock speed is set to in the BIOS, but (a) it is probably >50% higher, and (b) the chip has the AVX-512 instruction set. AVX-512 doubles the vector width (doubling theoretical FLOPS) and doubles the number of registers (allowing for larger kernels, making it easier to reach peak FLOPS in matrix multiplication).

Where specifically are the bad versions of the instructions? Things like vmovapd usually give you better access because of the guaranteed alignment (otherwise they throw).
Lines 177-780 from my earlier Godbolt link:

.L24:
        xor     ebx, ebx
        xor     r11d, r11d
        jmp     .L2
Eigen::internal::gemm_pack_rhs<double, long, Eigen::internal::const_blas_data_mapper<double, long, 0>, 4, 0, false, false>::operator()(double*, Eigen::internal::const_blas_data_mapper<double, long, 0> const&, long, long, long, long) [clone .constprop.0]:
        push    rbp
        mov     rbp, rsp
        push    r15
        mov     r15, rdx
        push    r14
        push    r13
        mov     r13, rdi
        push    r12
        push    rbx
        and     rsp, -64
        sub     rsp, 8
        mov     QWORD PTR [rsp-88], rsi
        mov     QWORD PTR [rsp-96], rcx
        test    rcx, rcx
        lea     rax, [rcx+3]
        cmovns  rax, rcx
        add     rdx, 3
        test    r15, r15
        cmovns  rdx, r15
        and     rdx, -4
        and     rax, -4
        mov     QWORD PTR [rsp-64], rdx
        mov     QWORD PTR [rsp-80], rax
        jle     .L84
        mov     rax, rdx
        dec     rax
        shr     rax, 2
        lea     rsi, [4+rax*4]
        sal     rax, 4
        mov     QWORD PTR [rsp-112], rsi
        mov     QWORD PTR [rsp-104], rax
        mov     QWORD PTR [rsp-8], 0
        mov     QWORD PTR [rsp-24], 0
        mov     QWORD PTR [rsp], r15
        vmovdqa64       zmm5, ZMMWORD PTR .LC0[rip]
        vmovdqa64       zmm4, ZMMWORD PTR .LC1[rip]
        mov     r14, rdi
.L55:
        mov     rax, QWORD PTR [rsp-88]
        mov     rbx, QWORD PTR [rsp-24]
        mov     r10, QWORD PTR [rax+8]
        mov     r9, QWORD PTR [rax]
        lea     rax, [rbx+2]
        imul    rax, r10
        mov     r15, rbx
        lea     r11, [rbx+1]
        mov     QWORD PTR [rsp-32], rax
        lea     rdi, [r9+rax*8]
        lea     rax, [rbx+3]
        imul    rax, r10
        imul    r15, r10
        imul    r11, r10
        mov     QWORD PTR [rsp-40], rax
        lea     r8, [r9+rax*8]
        mov     rax, QWORD PTR [rsp-8]
        mov     r10, QWORD PTR [rsp-64]
        lea     rdx, [r14+rax*8]
        xor     eax, eax
        cmp     QWORD PTR [rsp-64], 0
        lea     rcx, [r9+r15*8]
        lea     rsi, [r9+r11*8]
        jle     .L85
.L47:
        vmovupd ymm7, YMMWORD PTR [rcx+rax*8]
        sub     rdx, -128
        vunpckhpd       ymm2, ymm7, YMMWORD PTR [rsi+rax*8]
        vunpcklpd       ymm0, ymm7, YMMWORD PTR [rsi+rax*8]
        vmovupd ymm7, YMMWORD PTR [rdi+rax*8]
        vunpckhpd       ymm3, ymm7, YMMWORD PTR [r8+rax*8]
        vunpcklpd       ymm1, ymm7, YMMWORD PTR [r8+rax*8]
        vshuff64x2      ymm6, ymm2, ymm3, 0
        add     rax, 4
        vshuff64x2      ymm2, ymm2, ymm3, 3
        vshuff64x2      ymm3, ymm0, ymm1, 0
        vshuff64x2      ymm0, ymm0, ymm1, 3
        vmovupd YMMWORD PTR [rdx-128], ymm3
        vmovupd YMMWORD PTR [rdx-96], ymm6
        vmovupd YMMWORD PTR [rdx-64], ymm0
        vmovupd YMMWORD PTR [rdx-32], ymm2
        cmp     r10, rax
        jg      .L47
        mov     rax, QWORD PTR [rsp-8]
        mov     rbx, QWORD PTR [rsp-104]
        mov     rdx, QWORD PTR [rsp-112]
        lea     rax, [rax+16+rbx]
        mov     QWORD PTR [rsp-8], rax
.L48:
        cmp     QWORD PTR [rsp], rdx
        jle     .L45
        mov     rbx, QWORD PTR [rsp]
        mov     r12, QWORD PTR [rsp-8]
        sub     rbx, rdx
        mov     rax, rbx
        dec     rax
        lea     r10, [r12+rbx*4]
        mov     QWORD PTR [rsp-48], rax
        mov     r13, QWORD PTR [rsp-32]
        lea     rax, [r14+r12*8]
        lea     r12, [r14+r10*8]
        lea     r10, [r15+rdx]
        mov     QWORD PTR [rsp-56], rbx
        lea     rbx, [r9+r10*8]
        lea     r10, [r11+rdx]
        mov     QWORD PTR [rsp-16], r12
        lea     r12, [r9+r10*8]
        lea     r10, [r13+0+rdx]
        lea     r13, [r9+r10*8]
        mov     r10, QWORD PTR [rsp-40]
        add     r15, QWORD PTR [rsp]
        add     r10, rdx
        lea     r10, [r9+r10*8]
        mov     QWORD PTR [rsp-72], r10
        lea     r10, [r9+r15*8]
        cmp     rax, r10
        setnb   r15b
        cmp     rbx, QWORD PTR [rsp-16]
        setnb   r10b
        or      r15d, r10d
        mov     r10, QWORD PTR [rsp]
        add     r10, r11
        lea     r10, [r9+r10*8]
        cmp     rax, r10
        setnb   r11b
        cmp     QWORD PTR [rsp-16], r12
        setbe   r10b
        or      r10d, r11d
        and     r10d, r15d
        cmp     QWORD PTR [rsp-48], 6
        seta    r11b
        and     r10d, r11d
        mov     r11, QWORD PTR [rsp-32]
        add     r11, QWORD PTR [rsp]
        lea     r11, [r9+r11*8]
        cmp     rax, r11
        setnb   r11b
        cmp     QWORD PTR [rsp-16], r13
        setbe   r15b
        or      r11d, r15d
        test    r10b, r11b
        je      .L49
        mov     r11, QWORD PTR [rsp-40]
        add     r11, QWORD PTR [rsp]
        lea     r9, [r9+r11*8]
        mov     r11, QWORD PTR [rsp-72]
        cmp     rax, r9
        setnb   r9b
        cmp     QWORD PTR [rsp-16], r11
        setbe   r10b
        or      r9b, r10b
        je      .L49
        mov     r10, QWORD PTR [rsp-56]
        xor     r9d, r9d
        shr     r10, 3
        sal     r10, 6
.L50:
        vmovupd zmm0, ZMMWORD PTR [rbx+r9]
        vmovupd zmm3, ZMMWORD PTR [r13+0+r9]
        vmovupd zmm1, ZMMWORD PTR [r12+r9]
        vmovupd zmm6, ZMMWORD PTR [r11+r9]
        vmovapd zmm2, zmm0
        vpermt2pd       zmm2, zmm5, zmm3
        vpermt2pd       zmm0, zmm4, zmm3
        vmovapd zmm3, zmm1
        vpermt2pd       zmm3, zmm5, zmm6
        vpermt2pd       zmm1, zmm4, zmm6
        vmovapd zmm6, zmm2
        vpermt2pd       zmm2, zmm4, zmm3
        vmovupd ZMMWORD PTR [rax+64+r9*4], zmm2
        vmovapd zmm2, zmm0
        vpermt2pd       zmm6, zmm5, zmm3
        vpermt2pd       zmm2, zmm5, zmm1
        vpermt2pd       zmm0, zmm4, zmm1
        vmovupd ZMMWORD PTR [rax+r9*4], zmm6
        vmovupd ZMMWORD PTR [rax+128+r9*4], zmm2
        vmovupd ZMMWORD PTR [rax+192+r9*4], zmm0
        add     r9, 64
        cmp     r9, r10
        jne     .L50
        mov     rbx, QWORD PTR [rsp-56]
        mov     r15, QWORD PTR [rsp-8]
        mov     rax, rbx
        and     rax, -8
        add     rdx, rax
        lea     r10, [r15+rax*4]
        cmp     rbx, rax
        je      .L52
        vmovsd  xmm0, QWORD PTR [rcx+rdx*8]
        lea     rax, [0+r10*8]
        vmovsd  QWORD PTR [r14+r10*8], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+rdx*8]
        mov     rbx, QWORD PTR [rsp]
        vmovsd  QWORD PTR [r14+8+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+rdx*8]
        lea     r11, [rdx+1]
        vmovsd  QWORD PTR [r14+16+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+rdx*8]
        add     r10, 4
        vmovsd  QWORD PTR [r14+24+rax], xmm0
        cmp     rbx, r11
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+r11*8]
        lea     rax, [0+r10*8]
        vmovsd  QWORD PTR [r14+r10*8], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+r11*8]
        lea     r10, [rdx+2]
        vmovsd  QWORD PTR [r14+8+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+r11*8]
        lea     r9, [0+r11*8]
        vmovsd  QWORD PTR [r14+16+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+r11*8]
        vmovsd  QWORD PTR [r14+24+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+8+r9]
        lea     r10, [rdx+3]
        vmovsd  QWORD PTR [r14+32+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+8+r9]
        vmovsd  QWORD PTR [r14+40+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+8+r9]
        vmovsd  QWORD PTR [r14+48+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+8+r9]
        vmovsd  QWORD PTR [r14+56+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+16+r9]
        lea     r10, [rdx+4]
        vmovsd  QWORD PTR [r14+64+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+16+r9]
        vmovsd  QWORD PTR [r14+72+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+16+r9]
        vmovsd  QWORD PTR [r14+80+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+16+r9]
        vmovsd  QWORD PTR [r14+88+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+24+r9]
        lea     r10, [rdx+5]
        vmovsd  QWORD PTR [r14+96+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+24+r9]
        vmovsd  QWORD PTR [r14+104+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+24+r9]
        vmovsd  QWORD PTR [r14+112+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+24+r9]
        vmovsd  QWORD PTR [r14+120+rax], xmm0
        cmp     rbx, r10
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+32+r9]
        add     rdx, 6
        vmovsd  QWORD PTR [r14+128+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+32+r9]
        vmovsd  QWORD PTR [r14+136+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+32+r9]
        vmovsd  QWORD PTR [r14+144+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+32+r9]
        vmovsd  QWORD PTR [r14+152+rax], xmm0
        cmp     rbx, rdx
        jle     .L52
        vmovsd  xmm0, QWORD PTR [rcx+40+r9]
        vmovsd  QWORD PTR [r14+160+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+40+r9]
        vmovsd  QWORD PTR [r14+168+rax], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+40+r9]
        vmovsd  QWORD PTR [r14+176+rax], xmm0
        vmovsd  xmm0, QWORD PTR [r8+40+r9]
        vmovsd  QWORD PTR [r14+184+rax], xmm0
.L52:
        mov     rax, QWORD PTR [rsp-8]
        mov     rsi, QWORD PTR [rsp-48]
        lea     rax, [rax+4+rsi*4]
        mov     QWORD PTR [rsp-8], rax
.L45:
        add     QWORD PTR [rsp-24], 4
        mov     rax, QWORD PTR [rsp-24]
        cmp     QWORD PTR [rsp-80], rax
        jg      .L55
        mov     r15, QWORD PTR [rsp]
        mov     r13, r14
.L54:
        mov     rsi, QWORD PTR [rsp-80]
        cmp     QWORD PTR [rsp-96], rsi
        jle     .L82
        mov     rax, QWORD PTR [rsp-88]
        mov     r8, QWORD PTR [rax]
        mov     r14, QWORD PTR [rax+8]
        test    r15, r15
        jle     .L82
        mov     r9, r15
        and     r9, -8
        lea     rdi, [r9+3]
        mov     QWORD PTR [rsp-40], rdi
        lea     rdi, [r9+4]
        mov     QWORD PTR [rsp-48], rdi
        lea     rdi, [r9+5]
        mov     rcx, rsi
        mov     QWORD PTR [rsp-56], rdi
        lea     rdi, [r9+6]
        imul    rcx, r14
        mov     QWORD PTR [rsp-64], rdi
        lea     rdi, [r15-1]
        mov     rbx, QWORD PTR [rsp-8]
        lea     rax, [0+r14*8]
        mov     r10, rsi
        mov     QWORD PTR [rsp-24], rdi
        mov     rsi, r15
        mov     QWORD PTR [rsp-8], r14
        lea     r12, [r9+2]
        mov     QWORD PTR [rsp], rax
        shr     rsi, 3
        lea     rax, [0+r15*8]
        mov     QWORD PTR [rsp-16], rax
        mov     QWORD PTR [rsp-32], r12
        lea     rdx, [r8+rcx*8]
        lea     rax, [r13+0+rbx*8]
        sal     rsi, 6
        lea     r11, [r9+1]
        mov     r12, rbx
.L60:
        lea     rdi, [rdx+64]
        cmp     rax, rdi
        lea     rdi, [rax+64]
        setnb   bl
        cmp     rdx, rdi
        setnb   dil
        or      bl, dil
        mov     edi, 0
        je      .L56
        cmp     QWORD PTR [rsp-24], 6
        jbe     .L56
.L57:
        vmovupd zmm4, ZMMWORD PTR [rdx+rdi]
        vmovupd ZMMWORD PTR [rax+rdi], zmm4
        add     rdi, 64
        cmp     rdi, rsi
        jne     .L57
        lea     rdi, [r12+r9]
        cmp     r9, r15
        je      .L59
        lea     rbx, [rcx+r9]
        vmovsd  xmm0, QWORD PTR [r8+rbx*8]
        vmovsd  QWORD PTR [r13+0+rdi*8], xmm0
        inc     rdi
        cmp     r15, r11
        jle     .L59
        lea     r14, [rcx+r11]
        vmovsd  xmm0, QWORD PTR [r8+r14*8]
        lea     rbx, [0+rdi*8]
        vmovsd  QWORD PTR [r13+0+rdi*8], xmm0
        mov     rdi, QWORD PTR [rsp-32]
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-40]
        vmovsd  QWORD PTR [r13+8+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-48]
        vmovsd  QWORD PTR [r13+16+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-56]
        vmovsd  QWORD PTR [r13+24+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        mov     rdi, QWORD PTR [rsp-64]
        vmovsd  QWORD PTR [r13+32+rbx], xmm0
        cmp     r15, rdi
        jle     .L59
        add     rdi, rcx
        vmovsd  xmm0, QWORD PTR [r8+rdi*8]
        vmovsd  QWORD PTR [r13+40+rbx], xmm0
.L59:
        inc     r10
        add     r12, r15
        add     rdx, QWORD PTR [rsp]
        add     rax, QWORD PTR [rsp-16]
        add     rcx, QWORD PTR [rsp-8]
        cmp     QWORD PTR [rsp-96], r10
        jne     .L60
.L82:
        vzeroupper
        lea     rsp, [rbp-40]
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     r15
        pop     rbp
        ret
.L56:
        vmovsd  xmm0, QWORD PTR [rdx+rdi*8]
        vmovsd  QWORD PTR [rax+rdi*8], xmm0
        inc     rdi
        cmp     r15, rdi
        jne     .L56
        jmp     .L59
.L49:
        mov     r9, QWORD PTR [rsp]
.L53:
        vmovsd  xmm0, QWORD PTR [rcx+rdx*8]
        add     rax, 32
        vmovsd  QWORD PTR [rax-32], xmm0
        vmovsd  xmm0, QWORD PTR [rsi+rdx*8]
        vmovsd  QWORD PTR [rax-24], xmm0
        vmovsd  xmm0, QWORD PTR [rdi+rdx*8]
        vmovsd  QWORD PTR [rax-16], xmm0
        vmovsd  xmm0, QWORD PTR [r8+rdx*8]
        inc     rdx
        vmovsd  QWORD PTR [rax-8], xmm0
        cmp     r9, rdx
        jne     .L53
        jmp     .L52
.L85:
        xor     edx, edx
        jmp     .L48
.L84:
        mov     QWORD PTR [rsp-8], 0
        jmp     .L54
Eigen::internal::gebp_kernel<double, double, long, Eigen::internal::blas_data_mapper<double, long, 0, 0>, 12, 4, false, false>::operator()(Eigen::internal::blas_data_mapper<double, long, 0, 0> const&, double const*, double const*, long, long, long, double, long, long, long, long) [clone .constprop.0]:
        push    rbp
        lea     rax, [r9+3]
        vmovapd zmm19, zmm0
        mov     rbp, rsp
        push    r15
        mov     r15, rdx
        push    r14
        mov     r14, rdi
        mov     rdi, rcx
        push    r13
        mov     r13, r8
        push    r12
        push    rbx
        mov     rbx, r8
        and     rsp, -64
        sub     rsp, 136
        test    r9, r9
        cmovns  rax, r9
        mov     r11, rax
        mov     QWORD PTR [rsp-8], rdx
        mov     rax, rcx
        movabs  rdx, 3074457345618258603
        imul    rdx
        sar     rdi, 63
        and     r11, -4
        mov     rax, rdx
        sar     rax
        sub     rax, rdi
        lea     rax, [rax+rax*2]
        lea     rdi, [0+rax*4]
        mov     rax, rcx
        sub     rax, rdi
        mov     rdx, rax
        lea     rax, [rax+7]
        cmovns  rax, rdx
        mov     QWORD PTR [rsp], rcx
        and     rax, -8
        add     rax, rdi
        mov     QWORD PTR [rsp-24], rax
        test    rcx, rcx
        lea     rax, [rcx+3]
        cmovns  rax, rcx
        xor     edx, edx
        and     rax, -4
        mov     QWORD PTR [rsp-16], rax
        mov     eax, 1012
        sub     rax, r8
        sal     rax, 5
        mov     rcx, rax
        mov     QWORD PTR [rsp-56], rax
        lea     rax, [r8+r8*2]
        sal     rax, 5
        mov     r10, rax
        mov     QWORD PTR [rsp+96], rax
        mov     rax, rcx
        div     r10
        mov     QWORD PTR [rsp-32], rsi
        mov     QWORD PTR [rsp+56], r9
        mov     QWORD PTR [rsp+64], r11
        mov     QWORD PTR [rsp-40], rdi
        and     rbx, -8
        mov     edx, 1
        test    rax, rax
        cmovle  rax, rdx
        lea     rax, [rax+rax*2]
        sal     rax, 2
        mov     QWORD PTR [rsp-80], rax
        test    rdi, rdi
        jle     .L87
        imul    rax, r8
        imul    r11, r8
        mov     QWORD PTR [rsp+40], rsi
        sal     rax, 3
        mov     QWORD PTR [rsp-72], rax
        mov     rax, r8
        sal     rax, 5
        mov     QWORD PTR [rsp+8], rax
        lea     rax, [rbx-1]
        shr     rax, 3
        inc     rax
        mov     rcx, rax
        sal     rcx, 8
        lea     rdx, [rax+rax*2]
        sal     rax, 6
        mov     QWORD PTR [rsp+32], rcx
        mov     QWORD PTR [rsp+24], rax
        lea     rcx, [0+r8*8]
        mov     rax, rbx
        mov     QWORD PTR [rsp+16], rcx
        mov     r12, rdx
        lea     rcx, [r15+r11*8]
        neg     rax
        sal     r12, 8
        sal     rax, 3
        mov     QWORD PTR [rsp-64], rcx
        mov     QWORD PTR [rsp-48], 0
        mov     QWORD PTR [rsp+72], rax
        mov     rax, r12
        vbroadcastsd    ymm20, xmm0
        mov     r12, r8
        mov     r13, rax
.L101:
        mov     rcx, QWORD PTR [rsp-48]
        mov     rdi, QWORD PTR [rsp-80]
        mov     QWORD PTR [rsp+48], rcx
        mov     rax, rcx
        add     rcx, rdi
        mov     rdi, QWORD PTR [rsp-40]
        mov     QWORD PTR [rsp-48], rcx
        cmp     rdi, rcx
        cmovle  rcx, rdi
        cmp     QWORD PTR [rsp+64], 0
        mov     QWORD PTR [rsp+128], rcx
        jle     .L88
        cmp     rcx, rax
        jle     .L89
        xor     eax, eax
        mov     r11, QWORD PTR [rsp-8]
        mov     QWORD PTR [rsp+80], r13
        mov     r13, rax
.L92:
        lea     rax, [r13+1]
        mov     QWORD PTR [rsp+120], rax
        lea     rax, [r13+2]
        mov     QWORD PTR [rsp+112], rax
        lea     rax, [r13+3]
        mov     QWORD PTR [rsp+104], rax
        mov     rax, QWORD PTR [rsp+32]
        mov     r10, QWORD PTR [rsp+40]
        add     rax, r11
        mov     QWORD PTR [rsp+88], rax
        mov     r9, QWORD PTR [rsp+48]
.L90:
        mov     rdx, QWORD PTR [r14+8]
        mov     rcx, QWORD PTR [r14]
        mov     rsi, rdx
        imul    rsi, r13
        mov     rax, r10
        prefetcht0      [r10]
        add     rsi, r9
        lea     r8, [rcx+rsi*8]
        mov     rsi, QWORD PTR [rsp+120]
        prefetcht0      [r8]
        imul    rsi, rdx
        prefetcht0      [r11]
        add     rsi, r9
        lea     rdi, [rcx+rsi*8]
        mov     rsi, QWORD PTR [rsp+112]
        prefetcht0      [rdi]
        imul    rsi, rdx
        imul    rdx, QWORD PTR [rsp+104]
        add     rsi, r9
        add     rdx, r9
        lea     rsi, [rcx+rsi*8]
        lea     rcx, [rcx+rdx*8]
        prefetcht0      [rsi]
        prefetcht0      [rcx]
        test    rbx, rbx
        jle     .L207
        vxorpd  xmm0, xmm0, xmm0
        lea     rdx, [r10+512]
        mov     rax, r11
        vmovapd ymm2, ymm0
        vmovapd ymm3, ymm0
        vmovapd ymm18, ymm0
        vmovapd ymm6, ymm0
        vmovapd ymm7, ymm0
        vmovapd ymm8, ymm0
        vmovapd ymm9, ymm0
        vmovapd ymm10, ymm0
        vmovapd ymm11, ymm0
        vmovapd ymm12, ymm0
        vmovapd ymm13, ymm0
        xor     r15d, r15d

For fixed-size matrix multiplication at these sizes, packing is unnecessary (although it is necessary at large sizes, where it prevents memory throttling; packing also has an O(N^2) cost, and because matrix multiplication is O(N^3), packing is effectively “free” at larger sizes).

The kernel code itself otherwise looks more or less fine, other than that it doesn’t use the AVX-512 instruction set:

.L97:
        prefetcht0      [rax]
        vmovapd ymm16, YMMWORD PTR [rdx-448]
        vmovapd ymm15, YMMWORD PTR [rdx-512]
        vmovapd ymm14, YMMWORD PTR [rdx-480]
        vbroadcastsd    ymm1, QWORD PTR [rax]
        vmovapd ymm5, ymm16
        vfmadd231pd ymm15, ymm1, ymm13
        vfmadd231pd ymm5, ymm1, ymm11
        vfmadd231pd ymm14, ymm1, ymm12
        vbroadcastsd    ymm1, QWORD PTR [rax+8]
        vfmadd231pd ymm5, ymm1, ymm8
        vmovapd ymm5, ymm18
        vbroadcastsd    ymm18, QWORD PTR [rax+24]
        vmovapd ymm4, ymm16
        vfmadd231pd ymm15, ymm1, ymm10
        vfmadd231pd ymm14, ymm1, ymm9
        prefetcht0      [rdx]
        vbroadcastsd    ymm1, QWORD PTR [rax+16]
        vfmadd231pd ymm4, ymm1, ymm5
        vfmadd231pd ymm15, ymm1, ymm7
        vmovapd ymm4, ymm18
        vfmadd231pd ymm14, ymm1, ymm6
        vfmadd231pd ymm15, ymm4, ymm3

Note the 256-bit ymm registers instead of 512-bit zmm registers, despite compiling with -march=skylake-avx512 -mprefer-vector-width=512.

That said, my statement should be qualified as “poor vectorization at small fixed sizes with the AVX-512 instruction set”, because packing isn’t helpful yet at those sizes, and Eigen doesn’t seem to use that instruction set. Given larger matrices (where you want packing), and if you only have AVX2 (such as on the Haswell E5-2650V3 from Blaze’s benchmarks), my criticisms are not applicable.
[I looked at Eigen’s asm more closely in making this post; earlier I just saw that the asm was long, that it included things like `vunpckhpd`, `vshuff64`, etc, that it didn’t use `zmm` registers, and that performance was slow in those examples, and perhaps judged it unfairly without first looking more closely at what it was trying to do.]

  1. My Julia matrix multiplication library doesn’t really support dynamically sized matrices yet. Intel MKL JIT would probably be a great choice if it can be combined with Eigen as in @andre.pfeuffer’s link, although MKL may have poor performance on AMD CPUs (i.e., it may simply use SSE instead of AVX? I’m not sure about this).
  2. I’ll try and get around to it.

It would be nice to see the benchmark for yours alongside the Blaze benchmarks on your system! Tone often gets lost on the internet; I hope you don’t think I was dismissing your work at all, it’s v cool! You certainly seem to know more about these things than I do.

After this comment I was trying to find where in the packet math this gets built. It looks like the AVX-512 gemm stuff may have an update in the pipeline for Eigen 3.4:

http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1642

// yadayada
        vmovapd zmm11, ZMMWORD PTR [rdx+256]
        vfmadd231pd     zmm6, zmm11, QWORD PTR [rax-256]{1to8}
        vfmadd231pd     zmm4, zmm11, QWORD PTR [rax-248]{1to8}
        vfmadd231pd     zmm9, zmm11, QWORD PTR [rax-240]{1to8}
        vfmadd132pd     zmm11, zmm2, QWORD PTR [rax-232]{1to8}
// yadayada

Though there is still a bunch of yadayada and the unpacking, which will probs leave your implementation faster for small sizes.

Again no worries! It’s good to be excited about things

Intel “tries” to have good performance on AMD, but yeah, it probably won’t be as nice. Maybe we can look into checking whether a user has MKL set up and building Eigen with it if they do.

Thanks for diving down into the assembly. Some of this sounds like it might be better discussed on the Eigen lists. They’re very responsive if there are general techniques that can speed things up. You might also ask them how to get the most out of their code, as it may require some flags to set up.

It’s worse than that—it totally defeats locality and we have to do a lot of copying to get things into vector form to pass into other algorithms.

If it’s Eigen<var, -1, -1>, then it’s effectively Eigen<vari*, -1, -1>. Each pointer points to a contiguous pair of value and adjoint. These pointers don’t need to point to anything contiguous (just imagine an assignment, which sets a new pointer).

Only after unpacking that mess into an Eigen<double, -1, -1> do you get a contiguous array of double values in column-major order.
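That unpacking amounts to something like this (a sketch for illustration, not the actual implementation):

#include <Eigen/Dense>
#include <stan/math.hpp>

// Chase each pointer in the matrix of vars and copy the values into a plain,
// contiguous, column-major matrix of doubles.
Eigen::MatrixXd value_of_sketch(const Eigen::Matrix<stan::math::var, -1, -1>& m) {
  Eigen::MatrixXd vals(m.rows(), m.cols());
  for (int j = 0; j < m.cols(); ++j)
    for (int i = 0; i < m.rows(); ++i)
      vals(i, j) = m(i, j).val();   // read the value out of the pointed-to vari
  return vals;
}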

The vectorization here is of the function, not the assembly. We need clearer names. There’s a reduction inside the bernoulli_logit_glm so that it returns a single log probability mass. Was the Julia version doing reverse-mode autodiff?
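In other words, “vectorized” here means the signature accepts containers and reduces internally. Schematically (not the actual implementation, and without the analytic gradients):

#include <cmath>
#include <vector>
#include <Eigen/Dense>

// Take whole containers and reduce to a single log probability mass.
double bernoulli_logit_glm_sum_sketch(const std::vector<int>& y, const Eigen::MatrixXd& X,
                                      double alpha, const Eigen::VectorXd& beta) {
  Eigen::VectorXd logit = X * beta;   // linear predictor
  logit.array() += alpha;             // add the intercept
  double lp = 0.0;
  for (int i = 0; i < logit.size(); ++i)
    lp += y[i] == 1 ? -std::log1p(std::exp(-logit[i]))   // log Pr(y = 1)
                    : -std::log1p(std::exp(logit[i]));   // log Pr(y = 0)
  return lp;
}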

It’s really more geared toward big dynamic matrices. Partly we just don’t care so much when the matrices are small as everything’s fast.

Of course—it’s just general memory locality issues to start.

Nobody’s found a good data structure that lets you do everything we want, like have matrices where you can set elements for reverse-mode autodiff and where you get memory locality. You can design systems that do matrix derivatives, but then accessing their elements is painful. Everyone wants to go down this path when they first see our (and others’) code. I invite you to try, because if someone could genuinely solve the locality problem in a useful way, it’d be huge.

The best thing to study for our autodiff aside from the code is the arXiv paper.

One of the reasons we used Eigen was that it supports templating. But we wind up doing most of the packing/unpacking on Eigen<var, ...> structures, which eats up the gains from something like SIMD.

Sure; here are the same sizes I tested with Eigen, this time against Blaze (plus a padded variant):

julia> using BenchmarkTools, LinearAlgebra, PaddedMatrices, Random

julia> using PaddedMatrices: AbstractMutableFixedSizePaddedMatrix

julia> const CXXDIR = "/home/chriselrod/Documents/progwork/Cxx";

julia> const PADDED_BLAZE_LIB = joinpath(CXXDIR, "libblazemul_padded.so");

julia> const UNPADDED_BLAZE_LIB = joinpath(CXXDIR, "libblazemul_unpadded.so");

julia> @inline aligned_8(N) = (N & 7) == 0
aligned_8 (generic function with 1 method)

julia> @generated function blazemul!(
           C::AbstractMutableFixedSizePaddedMatrix{M,N,T,PC},
           A::AbstractMutableFixedSizePaddedMatrix{M,K,T,PA},
           B::AbstractMutableFixedSizePaddedMatrix{K,N,T,PB}
       ) where {M,N,K,T,PC,PA,PB}
           lib = if aligned_8(PC) && aligned_8(PA) && aligned_8(PB)
               PADDED_BLAZE_LIB
           else
               UNPADDED_BLAZE_LIB
           end
           func = QuoteNode(Symbol(:mul_,M,:x,K,:times,K,:x,N))
           quote
               ccall(
                   ($func, $lib), Cvoid,
                   (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble}),
                   C, A, B
               )
           end
       end
blazemul! (generic function with 1 method)

julia> align_64(x) = (x + 63) & ~63
align_64 (generic function with 1 method)

julia> align_64(x, ptr::Ptr{T}) where {T} = align_64(x*sizeof(T) + ptr)
align_64 (generic function with 2 methods)

julia> const PTR = Base.unsafe_convert(Ptr{Float64}, Libc.malloc(1<<21)); # more than enough space

julia> align_64(x::Ptr{T}) where {T} = reinterpret(Ptr{T},align_64(reinterpret(UInt,x)))
align_64 (generic function with 3 methods)

julia> approx_equal(A,B) = all(x -> isapprox(x[1],x[2]), zip(A,B))
approx_equal (generic function with 1 method)

julia> A16x24 = PtrMatrix{16,24,Float64,16}(align_64(PTR)); randn!(A16x24);

julia> B24x14 = PtrMatrix{24,14,Float64,24}(align_64(16*24,pointer(A16x24))); randn!(B24x14);

julia> C16x14b = PtrMatrix{16,14,Float64,16}(align_64(24*14,pointer(B24x14)));

julia> C16x14j = PtrMatrix{16,14,Float64,16}(align_64(16*14,pointer(C16x14b)));

julia> @benchmark blazemul!($C16x14b, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     101.626 ns (0.00% GC)
  median time:      102.874 ns (0.00% GC)
  mean time:        105.372 ns (0.00% GC)
  maximum time:     512.652 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     942

julia> @benchmark mul!($C16x14j, $A16x24, $B24x14)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     91.679 ns (0.00% GC)
  median time:      91.904 ns (0.00% GC)
  mean time:        93.996 ns (0.00% GC)
  maximum time:     136.302 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     955

julia> approx_equal(C16x14b, C16x14j)
true

julia> # Prime sizes
       A31x37 = PtrMatrix{31,37,Float64,31}(align_64(16*14,pointer(C16x14j))); randn!(A31x37);

julia> B37x29 = PtrMatrix{37,29,Float64,37}(align_64(31*37,pointer(A31x37))); randn!(B37x29);

julia> C31x29b = PtrMatrix{31,29,Float64,31}(align_64(37*29,pointer(B37x29)));

julia> C31x29j = PtrMatrix{31,29,Float64,31}(align_64(31*29,pointer(C31x29b)));

julia> @benchmark blazemul!($C31x29b, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     1.868 μs (0.00% GC)
  median time:      1.886 μs (0.00% GC)
  mean time:        1.929 μs (0.00% GC)
  maximum time:     4.423 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark mul!($C31x29j, $A31x37, $B37x29)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     582.324 ns (0.00% GC)
  median time:      602.297 ns (0.00% GC)
  mean time:        612.824 ns (0.00% GC)
  maximum time:     2.398 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     182

julia> approx_equal(C31x29b, C31x29j)
true

julia> A31x37p = PtrMatrix{31,37,Float64,32}(align_64(31*29,pointer(C31x29j))); randn!(A31x37p);

julia> B37x29p = PtrMatrix{37,29,Float64,40}(align_64(32*37,pointer(A31x37p))); randn!(B37x29p);

julia> C31x29bp = PtrMatrix{31,29,Float64,32}(align_64(40*29,pointer(B37x29p)));

julia> C31x29jp = PtrMatrix{31,29,Float64,32}(align_64(32*29,pointer(C31x29bp)));

julia> @benchmark blazemul!($C31x29bp, $A31x37p, $B37x29p)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     628.894 ns (0.00% GC)
  median time:      648.347 ns (0.00% GC)
  mean time:        663.371 ns (0.00% GC)
  maximum time:     972.388 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     170

julia> @benchmark mul!($C31x29jp, $A31x37p, $B37x29p)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     556.533 ns (0.00% GC)
  median time:      575.101 ns (0.00% GC)
  mean time:        584.645 ns (0.00% GC)
  maximum time:     1.001 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     184

julia> approx_equal(C31x29bp, C31x29jp)
true 

Summary (minimum times in nanoseconds):

16x24 * 24x14
Eigen: 446.9 (g++)
PaddedMatrices.jl: 91.5
Blaze: 101.7

31x37 * 37x29
Eigen: 2346 (clang++)
PaddedMatrices.jl: 594.2
Blaze: 1885

31x37 * 37x29 with padding:
PaddedMatrices.jl: 570.8
Blaze: 643.8

I compiled Blaze with

g++ -Ofast -march=native -mprefer-vector-width=512 -shared -fPIC blaze_test.cpp -o libblazemul_padded.so
g++ -DBLAZE_USE_PADDING=0 -Ofast -march=native -mprefer-vector-width=512 -shared -fPIC blaze_test.cpp -o libblazemul_unpadded.so

These build the padded and unpadded versions, respectively.

No worries. I did think you might be implying that my results were invalid – ie, that the benchmarked code wasn’t doing what I thought it was doing. Eg, either that the matrix multiplication code was broken and producing invalid results, or that the compiler was able to recognize constants in the benchmark and optimize them away / hoist them out.

I hope I don’t come across as dismissive or disrespectful either. I’m here because I’m a fan of the Stan project / MCMC and HMC/NUTS. It would be great if Stan could be made faster.

It may still be behind until they add kernels to take advantage of all 32 floating point registers:

It could be reduced by exploiting the additional registers offered by AVX512 but this has yet to be implemented.

At least until the matrices get relatively large. I’ll eventually look into adding packing/unpacking to my library, but it isn’t a high priority. Libraries like MKL are so well optimized that I may as well use them (ie, use the library Julia is linked with, and let users choose which library that is to the extent they can/care).

Thanks for those explanations. That all sounds far from ideal.

Does this happen automatically? If so, then it’ll at least scale well / that unpacking will be “free” asymptotically with respect to matrix multiplication (the O(N^2) copies are dominated by the O(N^3) multiply).

Well, the confusing thing is that (in my experience) that is normally what people coming from interpreted languages mean by vectorization, while people coming from compiled languages are normally talking about the assembly / the autovectorizer.
Most folks I talk to in person mostly have R experience, so it’d be nice if there were less ambiguous language to use in general.

Returns the log density (reduced to fit into one register, ie down to 4 or 8 doubles) and analytical gradients (equivalent to the Jacobians here, given that the functions are N-to-1).
Multiple dispatch makes it easy to take advantage of sparsity; I can define a lot of common lpdfs and other commonly used functions, and have them return diagonal, block-diagonal, or even specially-typed objects representing their Jacobians.

That is the gist of my approach:

  1. Write optimized versions of functions that also return (sparse, when possible) representations of their Jacobians.
  2. Use source-to-source transformation to call these functions.
  3. Let multiple dispatch select the appropriate update function for each Jacobian type, returning the derivative.

This lets me avoid having to use any dual number types / data layouts that may interfere with vectorization.
It isn’t as flexible as Stan’s approach (I don’t even have loops or control flow like if/else working!), but I plan on moving to a hybrid approach where I just use some other autodiff library to differentiate the parts that my library can’t. Probably Zygote once it gets more mature.
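To make the dispatch idea concrete, here is a rough sketch (the type and function names are made up, and I’m using C++ overloads as a stand-in for Julia’s multiple dispatch, since that’s the language of the other snippets in this thread): a function returns its value together with a compact representation of its Jacobian, and the reverse-pass “increment the seed” update is overloaded on that representation, so a diagonal Jacobian is never materialized as a dense matrix.

#include <Eigen/Dense>

struct DiagonalJacobian { Eigen::VectorXd d; };  // represents J = diag(d)
struct DenseJacobian    { Eigen::MatrixXd J; };  // general fallback

// Reverse-mode update: adjoint += J^T * seed, specialized per representation.
void increment_seed(const DiagonalJacobian& jac, const Eigen::VectorXd& seed,
                    Eigen::VectorXd& adjoint) {
  adjoint.array() += jac.d.array() * seed.array();  // O(n), no dense matrix ever built
}

void increment_seed(const DenseJacobian& jac, const Eigen::VectorXd& seed,
                    Eigen::VectorXd& adjoint) {
  adjoint.noalias() += jac.J.transpose() * seed;    // O(n^2) general case
}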

As an example of the “write optimized versions of functions” step, here is the multi normal cholesky function:

#include <cmath>
#include <stan/math.hpp>

extern "C" {
// Copies Y (P x N, one observation per column), mu (length P), and the lower
// triangle of L (P x P, column major) into stan::math::var matrices, accumulates
// the multi_normal_cholesky log density, runs the reverse pass, and writes the
// gradients into dY, dmu, and dL.
double normal_lpdf_cholesky(double* dY, double* dmu, double* dL, double* Y, double* mu, double* L, long N, long P){
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,Eigen::Dynamic> vY(P,N), vL(P,P);
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,1> vmu(P), vY1col(P);
  for (long p = 0; p < P; p++){
    vmu(p) = mu[p];
    for (long pr = 0; pr < p; pr++){
      vL(pr,p) = 0;
    }
    for (long pr = p; pr < P; pr++){
      vL(pr,p) = L[pr + p*P];
    }
  }

  stan::math::var lp = 0;

  // Accumulate the log density one observation (column of Y) at a time.
  for (long n = 0; n < N; n++){
    for (long p = 0; p < P; p++){
      vY1col(p) = Y[p + n*P];
    }
    lp += stan::math::multi_normal_cholesky_log<true>(vY1col, vmu, vL);
    // Keep copies of the vars so their adjoints can be read back after grad().
    for (long p = 0; p < P; p++){
      vY(p,n) = vY1col(p);
    }
  }

  // Reverse pass: propagate adjoints from lp back to every input var.
  lp.grad();

  for (long p = 0; p < P; p++){
    dmu[p] = vmu(p).adj();
    for (long pr = p; pr < P; pr++){
      dL[pr + p*P] = vL(pr,p).adj();
    }
  }
  for (long n = 0; n < N; n++){
    for (long p = 0; p < P; p++){
      dY[p + n*P] = vY(p,n).adj();
    }
  }  
  return lp.val();
}

}

which I compiled with

g++ -O3 -march=native -shared -fPIC -DNDEBUG -DEIGEN_NO_DEBUG -DADEPT_STACK_THREAD_UNSAFE -fno-signed-zeros -fno-trapping-math -fassociative-math -pipe -feliminate-unused-debug-types -I$STAN_MATH -I$STAN_MATH/lib/eigen_3.3.3 -I$STAN_MATH/lib/boost_1.69.0 -I$STAN_MATH/lib/sundials_4.1.0/include stan_normal_test.cpp -o libstan_normal_chol.so

If those flags can be improved, let me know.

julia> using ProbabilityDistributions, BenchmarkTools, PaddedMatrices, LinearAlgebra, SIMDPirates, VectorizationBase

julia> using ProbabilityDistributions: Normal, ∂Normal

julia> using StructuredMatrices: MutableLowerTriangularMatrix, MutableSymmetricMatrixL, choltest!

julia> approx_equal(A,B) = all(x -> isapprox(x[1],x[2]), zip(A,B))
approx_equal (generic function with 1 method)

julia> N, P = 480, 20;

julia> Y = @Mutable randn(N, P);

julia> μ = @Mutable randn(P);

julia> Σ = (@Constant randn(P,3P>>1)) |> x -> x * x' |> MutableSymmetricMatrixL;

julia> L = MutableLowerTriangularMatrix{P,Float64}(undef);

julia> choltest!(L, Σ); # test because the function is a work in progress

julia> Ya′ = Array(Y');

julia> μa = Array(μ);

julia> La = LowerTriangular(Array(L));

julia> const SPTR = PaddedMatrices.StackPointer(VectorizationBase.align(Libc.malloc(1<<26) + 63));

julia> # Not calling ∂Normal directly as workaround for https://github.com/JuliaLang/julia/issues/32414
       function normal_chol(sptr, Y, μ, L) 
           sptr2, (lp, ∂Y, ∂μ, ∂L) = ∂Normal(sptr, Y, μ, L, Val{(true,true,true)}())
           SIMDPirates.vsum(lp), ∂Y, ∂μ, ∂L 
       end
normal_chol (generic function with 1 method)

julia> # g++ -O3 -march=native -shared -fPIC -DNDEBUG -DEIGEN_NO_DEBUG -DADEPT_STACK_THREAD_UNSAFE -fno-signed-zeros -fno-trapping-math -fassociative-math -pipe -feliminate-unused-debug-types -I$STAN_MATH -I$STAN_MATH/lib/eigen_3.3.3 -I$STAN_MATH/lib/boost_1.69.0 -I$STAN_MATH/lib/sundials_4.1.0/include stan_normal_test.cpp -o libstan_normal_chol.so
       const STANDIR = "/home/chriselrod/Documents/progwork/Cxx";

julia> const STANLIB0 = joinpath(STANDIR, "libstan_normal_chol.so");

julia> function stan_multinormal_cholesky!(∂Y, ∂μ, ∂L, Y, μ, L)
           P, N = size(Y)
           ccall(
               (:normal_lpdf_cholesky, STANLIB0), Cdouble,
               (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Clong,Clong),
               ∂Y, ∂μ, ∂L.data, Y, μ, L.data, N, P
           )
       end
stan_multinormal_cholesky! (generic function with 1 method)

julia> ∂Ya′ = similar(Ya′);

julia> ∂μa = similar(μa);

julia> ∂La = similar(La);

julia> lp, ∂Y, ∂μ, ∂L = normal_chol(SPTR, Y, μ, L); lp
-15591.837040912393

julia> stan_multinormal_cholesky!(∂Ya′, ∂μa, ∂La, Ya′, μa, La)
-15591.837040912407

julia> approx_equal(∂Y', ∂Ya′)
true

julia> approx_equal(∂μ', ∂μa')
true

julia> approx_equal(∂L, ∂La)
true

julia> @benchmark normal_chol(SPTR, $Y, $μ, $L)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.845 μs (0.00% GC)
  median time:      7.961 μs (0.00% GC)
  mean time:        8.073 μs (0.00% GC)
  maximum time:     105.388 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

julia> @benchmark stan_multinormal_cholesky!($∂Ya′, $∂μa, $∂La, $Ya′, $μa, $La)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     70.675 ms (0.00% GC)
  median time:      81.732 ms (0.00% GC)
  mean time:        82.230 ms (0.00% GC)
  maximum time:     94.341 ms (0.00% GC)
  --------------
  samples:          61
  evals/sample:     1

That is 8 microseconds (for the Julia code) vs 70 ms for Stan’s autodiff, or roughly an 8,750-fold difference, to get the log density (dropping constants) and the gradients with respect to all the arguments.
(Edit: improved Julia’s performance, because I had forgotten to align Julia’s stack – I guess that’s an argument for using vmovapd instead of vmovupd; AFAIK they’re equally fast on aligned data, but the latter merely gets slower when the data is unaligned, while the former crashes.)

I thought stan::math would be much faster if, instead of calling stan::math::multi_normal_cholesky_log<true>, I computed the log density directly with Eigen expressions (roughly (Y' / L).dot_self() - N * L.diag().sum(); see normal_lpdf_cholesky2 below), but as the following post illustrates, this doesn’t seem to be the case.
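For reference, the quantity all of these versions compute (dropping additive constants), where L is the lower Cholesky factor of the covariance:

$$
\log p(Y \mid \mu, L) = -\tfrac{1}{2}\sum_{n=1}^{N} \bigl\lVert L^{-1}(y_n - \mu)\bigr\rVert^2 \;-\; N \sum_{p=1}^{P} \log L_{pp} \;+\; \mathrm{const}.
$$

The normal_lpdf_cholesky2/3 code below evaluates the first term as -0.5 times the squared Frobenius norm of Δ L^{-T} via a triangular solve, where the rows of Δ are the (y_n - μ)'.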

I’m definitely a non-expert when it comes to autodiff implementations, so I seriously doubt I’ve solved the locality problem. But are there any problems with the general hybrid approach of:

  1. First line: source-to-source transformation, using optimized methods when available (like the bernoulli logit or multi-normal cholesky), as described in 1-3 above. Keep all data packed to take advantage of SIMD when possible.
  2. Fall back to more general autodiff for complicated things like control flow and array indexing (especially assignments). If you need dual number types for this, try to use struct-of-arrays data structures – both because they often enable SIMD, and because that stops you from having to pack and unpack (a rough sketch follows this list). I would do this if I had to rely on Julia’s ForwardDiff, but Zygote (another source-to-source library) doesn’t need this.
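A rough sketch of what I mean by struct-of-arrays duals, assuming the number of partials is known at compile time (hypothetical types, not ForwardDiff.jl’s or Stan’s actual layout): instead of N interleaved {value, partials...} records, keep one contiguous array of values and one per partial direction, so elementwise operations become unit-stride loops the compiler can vectorize, and nothing needs to be packed or unpacked around them.

#include <array>
#include <cmath>
#include <cstddef>

template <std::size_t N, std::size_t NPartials>
struct DualSoA {
  std::array<double, N> val;                              // values, contiguous
  std::array<std::array<double, N>, NPartials> partial;   // partial[k][i] = d(val[i]) / d(x_k)
};

// Elementwise y = exp(x): the value sweep and each partial sweep are unit-stride.
template <std::size_t N, std::size_t NPartials>
DualSoA<N, NPartials> exp_soa(const DualSoA<N, NPartials>& x) {
  DualSoA<N, NPartials> y;
  for (std::size_t i = 0; i < N; ++i) y.val[i] = std::exp(x.val[i]);
  for (std::size_t k = 0; k < NPartials; ++k)
    for (std::size_t i = 0; i < N; ++i)
      y.partial[k][i] = y.val[i] * x.partial[k][i];       // d exp(x) = exp(x) * dx
  return y;
}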

I also describe the approach a little here. I show an example model and the code transformations that produce the reverse-mode AD code – largely a wall of ProbabilityModels.PaddedMatrices.RESERVED_INCREMENT_SEED_RESERVED calls, but because of multiple dispatch (akin to operator overloading), each call goes to a different method appropriate to the argument types.

I broke the code recently while tracking changes to DynamicHMC.jl, but I plan on fixing it over Labor Day weekend.

Compiling the following function:

// Same log density as above, but computed directly with Eigen expressions on
// stan::math::var matrices; Y is stored N x P here (one observation per row).
double normal_lpdf_cholesky2(double* dY, double* dmu, double* dL, double* Y, double* mu, double* L, long N, long P){
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,Eigen::Dynamic> vY(N,P), vL(P,P);
  Eigen::Matrix<stan::math::var,1,Eigen::Dynamic> vmu(P);
  //  std::cout << "N: " << N << " ; P: " << P << std::endl;
  for (long p = 0; p < P; p++){
    vmu(p) = mu[p];
    for (long pr = p; pr < P; pr++){
      vL(pr,p) = L[pr + p*P];
    }
    for (long n = 0; n < N; n++){
      vY(n,p) = Y[n + p*N];
    }
  }

  
  // Rows of vDelta are (y_n - mu)^T.
  Eigen::Matrix<stan::math::var,Eigen::Dynamic,Eigen::Dynamic> vDelta = vY.rowwise() - vmu;

  // In-place triangular solve on the right: vDelta becomes Delta * L^{-T}.
  vL.triangularView<Eigen::Lower>().adjoint().solveInPlace<Eigen::OnTheRight>(vDelta);

  // Log density with additive constants dropped.
  stan::math::var lp = -0.5*(vDelta.squaredNorm()) - N * (vL.diagonal().array().log().sum());

  lp.grad();

  for (long p = 0; p < P; p++){
    dmu[p] = vmu(p).adj();
    for (long pr = p; pr < P; pr++){
      dL[pr + p*P] = vL(pr,p).adj();
    }
    for (long n = 0; n < N; n++){
      dY[n + p*N] = vY(n,p).adj();
    }
  }
  return lp.val();
}

and using the same data layout for Y as my Julia version, I now get:

julia> function stan_multinormal_cholesky2!(∂Y, ∂μ, ∂L, Y, μ, L)
           N, P = size(Y)
           ccall(
               (:normal_lpdf_cholesky2, STANLIB0), Cdouble,
               (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Clong,Clong),
               ∂Y, ∂μ, ∂L.data, Y, μ, L.data, N, P
           )
       end
stan_multinormal_cholesky2! (generic function with 1 method)

julia> Ya = Array(Y);

julia> ∂Ya = similar(Ya);

julia> ∂μa = similar(μa);

julia> ∂La = similar(La);

julia> lp, ∂Y, ∂μ, ∂L = normal_chol(SPTR, Y, μ, L); lp
-15591.837040912393

julia> stan_multinormal_cholesky2!(∂Ya, ∂μa, ∂La, Ya, μa, La)
-15591.837040912395

julia> approx_equal(∂Y, ∂Ya)
true

julia> approx_equal(∂μ, ∂μa)
true

julia> approx_equal(∂L, ∂La)
true

julia> @benchmark normal_chol(SPTR, $Y, $μ, $L)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     7.955 μs (0.00% GC)
  median time:      8.004 μs (0.00% GC)
  mean time:        8.101 μs (0.00% GC)
  maximum time:     15.369 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

julia> @benchmark stan_multinormal_cholesky2!($∂Ya, $∂μa, $∂La, $Ya, $μa, $La)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     172.438 ms (0.00% GC)
  median time:      193.080 ms (0.00% GC)
  mean time:        191.596 ms (0.00% GC)
  maximum time:     214.988 ms (0.00% GC)
  --------------
  samples:          27
  evals/sample:     1

Or about a 20,000-fold difference in performance to compute the same thing.
For reference, I’m on the latest development version of Stan-math. I compiled the code using:

g++ -O3 -march=native -shared -fPIC -DNDEBUG -DEIGEN_NO_DEBUG -DADEPT_STACK_THREAD_UNSAFE -fno-signed-zeros -fno-trapping-math -fassociative-math -pipe -feliminate-unused-debug-types -I$STAN_MATH -I$STAN_MATH/lib/eigen_3.3.3 -I$STAN_MATH/lib/boost_1.69.0 -I$STAN_MATH/lib/sundials_4.1.0/include stan_normal_test.cpp -o libstan_normal_chol.so

The macro definitions were just copied and pasted from the Stan-math paper. The rest is generic optimization flags and include dirs (I installed Stan math and its dependencies by git-cloning CmdStan).

I’m on the latest g++:

> g++ --version
g++ (Clear Linux OS for Intel Architecture) 9.2.1 20190828 gcc-9-branch@274989
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

EDIT:
For the sake of it, here is a version with analytical gradients written using Eigen (note that the Eigen code uses dynamically sized arrays, while the Julia code uses statically sized ones; additionally, the Eigen code copies a little excessively, but I didn’t want to look too heavily into interop for dynamically sized arrays):

julia> function stan_multinormal_cholesky3!(∂Y, ∂μ, ∂L, Y, μ, L)
           N, P = size(Y)
           ccall(
               (:normal_lpdf_cholesky3, STANLIB0), Cdouble,
               (Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Ptr{Cdouble},Clong,Clong),
               ∂Y, ∂μ, ∂L.data, Y, μ, L.data, N, P
           )
       end
stan_multinormal_cholesky3! (generic function with 1 method)

julia> lp, ∂Y, ∂μ, ∂L = normal_chol(SPTR, Y, μ, L); lp
-14785.977219901622

julia> stan_multinormal_cholesky3!(∂Ya, ∂μa, ∂La, Ya, μa, La)
-14785.97721990162

julia> approx_equal(∂Y, ∂Ya)
true

julia> approx_equal(∂μ, ∂μa)
true

julia> approx_equal(∂L, ∂La)
true

julia> @benchmark normal_chol(SPTR, $Y, $μ, $L)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     8.098 μs (0.00% GC)
  median time:      8.360 μs (0.00% GC)
  mean time:        8.414 μs (0.00% GC)
  maximum time:     13.498 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     4

julia> @benchmark stan_multinormal_cholesky3!($∂Ya, $∂μa, $∂La, $Ya, $μa, $La)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     57.995 μs (0.00% GC)
  median time:      60.663 μs (0.00% GC)
  mean time:        60.579 μs (0.00% GC)
  maximum time:     181.221 μs (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1

Over 7 times slower than the Julia version, but the analytical-gradient Eigen version is definitely much better than the autodiff versions here.
C++ code:


// Analytical log density and gradients with plain doubles (no autodiff);
// Y is stored N x P (one observation per row), L is lower triangular, column major.
double normal_lpdf_cholesky3(double* dY, double* dmu, double* dL, double* Y, double* mu, double* L, long N, long P){
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> vY(N,P), vL(P,P);
  Eigen::Matrix<double,1,Eigen::Dynamic> vmu(P);
  //  std::cout << "N: " << N << " ; P: " << P << std::endl;
  for (long p = 0; p < P; p++){
    vmu(p) = mu[p];
    dmu[p] = 0;
  }
  for (long p = 0; p < P; p++){
    for (long pr = p; pr < P; pr++){
      vL(pr,p) = L[pr + p*P];
    }
    for (long n = 0; n < N; n++){
      vY(n,p) = Y[n + p*N];
    }
  }

  
  // Rows of vDelta are (y_n - mu)^T.
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> vDelta = vY.rowwise() - vmu;

  // In-place triangular solve on the right: vDelta becomes Z = Delta * L^{-T}.
  vL.triangularView<Eigen::Lower>().adjoint().solveInPlace<Eigen::OnTheRight>(vDelta);

  // Log density with constants dropped: -0.5*||Z||^2 - N*sum(log(diag(L))).
  double lp = -0.5*(vDelta.squaredNorm()) - N * (vL.diagonal().array().log().sum());
  // dlpdy = Z * L^{-1}: its column sums give d(lp)/dmu, and d(lp)/dY = -dlpdy.
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> dlpdy = vL.triangularView<Eigen::Lower>().solve<Eigen::OnTheRight>(vDelta);
  // dlpdl = L^{-T} * Z^T * Z: its lower triangle is d(lp)/dL, up to the
  // -N / L(p,p) diagonal correction from the log-determinant term (added below).
  Eigen::Matrix<double,Eigen::Dynamic,Eigen::Dynamic> dlpdl = dlpdy.adjoint() * vDelta;

  // Write out the gradients; only the lower triangle of dL is filled.
  for (long p = 0; p < P; p++){
    for (long n = 0; n < N; n++){
      dmu[p] += dlpdy(n,p);
      dY[n + p*N] = -dlpdy(n,p);
    }
    dL[p + p*P] = dlpdl(p,p) - N / L[p + p*P];
    for (long pr = p+1; pr < P; pr++){
      dL[pr + p*P] = dlpdl(pr,p);
    }
  }
  return lp;
}

I compiled with the same flags as before.

2 Likes

I’m not sure what you mean by automatically. What happens is that we construct two new matrices and do a loop to copy out the values and adjoints. It’s expensive and has to hit the heap for the matrices.

Yes, we know, and thus are sorry we ever used that term. Do you have a better term for what we’re doing with things like our probability functions?

Sparsity is key. The alternative is to use the adjoint-vector multiply operation that never has to construct an explicit Jacobian. That’s what we do for most of our matrix operations.
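For example, for C = A * B the reverse sweep only needs the adjoint of C together with the values of A and B; the full Jacobian of the multiply never has to exist. Sketched with plain Eigen types rather than our actual kernels:

#include <Eigen/Dense>

// Standard reverse-mode ("adjoint-vector" / vector-Jacobian product) updates
// for C = A * B. The explicit Jacobian with respect to A alone would have
// (rows(C)*cols(C)) x (rows(A)*cols(A)) entries; these updates avoid forming it.
void matmul_reverse(const Eigen::MatrixXd& A, const Eigen::MatrixXd& B,
                    const Eigen::MatrixXd& adj_C,
                    Eigen::MatrixXd& adj_A, Eigen::MatrixXd& adj_B) {
  adj_A.noalias() += adj_C * B.transpose();  // dL/dA = dL/dC * B^T
  adj_B.noalias() += A.transpose() * adj_C;  // dL/dB = A^T * dL/dC
}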

That’s a lot! Where is the speedup coming from? And it looks even bigger in later posts. Is this something we could do in Stan?

I don’t know how to do this.

1 Like