-
Notifications
You must be signed in to change notification settings - Fork 7
/
loss_functions.jl
60 lines (52 loc) · 1.89 KB
/
loss_functions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
using Statistics: Statistics
"""
    Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation; kwargs...)

Extend `Flux.Losses.logitbinarycrossentropy` to objects of type
`AbstractCounterfactualExplanation`.

The counterfactual state is decoded with `CounterfactualExplanations.decode_state`, mapped
to logits via `logits(ce.M, ⋅)`, and compared against `ce.target_encoded`. All keyword
arguments are forwarded unchanged to the underlying Flux loss.
"""
function Flux.Losses.logitbinarycrossentropy(
    ce::AbstractCounterfactualExplanation; kwargs...
)
    model = ce.M
    x_cf = CounterfactualExplanations.decode_state(ce)
    ŷ = logits(model, x_cf)
    return Flux.Losses.logitbinarycrossentropy(ŷ, ce.target_encoded; kwargs...)
end
"""
    Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation; kwargs...)

Extend `Flux.Losses.logitcrossentropy` to objects of type
`AbstractCounterfactualExplanation`.

The counterfactual state is decoded with `CounterfactualExplanations.decode_state`, mapped
to logits via `logits(ce.M, ⋅)`, and compared against `ce.target_encoded`. All keyword
arguments are forwarded unchanged to the underlying Flux loss.
"""
function Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation; kwargs...)
    model = ce.M
    x_cf = CounterfactualExplanations.decode_state(ce)
    ŷ = logits(model, x_cf)
    return Flux.Losses.logitcrossentropy(ŷ, ce.target_encoded; kwargs...)
end
"""
    Flux.Losses.mse(ce::AbstractCounterfactualExplanation; kwargs...)

Extend `Flux.Losses.mse` to objects of type `AbstractCounterfactualExplanation`.

The counterfactual state is decoded with `CounterfactualExplanations.decode_state`, mapped
to logits via `logits(ce.M, ⋅)` (raw logits, not probabilities), and compared against
`ce.target_encoded`. All keyword arguments are forwarded unchanged to the underlying
Flux loss.
"""
function Flux.Losses.mse(ce::AbstractCounterfactualExplanation; kwargs...)
    model = ce.M
    x_cf = CounterfactualExplanations.decode_state(ce)
    ŷ = logits(model, x_cf)
    return Flux.Losses.mse(ŷ, ce.target_encoded; kwargs...)
end
"""
    predictive_entropy(ce::AbstractCounterfactualExplanation; agg=Statistics.mean)

Compute the predictive entropy of the counterfactuals in `ce`.

The decoded counterfactual state is passed to
`CounterfactualExplanations.Models.predict_proba` together with the model `ce.M` and the
data `ce.data` to obtain predicted probabilities `p`. The terms `p log p` are summed
along `dims=2`, `agg` is applied to the resulting array, and the negated value is
returned.

Exact zero probabilities contribute `0` (the usual convention `0 log 0 = 0`) rather
than producing `NaN`.

Explained in https://arxiv.org/abs/1406.2541.

# Keywords
- `agg=Statistics.mean`: function applied to the array of sums to produce the result.
"""
function predictive_entropy(ce::AbstractCounterfactualExplanation; agg=Statistics.mean)
    model = ce.M
    counterfactual_data = ce.data
    X = CounterfactualExplanations.decode_state(ce)
    p = CounterfactualExplanations.Models.predict_proba(model, counterfactual_data, X)
    # In floating point `0 * log(0)` is `0 * -Inf == NaN`, which would poison the whole
    # reduction whenever a probability underflows to exactly zero. Substituting 1 as the
    # argument of `log` for exact zeros yields `0 * log(1) == 0` without evaluating
    # `log(0)`; every nonzero entry is unchanged.
    plogp = @. p * log(ifelse(iszero(p), one(p), p))
    # NOTE(review): the orientation of `p` is not visible here. If `predict_proba`
    # returns classes along rows and observations along columns, `dims=2` sums over
    # observations rather than classes — confirm the intended axis.
    output = -agg(sum(plogp; dims=2))
    return output
end