From 00238ab2e6fb65f2291c092a6cd81bb954b9c86c Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 16 Aug 2018 17:45:37 -0400 Subject: [PATCH] Fix reinterpret performance This fixes #25014 by making it more obvious what's going on to LLVM. Instead of a memcpy loop, we use a ccall to :memcpy and turn this into llvm.memcpy at the IR level, which is enough for LLVM to fold everything away. In the benchmark from #25014, we still see some regressions from 0.6, but that is because it needs to dereference through the pointers in the reinterpret and reshape wrappers. In any real code, that dereferencing should be loop-invariantly moved out of the inner loop. (cherry picked from commit 777810b8ea5cb84eacda87ecce5304788f86ebd4) --- base/reinterpretarray.jl | 34 +++++++++++++--------------------- base/reshapedarray.jl | 2 +- src/ccall.cpp | 14 ++++++++++++++ 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/base/reinterpretarray.jl b/base/reinterpretarray.jl index 2cde44e9a386d3..60b268465c14ad 100644 --- a/base/reinterpretarray.jl +++ b/base/reinterpretarray.jl @@ -104,6 +104,8 @@ end _getindex_ra(a, inds[1], tail(inds)) end +@inline _memcpy!(dst, src, n) = ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n) + @inline @propagate_inbounds function _getindex_ra(a::ReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT} # Make sure to match the scalar reinterpret if that is applicable if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0 @@ -123,11 +125,9 @@ end # once it knows the data layout while nbytes_copied < sizeof(T) s[] = a.parent[ind_start + i, tailinds...] - while nbytes_copied < sizeof(T) && sidx < sizeof(S) - unsafe_store!(tptr, unsafe_load(sptr, sidx + 1), nbytes_copied + 1) - sidx += 1 - nbytes_copied += 1 - end + nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied) + _memcpy!(tptr + nbytes_copied, sptr + sidx, nb) + nbytes_copied += nb sidx = 0 i += 1 end @@ -173,34 +173,26 @@ end # element from the original array and overwrite the relevant parts if sidx != 0 s[] = a.parent[ind_start + i, tailinds...] - while nbytes_copied < sizeof(T) && sidx < sizeof(S) - unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1) - sidx += 1 - nbytes_copied += 1 - end + nb = min(sizeof(S) - sidx, sizeof(T)) + _memcpy!(sptr + sidx, tptr, nb) + nbytes_copied += nb a.parent[ind_start + i, tailinds...] = s[] i += 1 sidx = 0 end # Deal with the main body of elements while nbytes_copied < sizeof(T) && (sizeof(T) - nbytes_copied) > sizeof(S) - while nbytes_copied < sizeof(T) && sidx < sizeof(S) - unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1) - sidx += 1 - nbytes_copied += 1 - end + nb = min(sizeof(S), sizeof(T) - nbytes_copied) + _memcpy!(sptr, tptr + nbytes_copied, nb) + nbytes_copied += nb a.parent[ind_start + i, tailinds...] = s[] i += 1 - sidx = 0 end # Deal with trailing partial elements if nbytes_copied < sizeof(T) s[] = a.parent[ind_start + i, tailinds...] - while nbytes_copied < sizeof(T) && sidx < sizeof(S) - unsafe_store!(sptr, unsafe_load(tptr, nbytes_copied + 1), sidx + 1) - sidx += 1 - nbytes_copied += 1 - end + nb = min(sizeof(S), sizeof(T) - nbytes_copied) + _memcpy!(sptr, tptr + nbytes_copied, nb) a.parent[ind_start + i, tailinds...] = s[] end end diff --git a/base/reshapedarray.jl b/base/reshapedarray.jl index ebd2efba68b36d..18b2008f2c4d11 100644 --- a/base/reshapedarray.jl +++ b/base/reshapedarray.jl @@ -222,7 +222,7 @@ end I = ind2sub_rs(axes(A.parent), A.mi, i) _unsafe_getindex_rs(parent(A), I) end -_unsafe_getindex_rs(A, i::Integer) = (@inbounds ret = A[i]; ret) +@inline _unsafe_getindex_rs(A, i::Integer) = (@inbounds ret = A[i]; ret) @inline _unsafe_getindex_rs(A, I) = (@inbounds ret = A[I...]; ret) @inline function setindex!(A::ReshapedArrayLF, val, index::Int) diff --git a/src/ccall.cpp b/src/ccall.cpp index 23e8d1c3ad9800..30a919e027427e 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -1820,6 +1820,20 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs) JL_GC_POP(); return mark_or_box_ccall_result(ctx, strp, retboxed, rt, unionall, static_rt); } + else if (is_libjulia_func(memcpy)) { + const jl_cgval_t &dst = argv[0]; + const jl_cgval_t &src = argv[1]; + const jl_cgval_t &n = argv[2]; + ctx.builder.CreateMemCpy( + ctx.builder.CreateIntToPtr( + emit_unbox(ctx, T_size, dst, (jl_value_t*)jl_voidpointer_type), T_pint8), + ctx.builder.CreateIntToPtr( + emit_unbox(ctx, T_size, src, (jl_value_t*)jl_voidpointer_type), T_pint8), + emit_unbox(ctx, T_size, n, (jl_value_t*)jl_ulong_type), 1, + false); + JL_GC_POP(); + return ghostValue(jl_void_type); + } jl_cgval_t retval = sig.emit_a_ccall( ctx,