trying #6

Merged
merged 18 commits into from
Apr 5, 2024
25 changes: 23 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
name = "TaijaParallel"
uuid = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0"
authors = ["Patrick Altmeyer <[email protected]>"]
version = "0.1.0"
version = "1.0.0"

[deps]
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
TaijaBase = "10284c91-9f28-4c9a-abbf-ee43576dfff6"

[weakdeps]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[extensions]
MPIExt = "MPI"

[compat]
Aqua = "0.8"
CounterfactualExplanations = "0.1"
Logging = "1.7, 1.8, 1.9, 1.10"
MLUtils = "0.4.4"
MPI = "0.20"
PackageExtensionCompat = "1"
ProgressMeter = "1"
Reexport = "1"
Serialization = "1.7, 1.8, 1.9, 1.10"
TaijaBase = "1"
Test = "1.7, 1.8, 1.9, 1.10"
julia = "1.7, 1.8, 1.9, 1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Aqua", "MPI", "Test"]
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
[![Build Status](https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/actions/workflows/CI.yml?query=branch%3Amaster)
[![Coverage](https://codecov.io/gh/JuliaTrustworthyAI/TaijaParallel.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaTrustworthyAI/TaijaParallel.jl)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

This package adds custom support for parallelization for certain [Taija](https://github.com/JuliaTrustworthyAI) packages.

29 changes: 12 additions & 17 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
using TaijaParallel
using Documenter

DocMeta.setdocmeta!(TaijaParallel, :DocTestSetup, :(using TaijaParallel); recursive=true)
DocMeta.setdocmeta!(TaijaParallel, :DocTestSetup, :(using TaijaParallel); recursive = true)

makedocs(;
modules=[TaijaParallel],
authors="Patrick Altmeyer",
repo="https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/blob/{commit}{path}#{line}",
sitename="TaijaParallel.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://JuliaTrustworthyAI.github.io/TaijaParallel.jl",
edit_link="main",
assets=String[],
modules = [TaijaParallel],
authors = "Patrick Altmeyer",
repo = "https://github.com/JuliaTrustworthyAI/TaijaParallel.jl/blob/{commit}{path}#{line}",
sitename = "TaijaParallel.jl",
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
canonical = "https://JuliaTrustworthyAI.github.io/TaijaParallel.jl",
edit_link = "main",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages = ["Home" => "index.md"],
)

deploydocs(;
repo="github.com/JuliaTrustworthyAI/TaijaParallel.jl",
devbranch="main",
)
deploydocs(; repo = "github.com/JuliaTrustworthyAI/TaijaParallel.jl", devbranch = "main")
9 changes: 6 additions & 3 deletions ext/MPIExt/MPIExt.jl
Original file line number Diff line number Diff line change
@@ -2,13 +2,14 @@ module MPIExt

export MPIParallelizer

using TaijaParallel
using Logging
using MPI
using ProgressMeter
using TaijaBase
using TaijaParallel

"The `MPIParallelizer` type is used to parallelize the evaluation of a function using `MPI.jl`."
struct MPIParallelizer <: AbstractParallelizer
struct MPIParallelizer <: TaijaParallel.AbstractParallelizer
comm::MPI.Comm
rank::Int
n_proc::Int
@@ -22,7 +23,9 @@ end
Create an `MPIParallelizer` object from an `MPI.Comm` object. Optionally, specify the number of observations to send to each process using `n_each`. If `n_each` is `nothing`, then all observations will be split into equally sized bins and sent to each process. If `threaded` is `true`, then the `MPIParallelizer` will use `Threads.@threads` to further parallelize the evaluation of a function.
"""
function TaijaParallel.MPIParallelizer(
comm::MPI.Comm; n_each::Union{Nothing,Int}=nothing, threaded::Bool=false
comm::MPI.Comm;
n_each::Union{Nothing,Int} = nothing,
threaded::Bool = false,
)
rank = MPI.Comm_rank(comm) # Rank of this process in the world 🌍
n_proc = MPI.Comm_size(comm) # Number of processes in the world 🌍
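
For orientation, here is a minimal usage sketch of the constructor documented in this hunk. It assumes `TaijaParallel`, `MPI.jl` and a working MPI runtime are installed, and that loading `MPI` next to `TaijaParallel` activates the `MPIExt` extension (native on Julia 1.9 or later; on older versions this presumably goes through `PackageExtensionCompat`). Only names that appear in this diff are used; the script layout itself is illustrative.

```julia
# Sketch only: requires TaijaParallel, MPI.jl and an MPI runtime.
using TaijaParallel
using MPI  # loading MPI is what makes the MPIExt methods available

MPI.Init()

# Keyword names follow the constructor signature shown in this diff.
parallelizer = TaijaParallel.MPIParallelizer(
    MPI.COMM_WORLD;
    n_each = nothing,   # `nothing`: split all observations evenly across processes
    threaded = false,   # `true`: additionally use `Threads.@threads` on each process
)

# `rank` and `n_proc` are fields of the struct defined above.
if parallelizer.rank == 0
    println("Running on ", parallelizer.n_proc, " process(es)")
end

MPI.Finalize()
```

Such a script would typically be started through an MPI launcher, for example `mpiexec -n 4 julia --project script.jl`, where `script.jl` is a hypothetical file name.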
27 changes: 15 additions & 12 deletions ext/MPIExt/evaluate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
parallelize(
TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
@@ -8,19 +8,19 @@

Parallelizes the evaluation of the `CounterfactualExplanations.Evaluation.evaluate` function. This function is used to evaluate the performance of a counterfactual explanation method.
"""
function parallelize(
function TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
verbose::Bool=false,
verbose::Bool = false,
kwargs...,
)

# Setup:
n_each = parallelizer.n_each

# Extract positional arguments:
counterfactuals = args[1] |> x -> CounterfactualExplanations.vectorize_collection(x)
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)
# Get meta data if supplied:
if length(args) > 1
meta_data = args[2]
@@ -33,7 +33,7 @@
# Break down into chunks:
args = zip(counterfactuals, meta_data)
if !isnothing(n_each)
chunks = Parallelization.chunk_obs(args, n_each, parallelizer.n_proc)
chunks = chunk_obs(args, n_each, parallelizer.n_proc)

else
chunks = [collect(args)]
end
@@ -43,15 +43,15 @@

# For each chunk:
for (i, chunk) in enumerate(chunks)
worker_chunk = Parallelization.split_obs(chunk, parallelizer.n_proc)
worker_chunk = TaijaParallel.split_obs(chunk, parallelizer.n_proc)
worker_chunk = MPI.scatter(worker_chunk, parallelizer.comm)
worker_chunk = stack(worker_chunk; dims=1)
worker_chunk = stack(worker_chunk; dims = 1)
if !parallelizer.threaded
if parallelizer.rank == 0 && verbose
# Evaluating counterfactuals with progress bar:
output = []
@showprogress desc = "Evaluating counterfactuals ..." for x in zip(
eachcol(worker_chunk)...
eachcol(worker_chunk)...,
)
with_logger(NullLogger()) do
push!(output, f(x...; kwargs...))
@@ -66,8 +66,11 @@
else
# Parallelize further with `Threads.@threads`:
second_parallelizer = ThreadsParallelizer()
output = parallelize(
second_parallelizer, f, eachcol(worker_chunk)...; kwargs...
output = TaijaBase.parallelize(

second_parallelizer,
f,
eachcol(worker_chunk)...;
kwargs...,
)
end
MPI.Barrier(parallelizer.comm)
@@ -84,7 +87,7 @@
# Load output from rank 0:
if parallelizer.rank == 0
outputs = []
for i in 1:length(chunks)
for i = 1:length(chunks)
output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
@@ -95,7 +98,7 @@
end

# Broadcast output to all processes:
final_output = MPI.bcast(output, parallelizer.comm; root=0)
final_output = MPI.bcast(output, parallelizer.comm; root = 0)
MPI.Barrier(parallelizer.comm)

return final_output
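
The chunk-then-split flow in this method can be sketched without MPI. The helpers below, `split_evenly` and `chunk_evenly`, are hypothetical stand-ins written for illustration; they are not the package's `split_obs` and `chunk_obs`, whose implementations are not shown in this diff. The sketch assumes that a chunk holds at most `n_each * n_proc` observations and that each chunk is then divided into `n_proc` near-equal bins, one per process.

```julia
# Hypothetical helper: split `obs` into `n_bins` contiguous, near-equal bins.
function split_evenly(obs::AbstractVector, n_bins::Integer)
    n_bins > 0 || throw(ArgumentError("n_bins must be positive"))
    base, extra = divrem(length(obs), n_bins)
    bins = Vector{Vector{eltype(obs)}}()
    start = 1
    for b in 1:n_bins
        len = base + (b <= extra ? 1 : 0)  # the first `extra` bins get one more item
        push!(bins, collect(obs[start:(start + len - 1)]))
        start += len
    end
    return bins
end

# Hypothetical helper: group observations into chunks of at most
# `n_each * n_proc` items (assumed chunk size, see lead-in).
function chunk_evenly(obs::AbstractVector, n_each::Integer, n_proc::Integer)
    return [collect(c) for c in Iterators.partition(obs, n_each * n_proc)]
end

obs = collect(1:10)
for chunk in chunk_evenly(obs, 2, 2)
    println(split_evenly(chunk, 2))  # one bin per (simulated) process
end
```

In the MPI method, the bins of one chunk are what `MPI.scatter` distributes, so every rank receives exactly one bin per chunk.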
28 changes: 16 additions & 12 deletions ext/MPIExt/generate_counterfactual.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using CounterfactualExplanations
using MLUtils: stack
using Serialization

"""
parallelize(
TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.generate_counterfactual),
args...;
@@ -11,19 +12,19 @@

Parallelizes the `CounterfactualExplanations.generate_counterfactual` function using `MPI.jl`. This function is used to generate counterfactual explanations.
"""
function parallelize(
function TaijaBase.parallelize(
parallelizer::MPIParallelizer,
f::typeof(CounterfactualExplanations.generate_counterfactual),
args...;
verbose::Bool=false,
verbose::Bool = false,
kwargs...,
)

# Setup:
n_each = parallelizer.n_each

# Extract positional arguments:
counterfactuals = args[1] |> x -> CounterfactualExplanations.vectorize_collection(x)
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)
target = args[2] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
data = args[3] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
M = args[4] |> x -> isa(x, AbstractArray) ? x : fill(x, length(counterfactuals))
@@ -32,7 +33,7 @@
# Break down into chunks:
args = zip(counterfactuals, target, data, M, generator)
if !isnothing(n_each)
chunks = Parallelization.chunk_obs(args, n_each, parallelizer.n_proc)
chunks = chunk_obs(args, n_each, parallelizer.n_proc)

else
chunks = [collect(args)]
end
@@ -42,15 +43,15 @@

# For each chunk:
for (i, chunk) in enumerate(chunks)
worker_chunk = Parallelization.split_obs(chunk, parallelizer.n_proc)
worker_chunk = TaijaParallel.split_obs(chunk, parallelizer.n_proc)
worker_chunk = MPI.scatter(worker_chunk, parallelizer.comm)
worker_chunk = stack(worker_chunk; dims=1)
worker_chunk = stack(worker_chunk; dims = 1)
if !parallelizer.threaded
if parallelizer.rank == 0 && verbose
# Generating counterfactuals with progress bar:
output = []
@showprogress desc = "Generating counterfactuals ..." for x in zip(
eachcol(worker_chunk)...
eachcol(worker_chunk)...,
)
with_logger(NullLogger()) do
push!(output, f(x...; kwargs...))
@@ -65,8 +66,11 @@
else
# Parallelize further with `Threads.@threads`:
second_parallelizer = ThreadsParallelizer()
output = parallelize(
second_parallelizer, f, eachcol(worker_chunk)...; kwargs...
output = TaijaBase.parallelize(

second_parallelizer,
f,
eachcol(worker_chunk)...;
kwargs...,
)
end
MPI.Barrier(parallelizer.comm)
@@ -83,7 +87,7 @@
# Load output from rank 0:
if parallelizer.rank == 0
outputs = []
for i in 1:length(chunks)
for i = 1:length(chunks)
output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
@@ -94,7 +98,7 @@
end

# Broadcast output to all processes:
final_output = MPI.bcast(output, parallelizer.comm; root=0)
final_output = MPI.bcast(output, parallelizer.comm; root = 0)
MPI.Barrier(parallelizer.comm)

return final_output
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using CounterfactualExplanations
using CounterfactualExplanations: generate_counterfactual
using CounterfactualExplanations.Evaluation: evaluate
import CounterfactualExplanations
using Logging
using ProgressMeter

include("assign_traits.jl")
include("threads/threads.jl")
include("threads/threads.jl")
5 changes: 3 additions & 2 deletions src/CounterfactualExplanations.jl/assign_traits.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"The `generate_counterfactual` method is parallelizable."
ProcessStyle(::Type{<:typeof(generate_counterfactual)}) = IsParallel()
ProcessStyle(::Type{<:typeof(CounterfactualExplanations.generate_counterfactual)}) =
IsParallel()

"The `evaluate` function is parallelizable."
function ProcessStyle(::Type{<:typeof(CounterfactualExplanations.Evaluation.evaluate)})
return IsParallel()
end
end
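
The trait assignments in this file follow what is commonly called the "Holy traits" pattern: a function opts into parallel execution by returning `IsParallel()` from a method that dispatches on the function's own type. The sketch below is self-contained and defines its own `ProcessStyle`, `IsParallel` and `NotParallel`; these mirror the names used in the diff but are illustrative assumptions, not the definitions from `TaijaBase`. `slow_square` and `run_all` are hypothetical.

```julia
# Illustrative trait types (assumed, not TaijaBase's definitions).
abstract type ProcessStyle end
struct IsParallel <: ProcessStyle end
struct NotParallel <: ProcessStyle end

# Default: a function is not parallelizable unless it opts in.
ProcessStyle(::Type) = NotParallel()

# Hypothetical workload standing in for `generate_counterfactual`.
slow_square(x) = x^2

# Opt in by dispatching on the function's type, as in the diff.
ProcessStyle(::Type{<:typeof(slow_square)}) = IsParallel()

# Pick an execution strategy from the trait.
run_all(f, xs) = run_all(ProcessStyle(typeof(f)), f, xs)
run_all(::NotParallel, f, xs) = map(f, xs)
function run_all(::IsParallel, f, xs)
    out = Vector{Any}(undef, length(xs))
    Threads.@threads for i in eachindex(xs)
        out[i] = f(xs[i])  # each iteration writes to its own slot
    end
    return out
end

println(run_all(slow_square, [1, 2, 3]))  # threaded path
println(run_all(abs, [-1, 2]))            # sequential fallback
```

Because the trait is resolved by dispatch on `typeof(f)`, adding support for another function is a one-line method definition and requires no change to the code that does the parallelizing.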
20 changes: 11 additions & 9 deletions src/CounterfactualExplanations.jl/threads/evaluate.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import TaijaBase

"""
parallelize(
TaijaBase.parallelize(
parallelizer::ThreadsParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
@@ -8,16 +10,16 @@

Parallelizes the evaluation of `f` using `Threads.@threads`. This function is used to evaluate counterfactual explanations.
"""
function parallelize(
function TaijaBase.parallelize(

parallelizer::ThreadsParallelizer,
f::typeof(CounterfactualExplanations.Evaluation.evaluate),
args...;
verbose::Bool=true,
verbose::Bool = true,
kwargs...,
)

# Setup:
counterfactuals = args[1] |> x -> CounterfactualExplanations.vectorize_collection(x)
counterfactuals = args[1] |> x -> TaijaBase.vectorize_collection(x)


# Get meta data if supplied:
if length(args) > 1
@@ -28,22 +30,22 @@

# Check meta data:
if typeof(meta_data) <: AbstractArray
meta_data = CounterfactualExplanations.vectorize_collection(meta_data)
meta_data = TaijaBase.vectorize_collection(meta_data)

@assert length(meta_data) == length(counterfactuals) "The number of meta data must match the number of counterfactuals."
else
meta_data = fill(meta_data, length(counterfactuals))
end

# Preallocate:
evaluations = [[] for _ in 1:Threads.nthreads()]
evaluations = [[] for _ = 1:Threads.nthreads()]


# Verbosity:
if verbose
prog = ProgressMeter.Progress(
length(counterfactuals);
desc="Evaluating counterfactuals ...",
showspeed=true,
color=:green,
desc = "Evaluating counterfactuals ...",
showspeed = true,
color = :green,
)
end
