Skip to content

Commit

Permalink
test: use TestExtras in MLDataDevices testing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 22, 2024
1 parent 2ed5366 commit 9412297
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 15 deletions.
2 changes: 2 additions & 0 deletions lib/MLDataDevices/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -39,5 +40,6 @@ ReverseDiff = "1.15"
SafeTestsets = "0.1"
SparseArrays = "1.10"
Test = "1.10"
TestExtras = "0.3.1"
Tracker = "0.2.36"
Zygote = "0.6.69"
4 changes: 2 additions & 2 deletions lib/MLDataDevices/test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Test
using MLDataDevices, Random, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
Expand Down Expand Up @@ -122,7 +122,7 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)
end
end

Expand Down
4 changes: 2 additions & 2 deletions lib/MLDataDevices/test/cuda_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Functors, Test
using MLDataDevices, Random, Functors, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
Expand Down Expand Up @@ -144,7 +144,7 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)

return_val2(x) = Val(get_device(x))
@test_throws ErrorException @inferred(return_val2(ps))
Expand Down
6 changes: 3 additions & 3 deletions lib/MLDataDevices/test/metal_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Test
using MLDataDevices, Random, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
Expand Down Expand Up @@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{get_device(x)}
@constinferred Val{get_device(x)} return_val2(ps)
end
end

Expand Down
6 changes: 3 additions & 3 deletions lib/MLDataDevices/test/misc_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Adapt, MLDataDevices, ComponentArrays, Random
using Adapt, MLDataDevices, ComponentArrays, Random, TestExtras
using ArrayInterface: parameterless_type
using ChainRulesTestUtils: test_rrule
using ReverseDiff, Tracker, ForwardDiff
Expand Down Expand Up @@ -148,10 +148,10 @@ end
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{typeof(cpu_device())}
@constinferred Val{typeof(cpu_device())} return_val(ps)

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{cpu_device()}
@constinferred Val{cpu_device()} return_val2(ps)
end

@testset "undefined references array" begin
Expand Down
6 changes: 3 additions & 3 deletions lib/MLDataDevices/test/oneapi_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Test
using MLDataDevices, Random, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
Expand Down Expand Up @@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)

return_val2(x) = Val(get_device(x))
@test @inferred(return_val2(ps)) isa Val{get_device(x)}
@constinferred Val{get_device(x)} return_val2(ps)
end
end

Expand Down
4 changes: 2 additions & 2 deletions lib/MLDataDevices/test/xla_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using MLDataDevices, Random, Test
using MLDataDevices, Random, Test, TestExtras
using ArrayInterface: parameterless_type

@testset "CPU Fallback" begin
Expand Down Expand Up @@ -108,7 +108,7 @@ using FillArrays, Zygote # Extensions
ps = (; weight=x, bias=x, d=(x, x))

return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work
@test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))}
@constinferred Val{parameterless_type(typeof(device))} return_val(ps)

return_val2(x) = Val(get_device(x))
@test_throws TypeError @inferred(return_val2(ps))
Expand Down

0 comments on commit 9412297

Please sign in to comment.