From c7f47d693d68bdd5bd5327e9c638a1b605a696b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 31 May 2024 20:26:15 -0700 Subject: [PATCH] Add tests for Ensemble Problems --- test/misc/ensemble_tests.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 test/misc/ensemble_tests.jl diff --git a/test/misc/ensemble_tests.jl b/test/misc/ensemble_tests.jl new file mode 100644 index 000000000..c034bf6bc --- /dev/null +++ b/test/misc/ensemble_tests.jl @@ -0,0 +1,21 @@ +@testitem "Ensemble Nonlinear Problems" tags=[:misc] begin + using NonlinearSolve + + prob_func(prob, i, repeat) = remake(prob; u0 = prob.u0[:, i]) + + prob_nls_oop = NonlinearProblem((u, p) -> u .* u .- p, rand(4, 128), 2.0) + prob_nls_iip = NonlinearProblem((du, u, p) -> du .= u .* u .- p, rand(4, 128), 2.0) + prob_nlls_oop = NonlinearLeastSquaresProblem((u, p) -> u .^ 2 .- p, rand(4, 128), 2.0) + prob_nlls_iip = NonlinearLeastSquaresProblem( + NonlinearFunction{true}((du, u, p) -> du .= u .^ 2 .- p; resid_prototype = rand(4)), + rand(4, 128), 2.0) + + for prob in (prob_nls_oop, prob_nls_iip, prob_nlls_oop, prob_nlls_iip) + ensembleprob = EnsembleProblem(prob; prob_func) + + for ensemblealg in (EnsembleThreads(), EnsembleSerial()) + sim = solve(ensembleprob, nothing, ensemblealg; trajectories = size(prob.u0, 2)) + @test all(SciMLBase.successful_retcode, sim.u) + end + end +end