diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl
index a3c083e..693e661 100644
--- a/ext/MLDataDevicesMLUtilsExt.jl
+++ b/ext/MLDataDevicesMLUtilsExt.jl
@@ -1,8 +1,7 @@
 module MLDataDevicesMLUtilsExt
 
-using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUDevice,
-                     CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, DeviceIterator,
-                     Internal
+using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
+                     MetalDevice, oneAPIDevice, DeviceIterator
 using MLUtils: MLUtils, DataLoader
 
 for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
@@ -12,44 +11,22 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
                 @warn "Using `buffer=true` for parallel DataLoader with automatic device \
                        transfer is currently not implemented. Ignoring `buffer=true`."
             end
-            return ParallelDeviceDataLoader(D, dataloader)
-        end
-        return DeviceIterator(D, dataloader)
-    end
-end
-
-# Parallel DataLoader that does the device transfer in the same task
-struct ParallelDeviceDataLoader{D <: AbstractDevice, DL <: DataLoader} <:
-       AbstractDeviceIterator{D, DL}
-    dev::D
-    iterator::DL
-end
 
-# Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
-function Base.iterate(c::ParallelDeviceDataLoader)
-    data = MLUtils.ObsView(c.iterator.data)
+            # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl
+            data = MLUtils.ObsView(dataloader.data)
+            # Shuffle with the loader's own RNG instead of the default global one.
+            data = dataloader.shuffle ? MLUtils.shuffleobs(dataloader.rng, data) : data
+            data = if dataloader.batchsize > 0
+                MLUtils.BatchView(
+                    data; dataloader.batchsize, dataloader.partial, dataloader.collate)
+            else
+                data
+            end
 
-    data = c.iterator.shuffle ? MLUtils.shuffleobs(c.iterator.rng, data) : data
-    data = if c.iterator.batchsize > 0
-        MLUtils.BatchView(
-            data; c.iterator.batchsize, c.iterator.partial, c.iterator.collate)
-    else
-        data
+            return DeviceIterator(D, eachobsparallel(D, data))
+        end
+        return DeviceIterator(D, dataloader)
     end
-
-    iter = eachobsparallel(c.dev, data)
-    item = iterate(iter)
-    item === nothing && return nothing
-    dev_batch, next_state = item
-    return dev_batch, ((iter, next_state), dev_batch)
-end
-
-function Base.iterate(::ParallelDeviceDataLoader, ((iter, state), prev_batch))
-    item = iterate(iter, state)
-    item === nothing && return nothing
-    dev_batch, next_state = item
-    Internal.unsafe_free!(prev_batch) # free the previous batch
-    return dev_batch, ((iter, next_state), dev_batch)
 end
 
 function eachobsparallel(dev::AbstractDevice, data)
diff --git a/src/iterator.jl b/src/iterator.jl
index 47969be..3b4345e 100644
--- a/src/iterator.jl
+++ b/src/iterator.jl
@@ -1,18 +1,5 @@
-abstract type AbstractDeviceIterator{D <: AbstractDevice, I} end
-
-function Base.IteratorSize(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
-    return Base.IteratorSize(I)
-end
-Base.length(c::AbstractDeviceIterator) = length(c.iterator)
-Base.axes(c::AbstractDeviceIterator) = axes(c.iterator)
-
-function Base.IteratorEltype(::Type{AbstractDeviceIterator{D, I}}) where {D, I}
-    return Base.IteratorEltype(I)
-end
-Base.eltype(c::AbstractDeviceIterator) = eltype(c.iterator)
-
 # This is based on CuIterator but generalized to work with any device
-struct DeviceIterator{D, I} <: AbstractDeviceIterator{D, I}
+struct DeviceIterator{D <: AbstractDevice, I}
     dev::D
     iterator::I
 end
@@ -33,3 +20,9 @@ function Base.iterate(c::DeviceIterator, (state, prev_batch))
     dev_batch = c.dev(batch)
     return dev_batch, (next_state, dev_batch)
 end
+
+Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I)
+Base.length(c::DeviceIterator) = length(c.iterator)
+Base.axes(c::DeviceIterator) = axes(c.iterator)
+
+Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown()
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
index 938908a..b5e4cb6 100644
--- a/test/qa_tests.jl
+++ b/test/qa_tests.jl
@@ -11,7 +11,8 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
     @test check_no_stale_explicit_imports(MLDataDevices) === nothing
     @test check_no_self_qualified_accesses(MLDataDevices) === nothing
     @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing
-    @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing
+    @test check_all_qualified_accesses_via_owners(
+        MLDataDevices; ignore=(:SparseArrays,)) === nothing
     # mostly upstream problems
     @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing
     @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing