Working example with two rotations and AD capabilities
commit a9b366b (1 parent: 39c8ca7)
Showing 5 changed files with 157 additions and 1 deletion.
SphereFit.jl
@@ -1,5 +1,21 @@
module SphereFit

# types
using Base: @kwdef
# utils
# training
using LinearAlgebra, Statistics
using Lux
using OrdinaryDiffEq
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays: ComponentVector

export SphereParameters, SphereData
export train_sphere

include("utils.jl")
include("types.jl")
include("train.jl")

end
train.jl
@@ -0,0 +1,83 @@
export train_sphere

function get_NN(params, rng, θ_trained)
    # Define neural network
    U = Lux.Chain(
        Lux.Dense(1, 5, tanh),
        Lux.Dense(5, 10, tanh),
        Lux.Dense(10, 5, tanh),
        Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax))
    )
    # This is what we have in ODINN.jl, but it is not clear if it is necessary:
    #
    # UA = Flux.f64(UA)
    # # See if parameters need to be retrained or not
    # θ, UA_f = Flux.destructure(UA)
    # if !isempty(θ_trained)
    #     θ = θ_trained
    # end
    # return UA_f, θ

    θ, st = Lux.setup(rng, U)
    return U, θ, st
end
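
# Note: U([t], θ, st)[1] maps a scalar time t to a 3-vector, used below as the
# angular momentum at time t; the sigmoid_cap output layer bounds each
# component to (-2ωmax, 2ωmax).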

function train_sphere(data::AbstractData,
                      params::AbstractParameters,
                      rng,
                      θ_trained=[])

    U, θ, st = get_NN(params, rng, θ_trained)

    function ude_rotation!(du, u, p, t)
        # Angular momentum given by network prediction
        L = U([t], p, st)[1]
        du .= cross(L, u)
        nothing
    end
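    # Since du = L × u is orthogonal to u, this vector field preserves ‖u‖:
    # trajectories that start on the unit sphere stay on it.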

    prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], θ)

    function predict(θ; u0=params.u0, T=data.times)
        _prob = remake(prob_nn, u0=u0,
                       tspan=(min(T[1], params.tmin), max(T[end], params.tmax)),
                       p=θ)
        Array(solve(_prob, Tsit5(), saveat=T,
                    abstol=params.abstol, reltol=params.reltol,
                    sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
    end
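    # QuadratureAdjoint with ReverseDiffVJP(true) is a common SciMLSensitivity
    # choice for small non-stiff systems: the adjoint ODE is solved backwards in
    # time and the parameter-gradient quadrature is computed separately, with a
    # compiled reverse-mode tape for the vector-Jacobian products.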

    function loss(θ)
        u_ = predict(θ)
        # Empirical error
        l_ = mean(abs2, u_ .- data.directions)
        return l_
    end
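    # Plain MSE between predicted and observed direction columns. For unit
    # vectors, ‖u - v‖² = 2 - 2cos(θ), so this is monotone in angular error.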

    losses = Float64[]
    callback = function (p, l)
        push!(losses, l)
        if length(losses) % 50 == 0
            println("Current loss after $(length(losses)) iterations: $(losses[end])")
        end
        return false
    end
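    # Returning false tells Optimization.solve to keep iterating; returning
    # true would stop the optimization early.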

    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, θ) -> loss(x), adtype)
    optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(θ))

    res1 = Optimization.solve(optprob, ADAM(0.001), callback=callback, maxiters=1000)
    println("Training loss after $(length(losses)) iterations: $(losses[end])")

    optprob2 = Optimization.OptimizationProblem(optf, res1.u)
    res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=300)
    println("Final training loss after $(length(losses)) iterations: $(losses[end])")

    θ_trained = res2.u

    return θ_trained, U, st
end
types.jl
@@ -0,0 +1,21 @@
export SphereParameters, AbstractParameters
export SphereData, AbstractData

abstract type AbstractParameters end

@kwdef struct SphereParameters{F <: AbstractFloat} <: AbstractParameters
    tmin::F
    tmax::F
    u0::Union{Vector{F}, Nothing}
    ωmax::F
    reltol::F
    abstol::F
end

abstract type AbstractData end

@kwdef struct SphereData{F <: AbstractFloat} <: AbstractData
    times::Vector{F}
    directions::Matrix{F}
    kappas::Union{Vector{F}, Nothing}
end
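
# Both structs use @kwdef, so they can be constructed with keyword arguments.
# A minimal construction sketch (values are illustrative, not from this commit):
#
#   params = SphereParameters(tmin=0.0, tmax=100.0, u0=[0.0, 0.0, 1.0],
#                             ωmax=1.0, reltol=1e-7, abstol=1e-7)
#   data = SphereData(times=collect(0.0:2.0:100.0),
#                     directions=rand(3, 51), kappas=nothing)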
utils.jl
@@ -0,0 +1,22 @@
export sigmoid_cap, relu_cap

# Normalization of the NN output. Ideally we want to do this with the L2 norm.

"""
    sigmoid_cap(x; ω₀)

Smoothly cap `x` to the interval (-2ω₀, 2ω₀) with a scaled and shifted sigmoid.
"""
function sigmoid_cap(x; ω₀)
    min_value = -2ω₀
    max_value = +2ω₀
    return min_value + (max_value - min_value) / (1.0 + exp(-x))
end
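
# Quick sanity check (hypothetical REPL values):
#   sigmoid_cap(0.0; ω₀=1.0)    # == 0.0, midpoint of (-2, 2)
#   sigmoid_cap(1e3; ω₀=1.0)    # ≈ 2.0, the upper cap 2ω₀
#   sigmoid_cap(-1e3; ω₀=1.0)   # ≈ -2.0, the lower cap -2ω₀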

# Piecewise-linear alternative: clamp x to [0, 1], then rescale the result to
# the interval [-2ω₀, 2ω₀].
function relu_cap(x; ω₀)
    min_value = -2ω₀
    max_value = +2ω₀
    return min_value + (max_value - min_value) * max(0.0, min(x, 1.0))
end
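
Taken together, the new files support the following workflow. A minimal end-to-end sketch, assuming the exports above; the synthetic rotation, the helper skew, and all numeric values are illustrative rather than part of this commit:

using SphereFit
using LinearAlgebra, Random

rng = MersenneTwister(616)

params = SphereParameters(tmin=0.0, tmax=100.0, u0=[0.0, 0.0, 1.0],
                          ωmax=1.0, reltol=1e-7, abstol=1e-7)

# Synthetic data: rotate u0 about a fixed axis L0, so u(t) = exp(t * skew(L0)) * u0
# solves du/dt = L0 × u exactly.
skew(v) = [0.0 -v[3] v[2]; v[3] 0.0 -v[1]; -v[2] v[1] 0.0]
L0 = [0.1, 0.05, 0.0]
times = collect(range(params.tmin, params.tmax, length=50))
directions = reduce(hcat, [exp(t * skew(L0)) * params.u0 for t in times])

data = SphereData(times=times, directions=directions, kappas=nothing)

θ_trained, U, st = train_sphere(data, params, rng)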