Working example with two rotations and AD capabilities
facusapienza21 committed Dec 24, 2023
1 parent 39c8ca7 commit a9b366b
Showing 5 changed files with 157 additions and 1 deletion.
14 changes: 14 additions & 0 deletions Project.toml
@@ -3,6 +3,20 @@ uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d"
authors = ["Facundo Sapienza"]
version = "1.0.0-DEV"

[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1"

18 changes: 17 additions & 1 deletion src/SphereFit.jl
@@ -1,5 +1,21 @@
module SphereFit

using Base: @kwdef
using LinearAlgebra, Statistics
using Lux
using OrdinaryDiffEq
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays: ComponentVector

export SphereParameters, SphereData
export train_sphere

include("utils.jl")
include("types.jl")
include("train.jl")

end
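If the repository is checked out locally, the module can be loaded for experimentation in the usual way (a sketch, not part of this commit; the path below is illustrative):

using Pkg
Pkg.develop(path="path/to/SphereFit")   # illustrative local path
using SphereFit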
83 changes: 83 additions & 0 deletions src/train.jl
@@ -0,0 +1,83 @@
export train_sphere

function get_NN(params, rng, θ_trained)
    # Define the neural network mapping time t ↦ angular momentum L(t) ∈ ℝ³
    U = Lux.Chain(
        Lux.Dense(1, 5, tanh),
        Lux.Dense(5, 10, tanh),
        Lux.Dense(10, 5, tanh),
        Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax))
    )
    # This is what we have in ODINN.jl, but it is not clear if it is necessary:
    #
    # UA = Flux.f64(UA)
    # # See if parameters need to be retrained or not
    # θ, UA_f = Flux.destructure(UA)
    # if !isempty(θ_trained)
    #     θ = θ_trained
    # end
    # return UA_f, θ

    θ, st = Lux.setup(rng, U)
    return U, θ, st
end

function train_sphere(data::AbstractData,
                      params::AbstractParameters,
                      rng,
                      θ_trained=[])

    U, θ, st = get_NN(params, rng, θ_trained)

    function ude_rotation!(du, u, p, t)
        # Angular momentum given by network prediction
        L = U([t], p, st)[1]
        du .= cross(L, u)
        nothing
    end

    prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], θ)

    function predict(θ; u0=params.u0, T=data.times)
        _prob = remake(prob_nn, u0=u0,
                       tspan=(min(T[1], params.tmin), max(T[end], params.tmax)),
                       p=θ)
        Array(solve(_prob, Tsit5(), saveat=T,
                    abstol=params.abstol, reltol=params.reltol,
                    sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
    end

    function loss(θ)
        u_ = predict(θ)
        # Empirical error
        l_ = mean(abs2, u_ .- data.directions)
        return l_
    end

    losses = Float64[]
    callback = function (p, l)
        push!(losses, l)
        if length(losses) % 50 == 0
            println("Current loss after $(length(losses)) iterations: $(losses[end])")
        end
        return false
    end

    adtype = Optimization.AutoZygote()
    optf = Optimization.OptimizationFunction((x, θ) -> loss(x), adtype)
    optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(θ))

    # Two-stage optimization: ADAM for coarse convergence, then LBFGS to refine
    res1 = Optimization.solve(optprob, ADAM(0.001), callback=callback, maxiters=1000)
    println("Training loss after $(length(losses)) iterations: $(losses[end])")

    optprob2 = Optimization.OptimizationProblem(optf, res1.u)
    res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=300)
    println("Final training loss after $(length(losses)) iterations: $(losses[end])")

    θ_trained = res2.u

    return θ_trained, U, st
end
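For orientation, a minimal end-to-end driver for this training loop (a sketch, not part of this commit: the rotation axis, time grid, and tolerances are illustrative, and SphereData/SphereParameters are the types defined in types.jl below):

using LinearAlgebra, Random
using SphereFit

rng = Random.MersenneTwister(666)

# Ground truth: constant angular momentum L = ω₀ ẑ, so u(t) rotates in the xy-plane
ω₀ = 0.5
u0 = [1.0, 0.0, 0.0]
times = collect(0.0:1.0:100.0)
directions = reduce(hcat, [[cos(ω₀ * t), sin(ω₀ * t), 0.0] for t in times])

data = SphereData(times=times, directions=directions, kappas=nothing)
params = SphereParameters(tmin=0.0, tmax=100.0, u0=u0,
                          ωmax=1.0, reltol=1e-7, abstol=1e-7)

θ_trained, U, st = train_sphere(data, params, rng)

# Evaluate the fitted angular-momentum field at an arbitrary time
L10 = U([10.0], θ_trained, st)[1]

Because the true angular momentum here is constant, the network only has to learn a constant field, which makes this a cheap smoke test for the adjoint/AD setup before moving to the two-rotation example named in the commit title.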



21 changes: 21 additions & 0 deletions src/types.jl
@@ -0,0 +1,21 @@
export SphereParameters, AbstractParameters
export SphereData, AbstractData

abstract type AbstractParameters end

@kwdef struct SphereParameters{F <: AbstractFloat} <: AbstractParameters
    tmin::F
    tmax::F
    u0::Union{Vector{F}, Nothing}
    ωmax::F
    reltol::F
    abstol::F
end

abstract type AbstractData end

@kwdef struct SphereData{F <: AbstractFloat} <: AbstractData
    times::Vector{F}
    directions::Matrix{F}
    kappas::Union{Vector{F}, Nothing}
end
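Both structs are parametric in the float type F, so the precision of the inputs propagates to the whole setup; for instance, Float32 inputs yield a SphereParameters{Float32} (values here are placeholders):

p32 = SphereParameters(tmin=0.0f0, tmax=1.0f0, u0=nothing,
                       ωmax=1.0f0, reltol=1.0f-6, abstol=1.0f-6)
typeof(p32)   # SphereParameters{Float32}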
22 changes: 22 additions & 0 deletions src/utils.jl
@@ -0,0 +1,22 @@
export sigmoid_cap, relu_cap

# Normalization of the NN output. Ideally we want to do this with the L2 norm.

"""
    sigmoid_cap(x; ω₀)

Sigmoid rescaled to take values in the open interval (-2ω₀, 2ω₀).
"""
function sigmoid_cap(x; ω₀)
    min_value = - 2ω₀
    max_value = + 2ω₀
    return min_value + (max_value - min_value) / (1.0 + exp(-x))
end

"""
    relu_cap(x; ω₀)

Piecewise-linear cap: clamps x to [0, 1] and rescales to [-2ω₀, 2ω₀].
"""
function relu_cap(x; ω₀)
    min_value = - 2ω₀
    max_value = + 2ω₀
    return min_value + (max_value - min_value) * max(0.0, min(x, 1.0))
end
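Both caps bound each coordinate of the network output to roughly (-2ω₀, 2ω₀): sigmoid_cap smoothly, relu_cap by clamping. A quick check of the limits (expected outputs as comments):

sigmoid_cap(0.0; ω₀=1.0)    # 0.0, the midpoint of (-2, 2)
sigmoid_cap(1e3; ω₀=1.0)    # ≈ 2.0, saturated at the upper bound
relu_cap(-0.5; ω₀=1.0)      # -2.0, clamped at the lower bound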