Skip to content

Commit

Permalink
Fix data races in n_avail(::Channel) to fix isready/isempty (#41833)
Browse files Browse the repository at this point in the history
This removes the data race from isready() and isempty(), which are now
implemented in terms of n_avail(). A new atomic `n_avail_items` field is added
to track the "current number of available items" (buffered + waiting
tasks). This is kept separate from the buffer and wait queue because these
consist of `Vector`s, whose length fields cannot easily be read
and written atomically.

For buffered channels, n_avail() now includes a count of any waiting
tasks in addition to the number of buffered items. This makes it
consistent with the computation for unbuffered channels.

Co-authored-by: Takafumi Arakaki <[email protected]>
  • Loading branch information
c42f and tkf authored Nov 10, 2021
1 parent f317d57 commit 924a13a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
35 changes: 30 additions & 5 deletions base/channels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ mutable struct Channel{T} <: AbstractChannel{T}
excp::Union{Exception, Nothing} # exception to be thrown when state !== :open

data::Vector{T}
@atomic n_avail_items::Int # Available items for taking, can be read without lock
sz_max::Int # maximum size of channel

function Channel{T}(sz::Integer = 0) where T
Expand All @@ -46,7 +47,7 @@ mutable struct Channel{T} <: AbstractChannel{T}
lock = ReentrantLock()
cond_put, cond_take = Threads.Condition(lock), Threads.Condition(lock)
cond_wait = (sz == 0 ? Threads.Condition(lock) : cond_take) # wait is distinct from take iff unbuffered
return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), sz)
return new(cond_take, cond_wait, cond_put, :open, nothing, Vector{T}(), 0, sz)
end
end

Expand Down Expand Up @@ -121,7 +122,7 @@ julia> chnl = Channel{Char}(1, spawn=true) do ch
put!(ch, c)
end
end
Channel{Char}(1) (1 item available)
Channel{Char}(1) (2 items available)
julia> String(collect(chnl))
"hello world"
Expand Down Expand Up @@ -317,17 +318,35 @@ function put!(c::Channel{T}, v) where T
return isbuffered(c) ? put_buffered(c, v) : put_unbuffered(c, v)
end

# Add `inc` (which may be negative) to the channel's available-item counter and
# return the stored value. The caller is *assumed* to already hold the channel lock.
function _increment_n_avail(c, inc)
    # Updates are serialized by the channel lock, so a plain read followed by
    # an addition is enough here; no atomic read-modify-write is needed.
    updated = c.n_avail_items + inc
    # The store is atomic so that threads reading c.n_avail_items outside the
    # lock do not race with this write.
    return @atomic :monotonic c.n_avail_items = updated
end

function put_buffered(c::Channel, v)
lock(c)
did_buffer = false
try
# Increment channel n_avail eagerly (before push!) to count data in the
# buffer as well as offers from tasks which are blocked in wait().
_increment_n_avail(c, 1)
while length(c.data) == c.sz_max
check_channel_state(c)
wait(c.cond_put)
end
push!(c.data, v)
did_buffer = true
# notify all, since some of the waiters may be on a "fetch" call.
notify(c.cond_take, nothing, true, false)
finally
# Decrement the available items if this task had an exception before pushing the
# item to the buffer (e.g., during `wait(c.cond_put)`):
did_buffer || _increment_n_avail(c, -1)
unlock(c)
end
return v
Expand All @@ -336,6 +355,7 @@ end
function put_unbuffered(c::Channel, v)
lock(c)
taker = try
_increment_n_avail(c, 1)
while isempty(c.cond_take.waitq)
check_channel_state(c)
notify(c.cond_wait)
Expand All @@ -344,6 +364,7 @@ function put_unbuffered(c::Channel, v)
# unfair scheduled version of: notify(c.cond_take, v, false, false); yield()
popfirst!(c.cond_take.waitq)
finally
_increment_n_avail(c, -1)
unlock(c)
end
schedule(taker, v)
Expand Down Expand Up @@ -390,6 +411,7 @@ function take_buffered(c::Channel)
wait(c.cond_take)
end
v = popfirst!(c.data)
_increment_n_avail(c, -1)
notify(c.cond_put, nothing, false, false) # notify only one, since only one slot has become available for a put!.
return v
finally
Expand Down Expand Up @@ -419,8 +441,11 @@ For unbuffered channels returns `true` if there are tasks waiting
on a [`put!`](@ref).
"""
isready(c::Channel) = n_avail(c) > 0
n_avail(c::Channel) = isbuffered(c) ? length(c.data) : length(c.cond_put.waitq)
isempty(c::Channel) = isbuffered(c) ? isempty(c.data) : isempty(c.cond_put.waitq)
# A channel is empty exactly when n_avail reports no available items.
isempty(c::Channel) = iszero(n_avail(c))
# Read the channel's available-item counter without taking the channel lock.
# This is the lock-free counterpart of `length(c.data) + length(c.cond_put.waitq)`.
n_avail(c::Channel) = @atomic :monotonic c.n_avail_items

lock(c::Channel) = lock(c.cond_take)
lock(f, c::Channel) = lock(f, c.cond_take)
Expand Down Expand Up @@ -456,7 +481,7 @@ function show(io::IO, ::MIME"text/plain", c::Channel)
print(io, " (empty)")
else
s = n == 1 ? "" : "s"
print(io, " (", n_avail(c), " item$s available)")
print(io, " (", n, " item$s available)")
end
end
end
Expand Down
41 changes: 41 additions & 0 deletions test/channels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using Random
using Base: Experimental
using Base: n_avail

@testset "single-threaded Condition usage" begin
a = Condition()
Expand Down Expand Up @@ -578,3 +579,43 @@ let c = Channel(3)
close(c)
@test repr(MIME("text/plain"), c) == "Channel{Any}(3) (closed)"
end

# PR #41833: data races in Channel
# Checks the value reported by n_avail for a buffered and an unbuffered channel:
# while items are put, while put! tasks are blocked, after a blocked put! is
# interrupted with an exception, and after the channel is closed.
# NOTE: the expected counts depend on the exact order in which the tasks below
# are started with yield(); keep the statement order intact when editing.
@testset "n_avail(::Channel)" begin
# Buffered: n_avail() = buffer length + number of waiting tasks
let c = Channel(2)
# The first two put! calls are made directly from the test task; the channel
# was created with capacity 2.
@test n_avail(c) == 0; put!(c, 0)
@test n_avail(c) == 1; put!(c, 0)
# Further put! calls are started in their own tasks via yield(t). Each one is
# expected to be counted by n_avail even though nothing takes from the channel.
@test n_avail(c) == 2; t1 = @task put!(c, 0); yield(t1)
@test n_avail(c) == 3; t2 = @task put!(c, 0); yield(t2)
@test n_avail(c) == 4
# Test n_avail(c) after interrupting a task waiting on the channel
t3 = @task put!(c, 0)
yield(t3)
@test n_avail(c) == 5
# Inject an exception into t3 from another task; the count must drop back by one.
@async Base.throwto(t3, ErrorException("Exit put!"))
# Any error raised by wait(t3) is deliberately ignored: only the counter is under test.
try wait(t3) catch end
@test n_avail(c) == 4
# After closing, wait for t1 and t2 to finish (errors ignored) before
# checking that only the two buffered items are still counted.
close(c)
try wait(t1) catch end
try wait(t2) catch end
@test n_avail(c) == 2 # Already-buffered items remain
end
# Unbuffered: n_avail() = number of waiting tasks
let c = Channel()
# Every put! is started in its own task; each started task is expected to add
# one to n_avail.
@test n_avail(c) == 0; t1 = @task put!(c, 0); yield(t1)
@test n_avail(c) == 1; t2 = @task put!(c, 0); yield(t2)
@test n_avail(c) == 2
# Test n_avail(c) after interrupting a task waiting on the channel
t3 = @task put!(c, 0)
yield(t3)
@test n_avail(c) == 3
# Same interruption scenario as in the buffered case above.
@async Base.throwto(t3, ErrorException("Exit put!"))
try wait(t3) catch end
@test n_avail(c) == 2
# With no buffer, nothing is expected to remain available once the channel is
# closed and t1/t2 have finished.
close(c)
try wait(t1) catch end
try wait(t2) catch end
@test n_avail(c) == 0
end
end

0 comments on commit 924a13a

Please sign in to comment.