diff --git a/Project.toml b/Project.toml index 35da279..0602650 100644 --- a/Project.toml +++ b/Project.toml @@ -49,7 +49,7 @@ ChainRulesCore = "1.23" FillArrays = "1" Functors = "0.4.8" GPUArrays = "10" -MLUtils = "0.4" +MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" Random = "1.10" diff --git a/ext/MLDataDevicesMLUtilsExt.jl b/ext/MLDataDevicesMLUtilsExt.jl index 57db601..a3c083e 100644 --- a/ext/MLDataDevicesMLUtilsExt.jl +++ b/ext/MLDataDevicesMLUtilsExt.jl @@ -5,9 +5,8 @@ using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUD Internal using MLUtils: MLUtils, DataLoader -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) - ldev = Symbol(dev, :Device) - @eval function (D::$(ldev))(dataloader::DataLoader) +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) + @eval function (D::$(dev))(dataloader::DataLoader) if dataloader.parallel if dataloader.buffer @warn "Using `buffer=true` for parallel DataLoader with automatic device \ diff --git a/test/Project.toml b/test/Project.toml index f770c7a..9914e0f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -28,6 +29,7 @@ ExplicitImports = "1.9.0" FillArrays = "1" ForwardDiff = "0.10.36" Functors = "0.4.8" +MLUtils = "0.4" Pkg = "1.10" Random = "1.10" RecursiveArrayTools = "3.8" diff --git a/test/iterator_tests.jl b/test/iterator_tests.jl new file mode 100644 index 0000000..78d4601 --- /dev/null +++ b/test/iterator_tests.jl @@ -0,0 +1,53 @@ +using MLDataDevices, MLUtils + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) + +if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all" + using LuxCUDA +end + +if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all" + using AMDGPU +end + +if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all" + using Metal +end + +if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all" + using oneAPI +end + +DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice] + +freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x) +freed_if_can_be_freed(::Type{CPUDevice}, x) = true +function freed_if_can_be_freed(::Type, x) + try + Array(x) + return false + catch err + err isa ArgumentError && return true + rethrow() + end +end + +@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES + dev = dev_type() + + !MLDataDevices.functional(dev) && continue + + @info "Testing Device Iterator for $(dev)..." + + @testset "Basic Device Iterator" begin + datalist = [rand(10) for _ in 1:10] + + prev_batch = nothing + for data in DeviceIterator(dev, datalist) + prev_batch === nothing || @test freed_if_can_be_freed(prev_batch) + prev_batch = data + @test size(data) == (10,) + @test get_device_type(data) == dev_type + end + end +end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 965e818..938908a 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -12,6 +12,7 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote @test check_no_self_qualified_accesses(MLDataDevices) === nothing @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing @test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing - @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems - @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem + # mostly upstream problems + @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing + @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing end diff --git a/test/runtests.jl b/test/runtests.jl index b9fb136..65cc190 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,7 +28,7 @@ end Test.@test true end + @safetestset "Iterator Tests" include("iterator_tests.jl") @safetestset "Misc Tests" include("misc_tests.jl") - @safetestset "QA Tests" include("qa_tests.jl") end