This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix: decide internal operation based on unwrapped arrays #141

Merged · 2 commits · Aug 23, 2024
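
Context for the change: internal_operation_mode previously inspected array arguments as given, so a StaticArray hidden inside a wrapper (for example the Base.ReshapedArray produced by reshaping an SVector) was not recognized, and an operation mode unsuited to such wrappers could be selected. This PR unwraps arrays before deciding the mode and adds an insert_batch_dim utility so static vectors keep a static type instead of being wrapped. A minimal illustration of the intended behavior (not part of the diff; it mirrors the new test added at the bottom):

using LuxLib, StaticArrays

x  = @SArray rand(Float32, 4)
xr = reshape(x, :, 1)   # a Base wrapper around the SArray, not itself a StaticArray

# With this patch the wrapper is peeled off before the mode is chosen, so wrapped
# StaticArrays also take the generic broadcasting path:
LuxLib.internal_operation_mode(xr) isa LuxLib.GenericBroadcastOp   # true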
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.48"
version = "0.3.49"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
4 changes: 2 additions & 2 deletions src/impl/bias_activation.jl
@@ -2,7 +2,7 @@
bias_activation(::typeof(identity), x::AbstractVector, ::Nothing) = x
for bType in (Nothing, AbstractVector)
@eval function bias_activation(σ::F, x::AbstractVector, bias::$(bType)) where {F}
-return vec(bias_activation(σ, reshape(x, :, 1), bias))
+return vec(bias_activation(σ, get_utils(:insert_batch_dim)(x), bias))
end
end

@@ -91,7 +91,7 @@ end
bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x
for bType in (Nothing, AbstractVector)
@eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F}
-return vec(bias_activation!!(σ, reshape(x, :, 1), bias))
+return vec(bias_activation!!(σ, get_utils(:insert_batch_dim)(x), bias))
end
end

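
The vector methods above still delegate to the two-dimensional methods by appending a batch dimension and dropping it again with vec; the only change is that the batch dimension is now added through the insert_batch_dim utility (defined in src/utils.jl below) so a static vector is not turned into a ReshapedArray wrapper. A rough usage sketch (the LuxLib.Impl module path is an assumption based on the file layout):

using LuxLib, StaticArrays

x = rand(Float32, 4)
b = rand(Float32, 4)

# Internally this is vec(bias_activation(tanh, <x with a batch dim>, b)):
y = LuxLib.Impl.bias_activation(tanh, x, b)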
6 changes: 4 additions & 2 deletions src/impl/matmul.jl
@@ -1,7 +1,7 @@
# Wrappers over Base & LinearAlgebra implementations to use poly algs if needed
matmuladd(A, B, ::Nothing) = matmul(A, B)
function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector)
-return matmuladd(A, reshape(B, :, 1), bias)
+return matmuladd(A, get_utils(:insert_batch_dim)(B), bias)
end
function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector)
return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias)
@@ -24,7 +24,9 @@ function matmuladd(opmode::AbstractInternalArrayOpMode, A::AbstractMatrix,
return C
end

-matmul(A::AbstractMatrix, B::AbstractVector) = vec(matmul(A, reshape(B, :, 1)))
+function matmul(A::AbstractMatrix, B::AbstractVector)
+return vec(matmul(A, get_utils(:insert_batch_dim)(B)))
+end
function matmul(A::AbstractMatrix, B::AbstractMatrix)
if size(A, 2) != size(B, 1)
throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))"))
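
The matrix–vector entry points follow the same pattern: a vector B is promoted to an n×1 matrix with insert_batch_dim before dispatching on internal_operation_mode, so an SVector stays a StaticArray (it becomes an SMatrix) rather than a ReshapedArray. The new test file at the bottom exercises exactly this; a condensed version:

using LuxLib.Impl: matmuladd
using StaticArrays

A    = rand(2, 2)
bias = rand(2)
b    = ones(SVector{2, Float64})

matmuladd(A, b, bias) ≈ A * b .+ bias   # true; b is carried through as a 2×1 SMatrix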
14 changes: 11 additions & 3 deletions src/traits.jl
@@ -6,6 +6,7 @@ using ForwardDiff: ForwardDiff
using NNlib: NNlib
using Static: True, False, static
using StaticArraysCore: StaticArray
using UnrolledUtilities: unrolled_map

using ..LuxLib: Numeric
using ..Utils
@@ -26,6 +27,12 @@ for op in (:has_dual, :has_float16, :is_tracked)
@eval $op(x::Numeric) = $op(eltype(x))
end

unwrap_array(x) = x
function unwrap_array(x::AbstractArray)
parent(x) === x && return x
return unwrap_array(parent(x))
end

has_dual(_) = False()
has_dual(::Type{<:ForwardDiff.Dual}) = True()

@@ -42,9 +49,10 @@ static_isa(x, ::Type{T}) where {T} = static(isa(x, T))
function use_generic_broadcasting(xs::Tuple)
# Float16 is a bit iffy and reordering operations are not optimal for numerical
# stability so we use the generic implementation for now.
-return Utils.unrolled_any(has_autodiff_value, xs) |
-Utils.unrolled_any(has_float16, xs) |
-Utils.unrolled_any(static_isa(StaticArray), xs)
+xs_unwrapped = unrolled_map(unwrap_array, xs)
+return Utils.unrolled_any(has_autodiff_value, xs_unwrapped) |
+Utils.unrolled_any(has_float16, xs_unwrapped) |
+Utils.unrolled_any(static_isa(StaticArray), xs_unwrapped)
end

activation_intermediate_not_needed(::typeof(identity), ::Type) = True()
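
unwrap_array follows parent until it reaches an array that is its own parent, and use_generic_broadcasting now runs its checks on those unwrapped arrays, so Duals, Float16 eltypes, or StaticArrays are detected even when hidden behind a wrapper. A standalone sketch of the unwrapping (same logic as the definition above, renamed to keep it self-contained):

using StaticArrays

unwrap(x) = x
function unwrap(x::AbstractArray)
    parent(x) === x && return x   # plain Array / SArray: parent is the array itself
    return unwrap(parent(x))      # peel ReshapedArray, SubArray, Adjoint, ... recursively
end

xr = reshape(@SArray(rand(Float32, 4)), :, 1)
unwrap(xr) isa StaticArray        # true: the SArray behind the wrapper is recovered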
4 changes: 4 additions & 0 deletions src/utils.jl
@@ -9,6 +9,7 @@ using LinearAlgebra: LinearAlgebra, BLAS
using MLDataDevices: get_device_type, CPUDevice
using NNlib: NNlib
using Static: Static, False, True
using StaticArraysCore: SVector, SMatrix

using ..LuxLib: Optional, ∂∅

@@ -231,6 +232,9 @@ end
return
end

insert_batch_dim(x::AbstractVector) = reshape(x, :, 1)
insert_batch_dim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x)

end

# Accessing properties of modules leads to type instability in Zygote reverse pass
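
insert_batch_dim is what keeps static vectors static: a plain vector gets the usual reshape(x, :, 1), while an SVector is rebuilt as an SMatrix, avoiding the ReshapedArray wrapper that the operation-mode logic would otherwise have to see through. A standalone sketch mirroring the two methods:

using StaticArrays   # the package version imports SVector/SMatrix from StaticArraysCore

insert_batch_dim(x::AbstractVector) = reshape(x, :, 1)
insert_batch_dim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x)

insert_batch_dim(rand(Float32, 3))       # 3×1 Matrix{Float32}
insert_batch_dim(SVector(1.0, 2.0, 3.0)) # 3×1 SMatrix, still a StaticArray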
33 changes: 33 additions & 0 deletions test/others/misc_tests.jl
@@ -0,0 +1,33 @@
@testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin
@testset "$mode" for (mode, aType, ongpu) in MODES
x = rand(Float32, 4, 3) |> aType
retval = ongpu ? LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp
@test LuxLib.internal_operation_mode(x) isa retval
end

using StaticArrays, JLArrays

x = rand(Float32, 4, 3) |> JLArray
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp

x = @SArray rand(Float32, 4, 3)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp

x = reshape(@SArray(rand(Float32, 4)), :, 1)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp
end

@testitem "Matmul: StaticArrays" tags=[:others] setup=[SharedTestSetup] begin
using LuxLib.Impl: matmuladd
using StaticArrays

A = rand(2, 2)
bias = rand(2)

# This works with LoopVectorization
B = ones(SMatrix{2, 1, Float64})
@test matmuladd(A, B, bias) ≈ A * B .+ bias

b = ones(SVector{2, Float64})
@test matmuladd(A, b, bias) ≈ A * b .+ bias
end