No method matching for NeuroTreeModel #475

Open
abuszydlik opened this issue Sep 23, 2024 · 4 comments · Fixed by #477
Labels
bug Something isn't working

Comments

@abuszydlik

I am trying to generate counterfactual explanations for a NeuroTreeModel that was also trained using CounterfactualExplanations.jl. The model itself seems to work as expected (it generates predictions for unseen samples with reasonable accuracy), but trying to generate an explanation fails with MethodError: no method matching.

Here is the code snippet, based entirely on the documentation, that seems to trigger the error:

factual = sim.counterfactual_data.X[:, agent.id]
factual = reshape(factual, size(factual)[1], 1)

η = 0.01
generator = GenericGenerator(; opt=Descent(η), λ=0.01)
conv = CounterfactualExplanations.Convergence.DecisionThresholdConvergence(;
    decision_threshold=0.9, max_iter=100
)
counterfactual = generate_counterfactual(factual, 0, sim.counterfactual_data, sim.model, generator; convergence=conv)

and the corresponding error:

ERROR: MethodError: no method matching (::NeuroTreeModel{NeuroTreeModels.MLogLoss, Chain{Tuple{BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, NeuroTreeModels.StackTree}}})(::Vector{Float32})

Closest candidates are:
  (::NeuroTreeModel)(::AbstractMatrix)
   @ NeuroTreeModels [path].julia\packages\NeuroTreeModels\QUDXW\src\model.jl:136
  (::NeuroTreeModel)(::AbstractDataFrame)
   @ NeuroTreeModels [path].julia\packages\NeuroTreeModels\QUDXW\src\model.jl:143

Stacktrace:
  [1] macro expansion
    @ [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::NeuroTreeModel{NeuroTreeModels.MLogLoss, Chain{Tuple{BatchNorm{…}, NeuroTreeModels.StackTree}}}, args::Vector{Float32})
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:87
  [3] logits
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\ext\NeuroTreeExt\neurotree.jl:86 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(logits), ::CounterfactualExplanations.Models.Model, ::CounterfactualExplanations.NeuroTreeModel, ::Vector{Float32})
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [5] logits
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\src\models\core_struct.jl:81 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::typeof(logits), ::CounterfactualExplanations.Models.Model, ::Vector{Float32})
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [7] #logitcrossentropy#9
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\src\objectives\loss_functions.jl:25 [inlined]
  [8] _pullback(::Zygote.Context{…}, ::CounterfactualExplanations.Objectives.var"##logitcrossentropy#9", ::@Kwargs{}, ::typeof(Flux.Losses.logitcrossentropy), ::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
  [9] logitcrossentropy
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\src\objectives\loss_functions.jl:24 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(Flux.Losses.logitcrossentropy), args::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [11] ℓ
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\loss.jl:18 [inlined]
 [12] _pullback(::Zygote.Context{…}, ::typeof(CounterfactualExplanations.Generators.ℓ), ::CounterfactualExplanations.Generators.GradientBasedGenerator, ::Nothing, ::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [13] ℓ
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\loss.jl:7 [inlined]
 [14] _pullback(::Zygote.Context{…}, ::typeof(CounterfactualExplanations.Generators.ℓ), ::CounterfactualExplanations.Generators.GradientBasedGenerator, ::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [15] #22
    @ [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\gradient_based\loss.jl:15 [inlined]
 [16] _pullback(ctx::Zygote.Context{…}, f::CounterfactualExplanations.Generators.var"#22#23"{…}, args::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [17] pullback(f::Function, cx::Zygote.Context{false}, args::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface.jl:90
 [18] pullback
    @ [path].julia\packages\Zygote\nsBv0\src\compiler\interface.jl:88 [inlined]
 [19] gradient(f::Function, args::CounterfactualExplanation)
    @ Zygote [path].julia\packages\Zygote\nsBv0\src\compiler\interface.jl:147
 [20] ∂ℓ(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\gradient_based\loss.jl:15
 [21] ∇(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\gradient_based\loss.jl:48
 [22] propose_state(::CounterfactualExplanations.Models.IsDifferentiable, generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\gradient_based\generate_perturbations.jl:38
 [23] propose_state(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\gradient_based\generate_perturbations.jl:22
 [24] generate_perturbations(generator::CounterfactualExplanations.Generators.GradientBasedGenerator, ce::CounterfactualExplanation)
    @ CounterfactualExplanations.Generators [path].julia\packages\CounterfactualExplanations\mRwLf\src\generators\gradient_based\generate_perturbations.jl:10
 [25] update!(ce::CounterfactualExplanation)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\mRwLf\src\counterfactuals\search.jl:9
 [26] generate_counterfactual(x::Matrix{…}, target::Int64, data::CounterfactualData, M::CounterfactualExplanations.Models.Model, generator::CounterfactualExplanations.Generators.GradientBasedGenerator; num_counterfactuals::Int64, initialization::Symbol, convergence::CounterfactualExplanations.Convergence.DecisionThresholdConvergence, timeout::Nothing)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\mRwLf\src\counterfactuals\generate_counterfactual.jl:114
 [27] default_logic(current_stage::Recourse, agent::Customer, sim::StandardABM{…})
    @ Main [path]Desktop\Agents\src\stages\recourse.jl:28
 [28] process
    @ [path]Desktop\Agents\src\stages\recourse.jl:8 [inlined]
 [29] process(current_stage::Recourse, agent::Customer, sim::StandardABM{…})
    @ Main [path]Desktop\Agents\src\stages\recourse.jl:7
 [30] welfare_step!(agent::Customer, model::StandardABM{…})
    @ Main [path]Desktop\Agents\src\simulation.jl:19
 [31] step_ahead!(model::StandardABM{…}, agent_step!::typeof(welfare_step!), model_step!::typeof(dummystep), n::Int64, t::Base.RefValue{…})
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\step_standard.jl:17
 [32] step!
    @ [path].julia\packages\Agents\MBOEF\src\simulations\step_standard.jl:5 [inlined]
 [33] _run!(model::StandardABM{…}, df_agent::DataFrame, df_model::DataFrame, n::Int64, when::Int64, when_model::Int64, mdata::Nothing, adata::Nothing, obtainer::typeof(identity), dt::Int64, p::ProgressMeter.Progress)
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\collect.jl:166
 [34] run!(model::StandardABM{…}, n::Int64; when::Int64, when_model::Int64, mdata::Nothing, adata::Nothing, obtainer::Function, showprogress::Bool, init::Bool, dt::Float64)
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\collect.jl:148
 [35] run!(model::StandardABM{GraphSpace{…}, Customer, Vector{…}, Tuple{…}, typeof(welfare_step!), typeof(dummystep), typeof(Agents.Schedulers.fastest), Dict{…}, Xoshiro}, n::Int64)
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\collect.jl:100
 [36] run_sim()
    @ Main [path]Desktop\Agents\src\simulation.jl:88
 [37] top-level scope
    @ REPL[13]:1
Some type information was truncated. Use `show(err)` to see complete types.

The codebase relies heavily on Agents.jl, but as far as I can tell the error is independent of that dependency.
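For readers following along, the dispatch failure in this first trace can be reproduced without any of the packages involved. The sketch below is only an illustration: `ToyModel` is a hypothetical stand-in that, like the `NeuroTreeModel` candidates listed above, defines a call method for `AbstractMatrix` but none for a `Vector`. It does not reflect the actual NeuroTreeModels.jl implementation. Note that the trace shows `generate_counterfactual` receiving a `Matrix`, while `logits` further down receives a `Vector{Float32}`.

```julia
# Hypothetical stand-in for a model that only defines a call method
# for matrices (one column per sample).
struct ToyModel
    W::Matrix{Float32}
end

(m::ToyModel)(x::AbstractMatrix) = m.W * x

m = ToyModel(Float32[1 2; 3 4])
x = Float32[1, 1]   # a single sample stored as a Vector

# Calling the model with a Vector finds no matching method.
vector_call_fails = try
    m(x)
    false
catch err
    err isa MethodError
end

# Indexing with [:, :] turns the Vector into an n×1 Matrix,
# which dispatches to the AbstractMatrix method.
y = m(x[:, :])
```

Under these assumptions, the call succeeds as soon as the input reaches the model as an n×1 matrix, which is what the `X[:, :]` idiom discussed later in this thread achieves.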

@abuszydlik added the bug label Sep 23, 2024
@pat-alt
Member

pat-alt commented Sep 23, 2024

Thanks @abuszydlik, will take a look asap (probably not today though)

@pat-alt linked a pull request Sep 24, 2024 that will close this issue
@pat-alt reopened this Sep 24, 2024
@abuszydlik
Author

The initial error has been resolved (thanks @pat-alt!), but generation now fails with the following stack trace:

ERROR: LoadError: MethodError: no method matching (::NeuroTree{Matrix{Float32}, Vector{Float32}, Array{Float32, 3}, typeof(tanh)})(::Matrix{AbstractFloat})

Closest candidates are:
  (::NeuroTree{W, B, P, F})(::W) where {W, B, P, F}
   @ NeuroTreeModels [path].julia\packages\NeuroTreeModels\QUDXW\src\model.jl:22

Stacktrace:
  [1] (::NeuroTreeModels.StackTree)(x::Matrix{AbstractFloat})
    @ NeuroTreeModels [path].julia\packages\NeuroTreeModels\QUDXW\src\model.jl:99
  [2] macro expansion
    @ [path].julia\packages\Flux\MtsAN\src\layers\basic.jl:53 [inlined]
  [3] _applychain(layers::Tuple{BatchNorm{…}, NeuroTreeModels.StackTree}, x::Matrix{AbstractFloat})
    @ Flux [path].julia\packages\Flux\MtsAN\src\layers\basic.jl:53
  [4] Chain
    @ [path].julia\packages\Flux\MtsAN\src\layers\basic.jl:51 [inlined]
  [5] (::NeuroTreeModel{NeuroTreeModels.MLogLoss, Chain{Tuple{…}}})(x::Matrix{AbstractFloat})
    @ NeuroTreeModels [path].julia\packages\NeuroTreeModels\QUDXW\src\model.jl:137
  [6] logits(M::CounterfactualExplanations.Models.Model, type::CounterfactualExplanations.NeuroTreeModel, X::Vector{…})
    @ NeuroTreeExt [path].julia\packages\CounterfactualExplanations\VIn4V\ext\NeuroTreeExt\neurotree.jl:87
  [7] logits(M::CounterfactualExplanations.Models.Model, X::Vector{AbstractFloat})
    @ CounterfactualExplanations.Models [path].julia\packages\CounterfactualExplanations\VIn4V\src\models\core_struct.jl:87
  [8] probs(M::CounterfactualExplanations.Models.Model, type::CounterfactualExplanations.NeuroTreeModel, X::Vector{…})
    @ NeuroTreeExt [path].julia\packages\CounterfactualExplanations\VIn4V\ext\NeuroTreeExt\neurotree.jl:102
  [9] probs(M::CounterfactualExplanations.Models.Model, X::Vector{AbstractFloat})
    @ CounterfactualExplanations.Models [path].julia\packages\CounterfactualExplanations\VIn4V\src\models\core_struct.jl:94
 [10] counterfactual_probability(ce::CounterfactualExplanation, x::Nothing)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\VIn4V\src\counterfactuals\info_extraction.jl:51
 [11] target_probs(ce::CounterfactualExplanation, x::Nothing)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\VIn4V\src\counterfactuals\info_extraction.jl:80
 [12] threshold_reached(ce::CounterfactualExplanation, x::Nothing)
    @ CounterfactualExplanations.Convergence [path].julia\packages\CounterfactualExplanations\VIn4V\src\convergence\decision_threshold.jl:57
 [13] converged
    @ [path].julia\packages\CounterfactualExplanations\VIn4V\src\convergence\decision_threshold.jl:45 [inlined]
 [14]
    @ CounterfactualExplanations.Convergence [path].julia\packages\CounterfactualExplanations\VIn4V\src\convergence\decision_threshold.jl:45
 [15] terminated(ce::CounterfactualExplanation)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\VIn4V\src\counterfactuals\termination.jl:7
 [16] update!(ce::CounterfactualExplanation)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\VIn4V\src\counterfactuals\search.jl:23
 [17] generate_counterfactual(x::Matrix{…}, target::Int64, data::CounterfactualData, M::CounterfactualExplanations.Models.Model, generator::CounterfactualExplanations.Generators.GradientBasedGenerator; num_counterfactuals::Int64, initialization::Symbol, convergence::CounterfactualExplanations.Convergence.DecisionThresholdConvergence, timeout::Nothing)
    @ CounterfactualExplanations [path].julia\packages\CounterfactualExplanations\VIn4V\src\counterfactuals\generate_counterfactual.jl:114
 [18] default_logic(current_stage::Recourse, agent::Customer, sim::StandardABM{…})
    @ Main [path]Desktop\Agents\src\stages\recourse.jl:22
 [19] process
    @ [path]Desktop\Agents\src\stages\recourse.jl:8 [inlined]
 [20] process(current_stage::Recourse, agent::Customer, sim::StandardABM{…})
    @ Main [path]Desktop\Agents\src\stages\recourse.jl:7
 [21] welfare_step!(agent::Customer, model::StandardABM{…})
    @ Main [path]Desktop\Agents\src\simulation.jl:19
 [22] step_ahead!(model::StandardABM{…}, agent_step!::typeof(welfare_step!), model_step!::typeof(dummystep), n::Int64, t::Base.RefValue{…})
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\step_standard.jl:17
 [23] step!
    @ [path].julia\packages\Agents\MBOEF\src\simulations\step_standard.jl:5 [inlined]
 [24] _run!(model::StandardABM{…}, df_agent::DataFrame, df_model::DataFrame, n::Int64, when::Int64, when_model::Int64, mdata::Nothing, adata::Nothing, obtainer::typeof(identity), dt::Int64, p::ProgressMeter.Progress)
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\collect.jl:166
 [25] run!(model::StandardABM{…}, n::Int64; when::Int64, when_model::Int64, mdata::Nothing, adata::Nothing, obtainer::Function, showprogress::Bool, init::Bool, dt::Float64)
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\collect.jl:148
 [26] run!(model::StandardABM{…}, n::Int64)
    @ Agents [path].julia\packages\Agents\MBOEF\src\simulations\collect.jl:100
 [27] top-level scope
    @ [path]Desktop\Agents\src\simulation.jl:92
in expression starting at [path]Desktop\Agents\src\simulation.jl:92
Some type information was truncated. Use `show(err)` to see complete types.
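The second failure looks like a consequence of Julia's invariant type parameters: the only candidate method requires the input to have exactly the weight type `W` (here `Matrix{Float32}`), and a `Matrix{AbstractFloat}` is a different type even when every element it holds is a `Float32`. A minimal sketch of that behaviour, using a hypothetical `ToyLayer` that mirrors the candidate signature from the trace rather than the real `NeuroTree` type:

```julia
# Hypothetical layer whose call method requires the input to have exactly
# the same type W as the stored weights.
struct ToyLayer{W}
    weights::W
end

(l::ToyLayer{W})(x::W) where {W} = l.weights * x

layer = ToyLayer(Float32[1 0; 0 1])

# A container declared with an abstract element type keeps that type,
# even though every element here is a Float32.
X_abstract = reshape(AbstractFloat[1.0f0, 2.0f0], 2, 1)

# Type parameters are invariant, so this is not a Matrix{Float32}.
is_subtype = Matrix{AbstractFloat} <: Matrix{Float32}

abstract_call_fails = try
    layer(X_abstract)
    false
catch err
    err isa MethodError
end

# Broadcasting the conversion yields a concretely typed Matrix{Float32}.
X_concrete = convert.(Float32, X_abstract)
y = layer(X_concrete)
```

Where the abstractly typed array originates inside the counterfactual search is not visible from the trace, so this only shows why the call is rejected once such an array reaches the model.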

@pat-alt
Member

pat-alt commented Sep 25, 2024

Closed by #479

@pat-alt closed this as completed Sep 25, 2024
@abuszydlik reopened this Sep 30, 2024
@abuszydlik
Author

abuszydlik commented Sep 30, 2024

Hi @pat-alt! Unfortunately the recent changes cause the error from my initial comment in this thread to reappear.

The call to M.fitresult() on line 86:

X = X[:, :] |> x -> convert.(eltype(Flux.params(M.fitresult().chain)[1]), x)

fails on the first call to logits (this is the same call that initially led to ERROR: MethodError: no method matching ...).

Conversely, if the conversion is omitted, the method is eventually called with a Matrix{AbstractFloat}, which causes the error from my second comment. I temporarily worked around this by hard-coding the type as follows:

X = X[:, :]                      # promote a Vector to an n×1 Matrix
X = convert(Matrix{Float32}, X)  # force a concrete Float32 element type
return M.fitresult(X)

If CounterfactualExplanations.jl does not need run-time inference of the element type in this method, then this is the same solution that was pointed out in this comment, and I believe it should fix the problem completely.
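If hard-coding `Float32` turns out to be too rigid, one possible middle ground is to read the element type from an array of model weights that is already at hand and convert to that. The sketch below is an assumption-laden illustration, not the package's code: `to_model_input` is a hypothetical helper and `W` stands in for whatever parameter array the fitted model exposes.

```julia
# Hypothetical helper, not part of CounterfactualExplanations.jl:
# coerce a vector or matrix of reals into a concrete Matrix{T}.
to_model_input(X::AbstractVecOrMat{<:Real}, T::Type{<:AbstractFloat}=Float32) =
    convert(Matrix{T}, X[:, :])

# Stand-in for a weight array taken from a fitted model (an assumption).
W = ones(Float32, 3, 2)

# A Vector with an abstract element type and mixed Float32/Float64 entries.
X = AbstractFloat[1.0f0, 2.0]

# The element type comes from the weights instead of being hard-coded.
X_ready = to_model_input(X, eltype(W))
out = W * X_ready
```

This handles both failure modes from this thread in one place: a `Vector` becomes an n×1 matrix, and an abstract element type is replaced by the concrete type of the weights.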
