
Unable to reproduce the AD optimization in the docs due to type error #248

Closed

Jarrod-Angove opened this issue Jun 17, 2024 · 5 comments

@Jarrod-Angove

I'm quite new to this package, so please forgive me if this is a mistake on my end, but I've been having some difficulties getting Optim AD methods working with Stheno. Even after directly copying the BFGS example from the docs, I end up with the following error:

ERROR: MethodError: no method matching AbstractGPs.FiniteGP(::Stheno.DerivedGP{Tuple{typeof(+), Stheno.DerivedGP{…}, Stheno.DerivedGP{…}}}, ::Float64)

Closest candidates are:
  AbstractGPs.FiniteGP(::AbstractGPs.AbstractGP, ::AbstractVector, ::AbstractVector{<:Real})
   @ AbstractGPs ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:13
  AbstractGPs.FiniteGP(::AbstractGPs.AbstractGP, ::AbstractVector, ::Real)
   @ AbstractGPs ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:19
  AbstractGPs.FiniteGP(::Tf, ::Tx, ::TΣ) where {Tf<:AbstractGPs.AbstractGP, Tx<:(AbstractVector), TΣ}
   @ AbstractGPs ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:8
  ...

This occurs if I copy and paste the tutorial from the docs directly into the REPL, or if I paste everything as a single function:

using Stheno, Zygote, Optim, ParameterHandling

function copy_stheno_test()
    l1 = 0.4
    s1 = 0.2
    l2 = 5.0
    s2 = 1.0

    g = @gppp let
        f1 = s1 * stretch(GP(Matern52Kernel()), 1 / l1)
        f2 = s2 * stretch(GP(SEKernel()), 1 / l2)
        f3 = f1 + f2
    end;

    x = GPPPInput(:f3, collect(range(-5.0, 5.0; length=100)));
    σ²_n = 0.02;
    fx = g(x, σ²_n);
    y = rand(fx);

    θ = (
        # Short length-scale and small variance.
        l1 = positive(0.4),
        s1 = positive(0.2),

        # Long length-scale and larger variance.
        l2 = positive(5.0),
        s2 = positive(1.0),

        # Observation noise variance -- we'll be learning this as well. Constrained to be
        # at least 1e-3.
        s_noise = positive(0.1, exp, 1e-3),
    )
    θ_flat_init, unflatten = flatten(θ);
    unpack = ParameterHandling.value ∘ unflatten;

    function build_model(θ::NamedTuple)
        return @gppp let
            f1 = θ.s1 * stretch(GP(SEKernel()), 1 / θ.l1)
            f2 = θ.s2 * stretch(GP(SEKernel()), 1 / θ.l2)
            f3 = f1 + f2
        end
    end

    function nlml(θ::NamedTuple)
        f = build_model(θ)
        return -logpdf(f(x, θ.s_noise + 1e-6), y)
    end

    results = Optim.optimize(
        nlml ∘ unpack,
        θ -> gradient(nlml ∘ unpack, θ)[1],
        θ_flat_init + randn(length(θ_flat_init)),
        BFGS(),
        Optim.Options(show_trace=true);
        inplace=false,
    )
    return results
end

This issue does not occur if I do not provide a gradient to Optim, as I believe it then defaults to a finite difference method to calculate the gradient. Similarly, using NelderMead() allows it to run without issue.
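For reference, the two variants that do run look roughly like this (just a sketch, assuming the same nlml, unpack, and θ_flat_init defined in the function above):

# Sketch of the two workarounds mentioned above; both avoid the explicit Zygote gradient.

# 1. No gradient supplied, so Optim falls back to finite differences.
results_fd = Optim.optimize(
    nlml ∘ unpack,
    θ_flat_init + randn(length(θ_flat_init)),
    BFGS(),
    Optim.Options(show_trace=true),
)

# 2. Gradient-free NelderMead.
results_nm = Optim.optimize(
    nlml ∘ unpack,
    θ_flat_init + randn(length(θ_flat_init)),
    NelderMead(),
    Optim.Options(show_trace=true),
)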

Here is the stack trace:

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::Type{AbstractGPs.FiniteGP}, ::Stheno.DerivedGP{Tuple{…}}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
  [3] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [4] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
  [5] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [6] AbstractGP
    @ ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:32 [inlined]
  [7] _pullback(ctx::Zygote.Context{false}, f::Stheno.DerivedGP{Tuple{…}}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [8] #rrule_via_ad#54
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:260 [inlined]
  [9] rrule_via_ad
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:248 [inlined]
 [10] #723
    @ ./none:0 [inlined]
 [11] iterate
    @ ./generator.jl:47 [inlined]
 [12] collect(itr::Base.Generator{Vector{Float64}, ChainRules.var"#723#728"{Zygote.ZygoteRuleConfig{…}, Stheno.DerivedGP{…}}})
    @ Base ./array.jl:834
 [13] rrule(config::Zygote.ZygoteRuleConfig{…}, ::typeof(sum), f::Stheno.DerivedGP{…}, xs::Vector{…}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Base/mapreduce.jl:102
 [14] rrule
    @ ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Base/mapreduce.jl:76 [inlined]
 [15] rrule(config::Zygote.ZygoteRuleConfig{…}, ::typeof(mean), f::Stheno.DerivedGP{…}, x::Vector{…}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Statistics/statistics.jl:28
 [16] rrule
    @ ~/.julia/packages/ChainRules/hShjJ/src/rulesets/Statistics/statistics.jl:21 [inlined]
 [17] chain_rrule
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:223 [inlined]
 [18] macro expansion
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(mean), ::Stheno.DerivedGP{Tuple{…}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:87
 [20] mean_and_cov
    @ ~/.julia/packages/AbstractGPs/XejGR/src/abstract_gp.jl:48 [inlined]
 [21] _pullback(::Zygote.Context{false}, ::typeof(mean_and_cov), ::Stheno.DerivedGP{Tuple{…}}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] mean_and_cov
    @ ~/.julia/packages/Stheno/ZSwgx/src/gaussian_process_probabilistic_programme.jl:74 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::typeof(mean_and_cov), ::Stheno.GaussianProcessProbabilisticProgramme{…}, ::GPPPInput{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [24] mean_and_cov
    @ ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:134 [inlined]
 [25] _pullback(ctx::Zygote.Context{…}, f::typeof(mean_and_cov), args::AbstractGPs.FiniteGP{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] logpdf
    @ ~/.julia/packages/AbstractGPs/XejGR/src/finite_gp_projection.jl:307 [inlined]
 [27] _pullback(::Zygote.Context{…}, ::typeof(logpdf), ::AbstractGPs.FiniteGP{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [28] nlml
    @ ./REPL[31]:3 [inlined]
 [29] _pullback(ctx::Zygote.Context{…}, f::typeof(nlml), args::@NamedTuple{})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [30] call_composed
    @ ./operators.jl:1044 [inlined]
 [31] #_#103
    @ ./operators.jl:1041 [inlined]
 [32] _pullback(::Zygote.Context{…}, ::Base.var"##_#103", ::@Kwargs{}, ::ComposedFunction{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [33] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [34] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [35] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [36] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [37] _pullback(ctx::Zygote.Context{false}, f::ComposedFunction{typeof(nlml), ComposedFunction{…}}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [38] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [39] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [40] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [41] (::var"#7#8")(θ::Vector{Float64})
    @ Main ./REPL[36]:3
 [42] (::NLSolversBase.var"#gg!#2"{var"#7#8"})(G::Vector{Float64}, x::Vector{Float64})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/inplace_factory.jl:21
 [43] (::NLSolversBase.var"#fg!#8"{})(gx::Vector{…}, x::Vector{…})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/objective_types/abstract.jl:13
 [44] value_gradient!!(obj::OnceDifferentiable{Float64, Vector{Float64}, Vector{Float64}}, x::Vector{Float64})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/kavn7/src/interface.jl:82
 [45] initial_state(method::BFGS{…}, options::Optim.Options{…}, d::OnceDifferentiable{…}, initial_x::Vector{…})
    @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/solvers/first_order/bfgs.jl:94
 [46] optimize
    @ ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/optimize.jl:36 [inlined]
 [47] optimize(f::Function, g::Function, initial_x::Vector{…}, method::BFGS{…}, options::Optim.Options{…}; inplace::Bool, autodiff::Symbol)
    @ Optim ~/.julia/packages/Optim/ZhuZN/src/multivariate/optimize/interface.jl:156
 [48] top-level scope
    @ REPL[36]:1
@willtebbutt
Member

Hmmm I'm really not sure what Zygote is doing here -- it has an annoying habit of breaking when you don't expect it to...

I would suggest trying with Tapir.jl. Could you please see if the following works for you?

using Stheno, Tapir, Random, Optim, ParameterHandling

l1 = 0.4
s1 = 0.2
l2 = 5.0
s2 = 1.0

g = @gppp let
    f1 = s1 * stretch(GP(Matern52Kernel()), 1 / l1)
    f2 = s2 * stretch(GP(SEKernel()), 1 / l2)
    f3 = f1 + f2
end;

x = GPPPInput(:f3, collect(range(-5.0, 5.0; length=100)));
σ²_n = 0.02;
fx = g(x, σ²_n);
y = rand(fx);

θ = (
    # Short length-scale and small variance.
    l1 = positive(0.4),
    s1 = positive(0.2),

    # Long length-scale and larger variance.
    l2 = positive(5.0),
    s2 = positive(1.0),

    # Observation noise variance -- we'll be learning this as well. Constrained to be
    # at least 1e-3.
    s_noise = positive(0.1, exp, 1e-3),
)
θ_flat_init, unflatten = flatten(θ);

function build_model(θ::NamedTuple)
    return @gppp let
        f1 = θ.s1 * stretch(GP(SEKernel()), 1 / θ.l1)
        f2 = θ.s2 * stretch(GP(SEKernel()), 1 / θ.l2)
        f3 = f1 + f2
    end
end

function nlml(θ::NamedTuple)
    f = build_model(θ)
    return -logpdf(f(x, θ.s_noise + 1e-6), y)
end

# Define objective function, check it runs, and compute a gradient to check that works.
obj(x) = nlml(ParameterHandling.value(unflatten(x)))
obj(θ_flat_init)

rule = Tapir.build_rrule(obj, θ_flat_init)
Tapir.value_and_gradient!!(rule, obj, θ_flat_init)

# Run optimisation.
results = Optim.optimize(
    obj,
    θ->Tapir.value_and_gradient!!(rule, obj, θ)[2][2],
    θ_flat_init + randn(length(θ_flat_init)),
    BFGS(),
    Optim.Options(show_trace=true,);
    inplace=false,
)

@Jarrod-Angove
Author

Jarrod-Angove commented Jun 17, 2024

Thanks for the quick response @willtebbutt !

I've run the suggested code but the Tapir.build_rrule function is failing:

julia> rule = Tapir.build_rrule(obj, θ_flat_init)
ERROR: MethodError: no method matching tangent_field_type(::Type{ParameterHandling.var"#unflatten_to_NamedTuple#15"{…}}, ::Int64)
The applicable method may be too new: running in world age 31913, while current world is 34727.

Closest candidates are:
  tangent_field_type(::Type{P}, ::Int64) where P (method too new to be called from this world context.)
   @ Tapir ~/.julia/packages/Tapir/7eB9t/src/tangents.jl:282

Edit: I thought it might be an issue with the ParameterHandling package, so I changed theta to a vector and tried Zygote, Tapir, and ForwardDiff again:

θ = [
    0.4,
    0.2,
    5.0,
    1.0,
    0.1,
]
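For context, the adapted objective looks roughly like this (just a sketch, assuming the element order l1, s1, l2, s2, s_noise from the vector above):

# Rough sketch of the vector-based objective used for the comparison below
# (element order assumed to be l1, s1, l2, s2, s_noise).
function obj(θ::AbstractVector)
    f = @gppp let
        f1 = θ[2] * stretch(GP(SEKernel()), 1 / θ[1])
        f2 = θ[4] * stretch(GP(SEKernel()), 1 / θ[3])
        f3 = f1 + f2
    end
    return -logpdf(f(x, θ[5] + 1e-6), y)
end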

Zygote still fails with the same type error. Tapir fails similarly, but this time the build_rrule function works and the value_and_gradient!! function fails:

ERROR: MethodError: no method matching tangent_field_type(::Type{Stheno.GaussianProcessProbabilisticProgramme{@NamedTuple{…}}}, ::Int64)
The applicable method may be too new: running in world age 31913, while current world is 34966.

Interestingly, by not using ParameterHandling, the ForwardDiff package is able to compute the gradient:

julia> ForwardDiff.gradient(obj, θ)
5-element Vector{Float64}:
  -87.61487924857141
 -182.63178639865728
    5.147333884587856
   17.701534645018494
  226.10513088590136

@willtebbutt
Member

Hmm, could you show me the output of Pkg.status()? It might be that you need a clean install. Also, what version of Julia are you on?

@Jarrod-Angove
Author

Jarrod-Angove commented Jun 17, 2024

Yep... It looks like there is something wrong with my project. I ran the code again in a clean install and it appears to work fine. I have no idea what could be causing this, but I'll try rebuilding everything piece by piece until I find the conflict.

For the sake of completeness:
My Julia version is 1.10.4, and here is my Pkg.status():

Project DilPredict v0.1.0
Status `~/Documents/grad_school/thesis/DilPredict/Project.toml`
⌃ [99985d1d] AbstractGPs v0.5.9
  [8bb1440f] DelimitedFiles v1.9.1
  [39dd38d3] Dierckx v0.5.3
  [f6369f11] ForwardDiff v0.10.36
  [033835bb] JLD2 v0.4.48
  [ec8451be] KernelFunctions v0.10.63
  [429524aa] Optim v1.9.4
  [2412ca09] ParameterHandling v0.5.0
  [91a5bcdd] Plots v1.40.4
  [c46f51b8] ProfileView v1.7.2
  [8188c328] Stheno v0.8.2
  [07d77754] Tapir v0.2.20
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [de0858da] Printf
  [9a3f8284] Random
  [10745b16] Statistics v1.10.0
  [fa267f1f] TOML v1.0.3

Sorry for the trouble!

Edit: For some reason Pkg had been fetching a super outdated version of AbstractGPs... All I needed to do was update :')
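For anyone hitting the same thing, the fix was just a normal package update, e.g.:

import Pkg
Pkg.update("AbstractGPs")   # or `] up` in the Pkg REPL to update everything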

@willtebbutt
Member

No trouble at all -- happy to have been able to help!
