Vectorization run between levels of nested AD #748

Closed
samakins opened this issue Apr 20, 2023 · 8 comments

samakins commented Apr 20, 2023

I may be misusing this, but when I take the second derivative of a simple scalar function I get very strange results.

using Enzyme

function derivative(f::F, x::T) where {F, T<:Real}
    adf = @inline xarg -> Enzyme.autodiff(Enzyme.Forward, f, Duplicated(xarg, 1.0))
    first(adf(x))
end

function secondderivative(f::F, x::T) where {F, T<:Real}
    adf = @inline xarg -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, Active(xarg))
    first(first(first(autodiff(Forward, adf, Duplicated(x, 1.0)))))
end

g1(x) = x*x*x*x + 0.0*x*x*x
g2(x) = x*x*x*x
g3(x) = x*x*x*x + 1e-9*x*x*x

println(derivative(g1, 1.0))
println(derivative(g2, 1.0))
println(derivative(g3, 1.0))
println(secondderivative(g1, 1.0))
println(secondderivative(g2, 1.0))
println(secondderivative(g3, 1.0))

This produces the output:

4.0
4.0
4.000000003
15.0
12.0
12.000000006
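
For reference, the expected values can be checked without Enzyme. The sketch below is not part of the original report: `dg1`, `d2g1`, and `fd2` are hypothetical helper names, the derivatives are written out by hand, and the finite difference is only an approximate cross-check. It suggests that the fourth printed value (15.0) is the wrong one and should be 12.0.

```julia
# g1 as defined in the report.
g1(x) = x*x*x*x + 0.0*x*x*x

# Hand-derived derivatives of x^4 + 0.0*x^3 (assumed helpers, not from the report).
dg1(x)  = 4x^3 + 0.0 * 3x^2
d2g1(x) = 12x^2 + 0.0 * 6x

# Central second difference as an independent, approximate check.
fd2(f, x; h=1e-4) = (f(x + h) - 2 * f(x) + f(x - h)) / h^2

println(dg1(1.0))       # analytic first derivative at x = 1
println(d2g1(1.0))      # analytic second derivative at x = 1
println(fd2(g1, 1.0))   # numerical estimate, close to the analytic value
```

The zero-weighted cubic term contributes nothing to either derivative, so g1 and g2 should agree at every order.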

wsmoses (Member) commented Apr 20, 2023

Which versions of Enzyme and Julia are you using?

On latest main and Julia 1.9, I get the following:

julia> using Enzyme

julia> function derivative(f::F, x::T) where {F, T<:Real}
           adf = @inline xarg -> Enzyme.autodiff(Enzyme.Forward, f, Duplicated(xarg, 1.0))
           first(adf(x))
       end
derivative (generic function with 1 method)

julia> function secondderivative(f::F, x::T) where {F, T<:Real}
           adf = @inline xarg -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, Active(xarg))
           first(first(first(autodiff(Forward, adf, Duplicated(x, 1.0)))))
       end
secondderivative (generic function with 1 method)


julia> g1(x) = x*x*x*x + 0.0*x*x*x
g1 (generic function with 1 method)

julia> println(secondderivative(g1, 1.0))
12.0

samakins (Author) commented

Running on Julia 1.8.1 with the following package versions:

[[deps.Enzyme]]
deps = ["CEnum", "EnzymeCore", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Printf", "Random"]
git-tree-sha1 = "28c09cee567a3b62a8ae2a99035bd477fdbaee3f"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
version = "0.11.0"

[[deps.EnzymeCore]]
deps = ["Adapt"]
git-tree-sha1 = "d0840cfff51e34729d20fd7d0a13938dc983878b"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
version = "0.3.0"

[[deps.Enzyme_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "b83b072e7a5eee2050800658f5676b0ff7b37dc6"
uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef"
version = "0.0.53+0"

wsmoses (Member) commented Apr 20, 2023

Can you add Enzyme.API.printall!(true) right after the `using Enzyme` line, run only the case that gives the incorrect result, and post the log output?

samakins (Author) commented

Sure

Running this:

using Enzyme
Enzyme.API.printall!(true)

function secondderivative(f::F, x::T) where {F, T<:Real}
    adf = xarg -> Enzyme.autodiff_deferred(Enzyme.Reverse, f, Active(xarg))
    first(first(first(autodiff(Forward, adf, Duplicated(x, 1.0)))))
end

g1(x) = x*x*x*x + 0.0*x*x*x

secondderivative(g1, 1.0)

produces this output:

secondderivative (generic function with 1 method)

g1 (generic function with 1 method)

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_g1_3388_inner.1(double %0) local_unnamed_addr #2 !dbg !21 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #3
  %2 = fmul double %0, %0, !dbg !22
  %3 = fmul double %2, %0, !dbg !22
  %4 = fmul double %3, %0, !dbg !26
  %5 = fmul double %0, 0.000000e+00, !dbg !22
  %6 = fmul double %5, %0, !dbg !22
  %7 = fmul double %6, %0, !dbg !26
  %8 = fadd double %4, %7, !dbg !28
  ret double %8, !dbg !29
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double } @diffejulia_g1_3388_inner.1(double %0, double %differeturn) local_unnamed_addr #2 !dbg !30 {
entry:
  %"'de" = alloca double, align 8
  %1 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %"'de1" = alloca double, align 8
  %2 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %2, align 8
  %"'de2" = alloca double, align 8
  %3 = getelementptr double, double* %"'de2", i64 0
  store double 0.000000e+00, double* %3, align 8
  %"'de4" = alloca double, align 8
  %4 = getelementptr double, double* %"'de4", i64 0
  store double 0.000000e+00, double* %4, align 8
  %"'de5" = alloca double, align 8
  %5 = getelementptr double, double* %"'de5", i64 0
  store double 0.000000e+00, double* %5, align 8
  %"'de8" = alloca double, align 8
  %6 = getelementptr double, double* %"'de8", i64 0
  store double 0.000000e+00, double* %6, align 8
  %"'de13" = alloca double, align 8
  %7 = getelementptr double, double* %"'de13", i64 0
  store double 0.000000e+00, double* %7, align 8
  %"'de16" = alloca double, align 8
  %8 = getelementptr double, double* %"'de16", i64 0
  store double 0.000000e+00, double* %8, align 8
  %9 = call {}*** @julia.get_pgcstack() #3
  %10 = fmul double %0, %0, !dbg !31
  %11 = fmul double %10, %0, !dbg !31
  %12 = fmul double %0, 0.000000e+00, !dbg !31
  %13 = fmul double %12, %0, !dbg !31
  br label %invertentry, !dbg !35

invertentry:                                      ; preds = %entry
  store double %differeturn, double* %"'de", align 8
  %14 = load double, double* %"'de", align 8
  store double 0.000000e+00, double* %"'de", align 8
  %15 = load double, double* %"'de1", align 8
  %16 = fadd fast double %15, %14
  store double %16, double* %"'de1", align 8
  %17 = load double, double* %"'de2", align 8
  %18 = fadd fast double %17, %14
  store double %18, double* %"'de2", align 8
  %19 = load double, double* %"'de2", align 8
  %m0diffe = fmul fast double %19, %0
  %m1diffe = fmul fast double %19, %13
  store double 0.000000e+00, double* %"'de2", align 8
  %20 = load double, double* %"'de4", align 8
  %21 = fadd fast double %20, %m0diffe
  store double %21, double* %"'de4", align 8
  %22 = load double, double* %"'de5", align 8
  %23 = fadd fast double %22, %m1diffe
  store double %23, double* %"'de5", align 8
  %24 = load double, double* %"'de4", align 8
  %m0diffe6 = fmul fast double %24, %0
  %m1diffe7 = fmul fast double %24, %12
  store double 0.000000e+00, double* %"'de4", align 8
  %25 = load double, double* %"'de8", align 8
  %26 = fadd fast double %25, %m0diffe6
  store double %26, double* %"'de8", align 8
  %27 = load double, double* %"'de5", align 8
  %28 = fadd fast double %27, %m1diffe7
  store double %28, double* %"'de5", align 8
  %29 = load double, double* %"'de8", align 8
  %m0diffe9 = fmul fast double %29, 0.000000e+00
  store double 0.000000e+00, double* %"'de8", align 8
  %30 = load double, double* %"'de5", align 8
  %31 = fadd fast double %30, %m0diffe9
  store double %31, double* %"'de5", align 8
  %32 = load double, double* %"'de1", align 8
  %m0diffe11 = fmul fast double %32, %0
  %m1diffe12 = fmul fast double %32, %11
  store double 0.000000e+00, double* %"'de1", align 8
  %33 = load double, double* %"'de13", align 8
  %34 = fadd fast double %33, %m0diffe11
  store double %34, double* %"'de13", align 8
  %35 = load double, double* %"'de5", align 8
  %36 = fadd fast double %35, %m1diffe12
  store double %36, double* %"'de5", align 8
  %37 = load double, double* %"'de13", align 8
  %m0diffe14 = fmul fast double %37, %0
  %m1diffe15 = fmul fast double %37, %10
  store double 0.000000e+00, double* %"'de13", align 8
  %38 = load double, double* %"'de16", align 8
  %39 = fadd fast double %38, %m0diffe14
  store double %39, double* %"'de16", align 8
  %40 = load double, double* %"'de5", align 8
  %41 = fadd fast double %40, %m1diffe15
  store double %41, double* %"'de5", align 8
  %42 = load double, double* %"'de16", align 8
  %m0diffe17 = fmul fast double %42, %0
  %m1diffe18 = fmul fast double %42, %0
  store double 0.000000e+00, double* %"'de16", align 8
  %43 = load double, double* %"'de5", align 8
  %44 = fadd fast double %43, %m0diffe17
  store double %44, double* %"'de5", align 8
  %45 = load double, double* %"'de5", align 8
  %46 = fadd fast double %45, %m1diffe18
  store double %46, double* %"'de5", align 8
  %47 = load double, double* %"'de5", align 8
  %48 = insertvalue { double } undef, double %47, 0
  ret { double } %48
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define [1 x [1 x double]] @preprocess_julia__5_3281_inner.1(double %0) local_unnamed_addr #4 !dbg !39 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #5
  %2 = fmul double %0, %0, !dbg !40
  %3 = fmul double %0, 0.000000e+00, !dbg !40
  %4 = insertelement <2 x double> poison, double %0, i32 0
  %5 = insertelement <2 x double> %4, double %3, i32 1
  %6 = insertelement <2 x double> <double poison, double 1.000000e+00>, double %2, i32 0
  %7 = fmul fast <2 x double> %5, %6
  %8 = fadd fast double %2, %3
  %9 = insertelement <2 x double> <double poison, double 2.000000e+00>, double %8, i32 0
  %10 = shufflevector <2 x double> %4, <2 x double> poison, <2 x i32> zeroinitializer
  %11 = insertelement <2 x double> <double 1.000000e+00, double poison>, double %0, i32 1
  %12 = fmul fast <2 x double> %10, %11
  %13 = fmul fast <2 x double> %12, %9
  %14 = fadd fast <2 x double> %13, %7
  %15 = extractelement <2 x double> %14, i32 1
  %reass.mul.i = fmul fast double %15, %0
  %16 = extractelement <2 x double> %14, i32 0
  %17 = fadd fast double %reass.mul.i, %16
  %.fca.0.0.insert = insertvalue [1 x [1 x double]] undef, double %17, 0, 0, !dbg !53
  ret [1 x [1 x double]] %.fca.0.0.insert, !dbg !54
}

; Function Attrs: mustprogress nofree nosync willreturn
define internal [1 x [1 x double]] @fwddiffejulia__5_3281_inner.1(double %0, double %"'") local_unnamed_addr #5 !dbg !55 {
entry:
  %1 = call {}*** @julia.get_pgcstack() #6
  %2 = fmul double %0, %0, !dbg !56
  %3 = fmul fast double %"'", %0, !dbg !56
  %4 = fmul fast double %"'", %0, !dbg !56
  %5 = fadd fast double %3, %4, !dbg !56
  %6 = fmul double %0, 0.000000e+00, !dbg !56
  %7 = fmul fast double %"'", 0.000000e+00
  %"'ipie" = insertelement <2 x double> zeroinitializer, double %"'", i32 0
  %8 = insertelement <2 x double> poison, double %0, i32 0
  %"'ipie16" = insertelement <2 x double> %"'ipie", double %7, i32 1
  %9 = insertelement <2 x double> %8, double %6, i32 1
  %"'ipie17" = insertelement <2 x double> <double poison, double 1.000000e+00>, double %5, i32 0
  %10 = insertelement <2 x double> <double poison, double 1.000000e+00>, double %2, i32 0
  %11 = fmul fast <2 x double> %9, %10
  %12 = fmul fast <2 x double> %"'ipie16", %10
  %13 = fmul fast <2 x double> %"'ipie17", %9
  %14 = fadd fast <2 x double> %12, %13
  %15 = fadd fast double %2, %6
  %16 = fadd fast double %5, %7
  %"'ipie18" = insertelement <2 x double> <double poison, double 2.000000e+00>, double %16, i32 0
  %17 = insertelement <2 x double> <double poison, double 2.000000e+00>, double %15, i32 0
  %"'ipsv" = shufflevector <2 x double> %"'ipie", <2 x double> poison, <2 x i32> zeroinitializer
  %18 = shufflevector <2 x double> %8, <2 x double> poison, <2 x i32> zeroinitializer
  %"'ipie19" = insertelement <2 x double> <double 1.000000e+00, double poison>, double %"'", i32 1
  %19 = insertelement <2 x double> <double 1.000000e+00, double poison>, double %0, i32 1
  %20 = fmul fast <2 x double> %18, %19
  %21 = fmul fast <2 x double> %"'ipsv", %19
  %22 = fmul fast <2 x double> %"'ipie19", %18
  %23 = fadd fast <2 x double> %21, %22
  %24 = fmul fast <2 x double> %20, %17
  %25 = fmul fast <2 x double> %23, %17
  %26 = fmul fast <2 x double> %"'ipie18", %20
  %27 = fadd fast <2 x double> %25, %26
  %28 = fadd fast <2 x double> %24, %11
  %29 = fadd fast <2 x double> %27, %14
  %"'ipee" = extractelement <2 x double> %29, i32 1
  %30 = extractelement <2 x double> %28, i32 1
  %31 = fmul fast double %"'ipee", %0
  %32 = fmul fast double %"'", %30
  %33 = fadd fast double %31, %32
  %"'ipee20" = extractelement <2 x double> %29, i32 0
  %34 = fadd fast double %33, %"'ipee20", !dbg !69
  %".fca.0.0.insert'ipiv" = insertvalue [1 x [1 x double]] zeroinitializer, double %34, 0, 0, !dbg !69
  ret [1 x [1 x double]] %".fca.0.0.insert'ipiv"
}

15.0

samakins (Author) commented Apr 20, 2023

I'm not sure whether this is helpful, but this function

g4(x) = x*x*x*x + 0.0*(x*x*x)

returns the correct second derivative of 12.0.

Additionally, just changing the number of x's in the zero term makes the problem disappear or reappear. For example,

g5(x) = x*x*x*x + 0.0*x*x*x*x*x*x

also gives the correct result.
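
The two spellings of the zero term reach the compiler as different expression trees, which may be why parenthesizing changes the outcome. As a small illustration (not from the original thread), Julia parses a chain such as `0.0*x*x*x` as a single n-ary call that is evaluated left to right for floats, matching the `fmul` chain `%5`, `%6`, `%7` in the log above, whereas `0.0*(x*x*x)` multiplies the cube by zero once at the end. The names `chained` and `grouped` are only illustrative.

```julia
# Parse both spellings without evaluating them.
chained = Meta.parse("0.0*x*x*x")     # one n-ary call: *(0.0, x, x, x)
grouped = Meta.parse("0.0*(x*x*x)")   # binary call: *(0.0, *(x, x, x))

println(chained.args)   # operator followed by four operands
println(grouped.args)   # operator, 0.0, and a nested call expression
```

This only shows that the inputs differ structurally; it does not by itself explain which later optimization step produced the wrong result.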

wsmoses (Member) commented Apr 21, 2023

wsmoses (Member) commented Apr 22, 2023

This should now be fixed on main, if you want to retry. Separately, I want to leave this issue open, since we are clearly running a vectorizer between the rounds of AD, which we shouldn't do (even though the result will now be correct).
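
For readers unfamiliar with the "rounds of AD" mentioned here: in nested AD the inner derivative is generated first, and the outer pass then differentiates that generated code. In the log above, the function handed to the outer forward pass (`preprocess_julia__5_3281_inner.1`) already contains `<2 x double>` vector operations, which is the vectorization between rounds. The sketch below illustrates nesting only. It is a hand-rolled dual-number toy, not Enzyme's implementation; the `Dual` type and both helper functions are hypothetical, and it uses forward-over-forward rather than the forward-over-reverse combination in the reproducer.

```julia
# Minimal dual number: a value and its derivative along one direction.
struct Dual{T}
    val::T
    der::T
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::Dual) = Dual(c * a.val, c * a.der)

g1(x) = x*x*x*x + 0.0*x*x*x

# One round of forward-mode AD.
first_derivative(f, x) = f(Dual(x, one(x))).der

# Two nested rounds: the outer dual carries the inner dual as its value.
function second_derivative(f, x)
    inner = Dual(x, one(x))
    outer = Dual(inner, Dual(one(x), zero(x)))
    return f(outer).der.der
end

println(first_derivative(g1, 1.0))
println(second_derivative(g1, 1.0))
```

In this toy there is no optimization step between the two rounds, so the outer round always sees the inner round's arithmetic unchanged, which is the behavior the comment above argues for.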

wsmoses changed the title from "0.0 values in functions don't take effect" to "Vectorization run between levels of nested AD" on Dec 19, 2023

wsmoses (Member) commented Dec 6, 2024

Should be resolved by #2161

wsmoses closed this as completed on Dec 6, 2024