diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 11f7429a40..84e6daf29a 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -3,7 +3,8 @@ abstract type NeuralPDEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
 """
 ```julia
 NNODE(chain, opt=OptimizationPolyalgorithms.PolyOpt(), init_params = nothing;
-      autodiff=false, batch=0, kwargs...)
+      autodiff=false, batch=0, additional_loss=nothing,
+      kwargs...)
 ```

 Algorithm for solving ordinary differential equations using a neural network. This is a specialization
@@ -23,6 +24,19 @@ of the physics-informed neural network which is used as a solver for a standard
 which thus uses the random initialization provided by the neural network library.

 ## Keyword Arguments
+* `additional_loss`: A function `additional_loss(phi, θ)`, where `phi` is the neural network trial solution
+  and `θ` are the weights of the neural network(s). It is added to the loss computed by the training strategy.
+
+## Example
+
+```julia
+ts = [t for t in 1:100]
+(u_, t_) = (analytical_func(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+alg = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)
+```

 * `autodiff`: The switch between automatic and numerical differentiation for
   the PDE operators. The reverse mode of the loss function is always
@@ -63,7 +77,9 @@ is an accurate interpolation (up to the neural network training result). In addi
 Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
 ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
 """
-struct NNODE{C, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
+struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function},
+             S <: Union{Nothing, AbstractTrainingStrategy}
+             } <:
        NeuralPDEAlgorithm
     chain::C
     opt::O
@@ -71,12 +87,13 @@ struct NNODE{C, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
     autodiff::Bool
     batch::B
     strategy::S
+    additional_loss::AL
     kwargs::K
 end
 function NNODE(chain, opt, init_params = nothing;
               strategy = nothing,
-              autodiff = false, batch = nothing, kwargs...)
-    NNODE(chain, opt, init_params, autodiff, batch, strategy, kwargs)
+              autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
+    NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
 end

 """
@@ -236,7 +253,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
 end

 """
-Representation of the loss function, paramtric on the training strategy `strategy`
+Representation of the loss function, parametric on the training strategy `strategy`
 """
 function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
                        batch)
@@ -250,15 +267,13 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
         sol.u
     end

-    # Default this to ForwardDiff until Integrals.jl autodiff is sorted out
-    OptimizationFunction(loss, Optimization.AutoForwardDiff())
+    return loss
 end

 function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
     ts = tspan[1]:(strategy.dx):tspan[2]

     # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
-
     function loss(θ, _)
         if batch
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
@@ -266,22 +281,22 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
             sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
         end
     end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
 end

 function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
                        batch)
+    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
     function loss(θ, _)
         ts = adapt(parameterless_type(θ),
                    [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
-
         if batch
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
         else
             sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
         end
     end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
 end

 function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
@@ -312,7 +327,7 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
             sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
         end
     end
-    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+    return loss
 end

 function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
@@ -407,7 +422,27 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
         alg.batch
     end

-    optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, p, batch)
+    inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
+    additional_loss = alg.additional_loss
+
+    # Total loss: the strategy's L2 residual loss plus the user-supplied additional_loss, if given
+    function total_loss(θ, _)
+        L2_loss = inner_f(θ, phi)
+        if !(additional_loss isa Nothing)
+            return additional_loss(phi, θ) + L2_loss
+        end
+        L2_loss
+    end
+
+    # Choose the AD backend per training strategy (QuadratureTraining stays on ForwardDiff until Integrals.jl autodiff is sorted out)
+    opt_algo = if strategy isa QuadratureTraining
+        Optimization.AutoForwardDiff()
+    else
+        Optimization.AutoZygote()
+    end
+
+    # Create the OptimizationFunction from total_loss
+    optf = OptimizationFunction(total_loss, opt_algo)

     iteration = 0
     callback = function (p, l)
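A note on the `__solve` hunk above: each `generate_loss` method now returns a bare closure, and `__solve` composes the final objective and picks the AD backend in one place. The sketch below is not part of the patch and only condenses that composition; `inner_f`, `phi`, `additional_loss`, and `strategy` stand for the values `__solve` has in scope at that point, and `make_total_loss`/`make_optf` are hypothetical helper names.

```julia
using Optimization  # OptimizationFunction, AutoForwardDiff, AutoZygote
using NeuralPDE     # QuadratureTraining and the other strategy types

# Sketch of the composition performed inside `__solve`: the strategy's
# L2 residual loss plus the optional user-supplied term.
function make_total_loss(inner_f, additional_loss, phi)
    return function (θ, _)
        L2_loss = inner_f(θ, phi)               # physics-informed residual
        if additional_loss !== nothing
            L2_loss += additional_loss(phi, θ)  # e.g. a data-fitting penalty
        end
        return L2_loss
    end
end

# QuadratureTraining differentiates through an integral solve, so it keeps
# ForwardDiff until Integrals.jl autodiff is sorted out; other strategies use Zygote.
make_optf(total_loss, strategy) = OptimizationFunction(total_loss,
    strategy isa QuadratureTraining ? Optimization.AutoForwardDiff() :
                                      Optimization.AutoZygote())
```

This keeps the backend decision out of each strategy's `generate_loss`, which is why those methods can now simply `return loss`.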
diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl
index ffc6d68158..1839f57c7d 100644
--- a/test/NNODE_tests.jl
+++ b/test/NNODE_tests.jl
@@ -207,6 +207,7 @@ sol = solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
             abstol = 1.0f-8, dt = 1 / 5.0f0)
 @test sol.errors[:l2] < 0.5

+# WeightedIntervalTraining (Lux Chain)
 function f(u, p, t)
     [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
 end
@@ -228,3 +229,97 @@ alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
 sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)
 @test abs(mean(sol) - mean(true_sol)) < 0.2
+
+# Check that the additional_loss feature works for NNODE
+linear = (u, p, t) -> cos(2pi * t)
+linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
+tspan = (0.0f0, 1.0f0)
+dt = (tspan[2] - tspan[1]) / 99
+ts = collect(tspan[1]:dt:tspan[2])
+prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
+opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))
+
+# Analytical solution
+u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)
+
+# GridTraining (Flux Chain)
+chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(chain, opt, strategy = GridTraining(0.01),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
+
+# GridTraining (Lux Chain)
+luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(luxchain, opt, strategy = GridTraining(0.01),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
+
+# QuadratureTraining (Flux Chain)
+chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-10, maxiters = 200)
+@test sol1.errors[:l2] < 0.5
+
+# QuadratureTraining (Lux Chain)
+luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(luxchain, opt, additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-10, maxiters = 200)
+@test sol1.errors[:l2] < 0.5
+
+# StochasticTraining (Flux Chain)
+chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(chain, opt, strategy = StochasticTraining(1000),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
+
+# StochasticTraining (Lux Chain)
+luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
+
+(u_, t_) = (u_analytical(ts), ts)
+function additional_loss(phi, θ)
+    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+end
+
+alg1 = NeuralPDE.NNODE(luxchain, opt, strategy = StochasticTraining(1000),
+                       additional_loss = additional_loss)
+
+sol1 = solve(prob, alg1, verbose = true, abstol = 1.0f-8, maxiters = 500)
+@test sol1.errors[:l2] < 0.5
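Condensed from the tests above, the end-to-end `additional_loss` workflow reduces to the following minimal sketch. It assumes `ODEProblem`, `ODEFunction`, and `solve` are in scope via the same SciML imports `test/NNODE_tests.jl` already uses, and relies on NNODE falling back to `QuadratureTraining` when no `strategy` is passed (as the "QuadratureTraining" test cases above do).

```julia
using NeuralPDE, Flux, OptimizationOptimisers

# Linear ODE u' = cos(2πt) with a known analytic solution, as in the tests
linear = (u, p, t) -> cos(2pi * t)
linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))

chain = Flux.Chain(Dense(1, 5, σ), Dense(5, 1))
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

# Supervised data term: fit the trial solution to samples of the analytic solution
t_ = collect(0.0f0:0.01f0:1.0f0)
u_ = (1 / (2pi)) .* sin.(2pi .* t_)
additional_loss(phi, θ) = sum(abs2, [phi(t, θ) for t in t_] .- u_) / length(u_)

# No explicit strategy: QuadratureTraining is used by default
alg = NeuralPDE.NNODE(chain, opt, additional_loss = additional_loss)
sol = solve(prob, alg, verbose = true, abstol = 1.0f-10, maxiters = 200)
@assert sol.errors[:l2] < 0.5  # same tolerance the tests assert
```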