feat: Implemented the DER correction 🥶
- Added a new method for aleatoric uncertainty
- Added a new method for epistemic uncertainty
- Added a new loss function for the NIG loss
- Updated documentation and added a quick example in README

This one closes #9
DoktorMike committed Jul 19, 2022
1 parent 20b1280 commit 450a76f
Showing 6 changed files with 255 additions and 7 deletions.
40 changes: 37 additions & 3 deletions README.md
@@ -6,6 +6,37 @@

This is a Julia implementation in Flux of the Evidential Deep Learning framework. It strives to estimate heteroskedastic aleatoric uncertainty as well as epistemic uncertainty along with every prediction made. All of it calculated in one glorious forward pass. Boom!

## For the impatient

Below is an example of how to train a Deep Evidential Regression model and
extract the predictions as well as the epistemic and aleatoric uncertainty.
For a more elaborate example, have a look in the examples folder.

```julia
using Flux
using Flux.Optimise: AdamW
using EvidentialFlux
using Statistics

x = Float32.(-4:0.1:4)
y = x .^3 .+ randn(Float32, length(x)) .* 3

lr = 0.0005
m = Chain(Dense(1 => 100, relu), Dense(100 => 100, relu), Dense(100 => 100, relu), NIG(100 => 1))
opt = AdamW(lr, (0.89, 0.995), 0.001)
pars = Flux.params(m)
for epoch in 1:500
    grads = Flux.gradient(pars) do
        ŷ = m(x')
        γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
        trnloss = Statistics.mean(nigloss2(y, γ, ν, α, β, 0.01, 2))
        trnloss
    end
    Flux.Optimise.update!(opt, pars, grads)
end

γ, ν, α, β = predict(m, x')
eu = epistemic(ν)
au = aleatoric(ν, α, β)
```

## Classification

@@ -19,14 +50,17 @@ for this example can be found in

## Regression

In the case of a regression problem we utilize the NormalInverseGamma distribution to model a type II likelihood
function that then explicitly models the aleatoric and epistemic uncertainty.
In the case of a regression problem we utilize the NormalInverseGamma
distribution to model a type II likelihood function that then explicitly
models the aleatoric and epistemic uncertainty. The code for the example
producing the plot below can be found in
[regression.jl](examples/regression.jl).
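
For orientation, the type II (marginal) likelihood induced by the NormalInverseGamma
prior is a Student-t; this is the standard result from Amini et al., sketched here for
context rather than anything specific to this package:

```math
p(y \mid \gamma, \nu, \alpha, \beta) = \mathrm{St}\left(y;\; \gamma,\; \frac{\beta(1+\nu)}{\nu\alpha},\; 2\alpha\right)
```

The prediction is read off as γ, while ν, α and β set the width of this Student-t and
hence the reported uncertainties.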

![uncertainty](images/cubefun.png)

## Summary

Uncertainty is crucial for the deployment and utilization of robust machine
learning in production. No model is perfect and each and every one of them have
learning in production. No model is perfect and each one of them has its
own strengths and weaknesses, but as a minimum requirement we should all
at least demand that our models report uncertainty in every prediction.
15 changes: 14 additions & 1 deletion docs/src/index.md
@@ -3,7 +3,7 @@
Evidential Deep Learning is a way to generate predictions and the uncertainty
associated with them in one single forward pass. This is in stark contrast to
traditional Bayesian neural networks which are typically based on Variational
Inference, Markov Chain Monte Carlo, Monte Carlo Dropout or Ensembles.
Inference, Markov Chain Monte Carlo, Monte Carlo Dropout or Ensembles.

## Deep Evidential Regression

@@ -53,6 +53,14 @@ variable, namely ``\gamma,\nu,\alpha,\beta``. This means that in one forward
pass we can estimate the prediction, the heteroskedastic aleatoric uncertainty
as well as the epistemic uncertainty. Boom!
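
As a quick sketch of what this looks like in code (the layer sizes and input below are
made up purely for illustration):

```julia
using Flux
using EvidentialFlux

# A NIG head on top of an ordinary MLP: a single forward pass yields all four
# NormalInverseGamma parameters, from which both uncertainties follow.
m = Chain(Dense(1 => 32, relu), NIG(32 => 1))
x = reshape(Float32.(-1:0.1:1), 1, :)  # input of shape (features, batch)
γ, ν, α, β = predict(m, x)             # prediction γ plus the NIG parameters
eu = epistemic(ν)                      # epistemic uncertainty per sample
au = aleatoric(ν, α, β)                # aleatoric uncertainty per sample
```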

### Theoretical justifications

Although this approach seems to work well for the problems illustrated by
Amini et al., it has been shown in [^nis2022] that it has theoretical
shortcomings in how the aleatoric and epistemic uncertainty are expressed.
They propose a correction of the loss as well as of the uncertainty
calculations. In this package I have implemented both.
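
Concretely, as implemented here by `epistemic`, `aleatoric` and `nigloss2` (a sketch;
see the paper for the derivation and its exact conventions), the epistemic uncertainty
is taken as ``1/\sqrt{\nu}``, the aleatoric uncertainty is taken from the width of the
induced Student-t, and the loss replaces the original regularizer with the absolute
error scaled by that width, raised to a power ``p`` and weighted by the total evidence
``2\nu + \alpha``:

```math
u_{\mathrm{ep}} = \frac{1}{\sqrt{\nu}}, \qquad
u_{\mathrm{al}} = \frac{\beta\,(1+\nu)}{\nu\,\alpha}, \qquad
\mathcal{L} = \mathrm{NLL}_{\mathrm{NIG}} + \lambda \left(\frac{|y-\gamma|}{u_{\mathrm{al}}}\right)^{p} (2\nu + \alpha)
```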

## Deep Evidential Classification

We follow [^sensoy2018] in our implementation of Deep Evidential
@@ -87,8 +95,11 @@ DIR
NIG
predict
uncertainty
aleatoric
epistemic
evidence
nigloss
nigloss2
dirloss
```

@@ -103,3 +114,5 @@ dirloss

[^sensoy2018]: Sensoy, Murat, Lance Kaplan, and Melih Kandemir. “Evidential Deep Learning to Quantify Classification Uncertainty.” Advances in Neural Information Processing Systems 31 (June 2018): 3179–89.

[^nis2022]: Meinert, Nis, Jakob Gawlikowski, and Alexander Lavin. “The Unreasonable Effectiveness of Deep Evidential Regression.” arXiv, May 20, 2022. http://arxiv.org/abs/2205.10060.

133 changes: 133 additions & 0 deletions examples/regression2.jl
@@ -0,0 +1,133 @@
using EvidentialFlux
using Flux
using Flux.Optimise: AdamW, Adam
using GLMakie
using Statistics


f1(x) = sin.(x)
f2(x) = 0.01 * x .^ 3 .- 0.1 * x
f3(x) = x .^ 3
function gendata(id = 1)
    x = Float32.(-4:0.05:4)
    if id == 1
        y = f1(x) .+ 0.2 * randn(size(x))
    elseif id == 2
        y = f2(x) .* (1.0 .+ 0.2 * randn(size(x))) .+ 0.2 * randn(size(x))
    else
        y = f3(x) .+ randn(size(x)) .* 3.0
    end
    #scatterplot(x, y)
    x, y
end

"""
predict_all(m, x)
Predicts the output of the model m on the input x.
"""
function predict_all(m, x)
= m(x)
γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
# Correction for α = 1 case
α = α .+ 1.0f-7
au = aleatoric(ν, α, β)
eu = epistemic(ν)
(pred = γ, eu = eu, au = au)
end

function plotfituncert!(m, x, y, wband = true)
    ŷ, u, au = predict_all(m, x')
    #u, au = u ./ maximum(u), au ./ maximum(au)
    u, au = u ./ maximum(u) .* std(y), au ./ maximum(au) .* std(y)
    GLMakie.scatter!(x, y, color = "#5E81AC")
    GLMakie.lines!(x, ŷ, color = "#BF616A", linewidth = 5)
    if wband == true
        GLMakie.band!(x, ŷ + u, ŷ - u, color = "#5E81ACAC")
    else
        GLMakie.scatter!(x, u, color = :yellow)
        GLMakie.scatter!(x, au, color = :green)
    end
end

function plotfituncert(m, x, y, wband = true)
    f = Figure()
    Axis(f[1, 1], xlabel = "x", ylabel = "y")
    ŷ, u, au = predict_all(m, x')
    #u, au = u ./ maximum(u), au ./ maximum(au)
    #u, au = u ./ maximum(u) .* std(y), au ./ maximum(au) .* std(y)
    GLMakie.scatter!(x, y, color = "#5E81AC")
    GLMakie.lines!(x, ŷ, color = "#BF616A", linewidth = 5)
    if wband == true
        #GLMakie.band!(x, ŷ + u, ŷ - u, color = "#5E81ACAC")
        GLMakie.band!(x, ŷ + u, ŷ - u, color = "#EBCB8BAC")
        GLMakie.band!(x, ŷ + au, ŷ - au, color = "#A3BE8CAC")
    else
        GLMakie.scatter!(x, u, color = :yellow)
        GLMakie.scatter!(x, au, color = :green)
    end
    f
end

mae(y, ŷ) = Statistics.mean(abs.(y - ŷ))

x, y = gendata(3)
GLMakie.scatter(x, y)
GLMakie.lines!(x, f3(x))

epochs = 6000
lr = 0.0005
m = Chain(Dense(1 => 100, relu), Dense(100 => 100, relu), Dense(100 => 100, relu), NIG(100 => 1))
#m(x')
opt = AdamW(lr, (0.89, 0.995), 0.001)
#opt = Flux.Optimiser(AdamW(lr), ClipValue(1e1))
pars = Flux.params(m)
trnlosses = zeros(epochs)
f = Figure()
Axis(f[1, 1])
for epoch in 1:epochs
    local trnloss = 0
    grads = Flux.gradient(pars) do
        ŷ = m(x')
        γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
        trnloss = Statistics.mean(nigloss2(y, γ, ν, α, β, 0.01, 1))
        trnloss
    end
    Flux.Optimise.update!(opt, pars, grads)
    trnlosses[epoch] = trnloss
    if epoch % 2000 == 0
        println("Epoch: $epoch, Loss: $trnloss")
        plotfituncert!(m, x, f3(x), true)
    end
end

# The convergence plot shows that the loss converges to a local minimum
GLMakie.scatter(1:epochs, trnlosses)
# And the MAE corresponds to the noise we added to the target
ŷ, u, au = predict_all(m, x')
u, au = u ./ maximum(u), au ./ maximum(au)
println("MAE: $(mae(y, ŷ))")

# Correlation plot confirms the fit
GLMakie.scatter(y, ŷ)
GLMakie.lines!(-2:0.01:2, -2:0.01:2)

plotfituncert(m, x, y, true)
GLMakie.ylims!(-100, 100)

## Out of sample predictions to the left and right
xood = Float32.(-6:0.2:6);
plotfituncert(m, xood, f3(xood), true)
GLMakie.ylims!(-200, 200)
GLMakie.band!(4:0.01:6, -200, 200, color = "#8FBCBBB1")
GLMakie.band!(-6:0.01:-4, -200, 200, color = "#8FBCBBB1")

## Out of sample predictions to the right
xood = Float32.(0:0.2:6);
plotfituncert(m, xood, f3(xood), true)
GLMakie.ylims!(-200, 200)

## Out of sample predictions to the left
xood = Float32.(-6:0.2:0);
plotfituncert(m, xood, f3(xood), true)
GLMakie.ylims!(-200, 200)
3 changes: 3 additions & 0 deletions src/EvidentialFlux.jl
@@ -12,10 +12,13 @@ export DIR

include("losses.jl")
export nigloss
export nigloss2
export dirloss

include("utils.jl")
export uncertainty
export aleatoric
export epistemic
export evidence
export predict

44 changes: 41 additions & 3 deletions src/losses.jl
@@ -16,11 +16,11 @@ function: μ and σ.
"""
function nigloss(y, γ, ν, α, β, λ = 1, ϵ = 1e-4)
# NLL: Calculate the negative log likelihood of the Normal-Inverse-Gamma distribution
twoβλ = 2 * β .* (1 .+ ν)
Ω = 2 * β .* (1 .+ ν)
logγ = SpecialFunctions.loggamma
nll = 0.5 * log.(π ./ ν) -
α .* log.(twoβλ) +
(α .+ 0.5) .* log.(ν .* (y - γ) .^ 2 + twoβλ) +
α .* log.(Ω) +
(α .+ 0.5) .* log.(ν .* (y - γ) .^ 2 + Ω) +
logγ.(α) -
logγ.(α .+ 0.5)
# REG: Calculate regularizer based on absolute error of prediction
@@ -31,6 +31,44 @@ function nigloss(y, γ, ν, α, β, λ = 1, ϵ = 1e-4)
loss
end

"""
    nigloss2(y, γ, ν, α, β, λ = 1, p = 1)

The corrected loss function for Deep Evidential Regression (DER) as
recommended by Meinert, Nis, Jakob Gawlikowski, and Alexander Lavin. “The
Unreasonable Effectiveness of Deep Evidential Regression.” arXiv, May 20,
2022. http://arxiv.org/abs/2205.10060. It is the loss for evidential
inference given a NormalInverseGamma posterior over the parameters of the
Gaussian likelihood function: μ and σ.

# Arguments:
- `y`: the targets whose shape should be (O, B)
- `γ`: the γ parameter of the NIG distribution which corresponds to its mean and whose shape should be (O, B)
- `ν`: the ν parameter of the NIG distribution which relates to its precision and whose shape should be (O, B)
- `α`: the α parameter of the NIG distribution which relates to its precision and whose shape should be (O, B)
- `β`: the β parameter of the NIG distribution which relates to its uncertainty and whose shape should be (O, B)
- `λ`: the weight to put on the regularizer (default: 1)
- `p`: the power to which the scaled absolute prediction error is raised (default: 1)
"""
function nigloss2(y, γ, ν, α, β, λ = 1, p = 1)
    # NLL: Calculate the negative log likelihood of the Normal-Inverse-Gamma distribution
    Ω = 2 * β .* (1 .+ ν)
    logγ = SpecialFunctions.loggamma
    nll = 0.5 * log.(π ./ ν) -
          α .* log.(Ω) +
          (α .+ 0.5) .* log.(ν .* (y - γ) .^ 2 + Ω) +
          logγ.(α) -
          logγ.(α .+ 0.5)
    # REG: Calculate regularizer based on absolute error of prediction
    uₐ = aleatoric(ν, α, β)
    error = (abs.(y - γ) ./ uₐ) .^ p
    Φ = evidence(ν, α) # Total evidence
    reg = error .* Φ
    # Combine negative log likelihood and regularizer
    loss = nll + λ * reg
    loss
end

# The α here is actually the α̃, whose evidence has been scaled down, which is what we want.
# The α here is a matrix of size (K, B) or (O, B).
function kl(α)
27 changes: 27 additions & 0 deletions src/utils.jl
@@ -58,6 +58,33 @@ distribution is as ν virtual observations governing the mean μ of the likeliho
"""
evidence(ν, α) = @. 2ν + α

"""
    aleatoric(ν, α, β)

This is the aleatoric uncertainty as recommended by Meinert, Nis, Jakob
Gawlikowski, and Alexander Lavin. “The Unreasonable Effectiveness of Deep
Evidential Regression.” arXiv, May 20, 2022. http://arxiv.org/abs/2205.10060.
This is precisely the ``σ_{St}`` of the Student-t distribution.

# Arguments:
- `ν`: the ν parameter of the NIG distribution which relates to its precision and whose shape should be (O, B)
- `α`: the α parameter of the NIG distribution which relates to its precision and whose shape should be (O, B)
- `β`: the β parameter of the NIG distribution which relates to its uncertainty and whose shape should be (O, B)
"""
aleatoric(ν, α, β) = @. (β * (1 + ν)) / (ν * α)

"""
    epistemic(ν)

This is the epistemic uncertainty as recommended by Meinert, Nis, Jakob
Gawlikowski, and Alexander Lavin. “The Unreasonable Effectiveness of Deep
Evidential Regression.” arXiv, May 20, 2022. http://arxiv.org/abs/2205.10060.

# Arguments:
- `ν`: the ν parameter of the NIG distribution which relates to its precision and whose shape should be (O, B)
"""
epistemic(ν) = 1 ./ sqrt.(ν)

"""
predict(m, x)
