Skip to content

Commit

Permalink
feat: Updated regression example to produce something more useful and…
Browse files Browse the repository at this point in the history
… more inline with the paper.
  • Loading branch information
DoktorMike committed Jul 18, 2022
1 parent afb03d1 commit 8577d8e
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 50 deletions.
151 changes: 104 additions & 47 deletions examples/regression.jl
Original file line number Diff line number Diff line change
@@ -1,77 +1,134 @@
using EvidentialFlux
using Flux
using Flux.Optimise: ADAM
using UnicodePlots
using Flux.Optimise: AdamW, Adam
using GLMakie
using Statistics


const epochs = 10000
const lr = 0.001

function gendata()
x = Float32.(collect(-2π:0.1:2π))
y = Float32.(sin.(x) .+ 0.3 * randn(size(x)))
#scatterplot(x, y)
x, y
f1(x) = sin.(x)
f2(x) = 0.01 * x .^ 3 .- 0.1 * x
f3(x) = x .^ 3
function gendata(id = 1)
x = Float32.(-4:0.05:4)
if id == 1
y = f1(x) .+ 0.2 * randn(size(x))
elseif id == 2
y = f2(x) .* (1.0 .+ 0.2 * randn(size(x))) .+ 0.2 * randn(size(x))
else
y = f3(x) .+ randn(size(x)) .* 3.0
end
#scatterplot(x, y)
x, y
end

"""
predict(m, x)
predict_all(m, x)
Predicts the output of the model m on the input x.
"""
function predict(m, x)
= m(x)
γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
(pred=γ, eu=uncertainty(ν, α, β), au=uncertainty(α, β))
function predict_all(m, x)
= m(x)
γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
# Correction for α = 1 case
α = α .+ 1.0f-7
au = uncertainty(α, β)
eu = uncertainty(ν, α, β)
(pred = γ, eu = eu, au = au)
end

function plotfituncert!(m, x, y, wband = true)
ŷ, u, au = predict_all(m, x')
#u, au = u ./ maximum(u), au ./ maximum(au)
u, au = u ./ maximum(u) .* std(y), au ./ maximum(au) .* std(y)
GLMakie.scatter!(x, y, color = "#5E81AC")
GLMakie.lines!(x, ŷ, color = "#BF616A", linewidth = 5)
if wband == true
GLMakie.band!(x, ŷ + u, ŷ - u, color = "#5E81ACAC")
else
GLMakie.scatter!(x, u, color = :yellow)
GLMakie.scatter!(x, au, color = :green)
end
end

function plotfituncert(m, x, y, wband = true)
f = Figure()
Axis(f[1, 1])
ŷ, u, au = predict_all(m, x')
#u, au = u ./ maximum(u), au ./ maximum(au)
#u, au = u ./ maximum(u) .* std(y), au ./ maximum(au) .* std(y)
GLMakie.scatter!(x, y, color = "#5E81AC")
GLMakie.lines!(x, ŷ, color = "#BF616A", linewidth = 5)
if wband == true
#GLMakie.band!(x, ŷ + u, ŷ - u, color = "#5E81ACAC")
GLMakie.band!(x, ŷ + u, ŷ - u, color = "#EBCB8BAC")
GLMakie.band!(x, ŷ + au, ŷ - au, color = "#A3BE8CAC")
else
GLMakie.scatter!(x, u, color = :yellow)
GLMakie.scatter!(x, au, color = :green)
end
f
end

mae(y, ŷ) = Statistics.mean(abs.(y - ŷ))

x, y = gendata()
p = scatterplot(x, y, width = 80, height = 30)
lines!(p, x, sin.(x))
x, y = gendata(3)
GLMakie.scatter(x, y)
GLMakie.lines!(x, f3(x))

m = Chain(Dense(1 => 100, tanh), NIG(100 => 1))
epochs = 10000
lr = 0.005
m = Chain(Dense(1 => 100, relu), Dense(100 => 100, relu), Dense(100 => 100, relu), NIG(100 => 1))
#m(x')
opt = ADAM(lr)
opt = AdamW(lr, (0.89, 0.995), 0.001)
#opt = Flux.Optimiser(AdamW(lr), ClipValue(1e1))
pars = Flux.params(m)
trnlosses = zeros(epochs)
f = Figure()
Axis(f[1, 1])
for epoch in 1:epochs
local trnloss = 0
grads = Flux.gradient(pars) do
= m(x')
γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
trnloss = Statistics.mean(nigloss(y, γ, ν, α, β, 0.01, 1e-4))
trnloss
end
trnlosses[epoch] = trnloss
# Test that we can update the weights based on gradients
Flux.Optimise.update!(opt, pars, grads)
local trnloss = 0
grads = Flux.gradient(pars) do
= m(x')
γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
trnloss = Statistics.mean(nigloss(y, γ, ν, α, β, 0.01, 0.001))
trnloss
end
Flux.Optimise.update!(opt, pars, grads)
trnlosses[epoch] = trnloss
if epoch % 2000 == 0
println("Epoch: $epoch, Loss: $trnloss")
plotfituncert!(m, x, f3(x), true)
end
end

# The convergance plot shows the loss function converges to a local minimum
scatterplot(1:epochs, trnlosses, width = 80)
GLMakie.scatter(1:epochs, trnlosses)
# And the MAE corresponds to the noise we added in the target
ŷ, u, au = predict(m, x')
ŷ, u, au = predict_all(m, x')
u, au = u ./ maximum(u), au ./ maximum(au)
println("MAE: $(mae(y, ŷ))")

# Correlation plot confirms the fit
p = scatterplot(y, ŷ, width = 80, height = 30, marker = "o");
lines!(p, -2:0.01:2, -2:0.01:2)

p = scatterplot(x, y, width = 80, height = 30, marker = "o");
scatterplot!(p, x, ŷ, color = :red, marker = "x");
scatterplot!(p, x, u)
GLMakie.scatter(y, ŷ)
GLMakie.lines!(-2:0.01:2, -2:0.01:2)

p = scatterplot(x, ŷ, marker = :x, width = 80, height = 30, color = :red);
scatterplot!(p, x, y, marker = :x, color = :blue)
plotfituncert(m, x, y, true)
GLMakie.ylims!(-100, 100)

## Out of sample predictions
## Out of sample predictions to the left and right
xood = Float32.(-6:0.2:6);
plotfituncert(m, xood, f3(xood), true)
GLMakie.ylims!(-200, 200)
GLMakie.band!(4:0.01:6, -200, 200, color = "#8FBCBBB1")
GLMakie.band!(-6:0.01:-4, -200, 200, color = "#8FBCBBB1")
GLMakie.xlabel("hello")

x = Float32.(collect(0:0.1:3π));
ŷ, u, au = predict(m, x');
## Out of sample predictions to the right
xood = Float32.(0:0.2:6);
plotfituncert(m, xood, f3(xood), true)
GLMakie.ylims!(-200, 200)

p = scatterplot(x, sin.(x), width = 80, height = 30, marker = "o");
scatterplot!(p, x, ŷ, color = :red, marker = "x");
scatterplot!(p, x, u)
scatterplot!(p, x, au)
## Out of sample predictions to the left
xood = Float32.(-6:0.2:0);
plotfituncert(m, xood, f3(xood), true)
GLMakie.ylims!(-200, 200)
Binary file added images/cubefun.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 0 additions & 3 deletions src/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,9 @@ function nigloss(y, γ, ν, α, β, λ = 1, ϵ = 1e-4)
.+ 0.5) .* log.(ν .* (y - γ) .^ 2 + twoβλ) +
logγ.(α) -
logγ.(α .+ 0.5)
nll

# REG: Calculate regularizer based on absolute error of prediction
error = abs.(y - γ)
reg = error .* (2 * ν + α)

# Combine negative log likelihood and regularizer
loss = nll + λ .* (reg .- ϵ)
loss
Expand Down

0 comments on commit 8577d8e

Please sign in to comment.