
Commit 0c932e2: format
ChrisRackauckas committed Apr 4, 2023
1 parent a660d50 commit 0c932e2
Showing 5 changed files with 16 additions and 13 deletions.
README.md (1 change: 0 additions & 1 deletion)
@@ -27,7 +27,6 @@ For information on using the package,
 [in-development documentation](https://docs.sciml.ai/NeuralPDE/dev/) for the version of
 the documentation, which contains the unreleased features.
 
-
 ## Features
 
 - Physics-Informed Neural Networks for ODE, SDE, RODE, and PDE solving
src/NeuralPDE.jl (3 changes: 2 additions & 1 deletion)
@@ -52,7 +52,8 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
        KolmogorovParamDomain, NNParamKolmogorov,
        PhysicsInformedNN, discretize,
-       GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining, WeightedIntervalTraining,
+       GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
+       WeightedIntervalTraining,
        build_loss_function, get_loss_function,
        generate_training_sets, get_variables, get_argument, get_bounds,
        get_phi, get_numeric_derivative, get_numeric_integral,
src/ode_solve.jl (10 changes: 6 additions & 4 deletions)
@@ -284,7 +284,8 @@ function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tsp
     optf = OptimizationFunction(loss, Optimization.AutoZygote())
 end
 
-function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, batch)
+function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
+                       batch)
     minT = tspan[1]
     maxT = tspan[2]
 
@@ -297,12 +298,13 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
 
     data = Float64[]
     for (index, item) in enumerate(weights)
-        temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+ ((index - 1) * difference)
+        temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+
+                    ((index - 1) * difference)
         data = append!(data, temp_data)
     end
 
     ts = data
 
     function loss(θ, _)
         if batch
             sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
@@ -413,7 +415,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
         verbose && println("Current loss is: $l, Iteration: $iteration")
         l < abstol
     end
 
     optprob = OptimizationProblem(optf, init_params)
     res = solve(optprob, opt; callback, maxiters, alg.kwargs...)
 
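Note on the wrapped `temp_data` line above: in `generate_loss(strategy::WeightedIntervalTraining, ...)`, each weight owns an equal-width slice of `tspan`, and a share of the `samples` points proportional to that weight is drawn uniformly inside its slice. A minimal standalone sketch of that loop, with illustrative `weights`, `samples`, and `tspan`, and with `difference` assumed to be the subinterval width (its definition sits outside the visible hunk):

```julia
# Sketch of the weighted sampling used by WeightedIntervalTraining (illustrative values).
weights = [0.7, 0.2, 0.1]    # relative sampling density of each equal-width subinterval
samples = 200                # total number of time points to draw
tspan = (0.0, 3.0)
minT, maxT = tspan
difference = (maxT - minT) / length(weights)    # assumed width of one subinterval

data = Float64[]
for (index, item) in enumerate(weights)
    # draw `samples * item` points uniformly inside the index-th subinterval
    temp_data = rand(1, trunc(Int, samples * item)) .* difference .+ minT .+
                ((index - 1) * difference)
    append!(data, temp_data)
end
ts = data    # ≈140 points in [0, 1), ≈40 in [1, 2), ≈20 in [2, 3)
```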
src/training_strategies.jl (6 changes: 3 additions & 3 deletions)
@@ -299,7 +299,6 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature
     return loss
 end
 
-
 """
 ```julia
 WeightedIntervalTraining(weights, samples)
@@ -328,7 +327,8 @@ function WeightedIntervalTraining(weights, samples)
     WeightedIntervalTraining(weights, samples)
 end
 
-function get_loss_function(loss_function, train_set, eltypeθ, strategy::WeightedIntervalTraining;
+function get_loss_function(loss_function, train_set, eltypeθ,
+                           strategy::WeightedIntervalTraining;
                            τ = nothing)
     loss = (θ) -> mean(abs2, loss_function(train_set, θ))
 end
 end
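For context, the reformatted `get_loss_function` method above simply closes over the precomputed training points and returns a mean-squared residual. A sketch of that shape with placeholder inputs (`train_set` and `loss_function` below are illustrative stand-ins, not the package's real objects):

```julia
using Statistics: mean

train_set = rand(1, 200)                         # placeholder time points (see sampling sketch above)
loss_function = (ts, θ) -> ts .* θ .- sin.(ts)   # placeholder residual over those points

# Shape of the closure returned by the WeightedIntervalTraining method:
loss = θ -> mean(abs2, loss_function(train_set, θ))
loss(0.5)    # scalar mean-squared residual for θ = 0.5
```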
test/NNODE_tests.jl (9 changes: 5 additions & 4 deletions)
Expand Up @@ -218,12 +218,13 @@ true_sol = solve(prob_oop, Tsit5(), saveat = 0.01)
func = Lux.σ
N = 12
chain = Lux.Chain(Lux.Dense(1, N, func), Lux.Dense(N, N, func), Lux.Dense(N, N, func),
Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))
Lux.Dense(N, N, func), Lux.Dense(N, length(u0)))

opt = Optimisers.Adam(0.01)
weights = [0.7, 0.2, 0.1]
samples = 200
alg = NeuralPDE.NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
sol = solve(prob_oop, alg, verbose=true, maxiters = 100000, saveat = 0.01)
alg = NeuralPDE.NNODE(chain, opt, autodiff = false,
strategy = NeuralPDE.WeightedIntervalTraining(weights, samples))
sol = solve(prob_oop, alg, verbose = true, maxiters = 100000, saveat = 0.01)

@test abs(mean(sol) - mean(true_sol)) < 0.2
@test abs(mean(sol) - mean(true_sol)) < 0.2
