Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use package extensions #41

Merged
merged 3 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
name = "ArraysOfArrays"
uuid = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
version = "0.6.3"
version = "0.6.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[extensions]
ArraysOfArraysAdaptExt = "Adapt"
ArraysOfArraysChainRulesCoreExt = "ChainRulesCore"
ArraysOfArraysStaticArraysCoreExt = "StaticArraysCore"

[compat]
Adapt = "1, 2, 3, 4"
ChainRulesCore = "1"
StaticArraysCore = "1"
Statistics = "1"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "ElasticArrays", "StaticArrays", "StatsBase", "Test"]
28 changes: 28 additions & 0 deletions ext/ArraysOfArraysAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).

module ArraysOfArraysAdaptExt

import Adapt
using Adapt: adapt

using ArraysOfArrays: ArrayOfSimilarArrays, VectorOfArrays
using ArraysOfArrays: no_consistency_checks


function Adapt.adapt_structure(to, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
adapted_data = adapt(to, A.data)
ArrayOfSimilarArrays{eltype(adapted_data),M,N}(adapted_data)
end


function Adapt.adapt_structure(to, A::VectorOfArrays)
VectorOfArrays(
adapt(to, A.data),
adapt(to, A.elem_ptr),
adapt(to, A.kernel_size),
no_consistency_checks
)
end


end # module ArraysOfArraysAdaptExt
37 changes: 37 additions & 0 deletions ext/ArraysOfArraysChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).

module ArraysOfArraysChainRulesCoreExt

import ChainRulesCore
using ChainRulesCore: NoTangent, unthunk

using ArraysOfArrays: ArrayOfSimilarArrays
using ArraysOfArrays: flatview


function _aosa_ctor_fromflat_pullback(ΔΩ)
NoTangent(), flatview(convert(ArrayOfSimilarArrays, unthunk(ΔΩ)))
end

function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, flat_data::AbstractArray{U,L}) where {T,M,N,L,U}
return ArrayOfSimilarArrays{T,M,N}(flat_data), _aosa_ctor_fromflat_pullback
end

_aosa_ctor_fromnested_pullback(ΔΩ) = NoTangent(), ΔΩ

function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
return ArrayOfSimilarArrays{T,M,N}(A), _aosa_ctor_fromnested_pullback
end


function ChainRulesCore.rrule(::typeof(flatview), A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
function flatview_pullback(ΔΩ)
data = unthunk(ΔΩ)
NoTangent(), ArrayOfSimilarArrays{eltype(data),M,N}(data)

Check warning on line 30 in ext/ArraysOfArraysChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ArraysOfArraysChainRulesCoreExt.jl#L27-L30

Added lines #L27 - L30 were not covered by tests
end

return flatview(A), flatview_pullback

Check warning on line 33 in ext/ArraysOfArraysChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ArraysOfArraysChainRulesCoreExt.jl#L33

Added line #L33 was not covered by tests
end


end # module ArraysOfArraysChainRulesCoreExt
26 changes: 26 additions & 0 deletions ext/ArraysOfArraysStaticArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).

module ArraysOfArraysStaticArraysCoreExt

import StaticArraysCore
using StaticArraysCore: StaticArray, SVector

import ArraysOfArrays
using ArraysOfArrays: nestedview


@inline ArraysOfArrays.flatview(A::AbstractArray{SA,N}) where {S,T,M,N,SA<:StaticArray{S,T,M}} =
reshape(reinterpret(T, A), size(SA)..., size(A)...)


@inline function ArraysOfArrays.nestedview(A::AbstractArray{T}, SA::Type{SVector{S,T}}) where {T,S}
size_A = size(A)
size_A[1] == S || throw(DimensionMismatch("Length $S of static vector type does not match first dimension of array of size $size_A"))
reshape(reinterpret(SA, A), ArraysOfArrays._tail(size_A)...)
end

@inline ArraysOfArrays.nestedview(A::AbstractArray{T}, ::Type{SVector{S}}) where {T,S} =
nestedview(A, SVector{S,T})


end # module ArraysOfArraysStaticArraysCoreExt
11 changes: 6 additions & 5 deletions src/ArraysOfArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ ArraysOfArrays provides two different types of nested arrays:
"""
module ArraysOfArrays

using Adapt
using Statistics
using ChainRulesCore

import StaticArraysCore

include("util.jl")
include("functions.jl")
include("array_of_similar_arrays.jl")
include("vector_of_arrays.jl")
include("arrays_of_static_arrays.jl")

@static if !isdefined(Base, :get_extension)
include("../ext/ArraysOfArraysAdaptExt.jl")
include("../ext/ArraysOfArraysChainRulesCoreExt.jl")
include("../ext/ArraysOfArraysStaticArraysCoreExt.jl")
end

end # module
30 changes: 0 additions & 30 deletions src/array_of_similar_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,11 @@ end

export ArrayOfSimilarArrays

function _aosa_ctor_fromflat_pullback(ΔΩ)
NoTangent(), flatview(convert(ArrayOfSimilarArrays, unthunk(ΔΩ)))
end

function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, flat_data::AbstractArray{U,L}) where {T,M,N,L,U}
return ArrayOfSimilarArrays{T,M,N}(flat_data), _aosa_ctor_fromflat_pullback
end

function ArrayOfSimilarArrays{T,M,N}(A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
B = ArrayOfSimilarArrays{T,M,N}(Array{T}(undef, innersize(A)..., size(A)...))
copyto!(B, A)
end

_aosa_ctor_fromnested_pullback(ΔΩ) = NoTangent(), ΔΩ

function ChainRulesCore.rrule(::Type{ArrayOfSimilarArrays{T,M,N}}, A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U}
return ArrayOfSimilarArrays{T,M,N}(A), _aosa_ctor_fromnested_pullback
end

ArrayOfSimilarArrays{T}(A::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U} =
ArrayOfSimilarArrays{T,M,N}(A)

Expand Down Expand Up @@ -143,15 +129,6 @@ the result may be freely changed without breaking the inner consistency of
"""
flatview(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} = A.data

function ChainRulesCore.rrule(::typeof(flatview), A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
function flatview_pullback(ΔΩ)
data = unthunk(ΔΩ)
NoTangent(), ArrayOfSimilarArrays{eltype(data),M,N}(data)
end

return flatview(A), flatview_pullback
end


Base.size(A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N} = split_tuple(size(A.data), Val{M}())[2]

Expand Down Expand Up @@ -216,13 +193,6 @@ end
Base.prepend!(dest::ArrayOfSimilarArrays{T,M,N}, src::AbstractArray{<:AbstractArray{U,M},N}) where {T,M,N,U} =
prepend!(dest, ArrayOfSimilarArrays(src))


function Adapt.adapt_structure(to, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
adapted_data = adapt(to, A.data)
ArrayOfSimilarArrays{eltype(adapted_data),M,N}(adapted_data)
end


function innermap(f::Base.Callable, A::ArrayOfSimilarArrays{T,M,N}) where {T,M,N}
new_data = map(f, A.data)
U = eltype(new_data)
Expand Down
15 changes: 0 additions & 15 deletions src/arrays_of_static_arrays.jl

This file was deleted.

10 changes: 0 additions & 10 deletions src/vector_of_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,16 +398,6 @@ function Base.empty!(A::VectorOfArrays)
end


function Adapt.adapt_structure(to, A::VectorOfArrays)
VectorOfArrays(
adapt(to, A.data),
adapt(to, A.elem_ptr),
adapt(to, A.kernel_size),
no_consistency_checks
)
end


function innermap(f::Base.Callable, A::VectorOfArrays)
new_data = map(f, A.data)
VectorOfArrays(new_data, A.elem_ptr, A.kernel_size, simple_consistency_checks)
Expand Down
13 changes: 13 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Documenter = "1"
6 changes: 4 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).

using Test
import Test

Test.@testset "Package ArraysOfArrays" begin
include("test_aqua.jl")
include("functions.jl")
include("array_of_similar_arrays.jl")
include("vector_of_arrays.jl")
end
include("test_docs.jl")
end # testset
16 changes: 16 additions & 0 deletions test/test_aqua.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).

import Test
import Aqua
import ArraysOfArrays

Test.@testset "Package ambiguities" begin
Test.@test isempty(Test.detect_ambiguities(ArraysOfArrays))
end # testset

Test.@testset "Aqua tests" begin
Aqua.test_all(
ArraysOfArrays,
ambiguities = true
)
end # testset
13 changes: 13 additions & 0 deletions test/test_docs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# This file is a part of ArraysOfArrays.jl, licensed under the MIT License (MIT).

using Test
using ArraysOfArrays
import Documenter

Documenter.DocMeta.setdocmeta!(
ArraysOfArrays,
:DocTestSetup,
:(using ArraysOfArrays);
recursive=true,
)
Documenter.doctest(ArraysOfArrays)
Loading