-
Notifications
You must be signed in to change notification settings - Fork 7
/
path_tracking.jl
44 lines (37 loc) · 1.17 KB
/
path_tracking.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
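    path(ce::CounterfactualExplanation; feature_space=true)

Return the entire counterfactual path recorded in `ce.search[:path]`. If `feature_space`
is `true` (the default), every recorded state is passed through `decode_state` before
being returned; otherwise a deep copy of the recorded states is returned as is.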
"""
function path(ce::CounterfactualExplanation; feature_space=true)
    # Work on a deep copy so the states stored in `ce.search[:path]` are never
    # handed out directly.
    states = deepcopy(ce.search[:path])

    # Callers that do not want decoding get the copied states untouched.
    feature_space || return states

    # Otherwise run each recorded state through `decode_state`.
    return [decode_state(ce, state) for state in states]
end
"""
    counterfactual_probability_path(ce::CounterfactualExplanation)

Return the counterfactual probabilities for each step of the search.
"""
function counterfactual_probability_path(ce::CounterfactualExplanation)
    # One `counterfactual_probability` evaluation per element of `path(ce)`.
    probabilities = map(path(ce)) do X
        counterfactual_probability(ce, X)
    end
    return probabilities
end
"""
    counterfactual_label_path(ce::CounterfactualExplanation)

Return the counterfactual labels for each step of the search.
"""
function counterfactual_label_path(ce::CounterfactualExplanation)
    # Read the data and model off `ce` once, before walking the path.
    data, model = ce.data, ce.M

    # One `predict_label` call per element of `path(ce)`.
    labels = map(path(ce)) do X
        predict_label(model, data, X)
    end
    return labels
end
"""
    target_probs_path(ce::CounterfactualExplanation)

Return the target probabilities for each step of the search.
"""
function target_probs_path(ce::CounterfactualExplanation)
    # One `target_probs` evaluation per element of `path(ce)`.
    return map(path(ce)) do step
        target_probs(ce, step)
    end
end