From 0e552c51ccf05172717e38f2abce7842f3564754 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 12 Jan 2024 06:12:57 -0500 Subject: [PATCH] use axes to store the axes and not size --- Project.toml | 2 +- ext/StaticArraysChainRulesCoreExt.jl | 5 +++-- test/chainrules.jl | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index f854810d..22d98b43 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.0" +version = "1.9.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/ext/StaticArraysChainRulesCoreExt.jl b/ext/StaticArraysChainRulesCoreExt.jl index 5f7904aa..266a0dd7 100644 --- a/ext/StaticArraysChainRulesCoreExt.jl +++ b/ext/StaticArraysChainRulesCoreExt.jl @@ -15,12 +15,13 @@ end # Project SArray to SArray function ProjectTo(x::SArray{S, T}) where {S, T} - return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x)) + return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = axes(x), + size = Size(x)) end @inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx) -(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx) +(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx) # Adjoint for SArray constructor function rrule(::Type{T}, x::Tuple) where {T <: SArray} diff --git a/test/chainrules.jl b/test/chainrules.jl index ee3c02cc..c586432a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,4 +1,4 @@ -using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test +using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test @testset "ChainRules Integration" begin @testset "Projection" begin