Missing Rule for Roots.jl #2035

Open
mhauru opened this issue Oct 31, 2024 · 8 comments

Comments

@mhauru
Contributor

mhauru commented Oct 31, 2024

MWE:

module MWE
import Bijectors, Enzyme, StableRNGs
b = Bijectors.PlanarLayer(3)
binv = Bijectors.inverse(b)
x = randn(StableRNGs.StableRNG(23), (3, 3))
f = x -> sum(b(binv(x)))
Enzyme.gradient(Enzyme.Forward, f, x)
end

Output:

ERROR: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_init_state_16477([4 x double]* noalias nocapture noundef nonnull sret([4 x double]) align 8 dereferenceable(32) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %0, { [3 x double] } addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" "enzymejl_parmtype"="4457150416" "enzymejl_parmtype_ref"="1" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %2, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %3, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %4, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4770454480" "enzymejl_parmtype_ref"="0" %5) unnamed_addr #59 !dbg !5084 {
top:
  %6 = alloca [4 x double], align 8
  %7 = call {}*** @julia.get_pgcstack() #60
  %ptls_field14 = getelementptr inbounds {}**, {}*** %7, i64 2
  %8 = bitcast {}*** %ptls_field14 to i64***
  %ptls_load1516 = load i64**, i64*** %8, align 8, !tbaa !61
  %9 = getelementptr inbounds i64*, i64** %ptls_load1516, i64 2
  %safepoint = load i64*, i64** %9, align 8, !tbaa !65
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #60, !dbg !5085
  fence syncscope("singlethread") seq_cst
  %10 = call double @llvm.fabs.f64(double %2) #60, !dbg !5086
  %11 = bitcast double %10 to i64, !dbg !5087
  %.not = icmp eq i64 %11, 9218868437227405312, !dbg !5087
  %12 = call double @julia_nextfloat_16460(double %2) #61, !dbg !5088
  %value_phi = select i1 %.not, double %12, double %2, !dbg !5088
  %13 = call double @llvm.fabs.f64(double %3) #60, !dbg !5089
  %14 = bitcast double %13 to i64, !dbg !5090
  %.not17 = icmp eq i64 %14, 9218868437227405312, !dbg !5090
  %15 = call double @julia_prevfloat_16488(double %3) #61, !dbg !5091
  %value_phi2 = select i1 %.not17, double %15, double %3, !dbg !5091
  %16 = fcmp uge double %value_phi, 0.000000e+00, !dbg !5092
  %17 = fcmp ule double %value_phi, 0.000000e+00, !dbg !5095
  %18 = select i1 %17, double %value_phi, double 1.000000e+00, !dbg !5097
  %19 = select i1 %16, double %18, double -1.000000e+00, !dbg !5097
  %20 = fcmp uge double %value_phi2, 0.000000e+00, !dbg !5092
  %21 = fcmp ule double %value_phi2, 0.000000e+00, !dbg !5095
  %22 = select i1 %21, double %value_phi2, double 1.000000e+00, !dbg !5097
  %23 = select i1 %20, double %22, double -1.000000e+00, !dbg !5097
  %24 = fmul double %19, %23, !dbg !5098
  %25 = fcmp uge double %24, 0.000000e+00, !dbg !5099
  br i1 %25, label %L31, label %L47, !dbg !5094

L31:                                              ; preds = %top
  %26 = call double @llvm.fabs.f64(double %value_phi) #60, !dbg !5101
  %bitcast_coercion = bitcast double %26 to i64, !dbg !5105
  %27 = call double @llvm.fabs.f64(double %value_phi2) #60, !dbg !5106
  %bitcast_coercion9 = bitcast double %27 to i64, !dbg !5108
  %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !5109
  %29 = lshr i64 %28, 1, !dbg !5111
  %30 = fadd double %value_phi, %value_phi2, !dbg !5113
  %31 = fcmp uge double %30, 0.000000e+00, !dbg !5115
  %32 = fcmp ule double %30, 0.000000e+00, !dbg !5117
  %33 = select i1 %32, double %30, double 1.000000e+00, !dbg !5119
  %34 = select i1 %31, double %33, double -1.000000e+00, !dbg !5119
  %bitcast_coercion12 = bitcast i64 %29 to double, !dbg !5120
  %35 = fmul double %34, %bitcast_coercion12, !dbg !5121
  br label %L47, !dbg !5121

L47:                                              ; preds = %L31, %top
  %value_phi6 = phi double [ %35, %L31 ], [ 0.000000e+00, %top ]
  %36 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 1, !dbg !5122
  %37 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 2, !dbg !5122
  %unbox = load double, double addrspace(11)* %37, align 8, !dbg !5124, !tbaa !65, !alias.scope !211, !noalias !214
  %38 = fadd double %value_phi6, %unbox, !dbg !5124
  %39 = call double @julia_tanh_16435(double %38) #62, !dbg !5122
  %unbox7 = load double, double addrspace(11)* %36, align 8, !dbg !5125, !tbaa !65, !alias.scope !211, !noalias !214
  %40 = fmul double %39, %unbox7, !dbg !5125
  %41 = fadd double %value_phi6, %40, !dbg !5124
  %42 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 0, !dbg !5122
  %unbox8 = load double, double addrspace(11)* %42, align 8, !dbg !5126, !tbaa !65, !alias.scope !211, !noalias !214
  %43 = fsub double %41, %unbox8, !dbg !5126
  call fastcc void @julia__init_state_52_16483([4 x double]* noalias nocapture noundef nonnull sret([4 x double]) align 8 dereferenceable(32) %6, double %value_phi6, double %43, double %2, double %3, double %4, double %5) #60, !dbg !5085
  %44 = bitcast [4 x double]* %0 to i8*, !dbg !5085
  %45 = bitcast [4 x double]* %6 to i8*, !dbg !5085
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 8 dereferenceable(32) %44, i8* noundef nonnull align 8 dereferenceable(32) %45, i64 32, i1 false) #60, !dbg !5085, !noalias !5127
  ret void, !dbg !5085
}

 Type analysis state:
<analysis>
[4 x double]* %0: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
{ [3 x double] } addrspace(11)* %1: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
double %2: {[-1]:Float@double}, intvals: {}
double %3: {[-1]:Float@double}, intvals: {}
double %4: {[-1]:Float@double}, intvals: {}
double %5: {[-1]:Float@double}, intvals: {}
double -1.000000e+00: {[-1]:Float@double}, intvals: {}
double 1.000000e+00: {[-1]:Float@double}, intvals: {}
double 0.000000e+00: {[-1]:Anything}, intvals: {}
  %value_phi2 = select i1 %.not17, double %15, double %3, !dbg !77: {[-1]:Float@double}, intvals: {}
  %23 = select i1 %20, double %22, double -1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %22 = select i1 %21, double %value_phi2, double 1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %19 = select i1 %16, double %18, double -1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %18 = select i1 %17, double %value_phi, double 1.000000e+00, !dbg !88: {[-1]:Float@double}, intvals: {}
  %value_phi = select i1 %.not, double %12, double %2, !dbg !73: {[-1]:Float@double}, intvals: {}
  %14 = bitcast double %13 to i64, !dbg !76: {[-1]:Float@double}, intvals: {}
  %bitcast_coercion = bitcast double %26 to i64, !dbg !100: {[-1]:Float@double}, intvals: {}
  %ptls_load1516 = load i64**, i64*** %8, align 8, !tbaa !61: {}, intvals: {}
  %safepoint = load i64*, i64** %9, align 8, !tbaa !65: {}, intvals: {}
  %bitcast_coercion9 = bitcast double %27 to i64, !dbg !104: {[-1]:Float@double}, intvals: {}
  %11 = bitcast double %10 to i64, !dbg !71: {[-1]:Float@double}, intvals: {}
  %8 = bitcast {}*** %ptls_field14 to i64***: {[-1]:Pointer}, intvals: {}
  %6 = alloca [4 x double], align 8: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
  %7 = call {}*** @julia.get_pgcstack() #60: {}, intvals: {}
  %ptls_field14 = getelementptr inbounds {}**, {}*** %7, i64 2: {}, intvals: {}
  %9 = getelementptr inbounds i64*, i64** %ptls_load1516, i64 2: {[-1]:Pointer}, intvals: {}
  %10 = call double @llvm.fabs.f64(double %2) #60, !dbg !68: {[-1]:Float@double}, intvals: {}
  %12 = call double @julia_nextfloat_16460(double %2) #61, !dbg !73: {[-1]:Float@double}, intvals: {}
  %13 = call double @llvm.fabs.f64(double %3) #60, !dbg !75: {[-1]:Float@double}, intvals: {}
  %15 = call double @julia_prevfloat_16488(double %3) #61, !dbg !77: {[-1]:Float@double}, intvals: {}
  %26 = call double @llvm.fabs.f64(double %value_phi) #60, !dbg !95: {[-1]:Float@double}, intvals: {}
  %27 = call double @llvm.fabs.f64(double %value_phi2) #60, !dbg !102: {[-1]:Float@double}, intvals: {}
  %39 = call double @julia_tanh_16435(double %38) #62, !dbg !122: {[-1]:Float@double}, intvals: {}
i64 9218868437227405312: {[-1]:Anything}, intvals: {9218868437227405312,}
  %.not = icmp eq i64 %11, 9218868437227405312, !dbg !71: {[-1]:Integer}, intvals: {}
  %.not17 = icmp eq i64 %14, 9218868437227405312, !dbg !76: {[-1]:Integer}, intvals: {}
  %16 = fcmp uge double %value_phi, 0.000000e+00, !dbg !78: {[-1]:Integer}, intvals: {}
  %17 = fcmp ule double %value_phi, 0.000000e+00, !dbg !84: {[-1]:Integer}, intvals: {}
  %20 = fcmp uge double %value_phi2, 0.000000e+00, !dbg !78: {[-1]:Integer}, intvals: {}
  %21 = fcmp ule double %value_phi2, 0.000000e+00, !dbg !84: {[-1]:Integer}, intvals: {}
  %24 = fmul double %19, %23, !dbg !91: {[-1]:Float@double}, intvals: {}
  %25 = fcmp uge double %24, 0.000000e+00, !dbg !93: {[-1]:Integer}, intvals: {}
  %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !105: {}, intvals: {}
  %38 = fadd double %value_phi6, %unbox, !dbg !128: {[-1]:Float@double}, intvals: {}
  %43 = fsub double %41, %unbox8, !dbg !138: {[-1]:Float@double}, intvals: {}
  %value_phi6 = phi double [ %35, %L31 ], [ 0.000000e+00, %top ]: {[-1]:Float@double}, intvals: {}
</analysis>

Illegal updateBinop Analysis   %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !105
Illegal binopIn(down): 13 lhs: {[]:Float@double} rhs: {[]:Float@double}

MethodInstance for Roots.init_state(::Roots.Bisection, ::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#60#61"{Float64, Float64, Float64}, Nothing}, ::Float64, ::Float64, ::Float64, ::Float64)


Caused by:
Stacktrace:
 [1] +
   @ ./int.jl:87
 [2] __middle
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:135
 [3] __middle
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:124
 [4] _middle
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:117
 [5] init_state
   @ ~/.julia/packages/Roots/KNVCY/src/Bracketing/bisection.jl:34

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:1508
  [2] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, runtimeActivity::Bool, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/projects/Enzyme.jl/src/api.jl:319
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:4057
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:7125
  [5] codegen
    @ ~/projects/Enzyme.jl/src/compiler.jl:5950 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:8233
  [7] _thunk
    @ ~/projects/Enzyme.jl/src/compiler.jl:8233 [inlined]
  [8] cached_compilation
    @ ~/projects/Enzyme.jl/src/compiler.jl:8274 [inlined]
  [9] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/compiler.jl:8406
 [10] #s2079#19071
    @ ~/projects/Enzyme.jl/src/compiler.jl:8543 [inlined]
 [11]
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [13] runtime_generic_fwd(activity::Type{…}, runtimeActivity::Val{…}, width::Val{…}, RT::Val{…}, f::Bijectors.Inverse{…}, df::Nothing, df_2::Nothing, df_3::Nothing, df_4::Nothing, df_5::Nothing, df_6::Nothing, df_7::Nothing, df_8::Nothing, df_9::Nothing, primal_1::Matrix{…}, shadow_1_1::Matrix{…}, shadow_1_2::Matrix{…}, shadow_1_3::Matrix{…}, shadow_1_4::Matrix{…}, shadow_1_5::Matrix{…}, shadow_1_6::Matrix{…}, shadow_1_7::Matrix{…}, shadow_1_8::Matrix{…}, shadow_1_9::Matrix{…})
    @ Enzyme.Compiler ~/projects/Enzyme.jl/src/rules/jitrules.jl:290
 [14] #1
    @ ./REPL[3]:6 [inlined]
 [15] fwddiffe9julia__1_15870wrap
    @ ./REPL[3]:0
 [16] macro expansion
    @ ~/projects/Enzyme.jl/src/compiler.jl:8163 [inlined]
 [17] enzyme_call
    @ ~/projects/Enzyme.jl/src/compiler.jl:7729 [inlined]
 [18] ForwardModeThunk
    @ ~/projects/Enzyme.jl/src/compiler.jl:7518 [inlined]
 [19] autodiff
    @ ~/projects/Enzyme.jl/src/Enzyme.jl:647 [inlined]
 [20] autodiff
    @ ~/projects/Enzyme.jl/src/Enzyme.jl:512 [inlined]
 [21] macro expansion
    @ ~/projects/Enzyme.jl/src/Enzyme.jl:2069 [inlined]
 [22] gradient(::EnzymeCore.ForwardMode{…}, ::Main.MWE.var"#1#2", ::Matrix{…}; chunk::Nothing, shadows::Tuple{…})
    @ Enzyme ~/projects/Enzyme.jl/src/Enzyme.jl:1971
 [23] gradient(::EnzymeCore.ForwardMode{false, EnzymeCore.FFIABI, false, false}, ::Main.MWE.var"#1#2", ::Matrix{Float64})
    @ Enzyme ~/projects/Enzyme.jl/src/Enzyme.jl:1971
 [24] top-level scope
    @ REPL[3]:7
Some type information was truncated. Use `show(err)` to see complete types.

This affects both forward and reverse mode, on the current main branch of Enzyme.
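
For readers tracing the failure: the `L31` block of the IR above reinterprets two doubles as 64-bit integers, adds them, shifts right by one, and reinterprets the result as a double. The `add i64` in that sequence is the instruction named in `Illegal updateBinop Analysis`. A rough Julia reconstruction of that computation, read off the IR rather than copied from the Roots.jl source, might look like the sketch below. The name `bitwise_middle` is hypothetical, and the handling of infinite endpoints visible at the top of the IR is left out.

```julia
# Hypothetical reconstruction of the midpoint computation seen in the IR.
# Not the Roots.jl source; infinite endpoints are not handled here.
function bitwise_middle(a::Float64, b::Float64)
    # Opposite signs: the IR skips the bit arithmetic and uses 0.0.
    sign(a) * sign(b) < 0 && return 0.0
    # Reinterpret the magnitudes as raw 64-bit integers.
    ia = reinterpret(UInt64, abs(a))
    ib = reinterpret(UInt64, abs(b))
    # Integer add followed by a logical shift right by one bit.
    mid_bits = (ia + ib) >> 1
    # Reinterpret back to a double and restore the sign.
    return sign(a + b) * reinterpret(Float64, mid_bits)
end

bitwise_middle(1.0, 4.0)   # 2.0
bitwise_middle(-1.0, 1.0)  # 0.0
```

The integer addition operates on bit patterns that type analysis has already marked as `Float@double`, which appears to be what it rejects.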

@wsmoses
Member

wsmoses commented Nov 1, 2024

Looks like Roots.jl is missing a rule

@wsmoses changed the title from "Illegal type analysis, Bijectors.PlanarLayer" to "Missing Rule for Roots.jl" on Nov 3, 2024
@wsmoses
Member

wsmoses commented Nov 3, 2024

@mhauru can you reduce this to just the Roots.jl call that errors?

@ChrisRackauckas
Contributor

It's probably easiest to just switch to SimpleNonlinearSolve.jl? @avik-pal was adding a specialized adjoint for that case. We can do a PR to Bijectors? SciML/NonlinearSolve.jl#478

@devmotion
Contributor

devmotion commented Nov 4, 2024

We switched to Roots.jl at some point since (Simple)NonlinearSolve.jl caused problems. IIRC Roots.jl was also more lightweight.

It's just missing a rule which is available in closed form and implemented for the other supported AD backends. We don't rely on any generic adjoints in Roots.jl.
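
As a point of reference, the generic closed-form derivative of a root with respect to a parameter follows from the implicit function theorem. The symbols g, x^* and y below are illustrative and not taken from Bijectors.jl or Roots.jl, and the rule actually implemented for the other AD backends may be specialised further.

```latex
% Assume x^*(y) is defined implicitly by g(x^*(y), y) = 0
% and that \partial g / \partial x \neq 0 at the root.
\frac{\partial g}{\partial x}\,\frac{\mathrm{d}x^*}{\mathrm{d}y}
  + \frac{\partial g}{\partial y} = 0
\quad\Longrightarrow\quad
\frac{\mathrm{d}x^*}{\mathrm{d}y}
  = -\,\frac{\partial g / \partial y}{\partial g / \partial x}
  \bigg|_{x = x^*(y)}
```

The derivative depends only on the partial derivatives of g at the converged root, so nothing inside the iterative solver needs to be differentiated.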

@devmotion
Contributor

Ref TuringLang/Bijectors.jl#202

@ChrisRackauckas
Contributor

That PR isn't a one-to-one comparison: it compares NonlinearSolve with Roots, whereas BracketingNonlinearSolve (a simplified SimpleNonlinearSolve with only the bracketing methods?) is the direct comparison, and that is a much smaller library.

Also, you'd likely want to use ITP these days instead of Falsi in order to enforce robustness.

@devmotion
Contributor

devmotion commented Nov 4, 2024

BracketingNonlinearSolve didn't exist back then but it's still a heavier dependency than Roots.jl. Roots.jl also supports ITP.

For a bit more context: Root finding is only needed for a specific method with a closed-form adjoint, so there's no need for very generic root-finding support.
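
To make the idea of a closed-form derivative through a root solve concrete, here is a minimal sketch in plain Julia. Everything in it is an assumption for illustration: the residual `g`, its hand-written partial derivatives, the bracket `(0.0, 2.0)`, and the small `bisect` helper that stands in for a bracketing solver. It is not the rule used in Bijectors.jl, and it does not use the Enzyme rule API.

```julia
# Hypothetical residual whose root in x depends on the parameter y.
g(x, y) = x^3 - y
dg_dx(x, y) = 3x^2   # hand-written partial derivative with respect to x
dg_dy(x, y) = -1.0   # hand-written partial derivative with respect to y

# Plain bisection, standing in for a bracketing root finder.
function bisect(h, lo, hi; iters=200)
    flo = h(lo)
    for _ in 1:iters
        mid = (lo + hi) / 2
        fmid = h(mid)
        if sign(fmid) == sign(flo)
            lo, flo = mid, fmid   # root lies in the upper half
        else
            hi = mid              # root lies in the lower half
        end
    end
    return (lo + hi) / 2
end

# Root of g(., y) inside the assumed bracket (0, 2).
solve(y) = bisect(x -> g(x, y), 0.0, 2.0)

# Implicit-function derivative: dx*/dy = -(dg/dy) / (dg/dx) at the root.
function dsolve(y)
    x = solve(y)
    return -dg_dy(x, y) / dg_dx(x, y)
end

# Central finite difference through the solver, for comparison.
fd(y; h=1e-6) = (solve(y + h) - solve(y - h)) / (2h)
```

Comparing `dsolve(1.0)` with `fd(1.0)` is a quick sanity check that the closed form agrees with differentiating through the solver numerically.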

@mhauru
Contributor Author

mhauru commented Nov 26, 2024

@wsmoses here's a more barebones version:

module MWE

using Roots: Roots
using Enzyme: Enzyme


function f(y)
    lower = -1.0
    upper = 1.0
    α0 = Roots.find_zero(x -> y*x, (lower, upper), Roots.ITP())
    return α0
end

y = 1.0
Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Duplicated(y, 0.0))

end

Also, if the issue is just that Roots.jl is missing a rule, I think the error message should reflect that. The current error is confusing and, to the average Julia user, very intimidating.
