diff --git a/Project.toml b/Project.toml index 88980610..7b19264f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "0.3.48" +version = "0.3.49" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/impl/bias_activation.jl b/src/impl/bias_activation.jl index 536cd504..70cf7029 100644 --- a/src/impl/bias_activation.jl +++ b/src/impl/bias_activation.jl @@ -2,7 +2,7 @@ bias_activation(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation(σ::F, x::AbstractVector, bias::$(bType)) where {F} - return vec(bias_activation(σ, reshape(x, :, 1), bias)) + return vec(bias_activation(σ, get_utils(:insert_batch_dim)(x), bias)) end end @@ -91,7 +91,7 @@ end bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x for bType in (Nothing, AbstractVector) @eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F} - return vec(bias_activation!!(σ, reshape(x, :, 1), bias)) + return vec(bias_activation!!(σ, get_utils(:insert_batch_dim)(x), bias)) end end diff --git a/src/impl/matmul.jl b/src/impl/matmul.jl index 9794e2ee..25933898 100644 --- a/src/impl/matmul.jl +++ b/src/impl/matmul.jl @@ -1,7 +1,7 @@ # Wrappers over Base & LinearAlgebra implementations to use poly algs if needed matmuladd(A, B, ::Nothing) = matmul(A, B) function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) - return matmuladd(A, reshape(B, :, 1), bias) + return matmuladd(A, get_utils(:insert_batch_dim)(B), bias) end function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) @@ -24,7 +24,9 @@ function matmuladd(opmode::AbstractInternalArrayOpMode, A::AbstractMatrix, return C end -matmul(A::AbstractMatrix, B::AbstractVector) = vec(matmul(A, reshape(B, :, 1))) +function matmul(A::AbstractMatrix, B::AbstractVector) + return vec(matmul(A, get_utils(:insert_batch_dim)(B))) +end function matmul(A::AbstractMatrix, B::AbstractMatrix) if size(A, 2) != size(B, 1) throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) diff --git a/src/traits.jl b/src/traits.jl index 86130a6a..301dfd7c 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -6,6 +6,7 @@ using ForwardDiff: ForwardDiff using NNlib: NNlib using Static: True, False, static using StaticArraysCore: StaticArray +using UnrolledUtilities: unrolled_map using ..LuxLib: Numeric using ..Utils @@ -26,6 +27,12 @@ for op in (:has_dual, :has_float16, :is_tracked) @eval $op(x::Numeric) = $op(eltype(x)) end +unwrap_array(x) = x +function unwrap_array(x::AbstractArray) + parent(x) === x && return x + return unwrap_array(parent(x)) +end + has_dual(_) = False() has_dual(::Type{<:ForwardDiff.Dual}) = True() @@ -42,9 +49,10 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) function use_generic_broadcasting(xs::Tuple) # Float16 is a bit iffy and reordering operations are not optimal for numerical # stability so we use the generic implementation for now. - return Utils.unrolled_any(has_autodiff_value, xs) | - Utils.unrolled_any(has_float16, xs) | - Utils.unrolled_any(static_isa(StaticArray), xs) + xs_unwrapped = unrolled_map(unwrap_array, xs) + return Utils.unrolled_any(has_autodiff_value, xs_unwrapped) | + Utils.unrolled_any(has_float16, xs_unwrapped) | + Utils.unrolled_any(static_isa(StaticArray), xs_unwrapped) end activation_intermediate_not_needed(::typeof(identity), ::Type) = True() diff --git a/src/utils.jl b/src/utils.jl index d1d77613..a15d863b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,6 +9,7 @@ using LinearAlgebra: LinearAlgebra, BLAS using MLDataDevices: get_device_type, CPUDevice using NNlib: NNlib using Static: Static, False, True +using StaticArraysCore: SVector, SMatrix using ..LuxLib: Optional, ∂∅ @@ -231,6 +232,9 @@ end return end +insert_batch_dim(x::AbstractVector) = reshape(x, :, 1) +insert_batch_dim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x) + end # Accessing properties of modules leads to type instability in Zygote reverse pass diff --git a/test/others/misc_tests.jl b/test/others/misc_tests.jl new file mode 100644 index 00000000..6943de74 --- /dev/null +++ b/test/others/misc_tests.jl @@ -0,0 +1,33 @@ +@testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(Float32, 4, 3) |> aType + retval = ongpu ? LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp + @test LuxLib.internal_operation_mode(x) isa retval + end + + using StaticArrays, JLArrays + + x = rand(Float32, 4, 3) |> JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = @SArray rand(Float32, 4, 3) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = reshape(@SArray(rand(Float32, 4)), :, 1) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end + +@testitem "Matmul: StaticArrays" tags=[:others] setup=[SharedTestSetup] begin + using LuxLib.Impl: matmuladd + using StaticArrays + + A = rand(2, 2) + bias = rand(2) + + # This works with LoopVectorization + B = ones(SMatrix{2, 1, Float64}) + @test matmuladd(A, B, bias) ≈ A * B .+ bias + + b = ones(SVector{2, Float64}) + @test matmuladd(A, b, bias) ≈ A * b .+ bias +end