-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: more coverage for common NN operations #55
Conversation
Need to fix the compats. Lux failure will be gone with LuxDL/LuxLib.jl#105. |
a93e795
to
807d004
Compare
1.11 failures are due to Enzyme not working on 1.11 |
fe1df41
to
477bf21
Compare
ext/ReactantNNlibExt.jl
Outdated
end | ||
end | ||
end | ||
|
||
function Reactant.elem_apply( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We generally shouldn’t need to add more elem alllirs anymore now that we have batching properly. I think just defining this for an rarray of size zero should suffice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried this
function NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T}
return max(x, zero(T))
end
but I am getting
error: 'stablehlo.constant' op inferred type(s) 'tensor<f32>' are incompatible with return type(s) of operation 'tensor<2x3xf32>'
error: 'stablehlo.constant' op failed to infer returned types
ERROR: "failed to run pass manager on module"
Stacktrace:
[1] run!
@ /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Pass.jl:70 [inlined]
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String)
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1178
[3] compile_to_module(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{Any}; optimize::Bool)
@ Reactant /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1198
[4] (::var"#51#52")()
@ Main /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1315
[5] context!(f::var"#51#52", ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/research/lux/XLA/Reactant.jl/src/mlir/IR/Context.jl:71
[6] top-level scope
@ /mnt/research/lux/XLA/Reactant.jl/src/Reactant.jl:1313
Pre-optimize the code is
Module:
module {
func.func private @relu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%1 = stablehlo.maximum %0, %cst : tensor<f32>
%2 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
%3 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
return %2, %3 : tensor<f32>, tensor<f32>
}
func.func @main(%arg0: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
%1:2 = enzyme.batch @relu_broadcast_scalar(%0) {batch_shape = array<i64: 2, 3>} : (tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>)
%2 = stablehlo.transpose %1#0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
%3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
return %2, %3 : tensor<3x2xf32>, tensor<3x2xf32>
}
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the "we dont support constants in batching yet" which I'm presently working on. I'll try to get this squared today/tomorrow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I just leave a "TODO" in comment for relu or wait for the support?
ext/ReactantNNlibExt.jl
Outdated
::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N} | ||
) where {ElType,Shape,N} | ||
# See https://arxiv.org/pdf/1606.08415v5 Section 2 | ||
return lhs .* sigmoid.(ElType(1.702) .* lhs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there an erf
op in HLO?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One of the most positive things of setting up docs for the hlo stuff: https://enzymead.github.io/Reactant.jl/dev/api/#Reactant.MLIR.Dialects.chlo.erf-Tuple{Reactant.MLIR.IR.Value}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So we can use any of the dialects? I was using https://openxla.org/s/results?q=erf#gsc.tab=0&gsc.q=erf&gsc.sort= as a reference
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do we feel about adding a direct dep on SpecialFunctions
for erf
/erfinv
...? Else we will have to create a 2nd NNlibSpecialFunctionsExt to define the exact gelu impl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
presently yes (and we need to make sure we have corresponding lowering from one to the others, and potentially derivatives)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
erf
doesn't seem to have the derivatives implemented EnzymeAD/Enzyme-JAX#88, so I am more inclined to keep the current implementation, and switch it later once derivatives are implemented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you want to take a stab at it? It would go right here https://github.com/EnzymeAD/Enzyme-JAX/blob/4eeaef06e0da144bebd08ec739cf01911dcddb47/src/enzyme_ad/jax/Implementations/CHLODerivatives.td#L142 and shouldn't be bad cc @mofeing and or @Pangoraw who may be able to help with
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will give it a shot
How important is 1.9 support? Lux doesn't support 1.9 (neither does ArrayInterface) which causes the test failures. I could skip the tests on 1.9 but now that 1.10 will be LTS we might as well drop it |
It shouldn’t be hard to add (like two lines) if you want to go for it. I
also have an open PR fixing batched constants so we can get the new jll of
all the stuff at once
…On Sat, Jul 27, 2024 at 1:31 PM Avik Pal ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In ext/ReactantNNlibExt.jl
<#55 (comment)>:
> end
end
end
+function Reactant.elem_apply(
+ ::typeof(NNlib.relu), lhs::Reactant.TracedRArray{ElType,Shape,N}
+) where {ElType,Shape,N}
+ return ifelse.((lhs .> zero(ElType)), lhs, zero(ElType))
+end
+
+function Reactant.elem_apply(
+ ::typeof(NNlib.gelu), lhs::Reactant.TracedRArray{ElType,Shape,N}
+) where {ElType,Shape,N}
+ # See https://arxiv.org/pdf/1606.08415v5 Section 2
+ return lhs .* sigmoid.(ElType(1.702) .* lhs)
erf doesn't seem to have the derivatives implemented
EnzymeAD/Enzyme-JAX#88 <EnzymeAD/Enzyme-JAX#88>,
so I am more inclined to keep the current implementation, and switch it
later once derivatives are implemented.
—
Reply to this email directly, view it on GitHub
<#55 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJTUXE6DVCJ32POU2DRYA3ZOPKP5AVCNFSM6AAAAABLRHMKTSVHI2DSMVQWIX3LMV43YUDVNRWFEZLROVSXG5CSMV3GSZLXHMZDEMBTGIYTGMJSGQ>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Welcome to Codecov 🎉Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests. Thanks for integrating Codecov - We've got you covered ☂️ |
@avik-pal the jll bump with constant propagation of broadcast is merged. rebase this? |
9cf45c7
to
2c65ae7
Compare
Darn, can you extract the failing input hlo?
…On Thu, Aug 1, 2024 at 8:26 PM Avik Pal ***@***.***> wrote:
Still seems to error
https://github.com/EnzymeAD/Reactant.jl/actions/runs/10207710411/job/28242989697?pr=55#step:9:753
—
Reply to this email directly, view it on GitHub
<#55 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJTUXAIW2RKUBP2NVZ3XVLZPLG4PAVCNFSM6AAAAABLRHMKTSVHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENRUGI3DQNRXGM>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Module:
module {
func.func private @relu_broadcast_scalar(%arg0: tensor<f64>) -> (tensor<f64>, tensor<f64>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%1 = stablehlo.maximum %0, %cst : tensor<f64>
%2 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
%3 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
return %2, %3 : tensor<f64>, tensor<f64>
}
func.func @main(%arg0: tensor<2x2xf64>) -> (tensor<2x2xf64>, tensor<2x2xf64>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
%1:2 = enzyme.batch @relu_broadcast_scalar(%0) {batch_shape = array<i64: 2, 2>} : (tensor<2x2xf64>) -> (tensor<2x2xf64>, tensor<2x2xf64>)
%2 = stablehlo.transpose %1#0, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
%3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
return %2, %3 : tensor<2x2xf64>, tensor<2x2xf64>
}
} |
Now awaiting jll: JuliaPackaging/Yggdrasil#9196 [tho also @avik-pal this won't include the erf derivative yet -- unless you have a PR we can quickly merge and then get into the new jll] |
let's move forward without the erf derivative for now. I am trying to help one of our GSoCs with some benchmarking so that needs to be finished first 😓 |
@avik-pal the jll merged, can you rebase here? |
343a45d
to
d5d75f3
Compare
@@ -140,6 +172,17 @@ for (jlop, hloop, RT) in ( | |||
end | |||
end | |||
|
|||
function Base.ifelse( | |||
pred::TracedRArray{Bool,(),0}, x::TracedRArray{T1,(),0}, y::TracedRArray{T2,(),0} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we make this generalize to any shape/size, not just 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't the broadcasting handle the shape automatically?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, but someone could also call ifelse(true, ones(4,4), zeros(4,4)) or ifelse(trues(4,4), ones(4,4), zeros(4,4)), etc, outside a broadcast [tho yes the 0 dim one will generalize to anything in a broadcast]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though I don't think the latter case is legal in julia atm, so just generalizing to ifelse(true, ones(4,4), zeros(4,4))
probably makes sense
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something is wrong with this version I defined:
julia> f(x) = ifelse.(true, x, x)
f (generic function with 1 method)
julia> Reactant.@code_hlo optimize=false f(x)
Module:
module {
func.func private @ifelse_broadcast_scalar(%arg0: tensor<i1>, %arg1: tensor<f64>) -> (tensor<i1>, tensor<f64>, tensor<f64>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<i1>) -> tensor<i1>
%1 = stablehlo.transpose %arg1, dims = [] : (tensor<f64>) -> tensor<f64>
%2 = stablehlo.select %0, %1, %1 : tensor<i1>, tensor<f64>
%3 = stablehlo.transpose %0, dims = [] : (tensor<i1>) -> tensor<i1>
%4 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
%5 = stablehlo.transpose %2, dims = [] : (tensor<f64>) -> tensor<f64>
return %3, %4, %5 : tensor<i1>, tensor<f64>, tensor<f64>
}
func.func @main(%arg0: tensor<3x2xf64>) -> (tensor<3x2xf64>, tensor<3x2xf64>) {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<3x2xf64>) -> tensor<2x3xf64>
%c = stablehlo.constant dense<true> : tensor<2x3xi1>
%1:3 = enzyme.batch @ifelse_broadcast_scalar(%c, %0) {batch_shape = array<i64: 2, 3>} : (tensor<2x3xi1>, tensor<2x3xf64>) -> (tensor<2x3xi1>, tensor<2x3xf64>, tensor<2x3xf64>)
%2 = stablehlo.transpose %1#2, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
%3 = stablehlo.transpose %1#1, dims = [1, 0] : (tensor<2x3xf64>) -> tensor<3x2xf64>
return %2, %3 : tensor<3x2xf64>, tensor<3x2xf64>
}
}
julia> Reactant.@code_hlo f(x)
Module:
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<3x2xf64>) {
return
}
}
d74fbb2
to
d0e1dbd
Compare
Overview
sigmoid
sigmoid_fast
relu
gelu
abs2
mean
var