From a9b366bf4ca4fc77d7a29ea9490d3678f423bfd8 Mon Sep 17 00:00:00 2001
From: Facundo Sapienza
Date: Sat, 23 Dec 2023 22:16:09 -0300
Subject: [PATCH] Working example with two rotations and AD capabilities

---
 Project.toml     | 14 ++++++++
 src/SphereFit.jl | 18 ++++++++++-
 src/train.jl     | 83 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/types.jl     | 21 ++++++++++++
 src/utils.jl     | 22 +++++++++++++
 5 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 src/train.jl
 create mode 100644 src/types.jl
 create mode 100644 src/utils.jl

diff --git a/Project.toml b/Project.toml
index fbbecea..6973079 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,20 @@ uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d"
 authors = ["Facundo Sapienza"]
 version = "1.0.0-DEV"
 
+[deps]
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
+OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
+OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
+OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
+Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
+SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+
 [compat]
 julia = "1"
 
diff --git a/src/SphereFit.jl b/src/SphereFit.jl
index 82ed937..ea85c37 100644
--- a/src/SphereFit.jl
+++ b/src/SphereFit.jl
@@ -1,5 +1,21 @@
 module SphereFit
 
-# Write your package code here.
+# types
+using Base: @kwdef
+# utils
+# training
+using LinearAlgebra, Statistics
+using Lux
+using OrdinaryDiffEq
+using SciMLSensitivity
+using Optimization, OptimizationOptimisers, OptimizationOptimJL
+using ComponentArrays: ComponentVector
+
+export SphereParameters, SphereData
+export train_sphere
+
+include("utils.jl")
+include("types.jl")
+include("train.jl")
 
 end
diff --git a/src/train.jl b/src/train.jl
new file mode 100644
index 0000000..d1c8073
--- /dev/null
+++ b/src/train.jl
@@ -0,0 +1,83 @@
+export train_sphere
+
+function get_NN(params, rng, θ_trained)
+    # Define neural network
+    U = Lux.Chain(
+        Lux.Dense(1, 5, tanh),
+        Lux.Dense(5, 10, tanh),
+        Lux.Dense(10, 5, tanh),
+        Lux.Dense(5, 3, x->sigmoid_cap(x; ω₀=params.ωmax))
+    )
+    # This is what we have in ODINN.jl, but not clear if necessary
+    #
+    # UA = Flux.f64(UA)
+    # # See if parameters need to be retrained or not
+    # θ, UA_f = Flux.destructure(UA)
+    # if !isempty(θ_trained)
+    #     θ = θ_trained
+    # end
+    # return UA_f, θ
+
+    θ, st = Lux.setup(rng, U)
+    return U, θ, st
+end
+
+function train_sphere(data::AbstractData,
+                      params::AbstractParameters,
+                      rng,
+                      θ_trained=[])
+
+    U, θ, st = get_NN(params, rng, θ_trained)
+
+    function ude_rotation!(du, u, p, t)
+        # Angular momentum given by network prediction
+        L = U([t], p, st)[1]
+        du .= cross(L, u)
+        nothing
+    end
+
+    prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], θ)
+
+    function predict(θ; u0=params.u0, T=data.times)
+        _prob = remake(prob_nn, u0=u0,
+                       tspan=(min(T[1], params.tmin), max(T[end], params.tmax)),
+                       p = θ)
+        Array(solve(_prob, Tsit5(), saveat=T,
+                    abstol=params.abstol, reltol=params.reltol,
+                    sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
+    end
+
+    function loss(θ)
+        u_ = predict(θ)
+        # Empirical error
+        l_ = mean(abs2, u_ .- data.directions)
+        return l_
+    end
+
+    losses = Float64[]
+    callback = function (p, l)
+        push!(losses, l)
+        if length(losses) % 50 == 0
+            println("Current loss after $(length(losses)) iterations: $(losses[end])")
+        end
+        return false
+    end
+
+    adtype = Optimization.AutoZygote()
+    optf = Optimization.OptimizationFunction((x, θ) -> loss(x), adtype)
+    optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(θ))
+
+    res1 = Optimization.solve(optprob, ADAM(0.001), callback=callback, maxiters=1000)
+    println("Training loss after $(length(losses)) iterations: $(losses[end])")
+
+    optprob2 = Optimization.OptimizationProblem(optf, res1.u)
+    res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=300)
+    println("Final training loss after $(length(losses)) iterations: $(losses[end])")
+
+    θ_trained = res2.u
+
+    return θ_trained, U, st
+end
+
+
+
diff --git a/src/types.jl b/src/types.jl
new file mode 100644
index 0000000..c446059
--- /dev/null
+++ b/src/types.jl
@@ -0,0 +1,21 @@
+export SphereParameters, AbstractParameters
+export SphereData, AbstractData
+
+abstract type AbstractParameters end
+
+@kwdef struct SphereParameters{F <: AbstractFloat} <: AbstractParameters
+    tmin::F
+    tmax::F
+    u0::Union{Vector{F}, Nothing}
+    ωmax::F
+    reltol::F
+    abstol::F
+end
+
+abstract type AbstractData end
+
+@kwdef struct SphereData{F <: AbstractFloat} <: AbstractData
+    times::Vector{F}
+    directions::Matrix{F}
+    kappas::Union{Vector{F}, Nothing}
+end
\ No newline at end of file
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..6de88d3
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,22 @@
+export sigmoid_cap, relu_cap
+
+# Normalization of the NN. Ideally we want to do this with the L2 norm.
+
+"""
+    sigmoid_cap(x; ω₀)
+"""
+function sigmoid_cap(x; ω₀)
+    min_value = - 2ω₀
+    max_value = + 2ω₀
+    return min_value + (max_value - min_value) / ( 1.0 + exp(-x) )
+end
+
+function relu_cap(x; ω₀)
+    min_value = - 2ω₀
+    max_value = + 2ω₀
+    return min_value + (max_value - min_value) * max(0.0, min(x, 1.0))
+end
+
+
+
+
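
A minimal usage sketch of the new train_sphere API added by this patch. The synthetic single-rotation data, the seed, and the ωmax/tolerance values below are illustrative assumptions, and the skew helper exists only for this sketch:

    using SphereFit
    using LinearAlgebra, Random

    rng = MersenneTwister(616)

    # Synthetic data: directions following one constant rotation L0.
    # skew(L) * u == cross(L, u), so u(t) = exp(t * skew(L0)) * u0.
    skew(L) = [0.0 -L[3] L[2]; L[3] 0.0 -L[1]; -L[2] L[1] 0.0]
    L0 = [0.0, 0.0, 0.5]
    u0 = [1.0, 0.0, 0.0]
    times = collect(0.0:0.1:10.0)
    directions = reduce(hcat, [exp(t * skew(L0)) * u0 for t in times])

    data = SphereData(times=times, directions=directions, kappas=nothing)
    params = SphereParameters(tmin=0.0, tmax=10.0, u0=u0, ωmax=1.0,
                              reltol=1e-6, abstol=1e-6)

    θ_trained, U, st = train_sphere(data, params, rng)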
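
On the utils.jl note about normalizing with the L2 norm, one possible shape for that idea, sketched purely as an assumption (the l2_cap name and its hard-threshold behaviour are hypothetical, not part of the patch), is to rescale the whole network output by its Euclidean norm instead of capping each component:

    using LinearAlgebra: norm

    # Hypothetical vector-valued cap: leave x unchanged while norm(x) <= ω₀,
    # otherwise rescale it back onto the sphere of radius ω₀.
    function l2_cap(x::AbstractVector; ω₀)
        n = norm(x)
        return n > ω₀ ? (ω₀ / n) .* x : x
    end

Whether such a hard cap interacts well with the Zygote/QuadratureAdjoint pipeline would still need to be checked; the componentwise sigmoid_cap already used in get_NN stays smooth everywhere.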