-
Notifications
You must be signed in to change notification settings - Fork 7
/
loss.jl
30 lines (25 loc) · 940 Bytes
/
loss.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
ℓ(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
Dispatches to the appropriate loss function for any generator.
"""
function ℓ(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
return ℓ(generator, generator.loss, ce)
end
"""
ℓ(generator::AbstractGenerator, loss::Nothing, ce::AbstractCounterfactualExplanation)
Overloads the `ℓ` function for the case where no loss function is provided.
"""
function ℓ(
generator::AbstractGenerator, loss::Nothing, ce::AbstractCounterfactualExplanation
)
return CounterfactualExplanations.guess_loss(ce)(ce)
end
"""
ℓ(generator::AbstractGenerator, loss::Function, ce::AbstractCounterfactualExplanation)
Overloads the `ℓ` function for the case where a single loss function is provided.
"""
function ℓ(
generator::AbstractGenerator, loss::Function, ce::AbstractCounterfactualExplanation
)
return loss(ce)
end