Skip to content

Commit

Permalink
Make DictEncoding reading threadsafe (#535)
Browse files Browse the repository at this point in the history
Fixes #534
  • Loading branch information
quinnj authored Nov 27, 2024
1 parent 7fa18aa commit fc8b899
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 86 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ jobs:
- name: ArrowTypes.jl
dir: './src/ArrowTypes'
version:
- '1.6'
- 'min'
- 'lts'
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ SentinelArrays = "1"
Tables = "1.1"
TimeZones = "1"
TranscodingStreams = "0.9.12, 0.10, 0.11"
julia = "1.6"
julia = "1.9"
176 changes: 92 additions & 84 deletions src/table.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ mutable struct Stream
names::Vector{Symbol}
types::Vector{Type}
schema::Union{Nothing,Meta.Schema}
dictencodings::Dict{Int64,DictEncoding} # dictionary id => DictEncoding
dictencodings::Lockable{Dict{Int64,DictEncoding}} # dictionary id => DictEncoding
dictencoded::Dict{Int64,Meta.Field} # dictionary id => field
convert::Bool
compression::Ref{Union{Symbol,Nothing}}
Expand All @@ -82,7 +82,7 @@ function Stream(inputs::Vector{ArrowBlob}; convert::Bool=true)
names = Symbol[]
types = Type[]
schema = nothing
dictencodings = Dict{Int64,DictEncoding}()
dictencodings = Lockable(Dict{Int64,DictEncoding}())
dictencoded = Dict{Int64,Meta.Field}()
compression = Ref{Union{Symbol,Nothing}}(nothing)
Stream(
Expand Down Expand Up @@ -210,8 +210,26 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0))
if recordbatch.compression !== nothing
compression = recordbatch.compression
end
if haskey(x.dictencodings, id) && header.isDelta
# delta
@lock x.dictencodings begin
dictencodings = x.dictencodings[]
if haskey(dictencodings, id) && header.isDelta
# delta
field = x.dictencoded[id]
values, _, _ = build(
field,
field.type,
batch,
recordbatch,
x.dictencodings,
Int64(1),
Int64(1),
x.convert,
)
dictencoding = dictencodings[id]
append!(dictencoding.data, values)
continue
end
# new dictencoding or replace
field = x.dictencoded[id]
values, _, _ = build(
field,
Expand All @@ -223,32 +241,17 @@ function Base.iterate(x::Stream, (pos, id)=(1, 0))
Int64(1),
x.convert,
)
dictencoding = x.dictencodings[id]
append!(dictencoding.data, values)
continue
end
# new dictencoding or replace
field = x.dictencoded[id]
values, _, _ = build(
field,
field.type,
batch,
recordbatch,
x.dictencodings,
Int64(1),
Int64(1),
x.convert,
)
A = ChainedVector([values])
S =
field.dictionary.indexType === nothing ? Int32 :
juliaeltype(field, field.dictionary.indexType, false)
x.dictencodings[id] = DictEncoding{eltype(A),S,typeof(A)}(
id,
A,
field.dictionary.isOrdered,
values.metadata,
)
A = ChainedVector([values])
S =
field.dictionary.indexType === nothing ? Int32 :
juliaeltype(field, field.dictionary.indexType, false)
dictencodings[id] = DictEncoding{eltype(A),S,typeof(A)}(
id,
A,
field.dictionary.isOrdered,
values.metadata,
)
end # lock
@debug "parsed dictionary batch message: id=$id, data=$values\n"
elseif header isa Meta.RecordBatch
@debug "parsing record batch message: compression = $(header.compression)"
Expand Down Expand Up @@ -415,7 +418,7 @@ Table(inputs::Vector; kw...) =
function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
t = Table()
sch = nothing
dictencodings = Dict{Int64,DictEncoding}() # dictionary id => DictEncoding
dictencodingslockable = Lockable(Dict{Int64,DictEncoding}()) # dictionary id => DictEncoding
dictencoded = Dict{Int64,Meta.Field}() # dictionary id => field
sync = OrderedSynchronizer()
tsks = Channel{Any}(Inf)
Expand Down Expand Up @@ -465,65 +468,68 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
elseif header isa Meta.DictionaryBatch
id = header.id
recordbatch = header.data
@debug "parsing dictionary batch message: id = $id, compression = $(recordbatch.compression)"
if haskey(dictencodings, id) && header.isDelta
# delta
@info "parsing dictionary batch message: id = $id, compression = $(recordbatch.compression)"
@lock dictencodingslockable begin
dictencodings = dictencodingslockable[]
if haskey(dictencodings, id) && header.isDelta
# delta
field = dictencoded[id]
values, _, _ = build(
field,
field.type,
batch,
recordbatch,
dictencodingslockable,
Int64(1),
Int64(1),
convert,
)
dictencoding = dictencodings[id]
if typeof(dictencoding.data) <: ChainedVector
append!(dictencoding.data, values)
else
A = ChainedVector([dictencoding.data, values])
S =
field.dictionary.indexType === nothing ? Int32 :
juliaeltype(field, field.dictionary.indexType, false)
dictencodings[id] = DictEncoding{eltype(A),S,typeof(A)}(
id,
A,
field.dictionary.isOrdered,
values.metadata,
)
end
continue
end
# new dictencoding or replace
field = dictencoded[id]
values, _, _ = build(
field,
field.type,
batch,
recordbatch,
dictencodings,
dictencodingslockable,
Int64(1),
Int64(1),
convert,
)
dictencoding = dictencodings[id]
if typeof(dictencoding.data) <: ChainedVector
append!(dictencoding.data, values)
else
A = ChainedVector([dictencoding.data, values])
S =
field.dictionary.indexType === nothing ? Int32 :
juliaeltype(field, field.dictionary.indexType, false)
dictencodings[id] = DictEncoding{eltype(A),S,typeof(A)}(
id,
A,
field.dictionary.isOrdered,
values.metadata,
)
end
continue
end
# new dictencoding or replace
field = dictencoded[id]
values, _, _ = build(
field,
field.type,
batch,
recordbatch,
dictencodings,
Int64(1),
Int64(1),
convert,
)
A = values
S =
field.dictionary.indexType === nothing ? Int32 :
juliaeltype(field, field.dictionary.indexType, false)
dictencodings[id] = DictEncoding{eltype(A),S,typeof(A)}(
id,
A,
field.dictionary.isOrdered,
values.metadata,
)
A = values
S =
field.dictionary.indexType === nothing ? Int32 :
juliaeltype(field, field.dictionary.indexType, false)
dictencodings[id] = DictEncoding{eltype(A),S,typeof(A)}(
id,
A,
field.dictionary.isOrdered,
values.metadata,
)
end # lock
@debug "parsed dictionary batch message: id=$id, data=$values\n"
elseif header isa Meta.RecordBatch
anyrecordbatches = true
@debug "parsing record batch message: compression = $(header.compression)"
@wkspawn begin
cols = collect(VectorIterator(sch, $batch, dictencodings, convert))
cols = collect(VectorIterator(sch, $batch, dictencodingslockable, convert))
put!(() -> put!(tsks, cols), sync, $(rbi))
end
rbi += 1
Expand Down Expand Up @@ -610,7 +616,7 @@ end
struct VectorIterator
schema::Meta.Schema
batch::Batch # batch.msg.header MUST BE RecordBatch
dictencodings::Dict{Int64,DictEncoding}
dictencodings::Lockable{Dict{Int64,DictEncoding}}
convert::Bool
end

Expand Down Expand Up @@ -654,14 +660,16 @@ function build(field::Meta.Field, batch, rb, de, nodeidx, bufferidx, convert)
buffer = rb.buffers[bufferidx]
S = d.indexType === nothing ? Int32 : juliaeltype(field, d.indexType, false)
bytes, indices = reinterp(S, batch, buffer, rb.compression)
encoding = de[d.id]
A = DictEncoded(
bytes,
validity,
indices,
encoding,
buildmetadata(field.custom_metadata),
)
@lock de begin
encoding = de[][d.id]
A = DictEncoded(
bytes,
validity,
indices,
encoding,
buildmetadata(field.custom_metadata),
)
end
nodeidx += 1
bufferidx += 1
else
Expand Down

0 comments on commit fc8b899

Please sign in to comment.