This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Add a Lux test
avik-pal committed Mar 16, 2024
1 parent c45432f commit 18afaca
Showing 6 changed files with 121 additions and 29 deletions.
8 changes: 7 additions & 1 deletion Project.toml
@@ -36,6 +36,7 @@ Aqua = "0.8.4"
ArrayInterface = "7.8.1"
CUDA = "5.2.0"
ChainRulesCore = "1.23"
ComponentArrays = "0.15.10"
ConcreteStructs = "0.2.3"
ExplicitImports = "1.4.0"
FastClosures = "0.3.2"
@@ -44,6 +45,7 @@ FiniteDiff = "2.22"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
LinearSolve = "2.27"
Lux = "0.5.23"
LuxCUDA = "0.3.2"
LuxDeviceUtils = "0.1.17"
LuxTestUtils = "0.1.15"
@@ -52,25 +54,29 @@ Random = "<0.0.1, 1"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Zygote"]
28 changes: 14 additions & 14 deletions ext/BatchedRoutinesForwardDiffExt.jl
@@ -65,7 +65,7 @@ end
return ForwardDiff.value.(y_duals), J_partial
end

function BatchedRoutines.__batched_value_and_jacobian(
function __batched_value_and_jacobian(
ad::AutoForwardDiff, f::F, u::AbstractMatrix{T},
ck::Val{chunksize}) where {F, T, chunksize}
N, B = size(u)
@@ -104,19 +104,19 @@ function BatchedRoutines.__batched_value_and_jacobian(
return y, J
end

@generated function BatchedRoutines.__batched_value_and_jacobian(
@generated function __batched_value_and_jacobian(
ad::AutoForwardDiff{CK}, f::F, u::AbstractMatrix{T}) where {CK, F, T}
    if CK === nothing || CK ≤ 0
if _assert_type(u) && Base.issingletontype(F)
rType = Tuple{u, parameterless_type(u){T, 3}}
jac_call = :((y, J) = BatchedRoutines.__batched_value_and_jacobian(
jac_call = :((y, J) = __batched_value_and_jacobian(
ad, f, u, Val(batched_pickchunksize(u)))::$(rType))
else # Cases like ReverseDiff over ForwardDiff
jac_call = :((y, J) = BatchedRoutines.__batched_value_and_jacobian(
jac_call = :((y, J) = __batched_value_and_jacobian(
ad, f, u, Val(batched_pickchunksize(u))))
end
else
jac_call = :((y, J) = BatchedRoutines.__batched_value_and_jacobian(
jac_call = :((y, J) = __batched_value_and_jacobian(
ad, f, u, $(Val(CK))))
end
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalMatrix(J))))
@@ -140,7 +140,7 @@ end

@inline function BatchedRoutines._batched_jacobian(
ad::AutoForwardDiff, f::F, u::AbstractMatrix) where {F}
return last(BatchedRoutines.__batched_value_and_jacobian(ad, f, u))
return last(__batched_value_and_jacobian(ad, f, u))
end

# We don't use the ForwardDiff.gradient since it causes GPU compilation errors due to
@@ -226,23 +226,23 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{

function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
T = promote_type(eltype(x), eltype(u))
dev = get_device(x)
partials = ForwardDiff.Partials{1, T}.(tuple.(u)) |> dev
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
x_dual = _construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual)
return ForwardDiff.partials.(y_dual, 1)
end

function BatchedRoutines._jacobian_vector_product(
ad::AutoForwardDiff, f::F, x, u, p) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
T = promote_type(eltype(x), eltype(u))
dev = get_device(x)
partials = ForwardDiff.Partials{1, T}.(tuple.(u)) |> dev
x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
x_dual = _construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual, p)
return ForwardDiff.partials.(y_dual, 1)
end

@inline function _construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
end

end
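
For orientation (not part of this commit), here is a minimal sketch of the batched-Jacobian entry point that this extension backs. It assumes `batched_jacobian`, `AutoForwardDiff`, and `UniformBlockDiagonalMatrix` are available once BatchedRoutines and ForwardDiff are loaded, and it uses a hypothetical column-wise function `f`:

using BatchedRoutines, ADTypes, ForwardDiff

f(u) = u .^ 2 .+ 1                            # hypothetical map applied column-wise
u = randn(3, 5)                               # 3 states, 5 batch columns
J = batched_jacobian(AutoForwardDiff(), f, u)
# J should be a UniformBlockDiagonalMatrix with one 3×3 Jacobian block per batch column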
20 changes: 6 additions & 14 deletions src/chainrules.jl
@@ -1,9 +1,3 @@
function __batched_value_and_jacobian(ad, f::F, x) where {F}
J = batched_jacobian(ad, f, x)
return f(x), J
end

# TODO: Use OneElement for this
function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) where {F}
if !_is_extension_loaded(Val(:ForwardDiff)) || !_is_extension_loaded(Val(:Zygote))
throw(ArgumentError("`ForwardDiff.jl` and `Zygote.jl` need to be loaded to \
@@ -15,9 +9,8 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) wher
∇batched_jacobian = Δ -> begin
gradient_ad = AutoZygote()
_map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x),
x, reshape(Δᵢ, size(x)))
∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x), x, Δᵢ)
∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))
return NoTangent(), NoTangent(), NoTangent(), ∂x
end
return J, ∇batched_jacobian
@@ -33,16 +26,15 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}

∇batched_jacobian = Δ -> begin
_map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x),
x, reshape(Δᵢ, size(x)))
x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x), x, Δᵢ)

∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))

_map_fnₚ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
(x, p_) -> batched_gradient(AutoZygote(), p__ -> sum(vec(f(x, p__))[i:i]), p_),
x, reshape(Δᵢ, size(x)), p)
x, Δᵢ, p)

∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(eachrow(Δ))), size(p))
∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(_eachrow(Δ))), size(p))

return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
end
20 changes: 20 additions & 0 deletions src/helpers.jl
@@ -147,3 +147,23 @@ end
resolved && return T, true
return promote_type(T, eltype(f.x)), false
end

# eachrow override
@inline _eachrow(X) = eachrow(X)

# MLUtils.jl has too many unwanted dependencies
@inline fill_like(x::AbstractArray, v, ::Type{T}, dims...) where {T} = fill!(
similar(x, T, dims...), v)
@inline fill_like(x::AbstractArray, v, dims...) = fill_like(x, v, eltype(x), dims...)

@inline zeros_like(x::AbstractArray, ::Type{T}, dims...) where {T} = fill_like(
x, zero(T), T, dims...)
@inline zeros_like(x::AbstractArray, dims...) = zeros_like(x, eltype(x), dims...)

@inline ones_like(x::AbstractArray, ::Type{T}, dims...) where {T} = fill_like(
x, one(T), T, dims...)
@inline ones_like(x::AbstractArray, dims...) = ones_like(x, eltype(x), dims...)

CRC.@non_differentiable fill_like(::Any...)
CRC.@non_differentiable zeros_like(::Any...)
CRC.@non_differentiable ones_like(::Any...)
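
A brief usage sketch (illustrative values, not part of the commit) of the new allocation helpers, which mirror the MLUtils.jl API without adding that dependency; the helpers are internal, so the module-qualified calls below are an assumption:

x = rand(Float32, 3, 4)
BatchedRoutines.zeros_like(x)                 # 3×4 Float32 array of zeros, same array type as x
BatchedRoutines.ones_like(x, Float64, 2)      # length-2 Float64 vector of ones
BatchedRoutines.fill_like(x, 0.5f0, 2, 2)     # 2×2 Float32 array filled with 0.5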
17 changes: 17 additions & 0 deletions src/matrix.jl
@@ -149,6 +149,23 @@ function Base.fill!(A::UniformBlockDiagonalMatrix, v)
return A
end

@inline function _eachrow(X::UniformBlockDiagonalMatrix)
row_fn = @closure i -> begin
M, N, K = size(X.data)
k = (i - 1) ÷ M + 1
i_ = mod1(i, M)
data = view(X.data, i_, :, k)
if k == 1
return vcat(data, zeros_like(data, N * (K - 1)))
elseif k == K
return vcat(zeros_like(data, N * (K - 1)), data)
else
return vcat(zeros_like(data, N * (k - 1)), data, zeros_like(data, N * (K - k)))
end
end
return map(row_fn, 1:size(X, 1))
end

# Broadcasting
struct UniformBlockDiagonalMatrixStyle{N} <: Broadcast.AbstractArrayStyle{2} end

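An illustrative sketch (not from the commit) of what the new `_eachrow` override yields; it assumes `UniformBlockDiagonalMatrix` can be constructed directly from its underlying M×N×K array of blocks:

data = reshape(Float64.(1:8), 2, 2, 2)        # two 2×2 blocks
X = UniformBlockDiagonalMatrix(data)          # a 4×4 block-diagonal matrix
rows = BatchedRoutines._eachrow(X)            # 4 dense rows, each of length 4
# rows[1] == [1.0, 3.0, 0.0, 0.0]             # block 1, row 1, zero padding on the right
# rows[3] == [0.0, 0.0, 5.0, 7.0]             # block 2, row 1, zero padding on the left
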
57 changes: 57 additions & 0 deletions test/integration_tests.jl
@@ -52,3 +52,60 @@
end
end
end

@testitem "Simple Lux Integration" setup=[SharedTestSetup] begin
using ComponentArrays, ForwardDiff, Lux, Random, Zygote

rng = get_stable_rng(1001)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
model = Chain(Dense(4 => 6, tanh), Dense(6 => 3))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps) |> dev
st = st |> dev

x = randn(rng, 4, 3) |> dev
y = randn(rng, 4, 3) |> dev
target_jac = batched_jacobian(
AutoForwardDiff(; chunksize=4), StatefulLuxLayer(model, nothing, st), y, ps)

loss_function = (model, x, target_jac, ps, st) -> begin
m = StatefulLuxLayer(model, nothing, st)
jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x, ps)
return sum(abs2, jac_full .- target_jac)
end

@test loss_function(model, x, target_jac, ps, st) isa Number
@test !iszero(loss_function(model, x, target_jac, ps, st))

cdev = cpu_device()
_fn_x = x -> loss_function(model, x, target_jac |> cdev, ps |> cdev, st)
_fn_ps = p -> loss_function(
model, x |> cdev, target_jac |> cdev, ComponentArray(p, getaxes(ps)), st)

∂x_fdiff = ForwardDiff.gradient(_fn_x, cdev(x))
∂ps_fdiff = ForwardDiff.gradient(_fn_ps, cdev(ps))

_, ∂x, _, ∂ps, _ = Zygote.gradient(loss_function, model, x, target_jac, ps, st)

        @test cdev(∂x) ≈ ∂x_fdiff
        @test cdev(∂ps) ≈ ∂ps_fdiff

loss_function2 = (model, x, target_jac, ps, st) -> begin
m = StatefulLuxLayer(model, ps, st)
jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x)
return sum(abs2, jac_full .- target_jac)
end

@test loss_function2(model, x, target_jac, ps, st) isa Number
@test !iszero(loss_function2(model, x, target_jac, ps, st))

_fn_x = x -> loss_function2(model, x, cdev(target_jac), cdev(ps), st)

∂x_fdiff = ForwardDiff.gradient(_fn_x, cdev(x))

_, ∂x, _, _, _ = Zygote.gradient(loss_function2, model, x, target_jac, ps, st)

        @test cdev(∂x) ≈ ∂x_fdiff
end
end
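
To run just this new test item locally, something along these lines should work; the path and filter keyword follow the ReTestItems.jl API and are assumptions, not taken from the commit:

using ReTestItems
runtests("test/integration_tests.jl"; name = "Simple Lux Integration")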
