[enzyme] broken MeanPool gradient #2564

Closed
CarloLucibello opened this issue Dec 31, 2024 · 2 comments
CarloLucibello commented Dec 31, 2024

using Flux, Enzyme, Statistics, Random

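function enzyme_withgradient(f, x...)
    # Wrap each argument in an Enzyme activity annotation.
    args = []
    for xi in x
        if xi isa Number
            # Scalars are Active: their gradient comes back in the return value.
            push!(args, Enzyme.Active(xi))
        else
            # Everything else gets a zero-initialized shadow that accumulates the gradient.
            push!(args, Enzyme.Duplicated(xi, Enzyme.make_zero(xi)))
        end
    end
    ad = Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal)
    ret = Enzyme.autodiff(ad, Enzyme.Const(f), Enzyme.Active, args...)
    # ret[1] holds the derivatives of Active arguments, ret[2] the primal value.
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return ret[2], g
end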

loss(model, x) = mean(model(x))
model = MeanPool((3, 3))
x = rand(Float32, 3, 3, 2, 2)
enzyme_withgradient(loss, model, x)

Output:

ERROR: 
No create nofree of empty function (julia.gc_loaded) julia.gc_loaded)
 at context:   call fastcc void @julia__PoolDims_14_89677({ [2 x i64], [2 x i64], i64, [2 x i64], [4 x i64], [2 x i64] }* noalias nocapture nofree noundef nonnull writeonly sret({ [2 x i64], [2 x i64], i64, [2 x i64], [4 x i64], [2 x i64] }) align 8 dereferenceable(104) %3, [2 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %12, [4 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(48) %11, [4 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %20, [2 x i64] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(64) %13) #49, !dbg !75 (julia__PoolDims_14_89677)

Stacktrace:
 [1] PoolDims
   @ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:20
 [2] PoolDims
   @ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:43
 [3] MeanPool
   @ ~/.julia/dev/Flux/src/layers/conv.jl:774
 [4] loss
   @ ~/.julia/dev/Flux/prova.jl:14
 [5] loss
   @ ~/.julia/dev/Flux/prova.jl:0


Stacktrace:
  [1] PoolDims
    @ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:20 [inlined]
  [2] PoolDims
    @ ~/.julia/packages/NNlib/mRRJu/src/dim_helpers/PoolDims.jl:43 [inlined]
  [3] MeanPool
    @ ~/.julia/dev/Flux/src/layers/conv.jl:774 [inlined]
  [4] loss
    @ ~/.julia/dev/Flux/prova.jl:14 [inlined]
  [5] loss
    @ ~/.julia/dev/Flux/prova.jl:0 [inlined]
  [6] diffejulia_loss_89542_inner_6wrap
    @ ~/.julia/dev/Flux/prova.jl:0
  [7] macro expansion
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:5317 [inlined]
  [8] enzyme_call
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4863 [inlined]
  [9] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/DiEvV/src/compiler.jl:4735 [inlined]
 [10] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/DiEvV/src/Enzyme.jl:503
 [11] enzyme_withgradient(::Function, ::MeanPool{2, 4}, ::Vararg{Any})
    @ Main ~/.julia/dev/Flux/test/test_utils.jl:32
 [12] top-level scope
    @ ~/.julia/dev/Flux/prova.jl:17
Some type information was truncated. Use `show(err)` to see complete types.

cc @wsmoses


wsmoses commented Dec 31, 2024

Likely resolved by EnzymeAD/Enzyme.jl#2240

CarloLucibello commented

Fixed on Enzyme v0.13.27. Thanks!
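
For anyone landing here with the same error, a quick way to check whether the active environment already has an Enzyme release at or above the version mentioned above. This is only a sketch: `installed_version` is a hypothetical helper written for this example (not part of Pkg or Enzyme), and it assumes the fix is present in every release from v0.13.27 onward.

```julia
using Pkg

# Hypothetical helper: look up a package's version in the active environment.
# Returns `nothing` if the package is absent or has no recorded version.
function installed_version(name::AbstractString)
    for (_, info) in Pkg.dependencies()
        info.name == name && return info.version
    end
    return nothing
end

v = installed_version("Enzyme")
if v === nothing
    println("Enzyme was not found in the active environment")
elseif v < v"0.13.27"
    println("Enzyme $v predates v0.13.27; consider running Pkg.update(\"Enzyme\")")
else
    println("Enzyme $v is at or above v0.13.27")
end
```

If the version is older and `Pkg.update("Enzyme")` does not move it, another package in the environment may be holding it back through a compat bound.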
