argmax with user-defined function #16

Open
StefanKarpinski opened this issue Jun 26, 2018 · 10 comments

Comments

@StefanKarpinski
Member

StefanKarpinski commented Jun 26, 2018

This is a good benchmark of how well higher-order programming can be optimized, which is something we're actually notably missing in Julia. A basic implementation in Julia was given here:

function argmax(f, itr)
    r = iterate(itr)
    r === nothing && error("empty collection")
    m, state = r
    f_m = f(m)
    while true
        r = iterate(itr, state)
        r === nothing && break
        x, state = r
        f_x = f(x)
        isless(f_m, f_x) || continue
        m, f_m = x, f_x
    end
    return m
end

The really neat part is how tight the resulting machine code can be for a simple user-defined function based on this generic higher-order argmax definition, e.g.:

rangemax(f, a, b) = argmax(f, min(a,b):max(a,b))

the native code for a call like rangemax(x -> -abs(x-3), -10, 10) is quite efficient:

	cmpq	%rdi, %rsi
	movq	%rdi, %r8
	cmovleq	%rsi, %r8
	cmovlq	%rdi, %rsi
	cmpq	%rsi, %r8
	jne	L23
	movq	%r8, %rax
	retq
L23:
	leaq	-3(%r8), %rax
	movl	$3, %ecx
	subq	%r8, %rcx
	testq	%rax, %rax
	cmovnsq	%rax, %rcx
	negq	%rcx
	movq	%r8, %rax
L48:
	movl	$2, %edi
	subq	%r8, %rdi
	leaq	-2(%r8), %r9
	leaq	1(%r8), %rdx
	testq	%r9, %r9
	cmovnsq	%r9, %rdi
	negq	%rdi
	cmpq	%rdi, %rcx
	cmovlq	%rdx, %rax
	cmovlq	%rdi, %rcx
	movq	%rdx, %r8
	cmpq	%rsi, %rdx
	jne	L48
	retq

This is a case where C may actually have a fairly hard time and C++ will do better using templates. Rust seems like it would do well at this too; most other languages will get killed on this one.

@Enet4
Contributor

Enet4 commented Jun 26, 2018

Rust seems like it would do well at this too

You have caught my attention! The standard library already provides this as the iterator method max_by_key. So we have this:

fn arg_max<T, F, I>(f: F, iter: I) -> <I as IntoIterator>::Item
where
    I: IntoIterator,
    F: FnMut(&<I as IntoIterator>::Item) -> T,
    T: Ord,
{
    iter.into_iter().max_by_key(f).unwrap()
}

User-defined functions can be instantiated inline. I'm not quite sure which number type Julia has assumed in the code above, so I went for the 32-bit signed integer.

let range_max = |f, a, b| arg_max(f, i32::min(a, b)..i32::max(a, b));
range_max(|x: &i32| -(x - 3).abs(), -10, 10)

Also, the resulting machine code from the compiler explorer:

  subq $24, %rsp
  movl $-10, 12(%rsp)
  leaq 12(%rsp), %rax
  movl 12(%rsp), %edx
  movl $10, 16(%rsp)
  leaq 16(%rsp), %rax
  movl 16(%rsp), %r8d
  cmpl %r8d, %edx
  movl %edx, %eax
  cmovgl %r8d, %eax
  cmovgel %edx, %r8d
  cmpl %r8d, %eax
  jge .LBB0_5
  leal 1(%rax), %edx
  cmpl %r8d, %edx
  jge .LBB0_4
  leal -3(%rax), %edi
  movl $3, %esi
  subl %eax, %esi
  testl %edi, %edi
  cmovnsl %edi, %esi
  negl %esi
.LBB0_3:
  leal -3(%rdx), %edi
  movl $3, %ecx
  subl %edx, %ecx
  testl %edi, %edi
  cmovnsl %edi, %ecx
  negl %ecx
  cmpl %ecx, %esi
  cmovlel %edx, %eax
  leal 1(%rdx), %edi
  cmovlel %ecx, %esi
  movl %edi, %edx
  cmpl %r8d, %edi
  jne .LBB0_3
.LBB0_4:
  movl %eax, 20(%rsp)
  leaq 20(%rsp), %rax
  addq $24, %rsp
  retq
.LBB0_5:
  leaq .Lbyte_str.2(%rip), %rdi
  callq core::panicking::panic@PLT
  ud2
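
Two details differ between the Julia loop and the `arg_max` above: the Julia range `min(a,b):max(a,b)` includes its upper bound while `a..b` does not, and the Julia loop keeps the first maximal element while `Iterator::max_by_key` returns the last one when several are equally maximal. Neither affects this benchmark input, since 3 is the unique maximizer in both ranges. A minimal sketch of a variant that follows the Julia semantics more closely might look like this (`arg_max_first` and `range_max_inclusive` are hypothetical names, not part of the standard library, and returning `Option` instead of panicking is my own choice):

```rust
/// Returns the first element of `iter` that maximizes `f`, or `None` if
/// `iter` is empty. Like the Julia loop, the current best is replaced
/// only on a strict improvement, so ties keep the earliest element.
fn arg_max_first<T, F, I>(mut f: F, iter: I) -> Option<I::Item>
where
    I: IntoIterator,
    F: FnMut(&I::Item) -> T,
    T: PartialOrd,
{
    let mut iter = iter.into_iter();
    let mut best = iter.next()?;
    let mut best_key = f(&best);
    for x in iter {
        let key = f(&x);
        if best_key < key {
            best = x;
            best_key = key;
        }
    }
    Some(best)
}

/// Searches the inclusive range min(a, b)..=max(a, b), matching the
/// Julia expression `min(a,b):max(a,b)`.
fn range_max_inclusive<F>(f: F, a: i64, b: i64) -> Option<i64>
where
    F: FnMut(&i64) -> i64,
{
    arg_max_first(f, a.min(b)..=a.max(b))
}
```

Calling `range_max_inclusive(|x: &i64| -(*x - 3).abs(), -10, 10)` yields `Some(3)`, the same element the Julia call selects.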

@StefanKarpinski
Member Author

Nice, looks like pretty comparably good machine code!

@johnfgibson
Contributor

The rangemax illustrates the point well but for the benchmarks I feel like a clearer connection to a widely-familiar application would be better (IMHO).

I've tried to illustrate this point with an example of the iterated logistic map (see below, from my Why Julia? talk). I would really like an example showing efficient code generated from plugging user-defined differential equations into a generic ODE solver. But bringing Arrays into it brings in a whole bunch of function calls and kind of obscures the main point. Maybe StaticArrays or a 1d ODE plugged into a simple generic 1d ODE solver would be nice & clear.

The logistic map example

function iterator(f, N)

    # construct f^N, the N-fold composition of f
    function fN(x)
        for i in 1:N
            x = f(x)
        end
        x
    end

    fN
end

f(x) = 4*x*(1-x)            # f(x) = logistic map

fᴺ = iterator(f, 10^6)      # fᴺ(x) = millionth iterate of logistic map

The LLVM IR code shows that the code for fᴺ is as tight a loop as you could hope for, the same as for (int n=0; n<1000000; ++n) x = 4*x*(1-x) in C.

define double @julia_fN_62614(%"#fN#1"* nocapture readonly dereferenceable(8), double) #0 !dbg !5 {
top:
  %2 = getelementptr inbounds %"#fN#1", %"#fN#1"* %0, i64 0, i32 1
  %3 = load i64, i64* %2, align 8
  %4 = icmp slt i64 %3, 1
  br i1 %4, label %L14, label %if.preheader

if.preheader:                                     ; preds = %top
  br label %if

if:                                               ; preds = %if.preheader, %if
  %x.03 = phi double [ %8, %if ], [ %1, %if.preheader ]
  %"#temp#.02" = phi i64 [ %5, %if ], [ 1, %if.preheader ]
  %5 = add i64 %"#temp#.02", 1
  %6 = fmul double %x.03, 4.000000e+00
  %7 = fsub double 1.000000e+00, %x.03
  %8 = fmul double %6, %7
  %9 = icmp eq i64 %"#temp#.02", %3
  br i1 %9, label %L14.loopexit, label %if

L14.loopexit:                                     ; preds = %if
  br label %L14

L14:                                              ; preds = %L14.loopexit, %top
  %x.0.lcssa = phi double [ %1, %top ], [ %8, %L14.loopexit ]
  ret double %x.0.lcssa
}
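
For comparison with the other languages in this thread, the same closure-returning pattern can be sketched in Rust. This is only an illustration under my own naming (`iterated` and `logistic` are hypothetical names); it relies on the returned closure capturing `f` by value, so the compiler sees the concrete type of `f` and is free to inline it into the loop:

```rust
/// Builds the `n`-fold composition of `f` as a new closure.
/// `F` is a type parameter, so every distinct `f` gets its own
/// specialized copy of the loop.
fn iterated<F>(f: F, n: u64) -> impl Fn(f64) -> f64
where
    F: Fn(f64) -> f64,
{
    move |mut x: f64| {
        for _ in 0..n {
            x = f(x);
        }
        x
    }
}

/// The logistic map from the Julia listing.
fn logistic(x: f64) -> f64 {
    4.0 * x * (1.0 - x)
}
```

`iterated(logistic, 1_000_000)` then plays the role of fᴺ. A quick sanity check is that 0.5 maps to 1.0 and then to 0.0, where it stays.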

@ChrisRackauckas
Member

A simple implementation of RK4 without intermediate saving on a 1D ODE involves no arrays and would show off higher order functions very well. If you want to use tuples then something like the Lorenz equation would work too.

For higher order functions, the hidden difference is static vs shared library compilation. All of the wrapped codes typically used in R/Python/etc. are compiled as shared libraries for obvious reasons and then linked to. This adds quite a bit of overhead for smaller ODEs. I think this should be highlighted somehow since it's not just ODEs but also optimization where this matters.
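
The tuple-based Lorenz idea could look roughly like the following sketch, written in Rust with a fixed-size array standing in for a Julia tuple or StaticArray. The names `State`, `rk4_step`, `rk4`, and `lorenz` are my own, and the parameter values (σ = 10, ρ = 28, β = 8/3) are the commonly used ones, assumed here rather than taken from this thread:

```rust
/// Fixed-size state kept on the stack, so no heap allocation is involved.
type State = [f64; 3];

/// Component-wise x + h*k.
fn axpy(x: State, h: f64, k: State) -> State {
    [x[0] + h * k[0], x[1] + h * k[1], x[2] + h * k[2]]
}

/// One classical RK4 step for an autonomous system dx/dt = f(x).
fn rk4_step<F: Fn(State) -> State>(f: &F, x: State, dt: f64) -> State {
    let k1 = f(x);
    let k2 = f(axpy(x, dt / 2.0, k1));
    let k3 = f(axpy(x, dt / 2.0, k2));
    let k4 = f(axpy(x, dt, k3));
    let mut out = x;
    for i in 0..3 {
        out[i] += dt / 6.0 * (k1[i] + 2.0 * k2[i] + 2.0 * k3[i] + k4[i]);
    }
    out
}

/// Takes `n` RK4 steps and returns only the final state.
fn rk4<F: Fn(State) -> State>(f: F, mut x: State, n: usize, dt: f64) -> State {
    for _ in 0..n {
        x = rk4_step(&f, x, dt);
    }
    x
}

/// Lorenz right-hand side with the usual parameter choice (an assumption).
fn lorenz([x, y, z]: State) -> State {
    const SIGMA: f64 = 10.0;
    const RHO: f64 = 28.0;
    const BETA: f64 = 8.0 / 3.0;
    [SIGMA * (y - x), x * (RHO - z) - y, x * y - BETA * z]
}
```

One way to check such a solver without a reference trajectory is the z-axis, which is invariant for the Lorenz system: starting from (0, 0, 1), x and y stay at zero and z decays as exp(-βt).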

@johnfgibson
Contributor

Yes...

function rungekutta4(f, x, N, Δt) 
    Δt2 = Δt/2
    Δt6 = Δt/6
    t = 0.0
    for n = 1:N
        s1 = f(t, x)
        s2 = f(t + Δt2, x + Δt2*s1)
        s3 = f(t + Δt2, x + Δt2*s2)
        s4 = f(t + Δt,  x + Δt *s3)
        x = x + Δt6*(s1 + 2s2 + 2s3 + s4)
    end
    x
end

f(t,x) = x*(1-x)  # 1d logistic ODE

LLVM IR for rungekutta4(f, 0.5, 100, 0.01) is

define double @julia_rungekutta4_62609(double, i64, double) #0 !dbg !5 {
top:
  %3 = icmp slt i64 %1, 1
  br i1 %3, label %L28, label %if.lr.ph

if.lr.ph:                                         ; preds = %top
  %4 = fmul double %2, 5.000000e-01
  %5 = fdiv double %2, 6.000000e+00
  br label %if

if:                                               ; preds = %if.lr.ph, %if
  %x.03 = phi double [ %0, %if.lr.ph ], [ %27, %if ]
  %"#temp#.02" = phi i64 [ 1, %if.lr.ph ], [ %6, %if ]
  %6 = add i64 %"#temp#.02", 1
  %7 = fsub double 1.000000e+00, %x.03
  %8 = fmul double %x.03, %7
  %9 = fmul double %4, %8
  %10 = fadd double %x.03, %9
  %11 = fsub double 1.000000e+00, %10
  %12 = fmul double %10, %11
  %13 = fmul double %4, %12
  %14 = fadd double %x.03, %13
  %15 = fsub double 1.000000e+00, %14
  %16 = fmul double %14, %15
  %17 = fmul double %16, %2
  %18 = fadd double %x.03, %17
  %19 = fsub double 1.000000e+00, %18
  %20 = fmul double %18, %19
  %21 = fmul double %12, 2.000000e+00
  %22 = fadd double %8, %21
  %23 = fmul double %16, 2.000000e+00
  %24 = fadd double %22, %23
  %25 = fadd double %24, %20
  %26 = fmul double %5, %25
  %27 = fadd double %x.03, %26
  %28 = icmp eq i64 %"#temp#.02", %1
  br i1 %28, label %L28.loopexit, label %if

L28.loopexit:                                     ; preds = %if
  br label %L28

L28:                                              ; preds = %L28.loopexit, %top
  %x.0.lcssa = phi double [ %0, %top ], [ %27, %L28.loopexit ]
  ret double %x.0.lcssa
}

@johnfgibson
Contributor

Timings on that Julia code and two C implementations: one that passes f to rk4 by a function pointer, another that inlines f by defining it inline prior to the rk4 definition. For 10^5 rk4 steps of the 1d logistic ODE,

cputime code
1.69 ms Julia
1.72 ms C, inlined
1.82 ms C function pointer

I'm surprised that the function-call overhead is so low.
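
The two C variants have rough Rust counterparts, sketched below under my own naming (`rk4_generic`, `rk4_fn_ptr`, and `logistic_ode` are hypothetical). The generic version takes `f` as a type parameter, which corresponds to the inlined case; the second passes a plain `fn` pointer, which corresponds to the C function-pointer case, although an optimizer may still see through the pointer when the target is known. Unlike the Julia listing, this sketch also advances `t` each step:

```rust
/// Generic RK4: `F` is a type parameter, so each function or closure
/// gets its own specialized copy in which calls to `f` can be inlined.
fn rk4_generic<F: Fn(f64, f64) -> f64>(f: F, mut x: f64, n: u32, dt: f64) -> f64 {
    let mut t = 0.0;
    for _ in 0..n {
        let s1 = f(t, x);
        let s2 = f(t + dt / 2.0, x + dt / 2.0 * s1);
        let s3 = f(t + dt / 2.0, x + dt / 2.0 * s2);
        let s4 = f(t + dt, x + dt * s3);
        x += dt / 6.0 * (s1 + 2.0 * s2 + 2.0 * s3 + s4);
        t += dt; // advance time; has no effect for an autonomous ODE
    }
    x
}

/// Same loop, but `f` arrives as a plain function pointer, similar to
/// the C function-pointer variant.
fn rk4_fn_ptr(f: fn(f64, f64) -> f64, x: f64, n: u32, dt: f64) -> f64 {
    rk4_generic(f, x, n, dt)
}

/// 1d logistic ODE, dx/dt = x(1 - x).
fn logistic_ode(_t: f64, x: f64) -> f64 {
    x * (1.0 - x)
}
```

With x(0) = 0.5 the exact solution is x(t) = 1/(1 + exp(-t)), so 100 steps of size 0.01 should land very close to 1/(1 + exp(-1)).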

Thinking forward, I'm not seeing how to do fair & convincing comparisons with other languages. User-written rk4s will be terrible in interpreted languages, and people will justly complain, "use the library function!" Library functions will all use dynamically-allocated vectors and return the full trajectory (t,x).

@StefanKarpinski
Member Author

That’s why I thought argmax with a user-defined function was a good choice: it’s not something that’s in a library.

@johnfgibson
Contributor

I'm coming around to argmax. Maybe it's worth having both userfunc_argmax and rungekutta4_lorenz, the latter using libraries for rk4 and user funcs for lorenz.

@ChrisRackauckas
Member

I think the inability to write a good differential equation solver in many languages is something that is of note itself...

@baryluk

baryluk commented Nov 23, 2018

Hi, I just wanted to give some examples from the world of D language.

So, I have some experience writing numerical code in C, Fortran and D, and I have even developed ODE solvers in D that are super efficient.

It looks like Rust and Julia produce really efficient code for numerical work, even when using high-level abstractions. And that is great.

Here is a short comparison with D:

argmax

argmax example, https://run.dlang.io/is/dJgJxA (use ldc compiler, and compiler options -release -O3):

import std.algorithm;
import std.range;

auto arg_max(alias f, Range)(Range r) {
    return r.map!f.maxIndex;
}

auto range_max(alias f, T)(T a, T b) {
    return arg_max!f(iota(min(a,b), max(a,b)));
}

auto go(T)(T a, T b) {
    import std.math : abs;
    return range_max!(x => -abs(x-3))(a, b);
}

int main(string[] args) {
    auto a = cast(long) args.length - 11;  // -10
    auto b = cast(long) args.length + 9;   // 10
    return cast(int) go(a, b) & 0x100;
}

The dynamic dependence on args.length in main is there because otherwise ldc is smart enough to optimize the entire program to movl $13, %eax; retq! (The maximal index is 13, and the compiler knows that at compile time due to compile-time function evaluation, CTFE.) I used cast because otherwise I would be operating on unsigned types. The binary AND at the end just makes the return code 0 anyway, without letting the compiler optimize out the call to go.

Notice that I am using r.map!f.maxIndex, which would seem to first construct r.map!f, allocating memory and filling it up, and then run maxIndex over it. Nope. It does it on the fly using lazy evaluation, and to some extent because r is a random-access input range. You can still print r.map!f using writeln, and it will show up as a normal array. It is transparent. So to some extent this is more sophisticated than what Rust is doing with max_by_key, which is a specialized function for this task, whereas the D version combines two orthogonal methods, map and maxIndex. In D's Phobos, for finding min/max values of elements, one could use r.map!f.maxElement or r.maxElement!f, which are conceptually different, but in reality the resulting assembly code is going to be exactly the same, and no temporaries will be generated.

resulting go function in asm:

	cmpl	%edi, %esi
	movl	%edi, %r10d
	cmovlel	%esi, %r10d
	movl	%esi, %edx
	cmovll	%edi, %edx
	movq	$-1, %rax
	subl	%r10d, %edx
	jle	.LBB1_7
	cmpl	$2, %edx
	jb	.LBB1_2
	movslq	%edx, %r8
	notl	%esi
	notl	%edi
	cmpl	%edi, %esi
	cmovgel	%esi, %edi
	movl	$-3, %r9d
	subl	%edi, %r9d
	addl	$3, %edi
	movl	$1, %edx
	xorl	%r11d, %r11d
	.p2align	4, 0x90
.LBB1_4:
	leal	(%r9,%rdx), %ecx
	addl	$-1, %ecx
	testl	%ecx, %ecx
	cmovsl	%edi, %ecx
	negl	%ecx
	leal	(%r10,%r11), %eax
	movl	$3, %esi
	subl	%eax, %esi
	leal	(%r10,%r11), %eax
	addl	$-3, %eax
	testl	%eax, %eax
	cmovnsl	%eax, %esi
	negl	%esi
	movq	%rdx, %rax
	cmpl	%ecx, %esi
	jl	.LBB1_6
	movq	%r11, %rax
.LBB1_6:
	addq	$1, %rdx
	addl	$-1, %edi
	movq	%rax, %r11
	cmpq	%r8, %rdx
	jb	.LBB1_4
.LBB1_7:
	retq
.LBB1_2:
	xorl	%eax, %eax
	retq
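
For what it's worth, Rust iterator adapters are lazy as well, so a map-then-reduce formulation in the spirit of r.map!f.maxIndex can also be written there without materializing the mapped values. The sketch below is only an illustration; `max_index` is a hypothetical helper of my own, not a standard library method, and it returns `None` for an empty input rather than the -1 that the D assembly above produces:

```rust
/// Index of the first maximal value of `f` over `iter`, or `None` if
/// `iter` is empty. `map` and `enumerate` are lazy adapters, so values
/// are produced one at a time and nothing is collected into a buffer.
fn max_index<I, F, T>(iter: I, f: F) -> Option<usize>
where
    I: IntoIterator,
    F: FnMut(I::Item) -> T,
    T: PartialOrd,
{
    let mut best: Option<(usize, T)> = None;
    for (i, key) in iter.into_iter().map(f).enumerate() {
        // Replace only on a strict improvement, so ties keep the first index.
        let replace = match &best {
            Some((_, b)) => *b < key,
            None => true,
        };
        if replace {
            best = Some((i, key));
        }
    }
    best.map(|(i, _)| i)
}
```

For the half-open range -10..10 and the function x => -abs(x-3), this gives index 13, matching the constant that ldc folds the D program down to.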

Runge-Kutta 4

For the rk4 (ignoring the fact that t is not updated in the main loop of Julia version):

https://run.dlang.io/is/fHzucd (again compile with ldc, using the -release -O3 options).

auto rk4(alias f, T)(T x, int N, T dt) {
    import std.range : iota;
    auto dt2 = dt/2;
    auto dt6 = dt/6;
    auto t = 0.0;
    foreach (n; iota(0, N)) {
        auto s1 = f(t, x);
        auto s2 = f(t + dt2, x + dt2*s1);
        auto s3 = f(t + dt2, x + dt2*s2);
        auto s4 = f(t + dt,  x + dt *s3);
        x = x + dt6*(s1 + 2.0*s2 + 2.0*s3 + s4);
    }
    return x;
}

int main(string[] args) {
    import std.stdio : writeln;
    writeln(rk4!((t, x) => x*(1.0-x))(0.5, 100_000_000, 0.01));  // 1d logistic ODE
    return 0;
}

I used 10^8 steps instead of 10^5, so execution is 1000 times longer. This results in 1.49 ms per 10^5 iterations on my machine. So even faster than Julia, but that might just be machine dependence.

I would personally use const for all local variables in rk4, as documentation, but auto is short, and the compiler is smart enough to figure out on its own that they are const.

The compiler is also able to figure out that there will be no memory allocations or exceptions thrown in the entire code, and that execution will be deterministic (pure function), which is nice.

Changing foreach over iota (which is implemented in Phobos, the D standard library) to the built-in language construct foreach (n; 0 .. N) doesn't change the generated assembly code. Using 0.0 vs 0, 1.0-x vs 1-x, or 2.0*s2 vs 2*s2 doesn't change it either. I use the floating-point forms just for the convenience of having the same types everywhere.

The main loop in asm:

... //  setup stack in main() (2 instructions)

	movsd	.LCPI0_0(%rip), %xmm5
	movl	$100000000, %eax
	movsd	.LCPI0_1(%rip), %xmm2
	movsd	.LCPI0_2(%rip), %xmm1
	movapd	.LCPI0_3(%rip), %xmm8
	movsd	.LCPI0_4(%rip), %xmm3
	.p2align	4, 0x90
.LBB0_1:
	movapd	%xmm5, %xmm4
	movapd	%xmm2, %xmm5
	subsd	%xmm4, %xmm5
	mulsd	%xmm4, %xmm5
	movapd	%xmm5, %xmm6
	mulsd	%xmm1, %xmm6
	addsd	%xmm4, %xmm6
	movapd	%xmm2, %xmm7
	subsd	%xmm6, %xmm7
	mulsd	%xmm6, %xmm7
	movapd	%xmm7, %xmm6
	mulsd	%xmm1, %xmm6
	addsd	%xmm4, %xmm6
	movapd	%xmm2, %xmm0
	subsd	%xmm6, %xmm0
	mulsd	%xmm6, %xmm0
	addsd	%xmm7, %xmm7
	addsd	%xmm5, %xmm7
	movddup	%xmm0, %xmm0
	mulpd	%xmm8, %xmm0
	unpcklpd	%xmm4, %xmm7
	addpd	%xmm0, %xmm7
	movapd	%xmm7, %xmm0
	movhlps	%xmm7, %xmm0
	movapd	%xmm2, %xmm5
	subsd	%xmm0, %xmm5
	mulsd	%xmm0, %xmm5
	addsd	%xmm7, %xmm5
	mulsd	%xmm3, %xmm5
	addsd	%xmm4, %xmm5
	addl	$-1, %eax
	jne	.LBB0_1
	leaq	24(%rsp), %rbx
	movq	%rbx, %rdi
	movapd	%xmm5, (%rsp)
... // call writeln ...

Here is a LLVM IR for the rk4 function itself (zero uses in program, because it is inlined in main anyway):

; [#uses = 0]
; Function Attrs: norecurse uwtable
define weak_odr double @pure nothrow @nogc @safe double onlineapp.rk4!(onlineapp.main(immutable(char)[][]).__lambda2, double).rk4(double, int, double)(i8* nonnull %.nest_arg, double %dt_arg, i32 %N_arg, double %x_arg) local_unnamed_addr #1 comdat {
  %1 = fmul double %dt_arg, 5.000000e-01          ; [#uses = 2]
  %2 = fdiv double %dt_arg, 6.000000e+00          ; [#uses = 1]
  %3 = icmp sgt i32 %N_arg, 0                     ; [#uses = 1]
  %spec.select.i = select i1 %3, i32 %N_arg, i32 0 ; [#uses = 1]
  %4 = icmp slt i32 %N_arg, 1                     ; [#uses = 1]
  br i1 %4, label %endfor, label %forbody

forbody:                                          ; preds = %0, %forbody
  %x.024 = phi double [ %25, %forbody ], [ %x_arg, %0 ] ; [#uses = 6, type = double]
  %__r43.sroa.0.023 = phi i32 [ %26, %forbody ], [ 0, %0 ] ; [#uses = 1, type = i32]
  %5 = fsub double 1.000000e+00, %x.024           ; [#uses = 1]
  %6 = fmul double %x.024, %5                     ; [#uses = 2]
  %7 = fmul double %1, %6                         ; [#uses = 1]
  %8 = fadd double %x.024, %7                     ; [#uses = 2]
  %9 = fsub double 1.000000e+00, %8               ; [#uses = 1]
  %10 = fmul double %8, %9                        ; [#uses = 2]
  %11 = fmul double %1, %10                       ; [#uses = 1]
  %12 = fadd double %x.024, %11                   ; [#uses = 2]
  %13 = fsub double 1.000000e+00, %12             ; [#uses = 1]
  %14 = fmul double %12, %13                      ; [#uses = 2]
  %15 = fmul double %14, %dt_arg                  ; [#uses = 1]
  %16 = fadd double %x.024, %15                   ; [#uses = 2]
  %17 = fsub double 1.000000e+00, %16             ; [#uses = 1]
  %18 = fmul double %16, %17                      ; [#uses = 1]
  %19 = fmul double %10, 2.000000e+00             ; [#uses = 1]
  %20 = fadd double %6, %19                       ; [#uses = 1]
  %21 = fmul double %14, 2.000000e+00             ; [#uses = 1]
  %22 = fadd double %20, %21                      ; [#uses = 1]
  %23 = fadd double %22, %18                      ; [#uses = 1]
  %24 = fmul double %2, %23                       ; [#uses = 1]
  %25 = fadd double %x.024, %24                   ; [#uses = 2]
  %26 = add nuw nsw i32 %__r43.sroa.0.023, 1      ; [#uses = 2]
  %27 = icmp eq i32 %26, %spec.select.i           ; [#uses = 1]
  br i1 %27, label %endfor, label %forbody

endfor:                                           ; preds = %forbody, %0
  %x.0.lcssa = phi double [ %x_arg, %0 ], [ %25, %forbody ] ; [#uses = 1, type = double]
  ret double %x.0.lcssa
}

For many multi-dimensional ODE systems, D will to a large extent outperform C library functions (which call the function via a pointer and possibly move data via the stack or arrays), because of very heavy inlining and the compiler's ability to do sub-expression folding and software pipelining, which basically allows the CPU to run multiple evaluations of f in parallel on suitable hardware, even between calls.

5 participants