Skip to content

Commit

Permalink
Merge pull request #431 from JuliaTrustworthyAI/430-move-to-explicit-ad
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt authored Apr 17, 2024
2 parents 871d0ae + ad8a815 commit fd40d9f
Show file tree
Hide file tree
Showing 36 changed files with 173 additions and 72 deletions.
18 changes: 14 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,31 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

*Note*: We try to adhere to these practices as of version `1.1.1`.
*Note*: We try to adhere to these practices as of version [v1.1.1].

## [Unreleased]
## Version [v1.1.3] - 2024-04-17

### Added

- Adds a section on `Convergence` to the documentation, `Changelog.jl` functionality and a few doc tests. [#429]

## [1.1.2] - 2024-04-16
### Changed

- Changes style of taking gradients for the counterfactual search from implicit to explicit. [#430]
- Removed all implicit imports. [#430]

### Removed

- Removed CUDA.jl dependency, because redundant. [#430]
- Removed Parameters.jl dependency, because redundant. [#430]

## Version [v1.1.2] - 2024-04-16

### Changed

- Replaces the GIF in the README and introduction of docs for a static image.

## [1.1.1] - 2024-04-15
## Version [v1.1.1] - 2024-04-15

### Added

Expand Down
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["Patrick Altmeyer <[email protected]>"]
version = "1.1.3"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand All @@ -19,7 +18,6 @@ MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -41,7 +39,6 @@ NeuroTreeExt = "NeuroTreeModels"

[compat]
Aqua = "0.8"
CUDA = "3, 4, 5"
CategoricalArrays = "0.10"
ChainRulesCore = "1.15"
DataFrames = "1"
Expand All @@ -59,7 +56,6 @@ MLUtils = "0.2, 0.3, 0.4"
MultivariateStats = "0.9, 0.10"
NeuroTreeModels = "1.1.0"
PackageExtensionCompat = "1"
Parameters = "0.12"
ProgressMeter = "1"
Random = "1.6, 1.7, 1.8, 1.9, 1.10"
Serialization = "1.6, 1.7, 1.8, 1.9, 1.10"
Expand Down
4 changes: 2 additions & 2 deletions src/CounterfactualExplanations.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
module CounterfactualExplanations

# Package extensions:
using PackageExtensionCompat
using PackageExtensionCompat: PackageExtensionCompat, @require_extensions
function __init__()
@require_extensions
end

# Dependencies:
using Flux
using TaijaBase
using TaijaBase: TaijaBase

# Setup:
include("artifacts_setup.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/artifacts_setup.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LazyArtifacts
using LazyArtifacts: LazyArtifacts, @artifact_str

function generate_artifact_dir(name::String)
_artifacts_julia_version = "$(Int(VERSION.major)).$(Int(VERSION.minor))"
Expand Down
8 changes: 7 additions & 1 deletion src/convergence/invalidation_rate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using Distributions: Distributions
using Flux: Flux
using LinearAlgebra: LinearAlgebra

Base.@kwdef struct InvalidationRateConvergence <: AbstractConvergence
invalidation_rate::AbstractFloat = 0.1
max_iter::Int = 100
Expand Down Expand Up @@ -44,7 +48,9 @@ function invalidation_rate(ce::AbstractCounterfactualExplanation)
end
gradᵀ = LinearAlgebra.transpose(grad)

identity_matrix = LinearAlgebra.Matrix{Float32}(I, length(grad), length(grad))
identity_matrix = LinearAlgebra.Matrix{Float32}(
LinearAlgebra.I, length(grad), length(grad)
)
denominator = sqrt(gradᵀ * ce.convergence.variance * identity_matrix * grad)[1]

normalized_gradient = f_loss / denominator
Expand Down
6 changes: 3 additions & 3 deletions src/counterfactuals/Counterfactuals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ using .DataPreprocessing
using .GenerativeModels
using .Generators
using .Models
using ChainRulesCore
using ChainRulesCore: ChainRulesCore
using Flux
using MLUtils
using MLUtils: MLUtils
using MultivariateStats
using Statistics
using Statistics: Statistics
using StatsBase

include("core_struct.jl")
Expand Down
5 changes: 1 addition & 4 deletions src/counterfactuals/core_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,7 @@ function CounterfactualExplanation(
)

# Initialization:
adjust_shape!(ce) # adjust shape to specified number of counterfactuals
ce.s′ = encode_state(ce) # encode the counterfactual state
ce.s′ = initialize_state(ce) # initialize the counterfactual state
ce.x′ = decode_state(ce) # decode the counterfactual state
adjust_shape!(ce) |> encode_state! |> initialize_state! |> decode_state!

ce.search[:path] = [ce.s′]
ce.search[:times_changed_features] = zeros(size(decode_state(ce)))
Expand Down
34 changes: 32 additions & 2 deletions src/counterfactuals/encodings.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
using ChainRulesCore: ignore_derivatives
using MultivariateStats: MultivariateStats
using StatsBase: StatsBase

"""
encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Expand Down Expand Up @@ -68,7 +72,7 @@ function encode_state(
if !ce.generator.latent_space && data.standardize
dt = data.dt
idx = transformable_features(data)
ChainRulesCore.ignore_derivatives() do
ignore_derivatives() do
s = s′[idx, :]
s = encode_array(dt, s)
s′[idx, :] = s
Expand All @@ -85,6 +89,19 @@ function encode_state(
return s′
end

"""
encode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of `encode_state`.
"""
function encode_state!(
ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing
)
ce.s′ = encode_state(ce, x)

return ce
end

"""
function decode_state(
ce::CounterfactualExplanation,
Expand Down Expand Up @@ -118,7 +135,7 @@ function decode_state(

# Continuous:
idx = transformable_features(data)
ChainRulesCore.ignore_derivatives() do
ignore_derivatives() do
s = s′[idx, :]
s = decode_array(dt, s)
s′[idx, :] = s
Expand All @@ -137,3 +154,16 @@ function decode_state(

return s′
end

"""
decode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of `decode_state`.
"""
function decode_state!(
ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing
)
ce.x′ = decode_state(ce, x)

return ce
end
13 changes: 13 additions & 0 deletions src/counterfactuals/initialisation.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Flux

"""
initialize_state(ce::CounterfactualExplanation)
Expand Down Expand Up @@ -30,3 +32,14 @@ function initialize_state(ce::CounterfactualExplanation)

return s′
end

"""
initialize_state!(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s) in-place.
"""
function initialize_state!(ce::CounterfactualExplanation)
ce.s′ = initialize_state(ce)

return ce
end
2 changes: 2 additions & 0 deletions src/counterfactuals/latent_space_mappings.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Flux: Flux

"""
map_from_latent(
ce::CounterfactualExplanation,
Expand Down
6 changes: 4 additions & 2 deletions src/counterfactuals/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function guess_loss(ce::CounterfactualExplanation)
elseif ce.M.likelihood == :classification_multi
loss_fun = Objectives.logitcrossentropy
else
loss_fun = Flux.Losses.mse
loss_fun = Objectives.mse
end
else
loss_fun = nothing
Expand Down Expand Up @@ -66,7 +66,9 @@ function adjust_shape!(ce::CounterfactualExplanation)

search = ce.search
search[:mutability] = adjust_shape(ce, search[:mutability]) # augment to account for specified number of counterfactuals
return ce.search = search
ce.search = search

return ce
end

"""
Expand Down
8 changes: 4 additions & 4 deletions src/data_preprocessing/DataPreprocessing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module DataPreprocessing
using CategoricalArrays
using CounterfactualExplanations
using ..GenerativeModels
using DataFrames
using Flux
using MultivariateStats
using DataFrames: DataFrames
using Flux: Flux
using MultivariateStats: MultivariateStats
using StatsBase
using Tables
using Random
using Random: Random

include("counterfactual_data.jl")
include("utils.jl")
Expand Down
4 changes: 3 additions & 1 deletion src/data_preprocessing/counterfactual_data.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using MLJBase
using MLJBase: MLJBase, Continuous, Finite
using StatsBase: StatsBase, ZScoreTransform
using Tables: Tables

"""
CounterfactualData(
Expand Down
3 changes: 3 additions & 0 deletions src/data_preprocessing/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using CategoricalArrays: CategoricalArrays
using StatsBase: sample

"Treat `CounterfactualData` as scalar when broadcasting."
Base.broadcastable(data::CounterfactualData) = Ref(data)

Expand Down
2 changes: 1 addition & 1 deletion src/evaluation/Evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ..CounterfactualExplanations
using DataFrames
using ..Generators
using ..Models
using LinearAlgebra
using LinearAlgebra: LinearAlgebra
using Statistics

include("benchmark.jl")
Expand Down
6 changes: 4 additions & 2 deletions src/evaluation/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Base.Iterators
using Serialization
using DataFrames: DataFrames
using Serialization: Serialization
using Statistics: mean
using TaijaBase: AbstractParallelizer, vectorize_collection, parallelize
using UUIDs
using UUIDs: UUIDs

"A container for benchmarks of counterfactual explanations. Instead of subtyping `DataFrame`, it contains a `DataFrame` of evaluation measures (see [this discussion](https://discourse.julialang.org/t/creating-an-abstractdataframe-subtype/36451/6?u=pat-alt) for why we don't subtype `DataFrame` directly)."
struct Benchmark
Expand Down
3 changes: 2 additions & 1 deletion src/evaluation/evaluate.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using UUIDs
using DataFrames: nrow
using UUIDs: uuid1

"""
compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)
Expand Down
2 changes: 2 additions & 0 deletions src/evaluation/measures.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Statistics: Statistics

"""
validity(ce::CounterfactualExplanation; γ=0.5)
Expand Down
6 changes: 2 additions & 4 deletions src/generative_models/GenerativeModels.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
module GenerativeModels

using CounterfactualExplanations
using CUDA
using Flux
using Parameters
using ProgressMeter
using ProgressMeter: ProgressMeter
using Random
using Statistics
using Statistics: Statistics

"""
Base type for generative model.
Expand Down
3 changes: 3 additions & 0 deletions src/generative_models/encoders.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using Flux: sigmoid
using Random: Random

"""
Encoder
Expand Down
10 changes: 7 additions & 3 deletions src/generative_models/vae.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
using Flux: Flux, Adam, cpu, gpu
using ProgressMeter: Progress, next!
using Statistics: mean

"""
VAEParams <: AbstractGMParams
The default VAE parameters describing both the encoder/decoder architecture and the training process.
"""
Parameters.@with_kw mutable struct VAEParams <: AbstractGMParams
Base.@kwdef mutable struct VAEParams <: AbstractGMParams
η = 1e-3 # learning rate
λ = 0.01f0 # regularization parameter
batch_size = 50 # batch size
epochs = 100 # number of epochs
seed = 0 # random seed
cuda = true # use GPU
gpu = true # use GPU
device = gpu # default device
latent_dim = 2 # latent dimension
hidden_dim = 32 # hidden dimension
Expand Down Expand Up @@ -41,7 +45,7 @@ function VAE(input_dim; kws...)
args = VAEParams(; kws...)

# GPU config
if args.cuda && CUDA.has_cuda()
if args.gpu
args.device = gpu
else
args.device = cpu
Expand Down
12 changes: 5 additions & 7 deletions src/generators/Generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ using Flux
using LinearAlgebra
using ..Models
using ..Objectives
using Statistics
using Parameters
using Statistics: Statistics
using DecisionTree
using DataFrames
using MLJBase
using MLJDecisionTreeInterface
using Distributions
using DataFrames: DataFrames
using MLJBase: MLJBase
using MLJDecisionTreeInterface: MLJDecisionTreeInterface
using Distributions: Distributions
using Random
using Statistics

export AbstractGradientBasedGenerator
export AbstractNonGradientBasedGenerator
Expand Down
Loading

2 comments on commit fd40d9f

@pat-alt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/105119

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.3 -m "<description of version>" fd40d9fa58153874b5afc3b791b08ed38d629a6d
git push origin v1.1.3

Please sign in to comment.