-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathMLP.jl
29 lines (22 loc) · 811 Bytes
/
MLP.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
    MLP

Type tag for multi-layer perceptron (MLP) models.

It has no fields, so `MLP()` is a singleton instance. That instance is passed to
`Model` alongside the underlying network to mark the model family.
"""
struct MLP <: AbstractFluxNN
end
"""
MLP(model; likelihood::Symbol=:classification_binary)
An outer constructor for a multi-layer perceptron (MLP) model.
"""
function MLP(model; likelihood::Symbol=:classification_binary)
return Model(model, MLP(); likelihood=likelihood)
end
"""
(M::Model)(data::CounterfactualData, type::MLP; kwargs...)
Constructs a multi-layer perceptron (MLP) for the given data.
"""
function (M::Model)(data::CounterfactualData, type::MLP; kwargs...)
# Basic setup:
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(data)
input_dim = size(X, 1)
output_dim = size(y, 1)
# Build MLP:
model = build_mlp(; input_dim=input_dim, output_dim=output_dim, kwargs...)
M = Model(model, type; likelihood=data.likelihood)
return M
end