No method matching with argument IRTools.Inner.Undefined in gradient computation. #134

Closed
jumerckx opened this issue Aug 13, 2022 · 2 comments


@jumerckx
Contributor

The following code fails when added to the SimpleRNN example.

s = SpiralClassifier(10, 20, 30)
ps, st = Lux.setup(Random.default_rng(), s)
x = rand(10, 20, 16)

gradient(ps) do ps
    out, st = s(x, ps, st)
    return sum(out)
end

I couldn't find similar issues online, but I believe the above code should work?

The issue doesn't seem specific to this example: I had the same problem with a custom layer.
When the returned state is discarded instead of being assigned to st, there's no error.

gradient(ps) do ps
    out, _ = s(x, ps, st)
    return sum(out)
end

Stacktrace:

ERROR: MethodError: no method matching (::SpiralClassifier{LSTMCell{true, false, false, Tuple{typeof(Lux.zeros32), typeof(Lux.zeros32), typeof(Lux.ones32), typeof(Lux.zeros32)}, NTuple{4, typeof(Lux.glorot_uniform)}, typeof(Lux.zeros32), typeof(Lux.zeros32)}, Dense{true, typeof(sigmoid_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}})(::Array{Float64, 3}, ::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, ::IRTools.Inner.Undefined)
Closest candidates are:
  (::SpiralClassifier)(::AbstractArray{T, 3}, ::NamedTuple, ::NamedTuple) where T at ~/transformer/Lux.jl/examples/SimpleRNN/main.jl:60
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(::Zygote.Context, ::SpiralClassifier{LSTMCell{true, false, false, Tuple{typeof(Lux.zeros32), typeof(Lux.zeros32), typeof(Lux.ones32), typeof(Lux.zeros32)}, NTuple{4, typeof(Lux.glorot_uniform)}, typeof(Lux.zeros32), typeof(Lux.zeros32)}, Dense{true, typeof(sigmoid_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}, ::Array{Float64, 3}, ::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}}, ::IRTools.Inner.Undefined)
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:9
 [3] _pullback
   @ ~/transformer/Lux.jl/examples/SimpleRNN/main.jl:149 [inlined]
 [4] _pullback(ctx::Zygote.Context, f::var"#38#39", args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [5] _pullback(f::Function, args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
 [6] pullback(f::Function, args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
 [7] gradient(f::Function, args::NamedTuple{(:lstm_cell, :classifier), Tuple{NamedTuple{(:weight_i, :weight_h, :bias), Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}}, NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}}}})
   @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
 [8] top-level scope
   @ ~/transformer/Lux.jl/examples/SimpleRNN/main.jl:148
@avik-pal
Member

gradient(ps) do ps
    out, st = s(x, ps, st)
    return sum(out)
end

This is semantically invalid. Assigning to st inside the do block makes st a new local variable there, so the st passed to s on the right-hand side is read before it has been defined. It is always recommended to run the function without gradient first and check that it works.
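
For anyone hitting the same error, here is a minimal sketch of the scoping rule in plain Julia, with no Zygote or Lux involved. `toy_layer` and its `count` field are hypothetical stand-ins for a stateful layer call and are not part of any library. The sketch assumes `st` is defined at top level (a global), as in the snippet above. The unassigned local in the first function is presumably what shows up as `IRTools.Inner.Undefined` once Zygote traces the closure.

```julia
# Hypothetical stand-in for a stateful layer call: returns (output, new_state).
# `toy_layer` and the `count` field are illustrative names only.
toy_layer(x, st) = (x + st.count, (count = st.count + 1,))

st = (count = 0,)  # defined at top level, so `st` is a global

# Assigning to `st` inside the function body makes `st` a new local there,
# so the `st` on the right-hand side is read before it has a value.
function f_shadowed(x)
    out, st = toy_layer(x, st)
    return out
end

# Discarding the new state leaves `st` referring to the global.
function f_ignored(x)
    out, _ = toy_layer(x, st)
    return out
end

# Binding the new state to a different name also keeps the global readable.
function f_renamed(x)
    out, st_new = toy_layer(x, st)
    return out + st_new.count
end

# Calling each version directly, without any AD, exposes the problem early.
shadowed_error = try
    f_shadowed(1)
    nothing
catch err
    err
end

println(shadowed_error isa UndefVarError)
println(f_ignored(1))
println(f_renamed(1))
```

Running the closure body on its own like this, before wrapping it in `gradient`, produces a plain `UndefVarError` instead of the much longer `MethodError` from the AD machinery.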

@jumerckx
Contributor Author

I was testing in global scope, so I didn't notice that the code wasn't supposed to work.
Thanks a lot!!
