diff --git a/Project.toml b/Project.toml index 87a03923..1d85f3dc 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" [weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" @@ -62,6 +63,7 @@ Reexport = "1" ReverseDiff = "1.15" StableRNGs = "1" Statistics = "1.10" +Strided = "2" Test = "1.10" Tracker = "0.2.34" Zygote = "0.6.69" diff --git a/src/LuxLib.jl b/src/LuxLib.jl index c47a0f25..e962279e 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -16,6 +16,7 @@ using PrecompileTools: @recompile_invalidations using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, std, var + using Strided: Strided, @strided end @reexport using NNlib diff --git a/src/utils.jl b/src/utils.jl index 66fd289b..bc219fd5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -130,14 +130,19 @@ CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...) end @inline function __fast_broadcast(f::F, x, args...) where {F} - return ArrayInterface.fast_scalar_indexing(x) ? @..(f(x, args...)) : @.(f(x, args...)) + ArrayInterface.fast_scalar_indexing(x) && return @.. f(x, args...) + return @. f(x, args...) end @inline function __fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - @.. x = f(x, args...) + if maximum(length, (x, args...)) > 20_000 + @strided x .= f.(x, args...) + else + @.. x = f(x, args...) + end elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 - # Has GPU Compilation Problems - x .= sigmoid_fast.(x .+ first(args)) + y = first(args) + @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems else @. x = f(x, args...) end @@ -145,13 +150,14 @@ end end @inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F} if ArrayInterface.fast_scalar_indexing(x) - bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...)) - @simd ivdep for i in eachindex(bc) - @inbounds x[i] = bc[i] + if maximum(length, (x, args...)) > 20_000 + @strided x .= f.(x, args...) + else + @. x = f(x, args...) end elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1 - # Has GPU Compilation Problems - x .= sigmoid_fast.(x .+ first(args)) + y = first(args) + @. x = sigmoid_fast(x + y) # Has GPU Compilation Problems else @. x = f(x, args...) end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 644830b5..dc3d3d99 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,7 +1,7 @@ @testitem "Aqua: Quality Assurance" tags=[:nworkers, :others] begin using Aqua - Aqua.test_all(LuxLib; unbound_args=(; broken = true)) + Aqua.test_all(LuxLib; unbound_args=(; broken=true)) end @testitem "Explicit Imports" tags=[:nworkers, :others] begin