Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try mutating Tapir #126

Closed
wants to merge 6 commits into from
Closed

Try mutating Tapir #126

wants to merge 6 commits into from

Conversation

gdalle
Copy link
Member

@gdalle gdalle commented Apr 1, 2024

Extensions

First try at mutating functions f!(y, x) = nothing with Tapir backend, and it fails due to a mismatched tangent type.

AssertionError: _typeof(tangent(out)) == T
  Stacktrace:num_to_arr!
    [1] value_and_pullback!!(::Tapir.DerivedRule{Core.OpaqueClosure{Tuple{Tapir.AugmentedRegister{Tapir.CoDual{DifferentiationInterfaceTest.var"#num_to_arr!#10"{Matrix{Int64}}, Tapir.Tangent{@NamedTuple{a::Matrix{Tapir.NoTangent}}}}, Base.RefArray{Tapir.Tangent{@NamedTuple{a::Matrix{Tapir.NoTangent}}}, Vector{Tapir.Tangent{@NamedTuple{a::Matrix{Tapir.NoTangent}}}}, Nothing}}, Tapir.AugmentedRegister{Tapir.CoDual{Matrix{Float64}, Matrix{Float64}}, Base.RefArray{Matrix{Float64}, Vector{Matrix{Float64}}, Nothing}}, Tapir.AugmentedRegister{Tapir.CoDual{Float64, Float64}, Base.RefArray{Float64, Vector{Float64}, Nothing}}}, Tapir.AugmentedRegister{Tapir.CoDual{Nothing, Tapir.NoTangent}, Tapir.NoTangentRef}}, Tuple{Tapir.Stack{Tapir.Tangent{@NamedTuple{a::Matrix{Tapir.NoTangent}}}}, Tapir.Stack{Matrix{Float64}}, Tapir.Stack{Float64}}, Core.OpaqueClosure{Tuple{Tapir.NoTangent, Tapir.Tangent{@NamedTuple{a::Matrix{Tapir.NoTangent}}}, Matrix{Float64}, Float64}, Nothing}, Val{false}, Val{3}}, ::Nothing, ::Tapir.CoDual{DifferentiationInterfaceTest.var"#num_to_arr!#10"{Matrix{Int64}}, Tapir.Tangent{@NamedTuple{a::Matrix{Tapir.NoTangent}}}}, ::Tapir.CoDual{Matrix{Float64}, Matrix{Float64}}, ::Tapir.CoDual{Float64, Float64})
      @ Tapir ~/.julia/packages/Tapir/456MF/src/interface.jl:11
    [2] value_and_pullback!!(f!::Function, y::Matrix{Float64}, dx::Float64, ::AutoTapir, x::Float64, dy::FillArrays.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, extras::Nothing)
      @ DifferentiationInterfaceTapirExt ~/Work/GitHub/Julia/DifferentiationInterface.jl/ext/DifferentiationInterfaceTapirExt/mutating.jl:6
    [3] value_and_pullback!!(f!::Function, y::Matrix{Float64}, dx::Float64, backend::AutoTapir, x::Float64, dy::FillArrays.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
      @ DifferentiationInterface ~/Work/GitHub/Julia/DifferentiationInterface.jl/src/pullback.jl:79
    [4] (::DifferentiationInterface.var"#5#7"{DifferentiationInterfaceTest.var"#num_to_arr!#10"{Matrix{Int64}}, Matrix{Float64}, AutoTapir, Float64, Float64})(i::CartesianIndex{2})
      @ DifferentiationInterface ~/Work/GitHub/Julia/DifferentiationInterface.jl/src/pushforward.jl:83

@gdalle
Copy link
Member Author

gdalle commented Apr 1, 2024

@willtebbutt if you have any clues I'm all ears!

@gdalle
Copy link
Member Author

gdalle commented Apr 2, 2024

I applied all your suggestions @willtebbutt but now I think Tapir is angry at me because my mutating function returns nothing (I adopted the Enzyme convention):

MethodError: no method matching copy(::Nothing)
  function:  num_to_arr!
  Closest candidates are:
    copy(::DataFrames.DataFrame; copycols)
     @ DataFrames ~/.julia/packages/DataFrames/58MUJ/src/dataframe/dataframe.jl:805
    copy(::IRTools.Inner.Branch)
     @ IRTools ~/.julia/packages/IRTools/Q1Ewy/src/ir/ir.jl:73
    copy(::LinearAlgebra.Transpose{Bool, BitMatrix})
     @ LinearAlgebra ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/bitarray.jl:240
    ...
  
  Stacktrace:
    [1] value_and_pullback!!(::Tapir.DerivedRule{…}, ::Tapir.NoTangent, ::Tapir.CoDual{…}, ::Tapir.CoDual{…}, ::Tapir.CoDual{…})
      @ Tapir ~/.julia/packages/Tapir/BqxEi/src/interface.jl:13
    [2] value_and_pullback!!(f!::Function, y::Vector{…}, dx::Float64, ::AutoTapir, x::Float64, dy::FillArrays.OneElement{…}, extras::DifferentiationInterfaceTapirExt.TapirMutatingPullbackExtras{…})
      @ DifferentiationInterfaceTapirExt ~/Work/GitHub/Julia/DifferentiationInterface.jl/ext/DifferentiationInterfaceTapirExt/mutating.jl:15
    [3] (::DifferentiationInterface.var"#5#7"{})(i::CartesianIndex{…})
      @ DifferentiationInterface ~/Work/GitHub/Julia/DifferentiationInterface.jl/src/pushforward.jl:119

@gdalle
Copy link
Member Author

gdalle commented Apr 2, 2024

@willtebbutt
Copy link
Member

willtebbutt commented Apr 2, 2024

Ahh okay. That's unfortunate. Perhaps I should be taking a deepcopy of whatever is returned?

edit: PR here

@gdalle
Copy link
Member Author

gdalle commented Apr 2, 2024

Why do you even copy it?
I don't wanna force you to make a deep copy without good reason, it is usually much slower and type unstable

@gdalle
Copy link
Member Author

gdalle commented Apr 2, 2024

@willtebbutt you're right compintell/Mooncake.jl#112 makes it no longer error, but the outputs are all wrong (seemingly uninitialized or zero arrays).
I have put together a MWE for you to play with if you wanna help:

  • Clone the repo from the branch gd/mutating_tapir
  • Activate the test environment with TestEnv.activate()
  • Add Tapir from the branch wct/deepcopy-output
  • Include test/tapir.jl

@gdalle gdalle marked this pull request as draft April 2, 2024 20:37
@willtebbutt
Copy link
Member

Why do you even copy it?

I need to make a copy because the reverse-pass has to "undo" all of the operations that the forwards-pass applies to any mutable state. For example, this is why you were seeing uninitialised values before I added in copy. I'm a little surprised that deepcopy doesn't also solve this problem -- will take a look tomorrow!

I don't wanna force you to make a deep copy without good reason, it is usually much slower and type unstable

Good point -- I'll probably just need to add my own _copy function or something so that I can have fast copying without type piracy.

@gdalle
Copy link
Member Author

gdalle commented Apr 3, 2024

See compintell/Mooncake.jl#113 for the distilled bug on your end / wrong use on mine

@gdalle gdalle closed this Apr 5, 2024
@gdalle gdalle deleted the gd/mutating_tapir branch April 5, 2024 11:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants