diff --git a/base/channels.jl b/base/channels.jl index 31cbd98a2abb26..206e8d98927bc9 100644 --- a/base/channels.jl +++ b/base/channels.jl @@ -37,6 +37,7 @@ mutable struct Channel{T} <: AbstractChannel{T} excp::Union{Exception, Nothing} # exception to be thrown when state !== :open data::Vector{T} + @atomic n_avail_items::Int # Available items for taking, can be read without lock sz_max::Int # maximum size of channel function Channel{T}(sz::Integer = 0) where T @@ -46,7 +47,7 @@ mutable struct Channel{T} <: AbstractChannel{T} lock = ReentrantLock() cond_put, cond_take = Threads.Condition(lock), Threads.Condition(lock) cond_wait = (sz == 0 ? Threads.Condition(lock) : cond_take) # wait is distinct from take iff unbuffered - return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), sz) + return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), 0, sz) end end @@ -121,7 +122,7 @@ julia> chnl = Channel{Char}(1, spawn=true) do ch put!(ch, c) end end -Channel{Char}(1) (1 item available) +Channel{Char}(1) (2 items available) julia> String(collect(chnl)) "hello world" @@ -317,17 +318,35 @@ function put!(c::Channel{T}, v) where T return isbuffered(c) ? put_buffered(c, v) : put_unbuffered(c, v) end +# Atomically update channel n_avail, *assuming* we hold the channel lock. +function _increment_n_avail(c, inc) + # We hold the channel lock so it's safe to non-atomically read and + # increment c.n_avail_items + newlen = c.n_avail_items + inc + # Atomically store c.n_avail_items to prevent data races with other threads + # reading this outside the lock. + @atomic :monotonic c.n_avail_items = newlen +end + function put_buffered(c::Channel, v) lock(c) + did_buffer = false try + # Increment channel n_avail eagerly (before push!) to count data in the + # buffer as well as offers from tasks which are blocked in wait(). + _increment_n_avail(c, 1) while length(c.data) == c.sz_max check_channel_state(c) wait(c.cond_put) end push!(c.data, v) + did_buffer = true # notify all, since some of the waiters may be on a "fetch" call. notify(c.cond_take, nothing, true, false) finally + # Decrement the available items if this task had an exception before pushing the + # item to the buffer (e.g., during `wait(c.cond_put)`): + did_buffer || _increment_n_avail(c, -1) unlock(c) end return v @@ -336,6 +355,7 @@ end function put_unbuffered(c::Channel, v) lock(c) taker = try + _increment_n_avail(c, 1) while isempty(c.cond_take.waitq) check_channel_state(c) notify(c.cond_wait) @@ -344,6 +364,7 @@ function put_unbuffered(c::Channel, v) # unfair scheduled version of: notify(c.cond_take, v, false, false); yield() popfirst!(c.cond_take.waitq) finally + _increment_n_avail(c, -1) unlock(c) end schedule(taker, v) @@ -390,6 +411,7 @@ function take_buffered(c::Channel) wait(c.cond_take) end v = popfirst!(c.data) + _increment_n_avail(c, -1) notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!. return v finally @@ -419,8 +441,11 @@ For unbuffered channels returns `true` if there are tasks waiting on a [`put!`](@ref). """ isready(c::Channel) = n_avail(c) > 0 -n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.cond_put.waitq) -isempty(c::Channel) = isbuffered(c) ? isempty(c.data) : isempty(c.cond_put.waitq) +isempty(c::Channel) = n_avail(c) == 0 +function n_avail(c::Channel) + # Lock-free equivalent to `length(c.data) + length(c.cond_put.waitq)` + @atomic :monotonic c.n_avail_items +end lock(c::Channel) = lock(c.cond_take) lock(f, c::Channel) = lock(f, c.cond_take) @@ -456,7 +481,7 @@ function show(io::IO, ::MIME"text/plain", c::Channel) print(io, " (empty)") else s = n == 1 ? "" : "s" - print(io, " (", n_avail(c), " item$s available)") + print(io, " (", n, " item$s available)") end end end diff --git a/test/channels.jl b/test/channels.jl index 0611b387e6f884..1a989747c38635 100644 --- a/test/channels.jl +++ b/test/channels.jl @@ -2,6 +2,7 @@ using Random using Base: Experimental +using Base: n_avail @testset "single-threaded Condition usage" begin a = Condition() @@ -578,3 +579,43 @@ let c = Channel(3) close(c) @test repr(MIME("text/plain"), c) == "Channel{Any}(3) (closed)" end + +# PR #41833: data races in Channel +@testset "n_avail(::Channel)" begin + # Buffered: n_avail() = buffer length + number of waiting tasks + let c = Channel(2) + @test n_avail(c) == 0; put!(c, 0) + @test n_avail(c) == 1; put!(c, 0) + @test n_avail(c) == 2; t1 = @task put!(c, 0); yield(t1) + @test n_avail(c) == 3; t2 = @task put!(c, 0); yield(t2) + @test n_avail(c) == 4 + # Test n_avail(c) after interrupting a task waiting on the channel + t3 = @task put!(c, 0) + yield(t3) + @test n_avail(c) == 5 + @async Base.throwto(t3, ErrorException("Exit put!")) + try wait(t3) catch end + @test n_avail(c) == 4 + close(c) + try wait(t1) catch end + try wait(t2) catch end + @test n_avail(c) == 2 # Already-buffered items remain + end + # Unbuffered: n_avail() = number of waiting tasks + let c = Channel() + @test n_avail(c) == 0; t1 = @task put!(c, 0); yield(t1) + @test n_avail(c) == 1; t2 = @task put!(c, 0); yield(t2) + @test n_avail(c) == 2 + # Test n_avail(c) after interrupting a task waiting on the channel + t3 = @task put!(c, 0) + yield(t3) + @test n_avail(c) == 3 + @async Base.throwto(t3, ErrorException("Exit put!")) + try wait(t3) catch end + @test n_avail(c) == 2 + close(c) + try wait(t1) catch end + try wait(t2) catch end + @test n_avail(c) == 0 + end +end