Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

436 dont pass entire data container #439

Merged
merged 14 commits into from
Apr 30, 2024
Merged

Conversation

pat-alt
Copy link
Member

@pat-alt pat-alt commented Apr 30, 2024

No description provided.

@pat-alt pat-alt linked an issue Apr 30, 2024 that may be closed by this pull request
Comment on lines 82 to 83
:iteration_count => 0,
:mutability => DataPreprocessing.mutability_constraints(data),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
:iteration_count => 0,
:mutability => DataPreprocessing.mutability_constraints(data),
:iteration_count => 0, :mutability => DataPreprocessing.mutability_constraints(data)

@@ -39,6 +39,7 @@ end
Computes the distance of the counterfactual from a point in the target main.
"""
function distance_from_target(ce::AbstractCounterfactualExplanation; K::Int=50, kwrgs...)
get!(ce.search, :potential_neighbours, CounterfactualExplanations.find_potential_neighbours(ce))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
get!(ce.search, :potential_neighbours, CounterfactualExplanations.find_potential_neighbours(ce))
get!(
ce.search,
:potential_neighbours,
CounterfactualExplanations.find_potential_neighbours(ce),
)

4. Initializes the loss.
"""
function initialize!(ce::CounterfactualExplanation)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change


return loss
end
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

Comment on lines 12 to 13
PenaltyRequirements(::Type{<:typeof(distance_from_target)}) =
NeedsNeighbours()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
PenaltyRequirements(::Type{<:typeof(distance_from_target)}) =
NeedsNeighbours()
PenaltyRequirements(::Type{<:typeof(distance_from_target)}) = NeedsNeighbours()


Computes the total loss of a counterfactual explanation with respect to the search objective.
"""
total_loss(ce::AbstractCounterfactualExplanation) = ℓ(ce.generator, ce) + h(ce.generator, ce)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
total_loss(ce::AbstractCounterfactualExplanation) = (ce.generator, ce) + h(ce.generator, ce)
total_loss(ce::AbstractCounterfactualExplanation) =
(ce.generator, ce) + h(ce.generator, ce)


Computes the total loss of a counterfactual explanation with respect to the search objective.
"""
total_loss(ce::AbstractCounterfactualExplanation) = hasfield(typeof(ce.generator), :loss) ? ℓ(ce.generator, ce) + h(ce.generator, ce) : nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
total_loss(ce::AbstractCounterfactualExplanation) = hasfield(typeof(ce.generator), :loss) ? (ce.generator, ce) + h(ce.generator, ce) : nothing
total_loss(ce::AbstractCounterfactualExplanation) = if hasfield(typeof(ce.generator), :loss)
(ce.generator, ce) + h(ce.generator, ce)
else
nothing
end


Check if a generator needs access to neighbors in the target class.
"""
needs_neighbours(gen::AbstractGenerator) = hasfield(typeof(gen), :penalty) ? any(needs_neighbours.(gen.penalty)) : false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
needs_neighbours(gen::AbstractGenerator) = hasfield(typeof(gen), :penalty) ? any(needs_neighbours.(gen.penalty)) : false
needs_neighbours(gen::AbstractGenerator) =
hasfield(typeof(gen), :penalty) ? any(needs_neighbours.(gen.penalty)) : false

Comment on lines 83 to 87
total_loss(ce::AbstractCounterfactualExplanation) = if hasfield(typeof(ce.generator), :loss)
ℓ(ce.generator, ce) + h(ce.generator, ce)
else
nothing
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
total_loss(ce::AbstractCounterfactualExplanation) = if hasfield(typeof(ce.generator), :loss)
(ce.generator, ce) + h(ce.generator, ce)
else
nothing
end
total_loss(ce::AbstractCounterfactualExplanation) =
if hasfield(typeof(ce.generator), :loss)
(ce.generator, ce) + h(ce.generator, ce)
else
nothing
end


Fit a transformer to the data for an `InputTransformer` object. This is a no-op.
"""
function fit_transformer(data::CounterfactualData, input_encoder::InputTransformer; kwargs...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
function fit_transformer(data::CounterfactualData, input_encoder::InputTransformer; kwargs...)
function fit_transformer(
data::CounterfactualData, input_encoder::InputTransformer; kwargs...
)

@pat-alt
Copy link
Member Author

pat-alt commented Apr 30, 2024

It turns out that passing the data container was not the issue: I tried using Ref to the container instead with no change in performance.

Instead, there were a few unnecessary forward passes through the entire dataset that have now been removed. There is now little to no dependency of performance on the dataset size (as in number of samples).

@pat-alt pat-alt merged commit dd7181a into main Apr 30, 2024
8 of 10 checks passed
@pat-alt pat-alt deleted the 436-dont-pass-entire-data-container branch April 30, 2024 13:11
@pat-alt pat-alt mentioned this pull request Apr 30, 2024
2 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Don't pass entire data container
1 participant