-
Notifications
You must be signed in to change notification settings - Fork 7
/
generator_conditions.jl
56 lines (49 loc) · 2.05 KB
/
generator_conditions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
    GeneratorConditionsConvergence <: AbstractConvergence

Convergence criterion for counterfactual explanations based on the generator conditions.
The search is treated as converged (see `converged`) once the decision threshold has been
reached and the generator's conditions are satisfied. The intent is that the gradients of
the search objective fall below `gradient_tol`.

# Fields
- `decision_threshold::AbstractFloat`: The threshold for the decision probability.
- `gradient_tol::AbstractFloat`: The tolerance for the gradients of the search objective.
  NOTE(review): this field is not read in the code shown here; presumably it is consumed
  by `Generators.conditions_satisfied`. Confirm.
- `max_iter::Int`: The maximum number of iterations.
- `min_success_rate::AbstractFloat`: The minimum success rate for the generator conditions
  (across counterfactuals). The keyword constructor requires it to lie in `(0.0, 1.0]`.

Prefer the keyword constructor `GeneratorConditionsConvergence(; ...)`. The positional
default constructor performs no validation.
"""
struct GeneratorConditionsConvergence <: AbstractConvergence
    # NOTE(review): the fields below are abstractly typed (`AbstractFloat`), so instances
    # store them boxed. Acceptable for a small config object; parametrize if this is hot.
    decision_threshold::AbstractFloat
    gradient_tol::AbstractFloat
    max_iter::Int
    min_success_rate::AbstractFloat
end
"""
    GeneratorConditionsConvergence(; decision_threshold=0.5, gradient_tol=1e-2, max_iter=100, min_success_rate=0.75, y_levels=nothing)

Outer keyword constructor for `GeneratorConditionsConvergence`.

# Keywords
- `decision_threshold::AbstractFloat=0.5`: Threshold for the decision probability.
  Ignored (overridden) when `y_levels` is provided.
- `gradient_tol::AbstractFloat=1e-2`: Tolerance for the gradients of the search objective.
- `max_iter::Int=100`: Maximum number of iterations.
- `min_success_rate::AbstractFloat=0.75`: Minimum success rate for the generator
  conditions (across counterfactuals). Must lie in the half-open interval `(0.0, 1.0]`.
- `y_levels::Union{Nothing,AbstractVector}=nothing`: If given, the decision threshold is
  set to `1 / length(y_levels)`, e.g. `0.5` for two levels.

# Throws
- `AssertionError` if `min_success_rate` is not in `(0.0, 1.0]`.
- `ArgumentError` if `y_levels` is an empty vector.
"""
function GeneratorConditionsConvergence(;
    decision_threshold::AbstractFloat=0.5,
    gradient_tol::AbstractFloat=1e-2,
    max_iter::Int=100,
    min_success_rate::AbstractFloat=0.75,
    y_levels::Union{Nothing,AbstractVector}=nothing,
)
    # Lower bound exclusive, upper bound inclusive. The message now states the same
    # interval the check enforces.
    @assert(
        0.0 < min_success_rate <= 1.0,
        "Minimum success rate should be ∈ (0.0,1.0], got $(min_success_rate)."
    )
    if !isnothing(y_levels)
        # An empty label set would give 1/0 = Inf. A threshold on the decision
        # probability could then never be reached, so reject it explicitly.
        isempty(y_levels) && throw(ArgumentError("`y_levels` must not be empty."))
        # 1/K for K target levels. This deliberately overrides any user-supplied threshold.
        decision_threshold = 1 / length(y_levels)
    end
    return GeneratorConditionsConvergence(
        decision_threshold, gradient_tol, max_iter, min_success_rate
    )
end
"""
    converged(
        convergence::GeneratorConditionsConvergence,
        ce::AbstractCounterfactualExplanation,
        x::Union{AbstractArray,Nothing}=nothing,
    )

Report whether the counterfactual search in `ce` has converged under the
generator-conditions criterion.

The first check is `threshold_reached(ce, x)`. If that is `false`, the function returns
`false` without consulting the generator. Otherwise it returns the result of
`Generators.conditions_satisfied(ce.generator, ce)`.

The `convergence` argument only selects this method via dispatch; none of its fields are
read in this body.
"""
function converged(
    convergence::GeneratorConditionsConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
    # Guard clause: the generator's conditions are evaluated only after the threshold is met.
    threshold_reached(ce, x) || return false
    return Generators.conditions_satisfied(ce.generator, ce)
end