From c7216c3241346d550f3ac2a4fe0211161e268ffe Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 8 Nov 2024 19:38:44 +0530 Subject: [PATCH] feat: add adjoint for `ArrayInterface.restructure` --- Project.toml | 7 ++++++- ext/ArrayInterfaceChainRulesCoreExt.jl | 22 ++++++++++++++++++++++ test/chainrules.jl | 7 +++++++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 ext/ArrayInterfaceChainRulesCoreExt.jl diff --git a/Project.toml b/Project.toml index 52cce972..c1b74442 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -24,6 +25,7 @@ ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" ArrayInterfaceCUDAExt = "CUDA" ArrayInterfaceCUDSSExt = "CUDSS" ArrayInterfaceChainRulesExt = "ChainRules" +ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" ArrayInterfaceReverseDiffExt = "ReverseDiff" ArrayInterfaceSparseArraysExt = "SparseArrays" @@ -37,6 +39,8 @@ BlockBandedMatrices = "0.13" CUDA = "5" CUDSS = "0.2, 0.3" ChainRules = "1" +ChainRulesCore = "1" +ChainRulesTestUtils = "1" GPUArraysCore = "0.1, 0.2" LinearAlgebra = "1.10" ReverseDiff = "1" @@ -51,6 +55,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -66,4 +71,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"] +test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays", "ChainRulesTestUtils"] diff --git a/ext/ArrayInterfaceChainRulesCoreExt.jl b/ext/ArrayInterfaceChainRulesCoreExt.jl new file mode 100644 index 00000000..6cf4c406 --- /dev/null +++ b/ext/ArrayInterfaceChainRulesCoreExt.jl @@ -0,0 +1,22 @@ +module ArrayInterfaceChainRulesCoreExt + +import ArrayInterface +import ChainRulesCore +import ChainRulesCore: unthunk, NoTangent, ZeroTangent, ProjectTo, @thunk + +function ChainRulesCore.rrule(::typeof(ArrayInterface.restructure), target, src) + projectT = ProjectTo(target) + function restructure_pullback(dt) + dt = unthunk(dt) + + f̄ = NoTangent() + t̄ = ZeroTangent() + s̄ = @thunk(projectT(ArrayInterface.restructure(src, dt))) + + f̄, t̄, s̄ + end + + return ArrayInterface.restructure(target, src), restructure_pullback +end + +end diff --git a/test/chainrules.jl b/test/chainrules.jl index df47b2ff..759a55be 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,6 +1,13 @@ using ArrayInterface, ChainRules, Test +using ComponentArrays, ChainRulesTestUtils, StaticArrays x = ChainRules.OneElement(3.0, (3, 3), (1:4, 1:4)) @test !ArrayInterface.can_setindex(x) @test !ArrayInterface.can_setindex(typeof(x)) + +arr = ComponentArray(a = 1.0, b = [2.0, 3.0], c = (; a = 4.0, b = 5.0), d = SVector{2}(6.0, 7.0)) +b = zeros(length(arr)) + +ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, arr, b) +ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, b, arr)