fix: different device handling
avik-pal committed Oct 16, 2024
1 parent 7535186 commit a453576
Showing 9 changed files with 98 additions and 78 deletions.
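
For context: before this commit, NeuralPDE assumed that time points, initial conditions, and training sets lived on the same device as the network parameters θ, which breaks when θ is on a GPU. The fix threads the parameter device through every evaluation and loss-construction path via MLDataDevices.jl. Below is a minimal sketch of the two functions the commit imports, get_device and cpu_device; the arrays are illustrative stand-ins and everything runs on CPU:

```julia
using MLDataDevices: cpu_device, get_device

θ = (; depvar = rand(Float32, 8))    # parameter container; may hold GPU arrays
dev = get_device(θ)                  # device object θ lives on, e.g. CPUDevice()
t = dev(collect(0.0f0:0.1f0:1.0f0))  # move data onto θ's device before mixing
y = dev(rand(Float32, 4))
y_host = y |> cpu_device()           # bring a result back to the host
```

Device objects are callable, so `dev(x)` transfers `x` (including nested structures) to that device.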
2 changes: 2 additions & 0 deletions Project.toml
@@ -23,6 +23,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -70,6 +71,7 @@ LuxCUDA = "0.3.3"
LuxCore = "1.0.1"
LuxLib = "1.3.2"
MCMCChains = "6"
MLDataDevices = "1.2.0"
MethodOfLines = "0.11.6"
ModelingToolkit = "9.46"
MonteCarloMeasurements = "1.1"
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
@@ -25,6 +25,7 @@ using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer
using Lux: FromFluxAdaptor, recursive_eltype
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer
using MCMCChains: MCMCChains, Chains, sample
+using MLDataDevices: cpu_device, get_device
using ModelingToolkit: ModelingToolkit, Num, PDESystem, toexpr, expand_derivatives, infimum,
supremum
using MonteCarloMeasurements: Particles
5 changes: 4 additions & 1 deletion src/advancedHMC_MCMC.jl
@@ -19,7 +19,10 @@ NN OUTPUT AT t,θ ~ phi(t,θ).
"""
function (f::LogTargetDensity)(t::AbstractVector, θ)
    θ = vector_to_parameters(θ, f.init_params)
-    return f.prob.u0 .+ (t' .- f.prob.tspan[1]) .* f.smodel(t', θ)
+    dev = get_device(θ)
+    t = t |> dev
+    u0 = f.prob.u0 |> dev
+    return u0 .+ (t' .- f.prob.tspan[1]) .* f.smodel(t', θ)
end

(f::LogTargetDensity)(t::Number, θ) = f([t], θ)[:, 1]
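
This hunk is the template for the rest of the commit: `t` arrives from the CPU-side sampler while `θ` (and hence the network output) may live on a GPU, so `t` and `u0` are moved to `θ`'s device before the broadcast. A hedged, self-contained sketch of the idea, with a plain closure standing in for `f.smodel` (GPU usage is implied but everything below runs on CPU):

```julia
using MLDataDevices: get_device

smodel(t, θ) = θ.w * t                 # stand-in for the stateful Lux model
u0, tspan = [1.0f0], (0.0f0, 1.0f0)

function ansatz(t, θ)
    dev = get_device(θ)                # the device of the parameters wins
    t, u0d = dev(t), dev(u0)           # colocate inputs with θ
    return u0d .+ (t' .- tspan[1]) .* smodel(t', θ)
end

ansatz(collect(0.0f0:0.5f0:1.0f0), (; w = fill(2.0f0, 1, 1)))
```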
6 changes: 3 additions & 3 deletions src/neural_adapter.jl
@@ -38,7 +38,7 @@ function get_loss_function_neural_adapter(
    eqs isa Array || (eqs = [eqs])
    eltypeθ = recursive_eltype(init_params)
    train_set = generate_training_sets(pde_system.domain, strategy.dx, eqs, eltypeθ)
-    return get_loss_function(loss, train_set, eltypeθ, strategy)
+    return get_loss_function(init_params, loss, train_set, eltypeθ, strategy)
end

function get_loss_function_neural_adapter(loss, init_params, pde_system,
@@ -51,7 +51,7 @@ function get_loss_function_neural_adapter(loss, init_params, pde_system,

    eltypeθ = recursive_eltype(init_params)
    bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
-    return get_loss_function(loss, bound, eltypeθ, strategy)
+    return get_loss_function(init_params, loss, bound, eltypeθ, strategy)
end

function get_loss_function_neural_adapter(
@@ -64,7 +64,7 @@ function get_loss_function_neural_adapter(

    eltypeθ = recursive_eltype(init_params)
    bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)
-    return get_loss_function(loss, bound[1][1], bound[2][1], eltypeθ, strategy)
+    return get_loss_function(init_params, loss, bound[1][1], bound[2][1], eltypeθ, strategy)
end

"""
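
All three `neural_adapter` overloads now forward `init_params` as a new leading argument to `get_loss_function` (defined in src/training_strategies.jl below), which uses it only to discover which device the parameters live on. A hypothetical call-site sketch; `residual`, `train_set`, and the direct use of the internal `get_loss_function` are stand-ins for illustration:

```julia
using NeuralPDE: GridTraining
using NeuralPDE: get_loss_function   # internal function, accessed for illustration

init_params = rand(Float32, 16)      # stand-in parameter vector
residual(x, θ) = sum(θ) .* x[1, :]   # stand-in datafree loss
train_set = rand(Float32, 1, 32)     # columnwise training points

loss = get_loss_function(init_params, residual, train_set, Float32, GridTraining(0.1f0))
loss(init_params)  # points were staged on init_params' device when `loss` was built
```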
19 changes: 15 additions & 4 deletions src/ode_solve.jl
@@ -113,6 +113,8 @@ respects boundary conditions, i.e. `phi(t) = u0 + t*NN(t)`.
smodel <: StatefulLuxLayer
end

+Functors.@functor ODEPhi (u0, t0)

function ODEPhi(model::AbstractLuxLayer, t0::Number, u0, st)
return ODEPhi(u0, t0, StatefulLuxLayer{true}(model, nothing, st))
end
@@ -127,13 +129,22 @@ function generate_phi_θ(chain::AbstractLuxLayer, t, u0, init_params)
return ODEPhi(chain, t, u0, st), init_params
end

-(f::ODEPhi{<:Number})(t::Number, θ) = f.u0 + (t - f.t0) * first(f.smodel([t], θ.depvar))
+function (f::ODEPhi)(t, θ)
+    dev = get_device(θ)
+    return (dev(f))(dev, dev(t), θ)
+end

-(f::ODEPhi{<:Number})(t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)
+function (f::ODEPhi{<:Number})(dev, t::Number, θ)
+    return f.u0 + (t - f.t0) * first(f.smodel(dev([t]), θ.depvar))
+end

+function (f::ODEPhi{<:Number})(_, t::AbstractVector, θ)
+    return f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)
+end

-(f::ODEPhi)(t::Number, θ) = f.u0 .+ (t .- f.t0) .* f.smodel([t], θ.depvar)
+(f::ODEPhi)(dev, t::Number, θ) = f.u0 .+ (t .- f.t0) .* f.smodel(dev([t]), θ.depvar)

-(f::ODEPhi)(t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)
+(f::ODEPhi)(_, t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* f.smodel(t', θ.depvar)

"""
ode_dfdx(phi, t, θ, autodiff)
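
Two pieces make this work. `Functors.@functor ODEPhi (u0, t0)` registers only `u0` and `t0` as movable leaves, so a device object can relocate the stored initial condition without disturbing the stateful model; the new entry point `(f::ODEPhi)(t, θ)` then moves `f` and `t` to `θ`'s device and re-dispatches to the device-aware methods, which take the device as their first argument. A self-contained sketch of the same pattern on a hypothetical `TinyPhi` (CPU-only here):

```julia
using Functors
using MLDataDevices: get_device

struct TinyPhi{U, T, M}
    u0::U
    t0::T
    model::M
end

Functors.@functor TinyPhi (u0, t0)   # only u0/t0 move when a device is applied

# entry point: align everything with θ's device, then dispatch
(f::TinyPhi)(t, θ) = (dev = get_device(θ); dev(f)(dev, dev(t), θ))
# device-aware method: t is already on the right device
(f::TinyPhi)(dev, t::AbstractVector, θ) = f.u0 .+ (t' .- f.t0) .* (first(θ.w) .* t')

phi = TinyPhi([1.0f0], 0.0f0, nothing)
phi(collect(0.0f0:0.5f0:1.0f0), (; w = fill(2.0f0, 1)))
```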
5 changes: 3 additions & 2 deletions src/pinn_types.jl
@@ -38,9 +38,9 @@ function Phi(layer::AbstractLuxLayer)
layer, nothing, initialstates(Random.default_rng(), layer)))
end

-(f::Phi)(x::Number, θ) = f([x], θ)[1]
+(f::Phi)(x::Number, θ) = (f([x], θ) |> cpu_device())[1]

-(f::Phi)(x::AbstractArray, θ) = f.smodel(x, θ)
+(f::Phi)(x::AbstractArray, θ) = f.smodel(get_device(θ)(x), θ)

"""
PhysicsInformedNN(chain, strategy; init_params = nothing, phi = nothing,
@@ -357,6 +357,7 @@ get_u() = (cord, θ, phi) -> phi(cord, θ)
function numeric_derivative(phi, u, x, εs, order, θ)
    ε = εs[order]
    _epsilon = inv(first(ε[ε .!= zero(ε)]))
+    ε = ε |> get_device(x)

# any(x->x!=εs[1],εs)
# εs is the epsilon for each order, if they are all the same then we use a fancy formula
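
Both edits here defend against device arrays. Indexing a GPU array with `[1]` is scalar indexing, which CUDA.jl disallows by default, so the scalar `Phi` method copies its one-point result to the host before extracting it; `numeric_derivative` likewise moves the perturbation `ε` onto `x`'s device so the finite-difference shift broadcasts on a single device. A small CPU sketch of the two rules (the GPU behaviour is described, not executed):

```julia
using MLDataDevices: cpu_device, get_device

y = rand(Float32, 1, 1)              # imagine a CuArray returned by the model
val = (y |> cpu_device())[1]         # copy to host first, then scalar-index safely

x = rand(Float32, 2, 8)              # coordinates, possibly on the GPU
ε = get_device(x)(Float32[1f-4, 0])  # perturbation colocated with x
x_shifted = x .+ ε                   # single-device broadcast, no implicit transfer
```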
17 changes: 12 additions & 5 deletions src/rode_solve.jl
@@ -20,19 +20,26 @@ end
smodel <: StatefulLuxLayer
end

+Functors.@functor RODEPhi (u0, t0)

RODEPhi(phi::ODEPhi) = RODEPhi(phi.u0, phi.t0, phi.smodel)

-function (f::RODEPhi{<:Number})(t::Number, W, θ)
-    return f.u0 + (t - f.t0) * first(f.smodel([t, W], θ.depvar))
+function (f::RODEPhi)(t, W, θ)
+    dev = get_device(θ)
+    return (dev(f))(dev, dev(t), dev(W), θ)
end

+function (f::RODEPhi{<:Number})(dev, t::Number, W, θ)
+    return f.u0 + (t - f.t0) * first(f.smodel(dev([t, W]), θ.depvar))
+end

-function (f::RODEPhi{<:Number})(t::AbstractVector, W, θ)
+function (f::RODEPhi{<:Number})(_, t::AbstractVector, W, θ)
    return f.u0 .+ (t' .- f.t0) .* f.smodel(vcat(t', W'), θ.depvar)
end

-(f::RODEPhi)(t::Number, W, θ) = f.u0 .+ (t .- f.t0) .* f.smodel([t, W], θ.depvar)
+(f::RODEPhi)(dev, t::Number, W, θ) = f.u0 .+ (t .- f.t0) .* f.smodel(dev([t, W]), θ.depvar)

-function (f::RODEPhi)(t::AbstractVector, W, θ)
+function (f::RODEPhi)(_, t::AbstractVector, W, θ)
    return f.u0 .+ (t' .- f.t0) .* f.smodel(vcat(t', W'), θ.depvar)
end

74 changes: 44 additions & 30 deletions src/training_strategies.jl
@@ -25,15 +25,15 @@ function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation,
# vector of points (pde_train_sets must be rowwise)
pde_loss_functions = if train_sets_pde !== nothing
pde_train_sets = [train_set[:, 2:end] for train_set in train_sets_pde] |> adaptor
-        [get_loss_function(_loss, _set, eltypeθ, strategy)
+        [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_pde_loss_function, pde_train_sets)]
else
nothing
end

bc_loss_functions = if train_sets_bc !== nothing
bcs_train_sets = [train_set[:, 2:end] for train_set in train_sets_bc] |> adaptor
-        [get_loss_function(_loss, _set, eltypeθ, strategy)
+        [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]
else
nothing
Expand All @@ -53,17 +53,20 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,

# the points in the domain and on the boundary
pde_train_sets, bcs_train_sets = train_sets |> adaptor
-    pde_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(
datafree_pde_loss_function, pde_train_sets)]

-    bc_loss_functions = [get_loss_function(_loss, _set, eltypeθ, strategy)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, _set, eltypeθ, strategy)
for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)]

return pde_loss_functions, bc_loss_functions
end

-function get_loss_function(loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
+function get_loss_function(
+        init_params, loss_function, train_set, eltype0, ::GridTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    train_set = train_set |> get_device(init_params) |> EltypeAdaptor{eltype0}()
return θ -> mean(abs2, loss_function(train_set, θ))
end
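
The `get_loss_function` methods now take the parameters (or the whole `PINNRepresentation`, unwrapped on the first line) as a leading argument purely to learn the device. For `GridTraining` that allows the training set to be staged once, at closure-build time, instead of being transferred on every loss call. A minimal sketch of the pattern, with hypothetical names:

```julia
using Statistics: mean
using MLDataDevices: get_device

function build_grid_loss(init_params, residual, train_set)
    dev = get_device(init_params)
    staged = dev(train_set)              # one host-to-device copy, up front
    return θ -> mean(abs2, residual(staged, θ))
end
```

`StochasticTraining` below cannot hoist the transfer this way, since it samples new points on every call.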

@@ -100,19 +103,21 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
pde_bounds, bcs_bounds = bounds

-    pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy)
for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]

-    bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy)
for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]

pde_loss_functions, bc_loss_functions
end

-function get_loss_function(loss_function, bound, eltypeθ, strategy::StochasticTraining;
-        τ = nothing)
+function get_loss_function(init_params, loss_function, bound, eltypeθ,
+        strategy::StochasticTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    dev = get_device(init_params)
return θ -> begin
-        sets = generate_random_points(strategy.points, bound, eltypeθ) |>
+        sets = generate_random_points(strategy.points, bound, eltypeθ) |> dev |>
EltypeAdaptor{recursive_eltype(θ)}()
return mean(abs2, loss_function(sets, θ))
end
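
Because fresh points are drawn on every evaluation, the transfer here lives inside the closure: sample on the host, pipe through `dev`, then match `θ`'s eltype. A sketch of the per-call pattern, with a uniform sampler standing in for `generate_random_points`:

```julia
using Statistics: mean
using MLDataDevices: get_device

function build_stochastic_loss(init_params, residual, bound, npoints)
    dev = get_device(init_params)
    lo, hi = bound
    return θ -> begin
        sets = dev(lo .+ (hi - lo) .* rand(Float32, 1, npoints))  # fresh, on-device
        return mean(abs2, residual(sets, θ))
    end
end
```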
@@ -175,35 +180,36 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy)
pde_bounds, bcs_bounds = bounds

-    pde_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy)
for (_loss, bound) in zip(datafree_pde_loss_function, pde_bounds)]

strategy_ = QuasiRandomTraining(strategy.bcs_points; strategy.sampling_alg,
strategy.resampling, strategy.minibatch)
-    bc_loss_functions = [get_loss_function(_loss, bound, eltypeθ, strategy_)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, bound, eltypeθ, strategy_)
for (_loss, bound) in zip(datafree_bc_loss_function, bcs_bounds)]

return pde_loss_functions, bc_loss_functions
end

-function get_loss_function(loss_function, bound, eltypeθ, strategy::QuasiRandomTraining;
-        τ = nothing)
+function get_loss_function(init_params, loss_function, bound, eltypeθ,
+        strategy::QuasiRandomTraining; τ = nothing)
    (; sampling_alg, points, resampling, minibatch) = strategy

+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    dev = get_device(init_params)

return if resampling
θ -> begin
sets = @ignore_derivatives QuasiMonteCarlo.sample(
points, bound[1], bound[2], sampling_alg)
-            sets = sets |> EltypeAdaptor{eltypeθ}()
+            sets = sets |> dev |> EltypeAdaptor{eltypeθ}()
return mean(abs2, loss_function(sets, θ))
end
else
point_batch = generate_quasi_random_points_batch(
-            points, bound, eltypeθ, sampling_alg, minibatch)
-        θ -> begin
-            sets = point_batch[rand(1:minibatch)] |> EltypeAdaptor{eltypeθ}()
-            return mean(abs2, loss_function(sets, θ))
-        end
+            points, bound, eltypeθ, sampling_alg, minibatch) |> dev |>
+                      EltypeAdaptor{eltypeθ}()
+        θ -> mean(abs2, loss_function(point_batch[rand(1:minibatch)], θ))
end
end
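
The non-resampling branch gains a secondary improvement: the quasi-random minibatches are generated, moved to the device, and eltype-converted once, outside the closure, and each call now just indexes a random precomputed batch; previously the `EltypeAdaptor` conversion ran on every evaluation. A sketch of the precomputed-minibatch idea, with plain random batches standing in for the `QuasiMonteCarlo` samples:

```julia
using Statistics: mean
using MLDataDevices: get_device

function build_minibatch_loss(init_params, residual, minibatch, npoints)
    dev = get_device(init_params)
    batches = [dev(rand(Float32, 1, npoints)) for _ in 1:minibatch]  # staged once
    return θ -> mean(abs2, residual(batches[rand(1:minibatch)], θ))
end
```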

@@ -250,27 +256,33 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation,
pde_bounds, bcs_bounds = bounds

lbs, ubs = pde_bounds
-    pde_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
+    pde_loss_functions = [get_loss_function(pinnrep, _loss, lb, ub, eltypeθ, strategy)
for (_loss, lb, ub) in zip(datafree_pde_loss_function, lbs, ubs)]
lbs, ubs = bcs_bounds
-    bc_loss_functions = [get_loss_function(_loss, lb, ub, eltypeθ, strategy)
+    bc_loss_functions = [get_loss_function(pinnrep, _loss, lb, ub, eltypeθ, strategy)
for (_loss, lb, ub) in zip(datafree_bc_loss_function, lbs, ubs)]

return pde_loss_functions, bc_loss_functions
end

-function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::QuadratureTraining;
-        τ = nothing)
-    length(lb) == 0 && return (θ) -> mean(abs2, loss_function(rand(eltypeθ, 1, 10), θ))
+function get_loss_function(init_params, loss_function, lb, ub, eltypeθ,
+        strategy::QuadratureTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    dev = get_device(init_params)
+
+    if length(lb) == 0
+        return (θ) -> mean(abs2, loss_function(dev(rand(eltypeθ, 1, 10)), θ))
+    end

area = eltypeθ(prod(abs.(ub .- lb)))
f_ = (lb, ub, loss_, θ) -> begin
function integrand(x, θ)
-            x = x |> EltypeAdaptor{eltypeθ}()
-            sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
+            x = x |> dev |> EltypeAdaptor{eltypeθ}()
+            return sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
end
integral_function = BatchIntegralFunction(integrand, max_batch = strategy.batch)
prob = IntegralProblem(integral_function, (lb, ub), θ)
-        solve(prob, strategy.quadrature_alg; strategy.reltol, strategy.abstol,
+        return solve(prob, strategy.quadrature_alg; strategy.reltol, strategy.abstol,
strategy.maxiters)[1]
end
return (θ) -> 1 / area * f_(lb, ub, loss_function, θ)
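
The quadrature library proposes evaluation nodes `x` on the host, so the integrand moves each node batch to the parameter device before evaluating the residual; the zero-dimensional edge case gets the same `dev(...)` wrapping around its random probe points. A simplified, self-contained integrand in the same spirit (`residual` is a stand-in; the real code reduces with `sum(abs2, ..., dims = 2)` inside a `BatchIntegralFunction`):

```julia
using MLDataDevices: get_device

init_params = rand(Float32, 4)
dev = get_device(init_params)
residual(x, θ) = sum(θ) .* x                 # stand-in datafree loss

# nodes arrive host-side with shape (ndims, batch); colocate, evaluate,
# and return one value per node for batched integration
integrand(x, θ) = abs2.(view(residual(dev(x), θ), 1, :))
integrand(rand(Float32, 1, 16), init_params)
```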
@@ -299,7 +311,9 @@ This training strategy can only be used with ODEs (`NNODE`).
points::Int
end

-function get_loss_function(loss_function, train_set, eltype0, ::WeightedIntervalTraining;
-        τ = nothing)
+function get_loss_function(init_params, loss_function, train_set, eltype0,
+        ::WeightedIntervalTraining; τ = nothing)
+    init_params = init_params isa PINNRepresentation ? init_params.init_params : init_params
+    train_set = train_set |> get_device(init_params) |> EltypeAdaptor{eltype0}()
return (θ) -> mean(abs2, loss_function(train_set, θ))
end
(1 of the 9 changed files is not shown above.)
