using MLJBase: MLJBase, Continuous, Finite
using StatsBase: StatsBase, ZScoreTransform
using Tables: Tables

using Graphs

using CausalInference: CausalInference

"""
    InputTransformer

Abstract type for data transformers. This can be any of the following:

- `StatsBase.AbstractDataTransform`: A data transformation object from the `StatsBase` package.
- `MultivariateStats.AbstractDimensionalityReduction`: A dimensionality reduction object from the `MultivariateStats` package.
- `GenerativeModels.AbstractGenerativeModel`: A generative model object from the `GenerativeModels` module.
"""
const InputTransformer = Union{
    StatsBase.AbstractDataTransform,
    MultivariateStats.AbstractDimensionalityReduction,
    GenerativeModels.AbstractGenerativeModel,
    CausalInference.SCM,
}

"""
    TypedInputTransformer

Abstract type for data transformers.
"""
const TypedInputTransformer = Union{
    Type{<:StatsBase.AbstractDataTransform},
    Type{<:MultivariateStats.AbstractDimensionalityReduction},
    Type{<:GenerativeModels.AbstractGenerativeModel},
    Type{<:CausalInference.SCM},
}

"""
    CounterfactualData(
        X::AbstractMatrix, y::AbstractMatrix;
        mutability::Union{Vector{Symbol},Nothing}=nothing,
        domain::Union{Any,Nothing}=nothing,
        features_categorical::Union{Vector{Int},Nothing}=nothing,
        features_continuous::Union{Vector{Int},Nothing}=nothing,
        standardize::Bool=false
    )
Stores data and metadata for counterfactual explanations.
"""

mutable struct CounterfactualData
    X::AbstractMatrix
    y::EncodedOutputArrayType
    likelihood::Symbol
    mutability::Union{Vector{Symbol},Nothing}
    domain::Union{Any,Nothing}
    features_categorical::Union{Vector{Vector{Int}},Nothing}
    features_continuous::Union{Vector{Int},Nothing}
    input_encoder::Union{Nothing,InputTransformer}
    y_levels::AbstractVector
    output_encoder::OutputEncoder
    function CounterfactualData(
        X,
        y,
        likelihood,
        mutability,
        domain,
        features_categorical,
        features_continuous,
        input_encoder,
        y_levels,
        output_encoder,
    )

        # Conditions:
        conditions = []
        # Feature dimension:
        conditions = vcat(
            conditions...,
            length(size(X)) != 2 ? error("Data should be in tabular format") : true,
        )
        # Output dimension:
        conditions = vcat(
            conditions...,
            if size(X)[2] != size(y)[2]
                throw(
                    DimensionMismatch(
                        "Number of output observations is $(size(y)[2]). Expected it to match the number of input observations: $(size(X)[2]).",
                    ),
                )
            else
                true
            end,
        )
        # Likelihood:
        available_likelihoods = [:classification_binary, :classification_multi]
        @assert likelihood ∈ available_likelihoods "Specified likelihood not available. Needs to be one of: $(available_likelihoods)."

        if all(conditions)
            new(
                X,
                y,
                likelihood,
                mutability,
                domain,
                features_categorical,
                features_continuous,
                input_encoder,
                y_levels,
                output_encoder,
            )
        end
    end
end

include("transformer.jl")

"""
    CounterfactualData(
        X::AbstractMatrix,
        y::RawOutputArrayType;
        mutability::Union{Vector{Symbol},Nothing}=nothing,
        domain::Union{Any,Nothing}=nothing,
        features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
        features_continuous::Union{Vector{Int},Nothing}=nothing,
        input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
    )

This outer constructor method prepares features `X` and labels `y` to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continues. These arguments are currently not used. 

# Examples

```julia-repl
using CounterfactualExplanations.Data
x, y = toy_data_linear()
X = hcat(x...)
counterfactual_data = CounterfactualData(X,y')
```

"""
function CounterfactualData(
    X::AbstractMatrix,
    y::RawOutputArrayType;
    mutability::Union{Vector{Symbol},Nothing}=nothing,
    domain::Union{Any,Nothing}=nothing,
    features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
    features_continuous::Union{Vector{Int},Nothing}=nothing,
    input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)

    # Output variable:
    y_raw = deepcopy(y)
    output_encoder = OutputEncoder(y_raw, nothing)
    y, y_levels, likelihood = output_encoder()

    # Feature type indices:
    if isnothing(features_categorical) && isnothing(features_continuous)
        features_continuous = 1:size(X, 1)
    elseif !isnothing(features_categorical) && isnothing(features_continuous)
        features_all = 1:size(X, 1)
        cat_indices = reduce(vcat, features_categorical)
        features_continuous = findall(map(i -> !(i ∈ cat_indices), features_all))
    end

    # Defaults:
    domain = typeof(domain) <: Tuple ? [domain for var in features_continuous] : domain          # domain constraints

    counterfactual_data = CounterfactualData(
        X,
        y,
        likelihood,
        mutability,
        domain,
        features_categorical,
        features_continuous,
        nothing,
        y_levels,
        output_encoder,
    )

    # Data transformations:
    if transformable_features(counterfactual_data) !=
        counterfactual_data.features_continuous
        @warn "Some of the underlying features are constant."
    end
    counterfactual_data.input_encoder = fit_transformer(counterfactual_data, input_encoder)
    counterfactual_data.X = Float32.(counterfactual_data.X)

    return counterfactual_data
end

"""
    function CounterfactualData(
        X::Tables.MatrixTable,
        y::RawOutputArrayType;
        kwrgs...
    )
    
Outer constructor method that accepts a `Tables.MatrixTable`. By default, the indices of categorical and continuous features are automatically inferred the features' `scitype`.

"""
function CounterfactualData(X::Tables.MatrixTable, y::RawOutputArrayType; kwrgs...)
    features_categorical = findall([
        MLJBase.scitype(x) <: AbstractVector{<:Finite} for x in X
    ])
    features_categorical =
        length(features_categorical) == 0 ? nothing : features_categorical
    features_continuous = findall([
        MLJBase.scitype(x) <: AbstractVector{<:Continuous} for x in X
    ])
    features_continuous = length(features_continuous) == 0 ? nothing : features_continuous
    X = permutedims(Tables.matrix(X))

    counterfactual_data = CounterfactualData(X, y; kwrgs...)

    return counterfactual_data
end

"""
    reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::Vector)

Reconstruct the categorical encoding for a single instance.
"""
function reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::AbstractArray)
    features_categorical = counterfactual_data.features_categorical

    if isnothing(features_categorical)
        return x
    end

    x = vec(x)
    map(features_categorical) do cat_group_index
        if length(cat_group_index) > 1
            x[cat_group_index] = Int.(x[cat_group_index] .== maximum(x[cat_group_index]))
            if sum(x[cat_group_index]) > 1
                ties = findall(x[cat_group_index] .== 1)
                _x = zeros(length(x[cat_group_index]))
                winner = rand(ties, 1)[1]
                _x[winner] = 1
                x[cat_group_index] = _x
            end
        else
            x[cat_group_index] = [round(clamp(x[cat_group_index][1], 0, 1))]
        end
    end

    return x
end

"""
    transformable_features(counterfactual_data::CounterfactualData)

Dispatches the `transformable_features` function to the appropriate method based on the type of the `dt` field.
"""
function transformable_features(counterfactual_data::CounterfactualData)
    return transformable_features(counterfactual_data, counterfactual_data.input_encoder)
end

"""
    transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)

By default, all continuous features are transformable. This function returns the indices of all continuous features.
"""
function transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)
    return counterfactual_data.features_continuous
end

"""
    transformable_features(
        counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
    )

Returns the indices of all continuous features that can be transformed. For constant features `ZScoreTransform` returns `NaN`.
"""
function transformable_features(
    counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
)
    # Find all columns that have varying values:
    idx_not_all_equal = [
        length(unique(counterfactual_data.X[i, :])) != 1 for
        i in counterfactual_data.features_continuous
    ]
    # Returns indices of columns that have varying values:
    return counterfactual_data.features_continuous[idx_not_all_equal]
end

"""
    transformable_features(
        counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
    )

Returns the indices of all features that have causal parents.
"""
function transformable_features(
    counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
)
    # Find all nodes that have causal parents
    g = counterfactual_data.input_encoder.dag
    child_causal_nodes = [v for v in vertices(g) if indegree(g, v) >= 1]
    return child_causal_nodes
end