Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: simplify parallel dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 28, 2024
1 parent 6190360 commit 4dfcfe3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 53 deletions.
52 changes: 14 additions & 38 deletions ext/MLDataDevicesMLUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module MLDataDevicesMLUtilsExt

using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice,
CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator,
Internal
using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
MetalDevice, oneAPIDevice, DeviceIterator
using MLUtils: MLUtils, DataLoader

for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
Expand All @@ -12,44 +11,21 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
@warn "Using `buffer=true` for parallel DataLoader with automatic device \
transfer is currently not implemented. Ignoring `buffer=true`."
end
return ParallelDeviceDataLoader(D, dataloader)
end
return DeviceIterator(D, dataloader)
end
end

# Parallel DataLoader that does the device transfer in the same task
struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <:
AbstractDeviceIterator{D, DL}
dev::D
iterator::DL
end

# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
function Base.iterate(c::ParallelDeviceDataLoader)
data = MLUtils.ObsView(c.iterator.data)
# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
data = MLUtils.ObsView(dataloader.data)
data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data
data = if dataloader.batchsize > 0
MLUtils.BatchView(
data; dataloader.batchsize, dataloader.partial, dataloader.collate)
else
data
end

data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data
data = if c.iterator.batchsize > 0
MLUtils.BatchView(
data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate)
else
data
return DeviceIterator(D, eachobsparallel(D, data))
end
return DeviceIterator(D, dataloader)
end

iter = eachobsparallel(c.dev, data)
item = iterate(iter)
item === nothing && return nothing
dev_batch, next_state = item
return dev_batch, ((iter, next_state), dev_batch)
end

function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch))
item = iterate(iter, state)
item === nothing && return nothing
dev_batch, next_state = item
Internal.unsafe_free!(prev_batch) # free the previous batch
return dev_batch, ((iter, next_state), dev_batch)
end

function eachobsparallel(dev::AbstractDevice, data)
Expand Down
21 changes: 7 additions & 14 deletions src/iterator.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end

function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
return Base.IteratorSize(I)
end
Base.length(c::AbstractDeviceIterator) = length(c.iterator)
Base.axes(c::AbstractDeviceIterator) = axes(c.iterator)

function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
return Base.IteratorEltype(I)
end
Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator)

# This is based on CuIterator but generalized to work with any device
struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I}
struct DeviceIterator{D <: AbstractDevice, I}
dev::D
iterator::I
end
Expand All @@ -33,3 +20,9 @@ function Base.iterate(c::DeviceIterator, (state, prev_batch))
dev_batch = c.dev(batch)
return dev_batch, (next_state, dev_batch)
end

Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I)
Base.length(c::DeviceIterator) = length(c.iterator)
Base.axes(c::DeviceIterator) = axes(c.iterator)

Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown()
3 changes: 2 additions & 1 deletion test/qa_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
@test check_no_stale_explicit_imports(MLDataDevices) === nothing
@test check_no_self_qualified_accesses(MLDataDevices) === nothing
@test check_all_explicit_imports_via_owners(MLDataDevices) === nothing
@test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing
@test check_all_qualified_accesses_via_owners(
MLDataDevices; ignore=(:SparseArrays,)) === nothing
# mostly upstream problems
@test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing
@test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing
Expand Down

0 comments on commit 4dfcfe3

Please sign in to comment.