NNlib activations cannot be compiled without manual dispatches #54

Closed
avik-pal opened this issue Jul 26, 2024 · 5 comments · Fixed by #130

Comments

@avik-pal
Collaborator

julia> xr = Reactant.ConcreteRArray(rand(Float32, 2, 3))
2×3 Reactant.ConcreteRArray{Float32, (2, 3), 2}:
 0.184252  0.863562  0.0996157
 0.14061   0.574859  0.236953

julia> Reactant.@code_hlo broadcast(tanh, xr)
ERROR: TypeError: in typeassert, expected Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}, got a value of type Reactant.TracedRArray{Float32, (2, 3), 2}
Stacktrace:
  [1] (::Reactant.var"#398#408"{typeof(broadcast), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{typeof(tanh), Reactant.TracedRArray{Float32, (2, 3), 2}}})()
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
  [2] block!(f::Reactant.var"#398#408"{typeof(broadcast), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{typeof(tanh), Reactant.TracedRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
  [3] make_mlir_fn(f::Function, args::Vector{Any}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
  [4] make_mlir_fn
    @ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
  [5] #127
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1186 [inlined]
  [6] block!(f::Reactant.var"#127#132"{typeof(broadcast), Vector{Any}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
  [7] #126
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1185 [inlined]
  [8] mmodule!(f::Reactant.var"#126#131"{Reactant.MLIR.IR.Module, typeof(broadcast), Vector{Any}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Module.jl:93
  [9] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Any}; optimize::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1182
 [10] (::var"#73#74")()
    @ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1314
 [11] context!(f::var"#73#74", ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
 [12] top-level scope
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1312
 [13] top-level scope
    @ none:1
@avik-pal
Collaborator Author

Might be related: if we try to compile a broadcasted function that doesn't have a StableHLO mapping defined, we get an error:

julia> using NNlib, Reactant

julia> act_fn(x) = swish.(x)
act_fn (generic function with 1 method)

julia> Reactant.@code_hlo act_fn(xr)
ERROR: MethodError: no method matching (::Core.OpaqueClosure{Tuple{Reactant.TracedRArray{Float32, (), 0}, Tuple{}}, Union{}})(::Reactant.TracedRArray{Float32, (), 0})
This error has been manually thrown, explicitly, so the method may exist but be intentionally marked as unimplemented.

Closest candidates are:
  (::Core.OpaqueClosure{Tuple{Reactant.TracedRArray{Float32, (), 0}, Tuple{}}, Union{}})(::Reactant.TracedRArray{Float32, (), 0}, ::Tuple{}) (method too new to be called from this world context.)
   @ Core :0

Stacktrace:
  [1] (::Reactant.var"#398#408"{typeof(swish), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (), 0}}})()
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
  [2] block!(f::Reactant.var"#398#408"{typeof(swish), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (), 0}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
  [3] make_mlir_fn(f::Function, args::Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
  [4] make_mlir_fn
    @ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
  [5] elem_apply(f::Function, args::Reactant.TracedRArray{Float32, (2, 3), 2})
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:190
  [6] _copyto!
    @ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:555 [inlined]
  [7] copyto!
    @ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:470 [inlined]
  [8] copyto!
    @ ./broadcast.jl:920 [inlined]
  [9] copy
    @ /mnt/research/lux/XLA/Reactant.jl/src/overloads.jl:461 [inlined]
 [10] materialize
    @ ./broadcast.jl:867 [inlined]
 [11] act_fn
    @ ./REPL[90]:1 [inlined]
 [12] (::Tuple{})(none::Reactant.TracedRArray{Float32, (2, 3), 2})
    @ Base.Experimental ./<missing>:0
 [13] (::Reactant.var"#398#408"{typeof(act_fn), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}})()
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:105
 [14] block!(f::Reactant.var"#398#408"{typeof(act_fn), Reactant.MLIR.IR.Block, Vector{Reactant.TracedRArray}, Tuple{Reactant.TracedRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
 [15] make_mlir_fn(f::Function, args::Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:67
 [16] make_mlir_fn
    @ /mnt/research/lux/XLA/Reactant.jl/src/utils.jl:16 [inlined]
 [17] #127
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1186 [inlined]
 [18] block!(f::Reactant.var"#127#132"{typeof(act_fn), Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Block.jl:201
 [19] #126
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1185 [inlined]
 [20] mmodule!(f::Reactant.var"#126#131"{Reactant.MLIR.IR.Module, typeof(act_fn), Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Module.jl:93
 [21] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Reactant.ConcreteRArray{Float32, (2, 3), 2}}; optimize::Bool)
    @ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1182
 [22] (::var"#77#78")()
    @ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1314
 [23] context!(f::var"#77#78", ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
 [24] top-level scope
    @ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1312
 [25] top-level scope
    @ none:1

Note that I am running a version of Reactant that has broadcast defined for sigmoid, so swish(x) = x * sigmoid(x) should compile fine.

@wsmoses
Member

wsmoses commented Sep 16, 2024

@avik-pal is this still an issue at the moment?

@avik-pal
Collaborator Author

Yeah, this still exists.

@avik-pal
Collaborator Author

ir =
754 1 ─ %1 = Core.tuple(_2)::Tuple{Reactant.TracedRArray{Float32, 0}}
    │   %2 = Base.Broadcast.nothing::Nothing
    │   %3 = %new(Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(swish), Tuple{Reactant.TracedRArray{Float32, 0}}}, $(QuoteNode(Reactant.AbstractReactantArrayStyle{0}())), NNlib.swish, %1, %2)::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(swish), Tuple{Reactant.TracedRArray{Float32, 0}}}
    │        invoke Base.Broadcast.copy(%3::Base.Broadcast.Broadcasted{Reactant.AbstractReactantArrayStyle{0}, Nothing, typeof(swish), Tuple{Reactant.TracedRArray{Float32, 0}}})::Union{}
    └──      unreachable
    2 ─      φ ()::Union{}
    └──      unreachable

@avik-pal
Collaborator Author

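I see what's happening. NNlib does this irritating thing of forwarding f(x::AbstractArray) to f.(x). Inside the broadcast, each traced scalar is itself a 0-dimensional TracedRArray, i.e. an AbstractArray, so swish dispatches back to the array method and broadcasts again. That recursion is a stack overflow, hence the unreachable in the IR above.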
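A minimal, self-contained sketch of that dispatch problem, without NNlib or Reactant: myswish and FakeTracedScalar are hypothetical stand-ins, assuming (as the traces above suggest) that a traced scalar is typed as a 0-dimensional array. The sketch only checks which method Julia's dispatch selects; it does not reproduce the stack overflow itself.

```julia
# Hypothetical stand-ins, NOT the real NNlib or Reactant definitions:
# `myswish` mimics NNlib's forwarding pattern, and `FakeTracedScalar`
# mimics a traced scalar that is typed as a 0-dimensional array.

# Scalar method: swish(x) = x * sigmoid(x)
myswish(x::Real) = x * (1 / (1 + exp(-x)))

# Array method: forwards to a broadcast over the elements
myswish(x::AbstractArray) = myswish.(x)

# A 0-dimensional array type standing in for a traced scalar
struct FakeTracedScalar <: AbstractArray{Float32,0}
    val::Float32
end
Base.size(::FakeTracedScalar) = ()
Base.getindex(x::FakeTracedScalar) = x.val

# Because the "scalar" is an AbstractArray, dispatch selects the array
# method, i.e. the one that broadcasts again instead of computing a value.
before = which(myswish, Tuple{FakeTracedScalar})
println(before.sig)

# A manual, more specific method short-circuits the forwarding by
# unwrapping the value and calling the scalar method directly.
myswish(x::FakeTracedScalar) = myswish(x.val)
after = which(myswish, Tuple{FakeTracedScalar})
println(after.sig)
```

Adding a method for the 0-dimensional traced type is the kind of manual dispatch the issue title refers to; the thread itself does not say how #130 resolved it.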

@avik-pal avik-pal changed the title "TypeError on compiling broadcast" to "NNlib activations cannot be compiled without manual dispatches" on Sep 30, 2024