diff --git a/Project.toml b/Project.toml index 88980610..7b19264f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.48" +version = "0.3.49" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/traits.jl b/src/traits.jl index 86130a6a..301dfd7c 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -6,6 +6,7 @@ using ForwardDiff: ForwardDiff using NNlib: NNlib using Static: True, False, static using StaticArraysCore: StaticArray +using UnrolledUtilities: unrolled_map using ..LuxLib: Numeric using ..Utils @@ -26,6 +27,12 @@ for op in (:has_dual, :has_float16, :is_tracked) @eval $op(x::Numeric) = $op(eltype(x)) end +unwrap_array(x) = x +function unwrap_array(x::AbstractArray) + parent(x) === x && return x + return unwrap_array(parent(x)) +end + has_dual(_) = False() has_dual(::Type{<:ForwardDiff.Dual}) = True() @@ -42,9 +49,10 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. - return Utils.unrolled_any(has_autodiff_value, xs) | - Utils.unrolled_any(has_float16, xs) | - Utils.unrolled_any(static_isa(StaticArray), xs) + xs_unwrapped = unrolled_map(unwrap_array, xs) + return Utils.unrolled_any(has_autodiff_value, xs_unwrapped) | + Utils.unrolled_any(has_float16, xs_unwrapped) | + Utils.unrolled_any(static_isa(StaticArray), xs_unwrapped) end activation_intermediate_not_needed(::typeof(identity), ::Type) = True() diff --git a/test/others/misc_tests.jl b/test/others/misc_tests.jl new file mode 100644 index 00000000..7b00aa64 --- /dev/null +++ b/test/others/misc_tests.jl @@ -0,0 +1,18 @@ +@testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(Float32, 4, 3) |> aType + retval = ongpu ? LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp + @test LuxLib.internal_operation_mode(x) isa retval + end + + using StaticArrays, JLArrays + + x = rand(Float32, 4, 3) |> JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = @SArray rand(Float32, 4, 3) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = reshape(@SArray(rand(Float32, 4)), :, 1) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end