-
Notifications
You must be signed in to change notification settings - Fork 7
/
counterfactual_data.jl
299 lines (258 loc) · 9.7 KB
/
counterfactual_data.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
using MLJBase: MLJBase, Continuous, Finite
using StatsBase: StatsBase, ZScoreTransform
using Tables: Tables
using Graphs
using CausalInference: CausalInference
"""
InputTransformer
Abstract type for data transformers. This can be any of the following:
- `StatsBase.AbstractDataTransform`: A data transformation object from the `StatsBase` package.
- `MultivariateStats.AbstractDimensionalityReduction`: A dimensionality reduction object from the `MultivariateStats` package.
- `GenerativeModels.AbstractGenerativeModel`: A generative model object from the `GenerativeModels` module.
"""
const InputTransformer = Union{
StatsBase.AbstractDataTransform,
MultivariateStats.AbstractDimensionalityReduction,
GenerativeModels.AbstractGenerativeModel,
CausalInference.SCM,
}
"""
TypedInputTransformer
Abstract type for data transformers.
"""
const TypedInputTransformer = Union{
Type{<:StatsBase.AbstractDataTransform},
Type{<:MultivariateStats.AbstractDimensionalityReduction},
Type{<:GenerativeModels.AbstractGenerativeModel},
Type{<:CausalInference.SCM},
}
"""
CounterfactualData(
X::AbstractMatrix, y::AbstractMatrix;
mutability::Union{Vector{Symbol},Nothing}=nothing,
domain::Union{Any,Nothing}=nothing,
features_categorical::Union{Vector{Int},Nothing}=nothing,
features_continuous::Union{Vector{Int},Nothing}=nothing,
standardize::Bool=false
)
Stores data and metadata for counterfactual explanations.
"""
mutable struct CounterfactualData
X::AbstractMatrix
y::EncodedOutputArrayType
likelihood::Symbol
mutability::Union{Vector{Symbol},Nothing}
domain::Union{Any,Nothing}
features_categorical::Union{Vector{Vector{Int}},Nothing}
features_continuous::Union{Vector{Int},Nothing}
input_encoder::Union{Nothing,InputTransformer}
y_levels::AbstractVector
output_encoder::OutputEncoder
function CounterfactualData(
X,
y,
likelihood,
mutability,
domain,
features_categorical,
features_continuous,
input_encoder,
y_levels,
output_encoder,
)
# Conditions:
conditions = []
# Feature dimension:
conditions = vcat(
conditions...,
length(size(X)) != 2 ? error("Data should be in tabular format") : true,
)
# Output dimension:
conditions = vcat(
conditions...,
if size(X)[2] != size(y)[2]
throw(
DimensionMismatch(
"Number of output observations is $(size(y)[2]). Expected it to match the number of input observations: $(size(X)[2]).",
),
)
else
true
end,
)
# Likelihood:
available_likelihoods = [:classification_binary, :classification_multi]
@assert likelihood ∈ available_likelihoods "Specified likelihood not available. Needs to be one of: $(available_likelihoods)."
if all(conditions)
new(
X,
y,
likelihood,
mutability,
domain,
features_categorical,
features_continuous,
input_encoder,
y_levels,
output_encoder,
)
end
end
end
include("transformer.jl")
"""
CounterfactualData(
X::AbstractMatrix,
y::RawOutputArrayType;
mutability::Union{Vector{Symbol},Nothing}=nothing,
domain::Union{Any,Nothing}=nothing,
features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
features_continuous::Union{Vector{Int},Nothing}=nothing,
input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)
This outer constructor method prepares features `X` and labels `y` to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continues. These arguments are currently not used.
# Examples
```julia-repl
using CounterfactualExplanations.Data
x, y = toy_data_linear()
X = hcat(x...)
counterfactual_data = CounterfactualData(X,y')
```
"""
function CounterfactualData(
X::AbstractMatrix,
y::RawOutputArrayType;
mutability::Union{Vector{Symbol},Nothing}=nothing,
domain::Union{Any,Nothing}=nothing,
features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
features_continuous::Union{Vector{Int},Nothing}=nothing,
input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)
# Output variable:
y_raw = deepcopy(y)
output_encoder = OutputEncoder(y_raw, nothing)
y, y_levels, likelihood = output_encoder()
# Feature type indices:
if isnothing(features_categorical) && isnothing(features_continuous)
features_continuous = 1:size(X, 1)
elseif !isnothing(features_categorical) && isnothing(features_continuous)
features_all = 1:size(X, 1)
cat_indices = reduce(vcat, features_categorical)
features_continuous = findall(map(i -> !(i ∈ cat_indices), features_all))
end
# Defaults:
domain = typeof(domain) <: Tuple ? [domain for var in features_continuous] : domain # domain constraints
counterfactual_data = CounterfactualData(
X,
y,
likelihood,
mutability,
domain,
features_categorical,
features_continuous,
nothing,
y_levels,
output_encoder,
)
# Data transformations:
if transformable_features(counterfactual_data) !=
counterfactual_data.features_continuous
@warn "Some of the underlying features are constant."
end
counterfactual_data.input_encoder = fit_transformer(counterfactual_data, input_encoder)
counterfactual_data.X = Float32.(counterfactual_data.X)
return counterfactual_data
end
"""
function CounterfactualData(
X::Tables.MatrixTable,
y::RawOutputArrayType;
kwrgs...
)
Outer constructor method that accepts a `Tables.MatrixTable`. By default, the indices of categorical and continuous features are automatically inferred the features' `scitype`.
"""
function CounterfactualData(X::Tables.MatrixTable, y::RawOutputArrayType; kwrgs...)
features_categorical = findall([
MLJBase.scitype(x) <: AbstractVector{<:Finite} for x in X
])
features_categorical =
length(features_categorical) == 0 ? nothing : features_categorical
features_continuous = findall([
MLJBase.scitype(x) <: AbstractVector{<:Continuous} for x in X
])
features_continuous = length(features_continuous) == 0 ? nothing : features_continuous
X = permutedims(Tables.matrix(X))
counterfactual_data = CounterfactualData(X, y; kwrgs...)
return counterfactual_data
end
"""
reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::Vector)
Reconstruct the categorical encoding for a single instance.
"""
function reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::AbstractArray)
features_categorical = counterfactual_data.features_categorical
if isnothing(features_categorical)
return x
end
x = vec(x)
map(features_categorical) do cat_group_index
if length(cat_group_index) > 1
x[cat_group_index] = Int.(x[cat_group_index] .== maximum(x[cat_group_index]))
if sum(x[cat_group_index]) > 1
ties = findall(x[cat_group_index] .== 1)
_x = zeros(length(x[cat_group_index]))
winner = rand(ties, 1)[1]
_x[winner] = 1
x[cat_group_index] = _x
end
else
x[cat_group_index] = [round(clamp(x[cat_group_index][1], 0, 1))]
end
end
return x
end
"""
transformable_features(counterfactual_data::CounterfactualData)
Dispatches the `transformable_features` function to the appropriate method based on the type of the `dt` field.
"""
function transformable_features(counterfactual_data::CounterfactualData)
return transformable_features(counterfactual_data, counterfactual_data.input_encoder)
end
"""
transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)
By default, all continuous features are transformable. This function returns the indices of all continuous features.
"""
function transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)
return counterfactual_data.features_continuous
end
"""
transformable_features(
counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
)
Returns the indices of all continuous features that can be transformed. For constant features `ZScoreTransform` returns `NaN`.
"""
function transformable_features(
counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
)
# Find all columns that have varying values:
idx_not_all_equal = [
length(unique(counterfactual_data.X[i, :])) != 1 for
i in counterfactual_data.features_continuous
]
# Returns indices of columns that have varying values:
return counterfactual_data.features_continuous[idx_not_all_equal]
end
"""
transformable_features(
counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
)
Returns the indices of all features that have causal parents.
"""
function transformable_features(
counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
)
# Find all nodes that have causal parents
g = counterfactual_data.input_encoder.dag
child_causal_nodes = [v for v in vertices(g) if indegree(g, v) >= 1]
return child_causal_nodes
end