diff --git a/base/channels.jl b/base/channels.jl index 206e8d98927bc..7258aca63657f 100644 --- a/base/channels.jl +++ b/base/channels.jl @@ -301,6 +301,26 @@ struct InvalidStateException <: Exception end showerror(io::IO, ex::InvalidStateException) = print(io, "InvalidStateException: ", ex.msg) +""" + tryput!(c::Channel, v) -> success::Bool + +Try to append an item `v` to the channel `c` and return `true`. Return `false` +if the channel `c` is closed. + +This function blocks until the channel is not full or closed. +""" +function tryput!(c::Channel, v) + try + put!(c, v) + return true + catch e + if isa(e, InvalidStateException) && e.state === :closed + return false + end + rethrow() + end +end + """ put!(c::Channel, v) @@ -393,6 +413,25 @@ function fetch_buffered(c::Channel) end fetch_unbuffered(c::Channel) = throw(ErrorException("`fetch` is not supported on an unbuffered Channel.")) +""" + maybetake!(c::Channel{T}) -> Some(item) or nothing + +Take an `item` from channel `c` if it is open and return `Some(item)`. Return +`nothing` if it is closed. + +When this function is called on an empty channel, it blocks until an item is +available or the channel is closed. +""" +function maybetake!(c::Channel{T}) where T + try + return Some{T}(take!(c)) + catch e + if isa(e, InvalidStateException) && e.state === :closed + return nothing + end + rethrow() + end +end """ take!(c::Channel) @@ -487,16 +526,7 @@ function show(io::IO, ::MIME"text/plain", c::Channel) end end -function iterate(c::Channel, state=nothing) - try - return (take!(c), nothing) - catch e - if isa(e, InvalidStateException) && e.state === :closed - return nothing - else - rethrow() - end - end -end +iterate(c::Channel, ::Nothing = nothing) = + (@something(maybetake!(c), return nothing), nothing) IteratorSize(::Type{<:Channel}) = SizeUnknown() diff --git a/base/exports.jl b/base/exports.jl index 84c53ca405e7d..8f7ccdebb9ce1 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -696,6 +696,8 @@ export # channels take!, put!, + maybetake!, + tryput!, isready, fetch, bind, diff --git a/stdlib/Distributed/src/Distributed.jl b/stdlib/Distributed/src/Distributed.jl index d428a6df0e683..ef41e427a858e 100644 --- a/stdlib/Distributed/src/Distributed.jl +++ b/stdlib/Distributed/src/Distributed.jl @@ -7,7 +7,7 @@ module Distributed # imports for extension import Base: getindex, wait, put!, take!, fetch, isready, push!, length, - hash, ==, kill, close, isopen, showerror + hash, ==, kill, close, isopen, showerror, maybetake!, tryput! # imports for use using Base: Process, Semaphore, JLOptions, buffer_writes, @sync_add, diff --git a/stdlib/Distributed/src/remotecall.jl b/stdlib/Distributed/src/remotecall.jl index 75caf7f3065b7..de82c2499017b 100644 --- a/stdlib/Distributed/src/remotecall.jl +++ b/stdlib/Distributed/src/remotecall.jl @@ -623,18 +623,17 @@ function put_future(rid, v, caller) del_client(rid, caller) nothing end - - put!(rv::RemoteValue, args...) = put!(rv.c, args...) -function put_ref(rid, caller, args...) + +function tryput_ref(rid, caller, args...) rv = lookup_ref(rid) - put!(rv, args...) + success = tryput!(rv.c, args...) if myid() == caller && rv.synctake !== nothing # Wait till a "taken" value is serialized out - github issue #29932 lock(rv.synctake) unlock(rv.synctake) end - nothing + return success end """ @@ -644,12 +643,18 @@ Store a set of values to the [`RemoteChannel`](@ref). If the channel is full, blocks until space is available. Return the first argument. """ -put!(rr::RemoteChannel, args...) = (call_on_owner(put_ref, rr, myid(), args...); rr) +function put!(rr::RemoteChannel, args...) + tryput!(rr, args...) || throw(Base.closed_exception()) + return rr +end + +tryput!(rr::RemoteChannel, args...) = call_on_owner(tryput_ref, rr, myid(), args...)::Bool # take! is not supported on Future take!(rv::RemoteValue, args...) = take!(rv.c, args...) -function take_ref(rid, caller, args...) + +function maybetake_ref(rid, caller, args...) rv = lookup_ref(rid) synctake = false if myid() != caller && rv.synctake !== nothing @@ -660,7 +665,7 @@ function take_ref(rid, caller, args...) end v = try - take!(rv, args...) + maybetake!(rv.c, args...) catch e # avoid unmatched unlock when exception occurs # github issue #33972 @@ -683,7 +688,13 @@ end Fetch value(s) from a [`RemoteChannel`](@ref) `rr`, removing the value(s) in the process. """ -take!(rr::RemoteChannel, args...) = call_on_owner(take_ref, rr, myid(), args...)::eltype(rr) +take!(rr::RemoteChannel, args...) = + @something(maybetake!(rr, args...), throw(Base.closed_exception())) +maybetake!(rr::RemoteChannel, args...) = + call_on_owner(maybetake_ref, rr, myid(), args...)::Union{Some{<:eltype(rr)},Nothing} + +Base.iterate(c::RemoteChannel, ::Nothing = nothing) = + (@something(maybetake!(c), return nothing), nothing) # close and isopen are not supported on Future diff --git a/stdlib/Distributed/test/distributed_exec.jl b/stdlib/Distributed/test/distributed_exec.jl index 78182e876f459..e9cdda67a80b6 100644 --- a/stdlib/Distributed/test/distributed_exec.jl +++ b/stdlib/Distributed/test/distributed_exec.jl @@ -425,12 +425,19 @@ function test_channel(c) @test take!(c) == 5.0 @test isready(c) == false @test isopen(c) == true + @test tryput!(c, :World) + @test tryput!(c, nothing) + @test maybetake!(c) === Some(:World) close(c) @test isopen(c) == false + @test !tryput!(c, nothing) + @test maybetake!(c) === Some(nothing) + @test maybetake!(c) === nothing end test_channel(Channel(10)) test_channel(RemoteChannel(()->Channel(10))) +test_channel(RemoteChannel(()->Channel(10), procs()[end])) c=Channel{Int}(1) @test_throws MethodError put!(c, "Hello") @@ -454,6 +461,9 @@ function test_iteration(in_c, out_c) end test_iteration(Channel(10), Channel(10)) +test_iteration(RemoteChannel(()->Channel(10)), Channel(10)) +test_iteration(RemoteChannel(()->Channel(10), procs()[end]), Channel(10)) + # make sure exceptions propagate when waiting on Tasks @test_throws CompositeException (@sync (@async error("oops"))) try diff --git a/test/channels.jl b/test/channels.jl index 1a989747c3863..bb58597cb1aff 100644 --- a/test/channels.jl +++ b/test/channels.jl @@ -580,6 +580,62 @@ let c = Channel(3) @test repr(MIME("text/plain"), c) == "Channel{Any}(3) (closed)" end +@testset "maybetake!(c)" begin + @testset "buffered" begin + c = Channel(Inf) + put!(c, 1) + close(c) + @test maybetake!(c) === Some(1) + @test maybetake!(c) === nothing + end + + @testset "unbuffered" begin + c = Channel() + @testset "on put!" begin + t = @task maybetake!(c) + yield(t) + @test !istaskdone(t) + put!(c, 1) + @test fetch(t) === Some(1) + end + @testset "on close" begin + t = @task maybetake!(c) + yield(t) + @test !istaskdone(t) + close(c) + @test fetch(t) === nothing + end + end +end + +@testset "tryput!(c, _)" begin + @testset "buffered" begin + c = Channel(Inf) + @test tryput!(c, 1) + close(c) + @test !tryput!(c, 2) + @test collect(c) == [1] + end + + @testset "unbuffered" begin + c = Channel() + @testset "on take!" begin + t = @task tryput!(c, 1) + yield(t) + @test !istaskdone(t) + @test take!(c) == 1 + @test fetch(t) + end + @testset "on close" begin + t = @task tryput!(c, 2) + yield(t) + @test !istaskdone(t) + close(c) + @test !fetch(t) + end + end +end + # PR #41833: data races in Channel @testset "n_avail(::Channel)" begin # Buffered: n_avail() = buffer length + number of waiting tasks