Enzyme compilation failed due to illegal type analysis #1573

Closed
mhauru opened this issue Jun 26, 2024 · 3 comments

mhauru (Contributor) commented Jun 26, 2024

MWE:

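```julia
using Bijectors
using Bijectors: PlanarLayer
using Enzyme
using LinearAlgebra  # provides `I`, the identity used as the MvNormal covariance

function f(θ)
    # θ[1:5]: planar-layer parameters; θ[6:7]: the point x
    layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
    flow = transformed(MvNormal(zeros(2), I), inverse(layer))
    x = θ[6:7]
    return logabsdetjac(flow.transform, x)
end

Enzyme.gradient(Enzyme.Forward, f, randn(7))
```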

Output:

ERROR: LoadError: Enzyme compilation failed due to illegal type analysis.
Current scope:
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia_init_state_4328([4 x double]* noalias nocapture nofree noundef nonnull writeonly sret([4 x double]) align 8 dereferenceable(32) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %0, { [3 x double] } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double}" "enzymejl_parmtype"="6145269136" "enzymejl_parmtype_ref"="1" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="5225778384" "enzymejl_parmtype_ref"="0" %2, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="5225778384" "enzymejl_parmtype_ref"="0" %3, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="5225778384" "enzymejl_parmtype_ref"="0" %4, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="5225778384" "enzymejl_parmtype_ref"="0" %5) unnamed_addr #53 !dbg !4122 {
top:
  %6 = alloca [4 x double], align 8
  %7 = call {}*** @julia.get_pgcstack() #54
  %ptls_field14 = getelementptr inbounds {}**, {}*** %7, i64 2
  %8 = bitcast {}*** %ptls_field14 to i64***
  %ptls_load1516 = load i64**, i64*** %8, align 8, !tbaa !44
  %9 = getelementptr inbounds i64*, i64** %ptls_load1516, i64 2
  %safepoint = load i64*, i64** %9, align 8, !tbaa !48
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #54, !dbg !4123
  fence syncscope("singlethread") seq_cst
  %10 = call double @llvm.fabs.f64(double %2) #54, !dbg !4124
  %11 = bitcast double %10 to i64, !dbg !4125
  %.not = icmp eq i64 %11, 9218868437227405312, !dbg !4125
  %12 = call double @julia_nextfloat_4313(double %2) #55, !dbg !4126
  %value_phi = select i1 %.not, double %12, double %2, !dbg !4126
  %13 = call double @llvm.fabs.f64(double %3) #54, !dbg !4127
  %14 = bitcast double %13 to i64, !dbg !4128
  %.not17 = icmp eq i64 %14, 9218868437227405312, !dbg !4128
  %15 = call double @julia_prevfloat_4339(double %3) #55, !dbg !4129
  %value_phi2 = select i1 %.not17, double %15, double %3, !dbg !4129
  %16 = fcmp uge double %value_phi, 0.000000e+00, !dbg !4130
  %17 = fcmp ule double %value_phi, 0.000000e+00, !dbg !4133
  %18 = select i1 %17, double %value_phi, double 1.000000e+00, !dbg !4135
  %19 = select i1 %16, double %18, double -1.000000e+00, !dbg !4135
  %20 = fcmp uge double %value_phi2, 0.000000e+00, !dbg !4130
  %21 = fcmp ule double %value_phi2, 0.000000e+00, !dbg !4133
  %22 = select i1 %21, double %value_phi2, double 1.000000e+00, !dbg !4135
  %23 = select i1 %20, double %22, double -1.000000e+00, !dbg !4135
  %24 = fmul double %19, %23, !dbg !4136
  %25 = fcmp uge double %24, 0.000000e+00, !dbg !4137
  br i1 %25, label %L31, label %L47, !dbg !4132

L31:                                              ; preds = %top
  %26 = call double @llvm.fabs.f64(double %value_phi) #54, !dbg !4139
  %bitcast_coercion = bitcast double %26 to i64, !dbg !4143
  %27 = call double @llvm.fabs.f64(double %value_phi2) #54, !dbg !4144
  %bitcast_coercion9 = bitcast double %27 to i64, !dbg !4146
  %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !4147
  %29 = lshr i64 %28, 1, !dbg !4149
  %30 = fadd double %value_phi, %value_phi2, !dbg !4151
  %31 = fcmp uge double %30, 0.000000e+00, !dbg !4153
  %32 = fcmp ule double %30, 0.000000e+00, !dbg !4155
  %33 = select i1 %32, double %30, double 1.000000e+00, !dbg !4157
  %34 = select i1 %31, double %33, double -1.000000e+00, !dbg !4157
  %bitcast_coercion12 = bitcast i64 %29 to double, !dbg !4158
  %35 = fmul double %34, %bitcast_coercion12, !dbg !4159
  br label %L47, !dbg !4159

L47:                                              ; preds = %L31, %top
  %value_phi6 = phi double [ %35, %L31 ], [ 0.000000e+00, %top ]
  %36 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 1, !dbg !4160
  %37 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 2, !dbg !4160
  %unbox = load double, double addrspace(11)* %37, align 8, !dbg !4162, !tbaa !48, !alias.scope !329, !noalias !330
  %38 = fadd double %value_phi6, %unbox, !dbg !4162
  %39 = call double @julia_tanh_4266(double %38) #56, !dbg !4160
  %unbox7 = load double, double addrspace(11)* %36, align 8, !dbg !4163, !tbaa !48, !alias.scope !329, !noalias !330
  %40 = fmul double %39, %unbox7, !dbg !4163
  %41 = fadd double %value_phi6, %40, !dbg !4162
  %42 = getelementptr inbounds { [3 x double] }, { [3 x double] } addrspace(11)* %1, i64 0, i32 0, i64 0, !dbg !4160
  %unbox8 = load double, double addrspace(11)* %42, align 8, !dbg !4164, !tbaa !48, !alias.scope !329, !noalias !330
  %43 = fsub double %41, %unbox8, !dbg !4164
  call fastcc void @julia__init_state_69_4334([4 x double]* noalias nocapture nofree noundef nonnull writeonly sret([4 x double]) align 8 dereferenceable(32) %6, double %value_phi6, double %43, double %2, double %3, double %4, double %5) #54, !dbg !4123
  %44 = bitcast [4 x double]* %0 to i8*, !dbg !4123
  %45 = bitcast [4 x double]* %6 to i8*, !dbg !4123
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture nofree noundef nonnull writeonly align 8 dereferenceable(32) %44, i8* noundef nonnull align 8 dereferenceable(32) %45, i64 noundef 32, i1 noundef false) #54, !dbg !4123, !noalias !4165
  ret void, !dbg !4123
}

 Type analysis state:
<analysis>
double 0.000000e+00: {[-1]:Anything}, intvals: {}
double 1.000000e+00: {[-1]:Float@double}, intvals: {}
double -1.000000e+00: {[-1]:Float@double}, intvals: {}
i64 9218868437227405312: {[-1]:Anything}, intvals: {9218868437227405312,}
[4 x double]* %0: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
{ [3 x double] } addrspace(11)* %1: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
double %2: {[-1]:Float@double}, intvals: {}
double %3: {[-1]:Float@double}, intvals: {}
double %4: {[-1]:Float@double}, intvals: {}
double %5: {[-1]:Float@double}, intvals: {}
  %10 = call double @llvm.fabs.f64(double %2) #54, !dbg !51: {[-1]:Float@double}, intvals: {}
  %12 = call double @julia_nextfloat_4313(double %2) #55, !dbg !56: {[-1]:Float@double}, intvals: {}
  %13 = call double @llvm.fabs.f64(double %3) #54, !dbg !58: {[-1]:Float@double}, intvals: {}
  %ptls_field14 = getelementptr inbounds {}**, {}*** %7, i64 2: {}, intvals: {}
  %9 = getelementptr inbounds i64*, i64** %ptls_load1516, i64 2: {[-1]:Pointer}, intvals: {}
  %39 = call double @julia_tanh_4266(double %38) #56, !dbg !105: {[-1]:Float@double}, intvals: {}
  %27 = call double @llvm.fabs.f64(double %value_phi2) #54, !dbg !85: {[-1]:Float@double}, intvals: {}
  %26 = call double @llvm.fabs.f64(double %value_phi) #54, !dbg !78: {[-1]:Float@double}, intvals: {}
  %15 = call double @julia_prevfloat_4339(double %3) #55, !dbg !60: {[-1]:Float@double}, intvals: {}
  %7 = call {}*** @julia.get_pgcstack() #54: {}, intvals: {}
  %6 = alloca [4 x double], align 8: {[-1]:Pointer, [-1,-1]:Float@double}, intvals: {}
  %value_phi6 = phi double [ %35, %L31 ], [ 0.000000e+00, %top ]: {[-1]:Float@double}, intvals: {}
  %19 = select i1 %16, double %18, double -1.000000e+00, !dbg !71: {[-1]:Float@double}, intvals: {}
  %value_phi = select i1 %.not, double %12, double %2, !dbg !56: {[-1]:Float@double}, intvals: {}
  %value_phi2 = select i1 %.not17, double %15, double %3, !dbg !60: {[-1]:Float@double}, intvals: {}
  %22 = select i1 %21, double %value_phi2, double 1.000000e+00, !dbg !71: {[-1]:Float@double}, intvals: {}
  %18 = select i1 %17, double %value_phi, double 1.000000e+00, !dbg !71: {[-1]:Float@double}, intvals: {}
  %23 = select i1 %20, double %22, double -1.000000e+00, !dbg !71: {[-1]:Float@double}, intvals: {}
  %.not17 = icmp eq i64 %14, 9218868437227405312, !dbg !59: {[-1]:Integer}, intvals: {}
  %16 = fcmp uge double %value_phi, 0.000000e+00, !dbg !61: {[-1]:Integer}, intvals: {}
  %.not = icmp eq i64 %11, 9218868437227405312, !dbg !54: {[-1]:Integer}, intvals: {}
  %38 = fadd double %value_phi6, %unbox, !dbg !111: {[-1]:Float@double}, intvals: {}
  %43 = fsub double %41, %unbox8, !dbg !121: {[-1]:Float@double}, intvals: {}
  %20 = fcmp uge double %value_phi2, 0.000000e+00, !dbg !61: {[-1]:Integer}, intvals: {}
  %17 = fcmp ule double %value_phi, 0.000000e+00, !dbg !67: {[-1]:Integer}, intvals: {}
  %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !88: {}, intvals: {}
  %25 = fcmp uge double %24, 0.000000e+00, !dbg !76: {[-1]:Integer}, intvals: {}
  %21 = fcmp ule double %value_phi2, 0.000000e+00, !dbg !67: {[-1]:Integer}, intvals: {}
  %24 = fmul double %19, %23, !dbg !74: {[-1]:Float@double}, intvals: {}
  %11 = bitcast double %10 to i64, !dbg !54: {[-1]:Float@double}, intvals: {}
  %ptls_load1516 = load i64**, i64*** %8, align 8, !tbaa !44: {}, intvals: {}
  %bitcast_coercion9 = bitcast double %27 to i64, !dbg !87: {[-1]:Float@double}, intvals: {}
  %safepoint = load i64*, i64** %9, align 8, !tbaa !48: {}, intvals: {}
  %14 = bitcast double %13 to i64, !dbg !59: {[-1]:Float@double}, intvals: {}
  %bitcast_coercion = bitcast double %26 to i64, !dbg !83: {[-1]:Float@double}, intvals: {}
  %8 = bitcast {}*** %ptls_field14 to i64***: {[-1]:Pointer}, intvals: {}
</analysis>

Illegal updateBinop Analysis   %28 = add i64 %bitcast_coercion9, %bitcast_coercion, !dbg !88
Illegal binopIn(down): 13 lhs: {[]:Float@double} rhs: {[]:Float@double}

MethodInstance for Roots.init_state(::Roots.Bisection, ::Roots.Callable_Function{Val{1}, Val{false}, Bijectors.var"#60#61"{Float64, Float64, Float64}, Nothing}, ::Float64, ::Float64, ::Float64, ::Float64)


Caused by:
Stacktrace:
 [1] +
   @ ./int.jl:87
 [2] __middle
   @ ~/.julia/packages/Roots/neTBD/src/Bracketing/bisection.jl:135
 [3] __middle
   @ ~/.julia/packages/Roots/neTBD/src/Bracketing/bisection.jl:124
 [4] _middle
   @ ~/.julia/packages/Roots/neTBD/src/Bracketing/bisection.jl:117
 [5] init_state
   @ ~/.julia/packages/Roots/neTBD/src/Bracketing/bisection.jl:34

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:1991
  [2] EnzymeCreateForwardDiff(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…})
    @ Enzyme.API ~/.julia/packages/Enzyme/3mqec/src/api.jl:170
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:3720
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:5845
  [5] codegen
    @ ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:5123 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:6652
  [7] _thunk
    @ ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:6652 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:6690 [inlined]
  [9] (::Enzyme.Compiler.var"#28587#28588"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:6759
 [10] JuliaContext(f::Enzyme.Compiler.var"#28587#28588"{…}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/nWT2N/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/nWT2N/src/driver.jl:42
 [12] #s2010#28586
    @ ~/.julia/packages/Enzyme/3mqec/src/compiler.jl:6710 [inlined]
 [13]
    @ Enzyme.Compiler ./none:0
 [14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [15] autodiff(::EnzymeCore.ForwardMode{…}, f::EnzymeCore.Const{…}, ::Type{…}, args::EnzymeCore.BatchDuplicated{…})
    @ Enzyme ~/.julia/packages/Enzyme/3mqec/src/Enzyme.jl:415
 [16] autodiff
    @ ~/.julia/packages/Enzyme/3mqec/src/Enzyme.jl:321 [inlined]
 [17] gradient(::EnzymeCore.ForwardMode{EnzymeCore.FFIABI}, f::Function, x::Vector{Float64}; shadow::NTuple{7, Vector{…}})
    @ Enzyme ~/.julia/packages/Enzyme/3mqec/src/Enzyme.jl:1064
 [18] gradient(::EnzymeCore.ForwardMode{EnzymeCore.FFIABI}, f::Function, x::Vector{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/3mqec/src/Enzyme.jl:1060
 [19] top-level scope
    @ ~/projects/Enzyme-mwes/illegal_type_analysis/mwe.jl:13
 [20] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [21] top-level scope
    @ REPL[1]:1
in expression starting at /Users/mhauru/projects/Enzyme-mwes/illegal_type_analysis/mwe.jl:1
Some type information was truncated. Use `show(err)` to see complete types.

Enzyme v0.12.19, Bijectors v0.13.14.

```
julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)
```
wsmoses (Member) commented Jun 26, 2024

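Enzyme is correctly yelling about a bit hack here rather than risking a wrong answer: https://github.com/JuliaMath/Roots.jl/blob/4263c7683423af209af1879bd9cd223b0ba55acd/src/Bracketing/bisection.jl#L135. There, `__middle` reinterprets the absolute values of the two floats as 64-bit integers, adds them and shifts right by one; that integer `add` on values typed as `Float@double` is what the "Illegal updateBinop Analysis" line reports.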
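A minimal Julia sketch of this kind of bit-level midpoint, reconstructed from the IR and stack trace above rather than copied from Roots.jl. The name `bit_middle` is hypothetical, and the sketch assumes both endpoints share a sign (the IR suggests `_middle` returns 0.0 when the signs differ, before `__middle` is reached).

```julia
# Hypothetical bit-level midpoint modelled on the `__middle` frames above.
# Assumes x and y have the same sign.
function bit_middle(x::Float64, y::Float64)
    xi = reinterpret(UInt64, abs(x))  # float bits viewed as an unsigned integer
    yi = reinterpret(UInt64, abs(y))
    mid = (xi + yi) >> 1              # the integer add + shift that Enzyme rejects
    return sign(x + y) * reinterpret(Float64, mid)
end

println(bit_middle(1.0, 4.0))  # 2.0, not the arithmetic mean 2.5
```

Because the shift halves the distance between bit patterns rather than between values, the result is not the arithmetic midpoint (for 0.0 and 1.0 it is about 1.1e-154). There is no float derivative rule for an integer add of float bit patterns, so Enzyme stops with the error above instead of guessing.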

The resolution is a custom rule for Roots.jl.

wsmoses (Member) commented Jun 26, 2024

@mhauru, can you identify (ideally via the backtrace) what calls Roots.jl, and where Turing currently has a custom rule in the call chain? It would be wise to emulate the existing rule infrastructure.

wsmoses (Member) commented Jun 26, 2024

From TuringLang/Bijectors.jl#319 it looks like you added a rule for `find_alpha`, so closing.
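A custom rule for a root solve can supply the derivative of the root directly instead of differentiating through the solver's iterations. The following is a hedged sketch of the math such a rule can use, not the code from Bijectors.jl#319. The closure in the IR above appears to solve α + c₁·tanh(α + c₂) - c₀ = 0 for α using three captured Float64 values; here they are called `y`, `u` and `b` (hypothetical names). The solver is a plain bisection rather than Roots.jl, and `u > -1` is assumed so the residual is increasing in α. The implicit function theorem then gives dα/dθ = -(∂F/∂θ)/(∂F/∂α) at the root.

```julia
# Hypothetical residual F(α) = α + u * tanh(α + b) - y (names are assumptions).
residual(α, y, u, b) = α + u * tanh(α + b) - y

# Plain bisection on [y - |u|, y + |u|]; the root lies there because
# |u * tanh(⋅)| ≤ |u|. Assumes u > -1, so F is strictly increasing in α.
function solve_alpha(y, u, b; iters = 200)
    lo, hi = y - abs(u), y + abs(u)
    for _ in 1:iters
        mid = (lo + hi) / 2  # ordinary arithmetic midpoint, no bit tricks
        if residual(mid, y, u, b) < 0
            lo = mid
        else
            hi = mid
        end
    end
    return (lo + hi) / 2
end

# Derivatives of the root via the implicit function theorem, evaluated at
# the solution; nothing is differentiated through the bisection loop.
function solve_alpha_with_grad(y, u, b)
    α = solve_alpha(y, u, b)
    t = tanh(α + b)
    sech2 = 1 - t^2            # sech²(α + b)
    dFdα = 1 + u * sech2
    dαdy = 1 / dFdα            # ∂F/∂y = -1
    dαdu = -t / dFdα           # ∂F/∂u = tanh(α + b)
    dαdb = -u * sech2 / dFdα   # ∂F/∂b = u * sech²(α + b)
    return α, (dαdy, dαdu, dαdb)
end

# Usage: compare against central finite differences of the solver.
y, u, b = 0.7, 1.3, -0.2
α, grad = solve_alpha_with_grad(y, u, b)
h = 1e-6
fd = (
    (solve_alpha(y + h, u, b) - solve_alpha(y - h, u, b)) / (2 * h),
    (solve_alpha(y, u + h, b) - solve_alpha(y, u - h, b)) / (2 * h),
    (solve_alpha(y, u, b + h) - solve_alpha(y, u, b - h)) / (2 * h),
)
println(grad, " vs ", fd)
```

The finite-difference comparison is only a sanity check. The point is that the derivative formula never touches the midpoint computation, so whatever bit tricks the solver uses internally would no longer need to be differentiated.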

Feel free to use Enzyme's chain rule import macro if you like:

```julia
Enzyme.@import_rrule(typeof(Base.sort), Any);
```
