Skip to content

Commit

Permalink
Change activation function to reproduce knots
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 committed Jan 10, 2024
1 parent 480bf9f commit b36934e
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 27 deletions.
10 changes: 5 additions & 5 deletions examples/double_rotation/double_rotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ rng = Random.default_rng()
Random.seed!(rng, 666)

# Total time simulation
tspan = [0, 130.0]
tspan = [0, 160.0]
# Number of sample points
N_samples = 50
# Times where we sample points
Expand Down Expand Up @@ -64,19 +64,19 @@ X_true = X_noiseless + FisherNoise(kappa=200.)

data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true)

regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode="Finite Differences"),
regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="Finite Differences"),
Regularization(order=0, power=2.0, λ=0.1, diff_mode="Finite Differences")]

params = SphereParameters(tmin=tspan[1], tmax=tspan[2],
reg=regs,
u0=[0.0, 0.0, -1.0], ωmax=2ω₀, reltol=reltol, abstol=abstol,
niter_ADAM=1000, niter_LBFGS=400)
u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol,
niter_ADAM=1000, niter_LBFGS=600)

results = train(data, params, rng, nothing)

##############################################################
###################### PyCall Plots #########################
##############################################################

plot_sphere(data, results, -20., 150., "examples/double_rotation/plot_module.pdf")
plot_sphere(data, results, -20., 150., "examples/double_rotation/plot_sphere.pdf")
plot_L(data, results, saveas="examples/double_rotation/plot_L.pdf")
Binary file modified examples/double_rotation/plot_L.pdf
Binary file not shown.
Binary file modified examples/double_rotation/plot_sphere.pdf
Binary file not shown.
19 changes: 6 additions & 13 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,11 @@ export train
function get_NN(params, rng, θ_trained)
# Define neural network
U = Lux.Chain(
Lux.Dense(1, 5, tanh),
Lux.Dense(5, 10, tanh),
Lux.Dense(10, 5, tanh),
Lux.Dense(1, 5, relu_cap), # explore discontinuity function for activation
Lux.Dense(5, 10, relu_cap),
Lux.Dense(10, 5, relu_cap),
Lux.Dense(5, 3, x->sigmoid_cap(x; ω₀=params.ωmax))
)
# This is what we have in ODINN.jl, but not clear if neccesary
#
# UA = Flux.f64(UA)
# # See if parameters need to be retrained or not
# θ, UA_f = Flux.destructure(UA)
# if !isempty(θ_trained)
# θ = θ_trained
# end
# return UA_f, θ

θ, st = Lux.setup(rng, U)
return U, θ, st
end
Expand All @@ -29,6 +19,9 @@ function train(data::AbstractData,

U, θ, st = get_NN(params, rng, θ_trained)

# one option is to restrict where the NN is evaluated to discrete t to
# generate piece-wise dynamics.

function ude_rotation!(du, u, p, t)
# Angular momentum given by network prediction
L = U([t], p, st)[1]
Expand Down
27 changes: 18 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
export sigmoid_cap, relu_cap
export sigmoid_cap, relu_cap, step_cap
export cart2sph
export AbstractNoise, FisherNoise

# Normalization of the NN. Ideally we want to do this with L2 norm .

"""
sigmoid_cap(x; ω₀)
sigmoid_cap(x; ω₀=1.0)
"""
function sigmoid_cap(x; ω₀)
min_value = - 2ω
max_value = + 2ω
function sigmoid_cap(x; ω₀=1.0)
min_value = - ω
max_value = + ω
return min_value + (max_value - min_value) / ( 1.0 + exp(-x) )
end

function relu_cap(x; ω₀)
min_value = - 2ω₀
max_value = + 2ω₀
"""
relu_cap(x; ω₀=1.0)
"""
function relu_cap(x; ω₀=1.0)
min_value = - ω₀
max_value = + ω₀
return min_value + (max_value - min_value) * max(0.0, min(x, 1.0))
end

"""
cart2sph(X::AbstractArray{<:Number}; radians::Bool=true)
Convert cartesian coordinates to spherical
"""
function cart2sph(X::AbstractArray{<:Number}; radians::Bool=true)
Expand All @@ -33,8 +38,12 @@ end

"""
Add Fisher noise to matrix of three dimensional unit vectors
"""
This is carried by the definition of type FisherNoise <: AbstractNoise and
extending the base definition +(,) to allow the simple syntax
X_noise = X_noiseless + FisherNoise(kappa=200.)
"""
abstract type AbstractNoise end

@kwdef struct FisherNoise{F <: AbstractFloat} <: AbstractNoise
Expand Down

0 comments on commit b36934e

Please sign in to comment.