-
Notifications
You must be signed in to change notification settings - Fork 7
/
distance_utils.jl
57 lines (51 loc) · 1.68 KB
/
distance_utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
using Flux: Flux
"""
    distance(
        ce::AbstractCounterfactualExplanation;
        from::Union{Nothing,AbstractArray}=nothing,
        agg=mean,
        p::Real=1,
        weights::Union{Nothing,AbstractArray}=nothing,
        cosine::Bool=false,
    )

Computes the distance of the counterfactual to the original factual.

# Keywords
- `from`: the point distances are measured from. When `nothing`, it is set to
  `CounterfactualExplanations.factual(ce)`.
- `agg`: function applied to the collection of per-counterfactual distances.
- `p`: order handed to `LinearAlgebra.norm`.
- `weights`: optional weights. They are only used when `cosine` is `false` and
  `ce.num_counterfactuals != 1`; their length must equal the size of the first dimension
  of each counterfactual slice.
- `cosine`: when `true`, the aggregated cosine distance (see `cos_dist`) is returned and
  `p` and `weights` are not used.

# Returns
With `cosine=false` and a single counterfactual, the `p`-norm of the difference between the
decoded counterfactual state and `from`. Otherwise, `agg` applied to the distances computed
for each slice along the last dimension of the decoded state.
"""
function distance(
    ce::AbstractCounterfactualExplanation;
    from::Union{Nothing,AbstractArray}=nothing,
    agg=mean,
    p::Real=1,
    weights::Union{Nothing,AbstractArray}=nothing,
    cosine::Bool=false,
)
    # Bind the reference point once: closures below capture a variable that is never
    # reassigned, which avoids boxing the captured value.
    reference = isnothing(from) ? CounterfactualExplanations.factual(ce) : from
    x′ = CounterfactualExplanations.decode_state(ce)

    # Cosine:
    if cosine
        slices = eachslice(x′; dims=ndims(x′))
        return agg(map(x -> cos_dist(x, reference), slices))
    end

    # NOTE(review): `weights` and `agg` are not applied on this path — confirm intended.
    if ce.num_counterfactuals == 1
        return LinearAlgebra.norm(x′ .- reference, p)
    end

    # slices along the last dimension (i.e. the number of counterfactuals)
    slices = eachslice(x′; dims=ndims(x′))

    if isnothing(weights)
        # aggregate across counterfactuals
        return agg(map(x -> LinearAlgebra.norm(x .- reference, p), slices))
    end

    # The product `norm.(…)' * weights` below contracts over the FIRST dimension of a slice,
    # so that is the dimension the weights have to match (for vector slices this is the
    # same as the last dimension; for `d×1` slices it is not).
    @assert(
        length(weights) == size(first(slices), 1),
        "The length of the weights vector must match the number of features."
    )
    # aggregate across counterfactuals
    return agg(
        map(x -> (LinearAlgebra.norm.(x .- reference, p)'weights)[1], slices)
    )
end
"""
    cos_dist(x, y)

Computes the cosine distance between two vectors, i.e. one minus the cosine similarity
`x'y / (norm(x) * norm(y))`.

No guard is applied for inputs whose norm is zero; the division is carried out as written.
"""
function cos_dist(x, y)
    # `norm` is qualified to match how the rest of this file refers to LinearAlgebra.
    scale = LinearAlgebra.norm(x) * LinearAlgebra.norm(y)
    # `x'y` is a scalar for vectors and a matrix for matrix inputs; `[1]` picks the
    # (first) entry in either case.
    cos_sim = (x'y / scale)[1]
    return 1 - cos_sim
end