Skip to content

Commit

Permalink
use axes to store the axes and not size
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 12, 2024
1 parent 7ca316b commit 31eea47
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.9.0"
version = "1.9.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
5 changes: 3 additions & 2 deletions ext/StaticArraysChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ end

# Project SArray to SArray
function ProjectTo(x::SArray{S, T}) where {S, T}
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x))
return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = axes(x),
size = Size(x))
end

@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx)

(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx)
(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx)

# Adjoint for SArray constructor
function rrule(::Type{T}, x::Tuple) where {T <: SArray}
Expand Down
2 changes: 1 addition & 1 deletion test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test
using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, LinearAlgebra, Test

@testset "ChainRules Integration" begin
@testset "Projection" begin
Expand Down

0 comments on commit 31eea47

Please sign in to comment.