Need an adjoint for constructor for Zygote compatibility #737

Open
NAThompson opened this issue Aug 18, 2024 · 2 comments

Comments

@NAThompson

I'm trying to use Zygote with Unitful. If I use standard units like m, s and m/s, everything works:

using Zygote
using Unitful

# Works:
struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2u"m", 3u"s", 3e8u"m/s")


g = Zygote.gradient(f -> bar(f), foo)

println(g)

However, if I use a custom unit (in this case Euros), I get an error saying that an adjoint is needed for the constructor:

ERROR: LoadError: Need an adjoint for constructor Quantity{Float64, 𝐓, Unitful.FreeUnits{(yr,), 𝐓, nothing}}. Gradient is of type Quantity{Float64, 𝐓⁻¹, Unitful.FreeUnits{(€, yr⁻¹), 𝐓⁻¹, nothing}}

Apologies if my google-fu has failed me; I have searched the Unitful docs but did not manage to find anything about defining an adjoint for a constructor. How can I define one?

Code to reproduce:

using Zygote
using Unitful

@unit € "€" Euro 1.0 false

struct Loan
    loan_amount::Number
    interest_rate::Number
    loan_term::Number
end

function monthly_payment(l::Loan)
    r = l.interest_rate*1.0u"yr"/12.0 # use monthly interest rate for monthly payment
    # Number of payments:
    n = l.loan_term*12.0/1.0u"yr"
    return l.loan_amount*r*(1+r)^n/((1+r)^n - 1)
end

loan = Loan(500000.0*€, 0.025/1.0u"yr", 30.0u"yr")
g = Zygote.gradient(l -> monthly_payment(l), loan)
println(g)
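
As a quick sanity check on the reproduction above, the following sketch uses only Unitful (no Zygote) to confirm that the year units cancel in `r` and `n`, so both are plain numbers and the currency is the only unit left in the result. The variable names mirror the fields of `Loan`; this is an illustration added for readers, not part of the original report.

```julia
using Unitful

# Same values as the Loan fields in the reproduction
interest_rate = 0.025 / 1.0u"yr"
loan_term = 30.0u"yr"

# yr^-1 * yr cancels, so r is a plain number (monthly interest rate)
r = interest_rate * 1.0u"yr" / 12.0

# yr / yr cancels, so n is a plain number (number of monthly payments)
n = loan_term * 12.0 / 1.0u"yr"

println(typeof(r), " ", typeof(n))
```

If either value still carried units, `(1 + r)^n` would fail before Zygote was involved at all, so this narrows the problem to the differentiation step.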

NAThompson changed the title from "Need an adjoint for constructor" to "Need an adjoint for constructor for Zygote compatibility" on Aug 18, 2024

@NAThompson
Author

NAThompson commented Sep 21, 2024

For those finding this issue, I have managed to figure out how to get this working. Simply add the following adjoint definition:

Zygote.@adjoint function Quantity(x::Number, u)
    return Quantity(x, u), (∂y -> (∂y * Quantity(1.0, u), nothing))
end

@sostock
Collaborator

sostock commented Dec 19, 2024

I don’t think it has anything to do with whether the units are user-defined or not. It also affects units defined in this package:

julia> gradient(x -> ustrip(u"eV", x), 1u"J")
((val = 6.241509074460763e18,),)

julia> gradient(x -> uconvert(u"eV", x), 1u"J")
ERROR: Need an adjoint for constructor Quantity{Float64, 𝐋^2 𝐌 𝐓^-2, Unitful.FreeUnits{(eV,), 𝐋^2 𝐌 𝐓^-2, nothing}}. Gradient is of type Float64
Stacktrace:
[]

We could add the adjoint in a package extension.
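
One visible difference between the two calls above is the type of the value they return, which can be checked with Unitful alone. The sketch below (variable names are illustrative) shows that `ustrip` hands back a bare `Float64`, while `uconvert` hands back a `Quantity` in `eV`, which matches the constructor type named in the error message. It does not by itself establish why Zygote handles one and not the other.

```julia
using Unitful

x = 1u"J"

stripped = ustrip(u"eV", x)     # returns a plain number
converted = uconvert(u"eV", x)  # returns a Quantity carrying eV

println(typeof(stripped))
println(typeof(converted))
```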
