diff --git a/Project.toml b/Project.toml
index b9f4c4d..185a324 100644
--- a/Project.toml
+++ b/Project.toml
@@ -36,6 +36,7 @@ Aqua = "0.8.4"
 ArrayInterface = "7.8.1"
 CUDA = "5.2.0"
 ChainRulesCore = "1.23"
+ComponentArrays = "0.15.10"
 ConcreteStructs = "0.2.3"
 ExplicitImports = "1.4.0"
 FastClosures = "0.3.2"
@@ -44,6 +45,7 @@ FiniteDiff = "2.22"
 ForwardDiff = "0.10.36"
 LinearAlgebra = "1.10"
 LinearSolve = "2.27"
+Lux = "0.5.23"
 LuxCUDA = "0.3.2"
 LuxDeviceUtils = "0.1.17"
 LuxTestUtils = "0.1.15"
@@ -52,16 +54,19 @@ Random = "<0.0.1, 1"
 ReTestItems = "1.23.1"
 ReverseDiff = "1.15"
 StableRNGs = "1.0.1"
+Statistics = "1.11.1"
 Test = "<0.0.1, 1"
 Zygote = "0.6.69"
 julia = "1.10"

 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
+ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
@@ -69,8 +74,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Aqua", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Test", "Zygote"]
+test = ["Aqua", "ComponentArrays", "ExplicitImports", "FiniteDiff", "ForwardDiff", "LinearSolve", "Lux", "LuxCUDA", "LuxDeviceUtils", "LuxTestUtils", "Random", "ReTestItems", "ReverseDiff", "StableRNGs", "Statistics", "Test", "Zygote"]
diff --git a/ext/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt.jl
index 0793a10..057c04a 100644
--- a/ext/BatchedRoutinesForwardDiffExt.jl
+++ b/ext/BatchedRoutinesForwardDiffExt.jl
@@ -65,7 +65,7 @@ end
     return ForwardDiff.value.(y_duals), J_partial
 end

-function BatchedRoutines.__batched_value_and_jacobian(
+function __batched_value_and_jacobian(
         ad::AutoForwardDiff, f::F, u::AbstractMatrix{T},
         ck::Val{chunksize}) where {F, T, chunksize}
     N, B = size(u)
@@ -104,19 +104,19 @@ function BatchedRoutines.__batched_value_and_jacobian(
     return y, J
 end

-@generated function BatchedRoutines.__batched_value_and_jacobian(
+@generated function __batched_value_and_jacobian(
         ad::AutoForwardDiff{CK}, f::F, u::AbstractMatrix{T}) where {CK, F, T}
     if CK === nothing || CK ≤ 0
         if _assert_type(u) && Base.issingletontype(F)
             rType = Tuple{u, parameterless_type(u){T, 3}}
-            jac_call = :((y, J) = BatchedRoutines.__batched_value_and_jacobian(
+            jac_call = :((y, J) = __batched_value_and_jacobian(
                 ad, f, u, Val(batched_pickchunksize(u)))::$(rType))
         else # Cases like ReverseDiff over ForwardDiff
-            jac_call = :((y, J) = BatchedRoutines.__batched_value_and_jacobian(
+            jac_call = :((y, J) = __batched_value_and_jacobian(
                 ad, f, u, Val(batched_pickchunksize(u))))
         end
     else
-        jac_call = :((y, J) = BatchedRoutines.__batched_value_and_jacobian(
+        jac_call = :((y, J) = __batched_value_and_jacobian(
             ad, f, u, $(Val(CK))))
     end
     return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalMatrix(J))))
@@ -140,7 +140,7 @@ end

 @inline function BatchedRoutines._batched_jacobian(
         ad::AutoForwardDiff, f::F, u::AbstractMatrix) where {F}
-    return last(BatchedRoutines.__batched_value_and_jacobian(ad, f, u))
+    return last(__batched_value_and_jacobian(ad, f, u))
 end

 # We don't use the ForwardDiff.gradient since it causes GPU compilation errors due to
@@ -226,10 +226,7 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{

 function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
     Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
-    T = promote_type(eltype(x), eltype(u))
-    dev = get_device(x)
-    partials = ForwardDiff.Partials{1, T}.(tuple.(u)) |> dev
-    x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
+    x_dual = _construct_jvp_duals(Tag, x, u)
     y_dual = f(x_dual)
     return ForwardDiff.partials.(y_dual, 1)
 end
@@ -237,12 +234,15 @@ end
 function BatchedRoutines._jacobian_vector_product(
         ad::AutoForwardDiff, f::F, x, u, p) where {F}
     Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
-    T = promote_type(eltype(x), eltype(u))
-    dev = get_device(x)
-    partials = ForwardDiff.Partials{1, T}.(tuple.(u)) |> dev
-    x_dual = ForwardDiff.Dual{Tag, T, 1}.(x, partials)
+    x_dual = _construct_jvp_duals(Tag, x, u)
     y_dual = f(x_dual, p)
     return ForwardDiff.partials.(y_dual, 1)
 end

+@inline function _construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
+    T = promote_type(eltype(x), eltype(u))
+    partials = ForwardDiff.Partials{1, T}.(tuple.(u))
+    return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
+end
+
 end
diff --git a/src/chainrules.jl b/src/chainrules.jl
index 3cb20f9..c3e1970 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -1,9 +1,3 @@
-function __batched_value_and_jacobian(ad, f::F, x) where {F}
-    J = batched_jacobian(ad, f, x)
-    return f(x), J
-end
-
-# TODO: Use OneElement for this
 function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) where {F}
     if !_is_extension_loaded(Val(:ForwardDiff)) || !_is_extension_loaded(Val(:Zygote))
         throw(ArgumentError("`ForwardDiff.jl` and `Zygote.jl` needs to be loaded to \
@@ -15,9 +9,8 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) wher
     ∇batched_jacobian = Δ -> begin
         gradient_ad = AutoZygote()
         _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
-            x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x),
-            x, reshape(Δᵢ, size(x)))
-        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
+            x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x), x, Δᵢ)
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))
         return NoTangent(), NoTangent(), NoTangent(), ∂x
     end
     return J, ∇batched_jacobian
@@ -33,16 +26,15 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F}

     ∇batched_jacobian = Δ -> begin
         _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
-            x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x),
-            x, reshape(Δᵢ, size(x)))
+            x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x), x, Δᵢ)

-        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x))
+        ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x))

         _map_fnₚ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(),
             (x, p_) -> batched_gradient(AutoZygote(), p__ -> sum(vec(f(x, p__))[i:i]), p_),
-            x, reshape(Δᵢ, size(x)), p)
+            x, Δᵢ, p)

-        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(eachrow(Δ))), size(p))
+        ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(_eachrow(Δ))), size(p))

         return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p
     end
diff --git a/src/helpers.jl b/src/helpers.jl
index c7b2217..7cc8c89 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -147,3 +147,23 @@ end
     resolved && return T, true
     return promote_type(T, eltype(f.x)), false
 end
+
+# eachrow override
+@inline _eachrow(X) = eachrow(X)
+
+# MLUtils.jl has too many unwanted dependencies
+@inline fill_like(x::AbstractArray, v, ::Type{T}, dims...) where {T} = fill!(
+    similar(x, T, dims...), v)
+@inline fill_like(x::AbstractArray, v, dims...) = fill_like(x, v, eltype(x), dims...)
+
+@inline zeros_like(x::AbstractArray, ::Type{T}, dims...) where {T} = fill_like(
+    x, zero(T), T, dims...)
+@inline zeros_like(x::AbstractArray, dims...) = zeros_like(x, eltype(x), dims...)
+
+@inline ones_like(x::AbstractArray, ::Type{T}, dims...) where {T} = fill_like(
+    x, one(T), T, dims...)
+@inline ones_like(x::AbstractArray, dims...) = ones_like(x, eltype(x), dims...)
+
+CRC.@non_differentiable fill_like(::Any...)
+CRC.@non_differentiable zeros_like(::Any...)
+CRC.@non_differentiable ones_like(::Any...)
diff --git a/src/matrix.jl b/src/matrix.jl
index 7d894c5..4f7cc1e 100644
--- a/src/matrix.jl
+++ b/src/matrix.jl
@@ -149,6 +149,23 @@ function Base.fill!(A::UniformBlockDiagonalMatrix, v)
     return A
 end

+@inline function _eachrow(X::UniformBlockDiagonalMatrix)
+    row_fn = @closure i -> begin
+        M, N, K = size(X.data)
+        k = (i - 1) ÷ M + 1
+        i_ = mod1(i, M)
+        data = view(X.data, i_, :, k)
+        if k == 1
+            return vcat(data, zeros_like(data, N * (K - 1)))
+        elseif k == K
+            return vcat(zeros_like(data, N * (K - 1)), data)
+        else
+            return vcat(zeros_like(data, N * (k - 1)), data, zeros_like(data, N * (K - k)))
+        end
+    end
+    return map(row_fn, 1:size(X, 1))
+end
+
 # Broadcasting
 struct UniformBlockDiagonalMatrixStyle{N} <: Broadcast.AbstractArrayStyle{2} end
diff --git a/test/integration_tests.jl b/test/integration_tests.jl
index 82aa7ee..a03603a 100644
--- a/test/integration_tests.jl
+++ b/test/integration_tests.jl
@@ -52,3 +52,60 @@ end
     end
 end
+
+@testitem "Simple Lux Integration" setup=[SharedTestSetup] begin
+    using ComponentArrays, ForwardDiff, Lux, Random, Zygote
+
+    rng = get_stable_rng(1001)
+
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
+        model = Chain(Dense(4 => 6, tanh), Dense(6 => 3))
+        ps, st = Lux.setup(rng, model)
+        ps = ComponentArray(ps) |> dev
+        st = st |> dev
+
+        x = randn(rng, 4, 3) |> dev
+        y = randn(rng, 4, 3) |> dev
+        target_jac = batched_jacobian(
+            AutoForwardDiff(; chunksize=4), StatefulLuxLayer(model, nothing, st), y, ps)
+
+        loss_function = (model, x, target_jac, ps, st) -> begin
+            m = StatefulLuxLayer(model, nothing, st)
+            jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x, ps)
+            return sum(abs2, jac_full .- target_jac)
+        end
+
+        @test loss_function(model, x, target_jac, ps, st) isa Number
+        @test !iszero(loss_function(model, x, target_jac, ps, st))
+
+        cdev = cpu_device()
+        _fn_x = x -> loss_function(model, x, target_jac |> cdev, ps |> cdev, st)
+        _fn_ps = p -> loss_function(
+            model, x |> cdev, target_jac |> cdev, ComponentArray(p, getaxes(ps)), st)
+
+        ∂x_fdiff = ForwardDiff.gradient(_fn_x, cdev(x))
+        ∂ps_fdiff = ForwardDiff.gradient(_fn_ps, cdev(ps))
+
+        _, ∂x, _, ∂ps, _ = Zygote.gradient(loss_function, model, x, target_jac, ps, st)
+
+        @test cdev(∂x) ≈ ∂x_fdiff
+        @test cdev(∂ps) ≈ ∂ps_fdiff
+
+        loss_function2 = (model, x, target_jac, ps, st) -> begin
+            m = StatefulLuxLayer(model, ps, st)
+            jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x)
+            return sum(abs2, jac_full .- target_jac)
+        end
+
+        @test loss_function2(model, x, target_jac, ps, st) isa Number
+        @test !iszero(loss_function2(model, x, target_jac, ps, st))
+
+        _fn_x = x -> loss_function2(model, x, cdev(target_jac), cdev(ps), st)
+
+        ∂x_fdiff = ForwardDiff.gradient(_fn_x, cdev(x))
+
+        _, ∂x, _, _, _ = Zygote.gradient(loss_function2, model, x, target_jac, ps, st)
+
+        @test cdev(∂x) ≈ ∂x_fdiff
+    end
+end
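Reviewer note (not part of the patch): a minimal sketch of what the new `_eachrow` fallback in `src/matrix.jl` is expected to return for a `UniformBlockDiagonalMatrix`, assuming the type is exported and constructible from its 3-D `data` array as used elsewhere in this diff. Each row materializes as the corresponding dense row, zero-padded outside its own block, which is what the reworked `rrule` pullbacks in `src/chainrules.jl` iterate over via `enumerate(_eachrow(Δ))`.

```julia
using BatchedRoutines  # assumes UniformBlockDiagonalMatrix is exported

# Two 2x2 blocks -> a 4x4 block-diagonal matrix (column-major fill).
data = reshape(Float32.(1:8), 2, 2, 2)
X = UniformBlockDiagonalMatrix(data)

# `_eachrow` is internal, hence the qualified call. Row 1 of block 1 is [1, 3],
# padded with zeros over the columns of the other block.
rows = BatchedRoutines._eachrow(X)
@assert length(rows) == size(X, 1) == 4
@assert rows[1] == Float32[1, 3, 0, 0]
```

The padded rows match what `eachrow` would produce on the equivalent dense matrix, so the pullback code stays identical whether the cotangent `Δ` arrives dense or as a `UniformBlockDiagonalMatrix`.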