Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Use a heuristic to select broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 22, 2024
1 parent 17ae652 commit 6feb0c2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down Expand Up @@ -62,6 +63,7 @@ Reexport = "1"
ReverseDiff = "1.15"
StableRNGs = "1"
Statistics = "1.10"
Strided = "2"
Test = "1.10"
Tracker = "0.2.34"
Zygote = "0.6.69"
Expand Down
1 change: 1 addition & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using PrecompileTools: @recompile_invalidations
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using Statistics: Statistics, mean, std, var
using Strided: Strided, @strided
end

@reexport using NNlib
Expand Down
24 changes: 15 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,28 +130,34 @@ CRC.@non_differentiable __get_concrete_fba_output_eltype(::Any...)
end

# Out-of-place broadcast of `f` over `x` and `args...`; the broadcast result is returned.
# `ArrayInterface.fast_scalar_indexing(x)` selects between the `@..` and the `@.` form.
# NOTE(review): this block was captured from a rendered diff without +/- markers. The
# ternary `return` and the two statements after it express the same selection and look
# like the removed and the added version of one body; read literally, the statements
# after the first `return` are unreachable. Confirm against the repository.
@inline function __fast_broadcast(f::F, x, args...) where {F}
return ArrayInterface.fast_scalar_indexing(x) ? @..(f(x, args...)) : @.(f(x, args...))
ArrayInterface.fast_scalar_indexing(x) && return @.. f(x, args...)
return @. f(x, args...)
end
# In-place broadcast: overwrites `x` with `f` applied elementwise to `x` and `args...`,
# then returns `x`.
# NOTE(review): this block was captured from a rendered diff without +/- markers, so
# removed and added lines appear together. Read literally, the first branch updates `x`
# with `@..` and then updates it again in the nested `if`, and the sigmoid branch writes
# `x` twice. Confirm which lines belong to the current version before relying on this.
@inline function __fast_broadcast!(f::F, x, args...) where {F}
if ArrayInterface.fast_scalar_indexing(x)
@.. x = f(x, args...)
# Size heuristic: when the longest of `x` and `args` has more than 20_000 elements the
# broadcast goes through `@strided`; otherwise it goes through `@..`.
if maximum(length, (x, args...)) > 20_000
@strided x .= f.(x, args...)
else
@.. x = f(x, args...)
end
# Special case: `f` is the composition `sigmoid_fast ∘ +` and there is exactly one extra
# argument. The composition is written out explicitly instead of calling `f`.
elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1
# Has GPU Compilation Problems
x .= sigmoid_fast.(x .+ first(args))
y = first(args)
@. x = sigmoid_fast(x + y) # Has GPU Compilation Problems
else
# Fallback for every other `f` when `fast_scalar_indexing(x)` is false: plain `@.`.
@. x = f(x, args...)
end
return x
end
@inline function __nonuniform_fast_broadcast!(f::F, x, args...) where {F}
if ArrayInterface.fast_scalar_indexing(x)
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, args...))
@simd ivdep for i in eachindex(bc)
@inbounds x[i] = bc[i]
if maximum(length, (x, args...)) > 20_000
@strided x .= f.(x, args...)
else
@. x = f(x, args...)
end
elseif f === ComposedFunction(sigmoid_fast, +) && length(args) == 1
# Has GPU Compilation Problems
x .= sigmoid_fast.(x .+ first(args))
y = first(args)
@. x = sigmoid_fast(x + y) # Has GPU Compilation Problems
else
@. x = f(x, args...)
end
Expand Down
2 changes: 1 addition & 1 deletion test/qa_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Test item that runs `Aqua.test_all` on `LuxLib`, passing `broken=true` for the
# `unbound_args` check.
# NOTE(review): the two `Aqua.test_all` lines differ only in the spacing around `=` and
# look like the removed and added sides of a formatting-only diff; read literally the
# suite would run twice. Confirm against the repository.
@testitem "Aqua: Quality Assurance" tags=[:nworkers, :others] begin
using Aqua

Aqua.test_all(LuxLib; unbound_args=(; broken = true))
Aqua.test_all(LuxLib; unbound_args=(; broken=true))
end

@testitem "Explicit Imports" tags=[:nworkers, :others] begin
Expand Down

0 comments on commit 6feb0c2

Please sign in to comment.