
Commit

Merge pull request #635 from sdesai1287/nnode_training_edits
Explored some new training strategies for large neural networks.
ChrisRackauckas authored Apr 4, 2023
2 parents d65a09c + 73294b5 commit a660d50
Showing 7 changed files with 92 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -79,4 +79,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
-test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"]
\ No newline at end of file
+test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "IntegralsCuba"]
1 change: 1 addition & 0 deletions README.md
@@ -27,6 +27,7 @@ For information on using the package,
[in-development documentation](https://docs.sciml.ai/NeuralPDE/dev/) for the version of
the documentation, which contains the unreleased features.

+
## Features

- Physics-Informed Neural Networks for ODE, SDE, RODE, and PDE solving
3 changes: 2 additions & 1 deletion docs/src/manual/training_strategies.md
@@ -14,7 +14,7 @@ method, such as `CubaVegas`, can be beneficial for difficult or stiff problems.

`GridTraining` should only be used for testing purposes and should not be relied upon for real
training cases. `StochasticTraining` achieves a lower convergence rate in the quasi-Monte Carlo
-methods and thus `QuasiRandomTraining` should be preferred in most cases.
+methods and thus `QuasiRandomTraining` should be preferred in most cases. `WeightedIntervalTraining` can only be used with ODEs (`NNODE`).

## API

@@ -23,4 +23,5 @@ GridTraining
StochasticTraining
QuasiRandomTraining
QuadratureTraining
+WeightedIntervalTraining
```
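For orientation, the new strategy is constructed from a vector of interval weights and a total sample count, and is handed to `NNODE` via its `strategy` keyword. A minimal sketch based on the test added in this commit (the chain and optimizer here are illustrative placeholders, not part of the diff):

```julia
using NeuralPDE, Optimisers
import Lux

weights = [0.7, 0.2, 0.1]  # proportion of samples per equal time segment
samples = 200              # total number of training points across tspan
strategy = WeightedIntervalTraining(weights, samples)

N = 12
chain = Lux.Chain(Lux.Dense(1, N, Lux.σ), Lux.Dense(N, 1))
alg = NNODE(chain, Optimisers.Adam(0.01); strategy = strategy)
# `alg` is then passed to `solve(prob, alg; ...)` for a scalar ODEProblem `prob`.
```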
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
@@ -52,7 +52,7 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
PhysicsInformedNN, discretize,
-GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
+GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining, WeightedIntervalTraining,
build_loss_function, get_loss_function,
generate_training_sets, get_variables, get_argument, get_bounds,
get_phi, get_numeric_derivative, get_numeric_integral,
33 changes: 30 additions & 3 deletions src/ode_solve.jl
@@ -271,9 +271,7 @@ end

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
                       batch)
-    # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
    function loss(θ, _)
-        # (tspan[2]-tspan[1])*rand() + tspan[1] gives Uniform(tspan[1],tspan[2])
        ts = adapt(parameterless_type(θ),
                   [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])

@@ -286,6 +284,35 @@ function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tsp
    optf = OptimizationFunction(loss, Optimization.AutoZygote())
end

+function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, batch)
+    minT = tspan[1]
+    maxT = tspan[2]
+
+    weights = strategy.weights ./ sum(strategy.weights)  # normalize so the proportions sum to 1
+
+    N = length(weights)
+    samples = strategy.samples
+
+    difference = (maxT - minT) / N  # width of each equal time segment
+
+    data = Float64[]
+    for (index, item) in enumerate(weights)  # draw trunc(Int, samples * item) points per segment
+        temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+ ((index - 1) * difference)
+        data = append!(data, temp_data)
+    end
+
+    ts = data
+
+    function loss(θ, _)
+        if batch
+            sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
+        else
+            sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
+        end
+    end
+    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+end
+
function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, tspan)
    error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
end
@@ -386,7 +413,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
        verbose && println("Current loss is: $l, Iteration: $iteration")
        l < abstol
    end

    optprob = OptimizationProblem(optf, init_params)
    res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

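The sampling scheme above splits `tspan` into `length(weights)` equal segments and, after normalizing the weights, draws `trunc(Int, samples * wᵢ)` uniform points in segment `i`. A standalone sketch of just the sampling step, mirroring the loop in the diff (`weighted_interval_points` is a hypothetical helper, not part of the package API):

```julia
# Mirrors the point generation in generate_loss(::WeightedIntervalTraining, ...).
function weighted_interval_points(weights, samples, tspan)
    minT, maxT = tspan
    w = weights ./ sum(weights)              # normalize, so weights need not sum exactly to 1
    difference = (maxT - minT) / length(w)   # width of each equal segment
    data = Float64[]
    for (index, item) in enumerate(w)
        # trunc(Int, samples * item) uniform points on the `index`-th segment
        append!(data, rand(trunc(Int, samples * item)) .* difference .+ minT .+ (index - 1) * difference)
    end
    return data
end

ts = weighted_interval_points([0.7, 0.2, 0.1], 200, (0.0, 3.0))
# length(ts) == 200: 140 points in [0.0, 1.0), 40 in [1.0, 2.0), 20 in [2.0, 3.0)
```

Note that `trunc` rounds each per-segment count down, so when `samples * wᵢ` is not an integer the total can come out slightly below `samples`.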
34 changes: 34 additions & 0 deletions src/training_strategies.jl
@@ -298,3 +298,37 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature
    loss = (θ) -> 1 / area * f_(lb, ub, loss_function, θ)
    return loss
end
+
+
+"""
+```julia
+WeightedIntervalTraining(weights, samples)
+```
+
+A training strategy that generates training points over the timespan according to a vector of interval weights.
+We split the timespan into equal segments based on the number of weights,
+then sample points in each segment based on that segment's corresponding weight,
+such that the total number of sampled points is equal to the given number of samples.
+
+## Positional Arguments
+* `weights`: A vector of weights that should sum to 1, representing the proportion of samples drawn from each interval.
+* `samples`: the total number of samples that we want, across the entire time span.
+
+## Limitations
+This training strategy can only be used with ODEs (`NNODE`).
+"""
+struct WeightedIntervalTraining{T} <: AbstractTrainingStrategy
+    weights::Vector{T}
+    samples::Int
+end
+
+function WeightedIntervalTraining(weights, samples)
+    # Dispatch to the default parametric constructor; an untyped self-call here would recurse.
+    WeightedIntervalTraining{eltype(weights)}(weights, samples)
+end
+
+function get_loss_function(loss_function, train_set, eltypeθ, strategy::WeightedIntervalTraining;
+                           τ = nothing)
+    loss = (θ) -> mean(abs2, loss_function(train_set, θ))
+    return loss
+end
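One consequence of the normalization in `generate_loss` is that the weights only need to be proportions, despite the docstring's "should sum to 1". A hedged illustration (both constructions sample identically up to floating-point rounding):

```julia
using NeuralPDE

s1 = WeightedIntervalTraining([0.7, 0.2, 0.1], 200)
s2 = WeightedIntervalTraining([7.0, 2.0, 1.0], 200)  # same sampling distribution as s1
```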
23 changes: 23 additions & 0 deletions test/NNODE_tests.jl
@@ -1,6 +1,8 @@
using Test, Flux
using Random, NeuralPDE
using OrdinaryDiffEq, Optimisers, Statistics
import Lux, OptimizationOptimisers, OptimizationOptimJL

Random.seed!(100)

# Run a solve on scalars
@@ -204,3 +206,24 @@ sol = solve(prob, NeuralPDE.NNODE(luxchain, opt; batch = true), verbose = true,
            maxiters = 400,
            abstol = 1.0f-8, dt = 1 / 5.0f0)
@test sol.errors[:l2] < 0.5
+
+function f(u, p, t)  # Lotka-Volterra predator-prey dynamics
+    [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
+end
+
+p = [1.5, 1.0, 3.0, 1.0]
+u0 = [1.0, 1.0]
+prob_oop = ODEProblem{false}(f, u0, (0.0, 3.0), p)
+true_sol = solve(prob_oop, Tsit5(), saveat = 0.01)
+func = Lux.σ
+N = 12
+chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
+                  Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))
+
+opt = Optimisers.Adam(0.01)
+weights = [0.7, 0.2, 0.1]
+samples = 200
+alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
+sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)
+
+@test abs(mean(sol) - mean(true_sol)) < 0.2
