From ef10e525adf65c20188a8f870c8f9ff32f3b5cd0 Mon Sep 17 00:00:00 2001 From: Sukera <11753998+Seelengrab@users.noreply.github.com> Date: Mon, 9 May 2022 17:47:17 +0200 Subject: [PATCH] Add export for `Splat(f)`, replacing `Base.splat` (#42717) * Deprecate `Base.splat(x)` in favor of `Splat(x)` (now exported) * Add pretty printing of `Splat(f)` --- NEWS.md | 3 +++ base/deprecated.jl | 6 ++++++ base/exports.jl | 1 + base/iterators.jl | 2 +- base/operators.jl | 24 +++++++++++++++++------ base/show.jl | 1 + base/strings/search.jl | 4 ++-- doc/src/base/base.md | 2 +- doc/src/devdocs/ast.md | 2 +- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 2 +- stdlib/LinearAlgebra/src/qr.jl | 2 +- test/broadcast.jl | 4 ++-- test/compiler/inference.jl | 2 +- test/iterators.jl | 2 +- 14 files changed, 40 insertions(+), 17 deletions(-) diff --git a/NEWS.md b/NEWS.md index c2e60b4bc0745..36660f0078c76 100644 --- a/NEWS.md +++ b/NEWS.md @@ -44,6 +44,8 @@ New library functions --------------------- * `Iterators.flatmap` was added ([#44792]). +* New helper `Splat(f)` which acts like `x -> f(x...)`, with pretty printing for + inspecting which function `f` was originally wrapped. ([#42717]) Library changes --------------- @@ -120,6 +122,7 @@ Standard library changes Deprecated or removed --------------------- +* Unexported `splat` is deprecated in favor of exported `Splat`, which has pretty printing of the wrapped function. ([#42717]) External dependencies --------------------- diff --git a/base/deprecated.jl b/base/deprecated.jl index 28a35e23635f4..3ff5155f44821 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -294,3 +294,9 @@ const var"@_noinline_meta" = var"@noinline" @deprecate getindex(t::Tuple, i::Real) t[convert(Int, i)] # END 1.8 deprecations + +# BEGIN 1.9 deprecations + +@deprecate splat(x) Splat(x) false + +# END 1.9 deprecations diff --git a/base/exports.jl b/base/exports.jl index dff6b0c9bc208..a8c8ff8dcda33 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -807,6 +807,7 @@ export atreplinit, exit, ntuple, + Splat, # I/O and events close, diff --git a/base/iterators.jl b/base/iterators.jl index 2702375d0f630..dd72772756795 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -300,7 +300,7 @@ the `zip` iterator is a tuple of values of its subiterators. `zip` orders the calls to its subiterators in such a way that stateful iterators will not advance when another iterator finishes in the current iteration. -See also: [`enumerate`](@ref), [`splat`](@ref Base.splat). +See also: [`enumerate`](@ref), [`Splat`](@ref Base.Splat). # Examples ```jldoctest diff --git a/base/operators.jl b/base/operators.jl index 92c016d00bf03..62874fd3c1a85 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -1185,27 +1185,39 @@ used to implement specialized methods. <(x) = Fix2(<, x) """ - splat(f) + Splat(f) -Defined as +Equivalent to ```julia - splat(f) = args->f(args...) + my_splat(f) = args->f(args...) ``` i.e. given a function returns a new function that takes one argument and splats its argument into the original function. This is useful as an adaptor to pass a multi-argument function in a context that expects a single argument, but -passes a tuple as that single argument. +passes a tuple as that single argument. Additionally has pretty printing. # Example usage: ```jldoctest -julia> map(Base.splat(+), zip(1:3,4:6)) +julia> map(Base.Splat(+), zip(1:3,4:6)) 3-element Vector{Int64}: 5 7 9 + +julia> my_add = Base.Splat(+) +Splat(+) + +julia> my_add((1,2,3)) +6 ``` """ -splat(f) = args->f(args...) +struct Splat{F} <: Function + f::F + Splat(f) = new{Core.Typeof(f)}(f) +end +(s::Splat)(args) = s.f(args...) +print(io::IO, s::Splat) = print(io, "Splat(", s.f, ')') +show(io::IO, s::Splat) = print(io, s) ## in and related operators diff --git a/base/show.jl b/base/show.jl index ca3ca90b29f1b..9af8bfbe8a57e 100644 --- a/base/show.jl +++ b/base/show.jl @@ -46,6 +46,7 @@ end show(io::IO, ::MIME"text/plain", c::ComposedFunction) = show(io, c) show(io::IO, ::MIME"text/plain", c::Returns) = show(io, c) +show(io::IO, ::MIME"text/plain", s::Splat) = show(io, s) function show(io::IO, ::MIME"text/plain", iter::Union{KeySet,ValueIterator}) isempty(iter) && get(io, :compact, false) && return show(io, iter) diff --git a/base/strings/search.jl b/base/strings/search.jl index 938ed8d527d99..6423c01a162bc 100644 --- a/base/strings/search.jl +++ b/base/strings/search.jl @@ -179,7 +179,7 @@ function _searchindex(s::Union{AbstractString,ByteArray}, if i === nothing return 0 end ii = nextind(s, i)::Int a = Iterators.Stateful(trest) - matched = all(splat(==), zip(SubString(s, ii), a)) + matched = all(Splat(==), zip(SubString(s, ii), a)) (isempty(a) && matched) && return i i = ii end @@ -435,7 +435,7 @@ function _rsearchindex(s::AbstractString, a = Iterators.Stateful(trest) b = Iterators.Stateful(Iterators.reverse( pairs(SubString(s, 1, ii)))) - matched = all(splat(==), zip(a, (x[2] for x in b))) + matched = all(Splat(==), zip(a, (x[2] for x in b))) if matched && isempty(a) isempty(b) && return firstindex(s) return nextind(s, popfirst!(b)[1])::Int diff --git a/doc/src/base/base.md b/doc/src/base/base.md index 6b80072fbd630..dd6e51518acf3 100644 --- a/doc/src/base/base.md +++ b/doc/src/base/base.md @@ -256,7 +256,7 @@ new Base.:(|>) Base.:(∘) Base.ComposedFunction -Base.splat +Base.Splat Base.Fix1 Base.Fix2 ``` diff --git a/doc/src/devdocs/ast.md b/doc/src/devdocs/ast.md index 9e9da6da70cb2..1978cd19a9a79 100644 --- a/doc/src/devdocs/ast.md +++ b/doc/src/devdocs/ast.md @@ -425,7 +425,7 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form. * `splatnew` Similar to `new`, except field values are passed as a single tuple. Works similarly to - `Base.splat(new)` if `new` were a first-class function, hence the name. + `Base.Splat(new)` if `new` were a first-class function, hence the name. * `isdefined` diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index ec93556988485..14bf761b8f817 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -18,7 +18,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as vec, zero using Base: IndexLinear, promote_eltype, promote_op, promote_typeof, @propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing, - splat + Splat using Base.Broadcast: Broadcasted, broadcasted using OpenBLAS_jll using libblastrampoline_jll diff --git a/stdlib/LinearAlgebra/src/qr.jl b/stdlib/LinearAlgebra/src/qr.jl index 6334c8a3474ef..61e3b092b2a38 100644 --- a/stdlib/LinearAlgebra/src/qr.jl +++ b/stdlib/LinearAlgebra/src/qr.jl @@ -159,7 +159,7 @@ function Base.hash(F::QRCompactWY, h::UInt) return hash(F.factors, foldr(hash, _triuppers_qr(F.T); init=hash(QRCompactWY, h))) end function Base.:(==)(A::QRCompactWY, B::QRCompactWY) - return A.factors == B.factors && all(splat(==), zip(_triuppers_qr.((A.T, B.T))...)) + return A.factors == B.factors && all(Splat(==), zip(_triuppers_qr.((A.T, B.T))...)) end function Base.isequal(A::QRCompactWY, B::QRCompactWY) return isequal(A.factors, B.factors) && all(zip(_triuppers_qr.((A.T, B.T))...)) do (a, b) diff --git a/test/broadcast.jl b/test/broadcast.jl index 113614505ba74..39af6e20b9a08 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -870,13 +870,13 @@ end ys = 1:2:20 bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys)) @test IndexStyle(bc) == IndexLinear() - @test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys)) + @test sum(bc) == mapreduce(Base.Splat(*), +, zip(xs, ys)) xs2 = reshape(xs, 1, :) ys2 = reshape(ys, 1, :) bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs2, ys2)) @test IndexStyle(bc) == IndexCartesian() - @test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys)) + @test sum(bc) == mapreduce(Base.Splat(*), +, zip(xs, ys)) xs = 1:5:3*5 ys = 1:4:3*4 diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index b400f17cb1fb3..15307a64c355b 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2842,7 +2842,7 @@ j30385(T, y) = k30385(f30385(T, y)) @test @inferred(j30385(:dummy, 1)) == "dummy" @test Base.return_types(Tuple, (NamedTuple{<:Any,Tuple{Any,Int}},)) == Any[Tuple{Any,Int}] -@test Base.return_types(Base.splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}] +@test Base.return_types(Base.Splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}] # test that return_type_tfunc isn't affected by max_methods differently than return_type _rttf_test(::Int8) = 0 diff --git a/test/iterators.jl b/test/iterators.jl index 554e120d94fd6..7ce47233f2ed5 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -608,7 +608,7 @@ end @test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I) @test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I) end - @test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter)) + @test all(Base.Splat(==), zip(Iterators.flatten(map(collect, P)), iter)) end end @testset "empty/invalid partitions" begin