Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: decide internal operation based on unwrapped arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 23, 2024
1 parent c185f04 commit a5983f9
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.48"
version = "0.3.49"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
14 changes: 11 additions & 3 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ForwardDiff: ForwardDiff
using NNlib: NNlib
using Static: True, False, static
using StaticArraysCore: StaticArray
using UnrolledUtilities: unrolled_map

using ..LuxLib: Numeric
using ..Utils
Expand All @@ -26,6 +27,12 @@ for op in (:has_dual, :has_float16, :is_tracked)
@eval $op(x::Numeric) = $op(eltype(x))
end

unwrap_array(x) = x

Check warning on line 30 in src/traits.jl

View check run for this annotation

Codecov / codecov/patch

src/traits.jl#L30

Added line #L30 was not covered by tests
function unwrap_array(x::AbstractArray)
parent(x) === x && return x
return unwrap_array(parent(x))
end

has_dual(_) = False()
has_dual(::Type{<:ForwardDiff.Dual}) = True()

Expand All @@ -42,9 +49,10 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T))
function use_generic_broadcasting(xs::Tuple)
# Float16 is a bit iffy and reordering operations are not optimal for numerical
# stability so we use the generic implementation for now.
return Utils.unrolled_any(has_autodiff_value, xs) |
Utils.unrolled_any(has_float16, xs) |
Utils.unrolled_any(static_isa(StaticArray), xs)
xs_unwrapped = unrolled_map(unwrap_array, xs)
return Utils.unrolled_any(has_autodiff_value, xs_unwrapped) |
Utils.unrolled_any(has_float16, xs_unwrapped) |
Utils.unrolled_any(static_isa(StaticArray), xs_unwrapped)
end

activation_intermediate_not_needed(::typeof(identity), ::Type) = True()
Expand Down
18 changes: 18 additions & 0 deletions test/others/misc_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin
@testset "$mode" for (mode, aType, ongpu) in MODES
x = rand(Float32, 4, 3) |> aType
retval = ongpu ? LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp
@test LuxLib.internal_operation_mode(x) isa retval
end

using StaticArrays, JLArrays

x = rand(Float32, 4, 3) |> JLArray
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp

x = @SArray rand(Float32, 4, 3)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp

x = reshape(@SArray(rand(Float32, 4)), :, 1)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp
end

0 comments on commit a5983f9

Please sign in to comment.