From f0ec53af7863a423f1d4b4749ea9da3117b289e2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 7 Nov 2024 21:38:10 -0500 Subject: [PATCH] test: Reactant support for the models --- test/Project.toml | 2 + test/layer_tests.jl | 20 ++++- test/shared_testsetup.jl | 20 +++++ test/vision_tests.jl | 161 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 201 insertions(+), 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index a05d3bd..a5a5fb5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,6 +22,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -51,6 +52,7 @@ NNlib = "0.9.21" Pkg = "1.10" Random = "1.10" ReTestItems = "1.24.0" +Reactant = "0.2.5" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 3477ba0..7b90856 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -16,14 +16,28 @@ model = Layers.MLP(2, (4, 4, 2), act; norm_layer=norm) ps, st = Lux.setup(StableRNG(0), model) |> dev + st_test = Lux.testmode(st) x = randn(Float32, 2, 2) |> aType - @jet model(x, ps, st) + @jet model(x, ps, st_test) __f = (x, ps) -> sum(abs2, first(model(x, ps, st))) @test_gradients(__f, x, ps; atol=1e-3, rtol=1e-3, soft_fail=[AutoFiniteDiff()]) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st_test) + x_ra = rdev(x) + + model_compiled = Reactant.compile(model, (x_ra, ps_ra, st_ra)) + @test first(model_compiled(x_ra, ps_ra, st_ra)) ≈ + Array(first(model(x, ps, st_test))) + end end end end @@ -217,6 +231,10 @@ end __f = x -> sum(first(layer(x, ps, st))) @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, enzyme_set_runtime_activity=true) + + # TODO: Reactant testing + # We need to solve https://github.com/EnzymeAD/Reactant.jl/issues/242 and + # https://github.com/EnzymeAD/Reactant.jl/issues/243 first end end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index dc8e429..b5bb538 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -19,6 +19,24 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" using AMDGPU end +@static if !Sys.iswindows() + @reexport using Reactant + test_reactant(mode) = mode != "amdgpu" + function set_reactant_backend!(mode) + if mode == "cuda" + Reactant.set_default_backend("gpu") + elseif mode == "cpu" + Reactant.set_default_backend("cpu") + end + end +else + test_reactant(::Any) = true + set_reactant_backend!(::Any) = nothing + macro compile(expr) + return :() + end +end + cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" function cuda_testing() return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && @@ -38,5 +56,7 @@ const MODES = begin end export MODES, BACKEND_GROUP +export test_reactant, set_reactant_backend! +export @compile end diff --git a/test/vision_tests.jl b/test/vision_tests.jl index 449e90d..dfb0752 100644 --- a/test/vision_tests.jl +++ b/test/vision_tests.jl @@ -23,7 +23,12 @@ function imagenet_acctest(model, ps, st, dev; size=224) TEST_X = size == 224 ? MONARCH_224 : (size == 256 ? MONARCH_256 : error("size must be 224 or 256")) x = TEST_X |> dev - ypred = first(model(x, ps, st)) |> collect |> vec + + if dev isa MLDataDevices.ReactantDevice + model = Reactant.compile(model, (x, ps, st)) + end + + ypred = first(model(x, ps, st)) |> cpu_device() |> collect |> vec top5 = TEST_LBLS[sortperm(ypred; rev=true)] return "monarch" in top5 end @@ -48,6 +53,23 @@ end end GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + Array(first(model(img, ps, st))) + + if pretrained + @test imagenet_acctest(model, ps_ra, st_ra, rdev) + end + end end end end @@ -63,6 +85,19 @@ end @test size(first(model(img, ps, st))) == (1000, 2) GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + Array(first(model(img, ps, st))) + end end end @@ -77,6 +112,19 @@ end @test size(first(model(img, ps, st))) == (1000, 2) GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + Array(first(model(img, ps, st))) + end end end @@ -91,6 +139,19 @@ end @test size(first(model(img, ps, st))) == (1000, 2) GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + Array(first(model(img, ps, st))) + end end end @@ -110,6 +171,23 @@ end end GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + first(model(img, ps, st)) + + if pretrained + @test imagenet_acctest(model, ps_ra, st_ra, rdev) + end + end end end end @@ -134,6 +212,23 @@ end end GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + first(model(img, ps, st)) + + if pretrained + @test imagenet_acctest(model, ps_ra, st_ra, rdev) + end + end end end end @@ -157,6 +252,23 @@ end end GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + first(model(img, ps, st)) + + if pretrained + @test imagenet_acctest(model, ps_ra, st_ra, rdev) + end + end end end end @@ -177,6 +289,23 @@ end end GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + first(model(img, ps, st)) + + if pretrained + @test imagenet_acctest(model, ps_ra, st_ra, rdev) + end + end end end end @@ -197,6 +326,23 @@ end end GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + first(model(img, ps, st)) + + if pretrained + @test imagenet_acctest(model, ps_ra, st_ra, rdev) + end + end end end end @@ -221,5 +367,18 @@ end @test size(first(model(img, ps, st))) == (1000, 2) GC.gc(true) + + if test_reactant(mode) + set_reactant_backend!(mode) + rdev = reactant_device() + + ps_ra = rdev(ps) + st_ra = rdev(st) + img_ra = rdev(img) + + model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra)) + @test first(model_compiled(img_ra, ps_ra, st_ra)) ≈ + Array(first(model(img, ps, st))) + end end end