argmax with user-defined function #16
You have caught my attention! The standard library already provides this via the iterator method `max_by_key`:

```rust
fn arg_max<T, F, I>(f: F, iter: I) -> <I as IntoIterator>::Item
where
    I: IntoIterator,
    F: FnMut(&<I as IntoIterator>::Item) -> T,
    T: Ord,
{
    iter.into_iter().max_by_key(f).unwrap()
}
```

User-defined functions can be instantiated inline. I'm not quite sure which number type Julia has assumed in the code above, so I went for a 32-bit signed integer:

```rust
let range_max = |f, a, b| arg_max(f, i32::min(a, b)..i32::max(a, b));
range_max(|x: &i32| -(x - 3).abs(), -10, 10)
```

Also, the resulting machine code from the Compiler Explorer:

```asm
subq $24, %rsp
movl $-10, 12(%rsp)
leaq 12(%rsp), %rax
movl 12(%rsp), %edx
movl $10, 16(%rsp)
leaq 16(%rsp), %rax
movl 16(%rsp), %r8d
cmpl %r8d, %edx
movl %edx, %eax
cmovgl %r8d, %eax
cmovgel %edx, %r8d
cmpl %r8d, %eax
jge .LBB0_5
leal 1(%rax), %edx
cmpl %r8d, %edx
jge .LBB0_4
leal -3(%rax), %edi
movl $3, %esi
subl %eax, %esi
testl %edi, %edi
cmovnsl %edi, %esi
negl %esi
.LBB0_3:
leal -3(%rdx), %edi
movl $3, %ecx
subl %edx, %ecx
testl %edi, %edi
cmovnsl %edi, %ecx
negl %ecx
cmpl %ecx, %esi
cmovlel %edx, %eax
leal 1(%rdx), %edi
cmovlel %ecx, %esi
movl %edi, %edx
cmpl %r8d, %edi
jne .LBB0_3
.LBB0_4:
movl %eax, 20(%rsp)
leaq 20(%rsp), %rax
addq $24, %rsp
retq
.LBB0_5:
leaq .Lbyte_str.2(%rip), %rdi
callq core::panicking::panic@PLT
ud2
```
Nice, looks like pretty comparably good machine code!
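Worth noting alongside the Rust version: a recent Julia can express the same wrapper as a one-liner. This is a sketch, assuming Julia 1.7 or later, where `Base.argmax(f, itr)` returns the element of `itr` that maximizes `f`; the name `range_max` simply mirrors the Rust closure above.

```julia
# One-line counterpart of the Rust wrapper (assumes Julia >= 1.7, where
# argmax(f, itr) returns the element of itr that maximizes f).
range_max(f, a, b) = argmax(f, min(a, b):max(a, b))

range_max(x -> -abs(x - 3), -10, 10)   # == 3
```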
I've tried to illustrate this point with an example of the iterated logistic map (see below, from my "Why Julia?" talk). I would really like an example showing efficient code generated from plugging user-defined differential equations into a generic ODE solver, but bringing arrays into it brings in a whole bunch of function calls and kind of obscures the main point. Maybe StaticArrays, or a 1d ODE plugged into a simple generic 1d ODE solver, would be nice & clear.

The logistic map example:

```julia
function iterator(f, N)
    # construct f^N
    function fN(x)
        for i ∈ 1:N
            x = f(x)
        end
        x
    end
    fN
end

f(x) = 4*x*(1-x) # f(x) = logistic map
fᴺ = iterator(f, 10^6)  # fᴺ(x) = millionth iterate of logistic map
```

The LLVM IR shows that the code for `fᴺ` reduces to a tight loop with `f` inlined.
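That IR (and the corresponding machine code) can be reproduced with Julia's standard reflection macros; a minimal sketch, assuming the definitions above have been evaluated:

```julia
# Inspect the code generated for the millionth-iterate closure fᴺ;
# the argument type (here Float64) drives the specialization.
@code_llvm fᴺ(0.3)       # LLVM IR
@code_native fᴺ(0.3)     # machine code
```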
A simple implementation of RK4 without intermediate saving on a 1D ODE involves no arrays and would show off higher-order functions very well. If you want to use tuples, then something like the Lorenz equations would work too. For higher-order functions, the hidden difference is static vs. shared-library compilation. All of the wrapped codes typically used in R/Python/etc. are compiled as shared libraries for obvious reasons and then linked to. This adds quite a bit of overhead for smaller ODEs. I think this should be highlighted somehow, since it's not just ODEs but also optimization where this matters.
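To illustrate the tuple/StaticArrays route mentioned above, here is a minimal sketch (not from the original comment): a Lorenz right-hand side returning an `SVector` from StaticArrays.jl, which supports `+` and scalar `*` without heap allocation and therefore drops straight into a generic scalar-style RK4 like the one in the next comment.

```julia
using StaticArrays

# Lorenz system with the classic parameters σ = 10, ρ = 28, β = 8/3.
# Returning an SVector keeps the state on the stack, so a generic RK4
# written for scalars works unchanged.
lorenz(t, u) = SVector(10.0 * (u[2] - u[1]),
                       u[1] * (28.0 - u[3]) - u[2],
                       u[1] * u[2] - (8 / 3) * u[3])

# e.g. rungekutta4(lorenz, SVector(1.0, 0.0, 0.0), 10^5, 0.001)
```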
Yes...

```julia
function rungekutta4(f, x, N, Δt)
    Δt2 = Δt/2
    Δt6 = Δt/6
    t = 0.0
    for n = 1:N
        s1 = f(t, x)
        s2 = f(t + Δt2, x + Δt2*s1)
        s3 = f(t + Δt2, x + Δt2*s2)
        s4 = f(t + Δt, x + Δt *s3)
        x = x + Δt6*(s1 + 2s2 + 2s3 + s4)
    end
    x
end
f(t,x) = x*(1-x)  # 1d logistic ODE
```

LLVM IR for `rungekutta4` with this `f` can be regenerated as sketched below.
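A minimal sketch (the `@code_llvm` macro is standard Julia; the step count and step size are illustrative):

```julia
rungekutta4(f, 0.5, 10^5, 0.01)             # integrate the logistic ODE from x = 0.5
@code_llvm rungekutta4(f, 0.5, 10^5, 0.01)  # LLVM IR for this specialization of the generic rk4
```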
Timings on that Julia code and two C implementations: one that passes `f` to `rk4` by a function pointer, and another where `f` is defined inline before the `rk4` definition so it can be inlined. For 10^5 rk4 steps of the 1d logistic ODE, … I'm surprised that the function-call overhead is so low. Thinking forward, I'm not seeing how to do fair and convincing comparisons with other languages. User-written rk4s will be terrible in interpreted languages, and people will justly complain "use the library function!" Library functions will all use dynamically allocated vectors and return the full trajectory (t, x).
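For the Julia side, such a timing could be run roughly as follows, reusing `rungekutta4` and `f` from above (a sketch; BenchmarkTools.jl is an assumption here, and the C comparisons are not reproduced):

```julia
using BenchmarkTools   # assumed dependency, not named in the comment

# time 10^5 RK4 steps of the 1d logistic ODE defined above;
# $-interpolation keeps global-variable access out of the measurement
@btime rungekutta4($f, 0.5, 10^5, 0.01)
```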
That’s why I thought argmax with a user-defined function was a good choice: it’s not something that’s in a library.
I'm coming around to …
I think the inability to write a good differential equation solver in many languages is itself something worth noting...
Hi, I just wanted to give some examples from the world of the D language. I have some experience writing numerical code in C, Fortran and D, and I have even developed ODE solvers in D that are super efficient. It looks like Rust and Julia produce really efficient code for numerical work even when using high-level abstractions, and that is great. Here is a short comparison with D.

**argmax**

argmax example, https://run.dlang.io/is/dJgJxA (use the ldc compiler):

```d
import std.algorithm;
import std.range;
auto arg_max(alias f, Range)(Range r) {
    return r.map!f.maxIndex;
}
auto range_max(alias f, T)(T a, T b) {
    return arg_max!f(iota(min(a,b), max(a,b)));
}
auto go(T)(T a, T b) {
    import std.math : abs;
    return range_max!(x => -abs(x-3))(a, b);
}
int main(string[] args) {
    auto a = cast(long) args.length - 11; // -10
    auto b = cast(long) args.length + 9; // 10
    return cast(int) go(a, b) & 0x100;
}
```

The dynamic dependence on `args.length` in `main` is added because otherwise ldc is smart enough to optimize the entire program down to a constant. Notice that I am using … The resulting `go` function in asm:

```asm
cmpl %edi, %esi
movl %edi, %r10d
cmovlel %esi, %r10d
movl %esi, %edx
cmovll %edi, %edx
movq $-1, %rax
subl %r10d, %edx
jle .LBB1_7
cmpl $2, %edx
jb .LBB1_2
movslq %edx, %r8
notl %esi
notl %edi
cmpl %edi, %esi
cmovgel %esi, %edi
movl $-3, %r9d
subl %edi, %r9d
addl $3, %edi
movl $1, %edx
xorl %r11d, %r11d
.p2align 4, 0x90
.LBB1_4:
leal (%r9,%rdx), %ecx
addl $-1, %ecx
testl %ecx, %ecx
cmovsl %edi, %ecx
negl %ecx
leal (%r10,%r11), %eax
movl $3, %esi
subl %eax, %esi
leal (%r10,%r11), %eax
addl $-3, %eax
testl %eax, %eax
cmovnsl %eax, %esi
negl %esi
movq %rdx, %rax
cmpl %ecx, %esi
jl .LBB1_6
movq %r11, %rax
.LBB1_6:
addq $1, %rdx
addl $-1, %edi
movq %rax, %r11
cmpq %r8, %rdx
jb .LBB1_4
.LBB1_7:
retq
.LBB1_2:
xorl %eax, %eax
retq
```

**Runge-Kutta 4**

For the rk4 (ignoring the fact that …), https://run.dlang.io/is/fHzucd (again compiled with ldc):

```d
auto rk4(alias f, T)(T x, int N, T dt) {
    import std.range : iota;
    auto dt2 = dt/2;
    auto dt6 = dt/6;
    auto t = 0.0;
    foreach (n; iota(0, N)) {
        auto s1 = f(t, x);
        auto s2 = f(t + dt2, x + dt2*s1);
        auto s3 = f(t + dt2, x + dt2*s2);
        auto s4 = f(t + dt, x + dt *s3);
        x = x + dt6*(s1 + 2.0*s2 + 2.0*s3 + s4);
    }
    return x;
}
int main(string[] args) {
    import std.stdio : writeln;
    writeln(rk4!((t, x) => x*(1.0-x))(0.5, 100_000_000, 0.01)); // 1d logistic ODE
    return 0;
}
```

I used 10^8 steps instead of 10^5, i.e. 1000× longer execution. This works out to 1.49 ms per 10^5 iterations on my machine, so even faster than Julia, but that might just be machine dependence. I would personally use … The compiler is also able to figure out that there will be no memory allocations or exceptions thrown in the entire code, and that execution will be deterministic (a pure function), which is nice. Changing … The main loop in asm:

```asm
... // setup stack in main() (2 instructions)
movsd .LCPI0_0(%rip), %xmm5
movl $100000000, %eax
movsd .LCPI0_1(%rip), %xmm2
movsd .LCPI0_2(%rip), %xmm1
movapd .LCPI0_3(%rip), %xmm8
movsd .LCPI0_4(%rip), %xmm3
.p2align 4, 0x90
.LBB0_1:
movapd %xmm5, %xmm4
movapd %xmm2, %xmm5
subsd %xmm4, %xmm5
mulsd %xmm4, %xmm5
movapd %xmm5, %xmm6
mulsd %xmm1, %xmm6
addsd %xmm4, %xmm6
movapd %xmm2, %xmm7
subsd %xmm6, %xmm7
mulsd %xmm6, %xmm7
movapd %xmm7, %xmm6
mulsd %xmm1, %xmm6
addsd %xmm4, %xmm6
movapd %xmm2, %xmm0
subsd %xmm6, %xmm0
mulsd %xmm6, %xmm0
addsd %xmm7, %xmm7
addsd %xmm5, %xmm7
movddup %xmm0, %xmm0
mulpd %xmm8, %xmm0
unpcklpd %xmm4, %xmm7
addpd %xmm0, %xmm7
movapd %xmm7, %xmm0
movhlps %xmm7, %xmm0
movapd %xmm2, %xmm5
subsd %xmm0, %xmm5
mulsd %xmm0, %xmm5
addsd %xmm7, %xmm5
mulsd %xmm3, %xmm5
addsd %xmm4, %xmm5
addl $-1, %eax
jne .LBB0_1
leaq 24(%rsp), %rbx
movq %rbx, %rdi
movapd %xmm5, (%rsp)
... // call writeln ...
```

Here is the LLVM IR for the rk4 function itself (it has zero uses in the program, because it is inlined into main anyway):

```llvm
; [#uses = 0]
; Function Attrs: norecurse uwtable
define weak_odr double @pure nothrow @nogc @safe double onlineapp.rk4!(onlineapp.main(immutable(char)[][]).__lambda2, double).rk4(double, int, double)(i8* nonnull %.nest_arg, double %dt_arg, i32 %N_arg, double %x_arg) local_unnamed_addr #1 comdat {
%1 = fmul double %dt_arg, 5.000000e-01 ; [#uses = 2]
%2 = fdiv double %dt_arg, 6.000000e+00 ; [#uses = 1]
%3 = icmp sgt i32 %N_arg, 0 ; [#uses = 1]
%spec.select.i = select i1 %3, i32 %N_arg, i32 0 ; [#uses = 1]
%4 = icmp slt i32 %N_arg, 1 ; [#uses = 1]
br i1 %4, label %endfor, label %forbody
forbody: ; preds = %0, %forbody
%x.024 = phi double [ %25, %forbody ], [ %x_arg, %0 ] ; [#uses = 6, type = double]
%__r43.sroa.0.023 = phi i32 [ %26, %forbody ], [ 0, %0 ] ; [#uses = 1, type = i32]
%5 = fsub double 1.000000e+00, %x.024 ; [#uses = 1]
%6 = fmul double %x.024, %5 ; [#uses = 2]
%7 = fmul double %1, %6 ; [#uses = 1]
%8 = fadd double %x.024, %7 ; [#uses = 2]
%9 = fsub double 1.000000e+00, %8 ; [#uses = 1]
%10 = fmul double %8, %9 ; [#uses = 2]
%11 = fmul double %1, %10 ; [#uses = 1]
%12 = fadd double %x.024, %11 ; [#uses = 2]
%13 = fsub double 1.000000e+00, %12 ; [#uses = 1]
%14 = fmul double %12, %13 ; [#uses = 2]
%15 = fmul double %14, %dt_arg ; [#uses = 1]
%16 = fadd double %x.024, %15 ; [#uses = 2]
%17 = fsub double 1.000000e+00, %16 ; [#uses = 1]
%18 = fmul double %16, %17 ; [#uses = 1]
%19 = fmul double %10, 2.000000e+00 ; [#uses = 1]
%20 = fadd double %6, %19 ; [#uses = 1]
%21 = fmul double %14, 2.000000e+00 ; [#uses = 1]
%22 = fadd double %20, %21 ; [#uses = 1]
%23 = fadd double %22, %18 ; [#uses = 1]
%24 = fmul double %2, %23 ; [#uses = 1]
%25 = fadd double %x.024, %24 ; [#uses = 2]
%26 = add nuw nsw i32 %__r43.sroa.0.023, 1 ; [#uses = 2]
%27 = icmp eq i32 %26, %spec.select.i ; [#uses = 1]
br i1 %27, label %endfor, label %forbody
endfor: ; preds = %forbody, %0
%x.0.lcssa = phi double [ %x_arg, %0 ], [ %25, %forbody ] ; [#uses = 1, type = double]
ret double %x.0.lcssa
}
```

To a large extent, with many multi-dimensional ODE systems, D will outperform C library functions (which call the function via a pointer and possibly move data via the stack or arrays), because of very heavy inlining and the compiler's ability to do sub-expression folding and software pipelining, which basically allows the CPU to run multiple versions of …
This is a good benchmark of how well higher-order programming can be optimized, which is something we're actually notably missing in Julia. A basic implementation in Julia was given here:
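Something along these lines, as a sketch for context (names and details are illustrative, not necessarily the originally referenced implementation):

```julia
# Sketch of a basic generic, higher-order argmax over an integer range
# (illustrative reconstruction; rangemax is the name used in the call below).
function rangemax(f, a, b)
    lo, hi = min(a, b), max(a, b)
    xbest, fbest = lo, f(lo)
    for x in lo+1:hi
        fx = f(x)
        if fx > fbest
            xbest, fbest = x, fx
        end
    end
    return xbest
end
```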
The really neat part is how tight the resulting machine code can be for a simple user-defined function based on this generic higher-order `argmax` definition; e.g. the native code for a call like `rangemax(x -> -abs(x-3), -10, 10)` is quite efficient. This is a case where C may actually have a fairly hard time and C++ will do better using templates. Rust seems like it would do well at this too; most other languages will get killed on this one.