Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
test: basic tests for free-ing data
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 28, 2024
1 parent a0756e9 commit 6190360
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ ChainRulesCore = "1.23"
FillArrays = "1"
Functors = "0.4.8"
GPUArrays = "10"
MLUtils = "0.4"
MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
Expand Down
5 changes: 2 additions & 3 deletions ext/MLDataDevicesMLUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ using MLDataDevices: MLDataDevices, AbstractDevice, AbstractDeviceIterator, CPUD
Internal
using MLUtils: MLUtils, DataLoader

for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
ldev = Symbol(dev, :Device)
@eval function (D::$(ldev))(dataloader::DataLoader)
for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
@eval function (D::$(dev))(dataloader::DataLoader)
if dataloader.parallel
if dataloader.buffer
@warn "Using `buffer=true` for parallel DataLoader with automatic device \
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -28,6 +29,7 @@ ExplicitImports = "1.9.0"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
MLUtils = "0.4"
Pkg = "1.10"
Random = "1.10"
RecursiveArrayTools = "3.8"
Expand Down
53 changes: 53 additions & 0 deletions test/iterator_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using MLDataDevices, MLUtils

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))

if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all"
using LuxCUDA
end

if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all"
using AMDGPU
end

if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all"
using Metal
end

if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
using oneAPI
end

DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice]

freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
freed_if_can_be_freed(::Type{CPUDevice}, x) = true
function freed_if_can_be_freed(::Type, x)
try
Array(x)
return false
catch err
err isa ArgumentError && return true
rethrow()
end
end

@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES
dev = dev_type()

!MLDataDevices.functional(dev) && continue

@info "Testing Device Iterator for $(dev)..."

@testset "Basic Device Iterator" begin
datalist = [rand(10) for _ in 1:10]

prev_batch = nothing
for data in DeviceIterator(dev, datalist)
prev_batch === nothing || @test freed_if_can_be_freed(prev_batch)
prev_batch = data
@test size(data) == (10,)
@test get_device_type(data) == dev_type
end
end
end
5 changes: 3 additions & 2 deletions test/qa_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import FillArrays, RecursiveArrayTools, SparseArrays, Zygote
@test check_no_self_qualified_accesses(MLDataDevices) === nothing
@test check_all_explicit_imports_via_owners(MLDataDevices) === nothing
@test check_all_qualified_accesses_via_owners(MLDataDevices) === nothing
@test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing # mostly upstream problems
@test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing # mostly upstream problem
# mostly upstream problems
@test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing
@test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end
Test.@test true
end

@safetestset "Iterator Tests" include("iterator_tests.jl")
@safetestset "Misc Tests" include("misc_tests.jl")

@safetestset "QA Tests" include("qa_tests.jl")
end

0 comments on commit 6190360

Please sign in to comment.