Skip to content

Commit

Permalink
Code clean and reuse for constructor. (#1016)
Browse files Browse the repository at this point in the history
* Define `construct_type`

Similar to `similar_type`, but not fallback to `SArray`. It is used to pick the most `concrete` constructor with `size`, `eltype` and `ndims` defined.

* Define general `constructor` with `construct_type`

* Define general `convert` with `construct_type`.

And fix for input with non-1 based axes.

* Make `FieldArray`'s general constructor based on `construct_type`

This also enable auto type promotion.

* `Constructor` and `convertor` clean.

With `construct_type`, there's no need to keep all these dispatches.
This also fix empty construction for `S/MVector`, and remove most of the ambiguities.

* Add more test

* Test fix and clean.

* Drop `StaticSquareMatrix` & add more test.

1. remove `StaticSquareMatrix` (`StaticMatrix{N,N}` should be shorter and clear enough)
2. Add missing Test.

* Add constructor test for `OffsetArray`.

* Rename `FirstClass` as `SizeEltypeAdaptable`

And convert comments to docstring.

* Replace `_NTuple` with `_TupleOf`

and typo fix.
  • Loading branch information
N5N3 authored May 16, 2022
1 parent c09bc9a commit 8ca11f8
Show file tree
Hide file tree
Showing 26 changed files with 351 additions and 248 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ julia = "1.6"
[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["InteractiveUtils", "Test", "BenchmarkTools"]
test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays"]
11 changes: 6 additions & 5 deletions src/FieldArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ array operations as in the example below.
"""
abstract type FieldVector{N, T} <: FieldArray{Tuple{N}, T, 1} end

@inline function (::Type{FA})(x::Tuple{Vararg{Any, N}}) where {N, FA <: FieldArray}
@boundscheck if length(FA) != length(x)
throw(DimensionMismatch("No precise constructor for $FA found. Length of input was $(length(x))."))
end
return FA(x...)
@inline (::Type{FA})(x::Tuple) where {FA <: FieldArray} = construct_type(FA, x)(x...)

function construct_type(::Type{FA}, x) where {FA <: FieldArray}
has_size(FA) || error("$FA has no static size!")
length_match_size(FA, x)
return adapt_eltype(FA, x)
end

@propagate_inbounds getindex(a::FieldArray, i::Int) = getfield(a, i)
Expand Down
29 changes: 4 additions & 25 deletions src/MArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,23 @@ the compiler (the element type may optionally also be specified).
mutable struct MArray{S <: Tuple, T, N, L} <: StaticArray{S, T, N}
data::NTuple{L,T}

function MArray{S,T,N,L}(x::NTuple{L,T}) where {S,T,N,L}
function MArray{S,T,N,L}(x::NTuple{L,T}) where {S<:Tuple,T,N,L}
check_array_parameters(S, T, Val{N}, Val{L})
new{S,T,N,L}(x)
end

function MArray{S,T,N,L}(x::NTuple{L,Any}) where {S,T,N,L}
function MArray{S,T,N,L}(x::NTuple{L,Any}) where {S<:Tuple,T,N,L}
check_array_parameters(S, T, Val{N}, Val{L})
new{S,T,N,L}(convert_ntuple(T, x))
end

function MArray{S,T,N,L}(::UndefInitializer) where {S,T,N,L}
function MArray{S,T,N,L}(::UndefInitializer) where {S<:Tuple,T,N,L}
check_array_parameters(S, T, Val{N}, Val{L})
new{S,T,N,L}()
end
end

@generated function (::Type{MArray{S,T,N}})(x::Tuple) where {S,T,N}
return quote
$(Expr(:meta, :inline))
MArray{S,T,N,$(tuple_prod(S))}(x)
end
end

@generated function (::Type{MArray{S,T}})(x::Tuple) where {S,T}
return quote
$(Expr(:meta, :inline))
MArray{S,T,$(tuple_length(S)),$(tuple_prod(S))}(x)
end
end

@generated function (::Type{MArray{S}})(x::T) where {S, T <: Tuple}
return quote
$(Expr(:meta, :inline))
MArray{S,promote_tuple_eltype(T),$(tuple_length(S)),$(tuple_prod(S))}(x)
end
end
@inline MArray{S,T,N}(x::Tuple) where {S<:Tuple,T,N} = MArray{S,T,N,tuple_prod(S)}(x)

@generated function (::Type{MArray{S,T,N}})(::UndefInitializer) where {S,T,N}
return quote
Expand All @@ -71,8 +52,6 @@ end
end
end

@inline MArray(a::StaticArray{S,T}) where {S<:Tuple,T} = MArray{S,T}(Tuple(a))

####################
## MArray methods ##
####################
Expand Down
30 changes: 0 additions & 30 deletions src/MMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,13 @@ unknown to the compiler (the element type may optionally also be specified).
"""
const MMatrix{S1, S2, T, L} = MArray{Tuple{S1, S2}, T, 2, L}

@generated function (::Type{MMatrix{S1}})(x::NTuple{L}) where {S1,L}
S2 = div(L, S1)
if S1*S2 != L
throw(DimensionMismatch("Incorrect matrix sizes. $S1 does not divide $L elements"))
end
return quote
$(Expr(:meta, :inline))
T = eltype(typeof(x))
MMatrix{S1, $S2, T, L}(x)
end
end

@generated function (::Type{MMatrix{S1,S2}})(x::NTuple{L}) where {S1,S2,L}
return quote
$(Expr(:meta, :inline))
T = eltype(typeof(x))
MMatrix{S1, S2, T, L}(x)
end
end

@generated function (::Type{MMatrix{S1,S2,T}})(x::NTuple{L}) where {S1,S2,T,L}
return quote
$(Expr(:meta, :inline))
MMatrix{S1, S2, T, L}(x)
end
end

@generated function (::Type{MMatrix{S1,S2,T}})(::UndefInitializer) where {S1,S2,T}
return quote
$(Expr(:meta, :inline))
MMatrix{S1, S2, T, $(S1*S2)}(undef)
end
end

@inline convert(::Type{MMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = MMatrix{S1,S2,T}(Tuple(a))
@inline MMatrix(a::StaticMatrix{N,M,T}) where {N,M,T} = MMatrix{N,M,T}(Tuple(a))

# Some more advanced constructor-like functions
@inline one(::Type{MMatrix{N}}) where {N} = one(MMatrix{N,N})

Expand Down
5 changes: 0 additions & 5 deletions src/MVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ compiler (the element type may optionally also be specified).
"""
const MVector{S, T} = MArray{Tuple{S}, T, 1, S}

@inline MVector(a::StaticVector{N,T}) where {N,T} = MVector{N,T}(a)
@inline MVector(x::NTuple{S,Any}) where {S} = MVector{S}(x)
@inline MVector{S}(x::NTuple{S,T}) where {S, T} = MVector{S, T}(x)
@inline MVector{S}(x::NTuple{S,Any}) where {S} = MVector{S, promote_tuple_eltype(typeof(x))}(x)

# Some more advanced constructor-like functions
@inline zeros(::Type{MVector{N}}) where {N} = zeros(MVector{N,Float64})
@inline ones(::Type{MVector{N}}) where {N} = ones(MVector{N,Float64})
Expand Down
28 changes: 3 additions & 25 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,18 @@ compiler (the element type may optionally also be specified).
struct SArray{S <: Tuple, T, N, L} <: StaticArray{S, T, N}
data::NTuple{L,T}

function SArray{S, T, N, L}(x::NTuple{L,T}) where {S, T, N, L}
function SArray{S, T, N, L}(x::NTuple{L,T}) where {S<:Tuple, T, N, L}
check_array_parameters(S, T, Val{N}, Val{L})
new{S, T, N, L}(x)
end

function SArray{S, T, N, L}(x::NTuple{L,Any}) where {S, T, N, L}
function SArray{S, T, N, L}(x::NTuple{L,Any}) where {S<:Tuple, T, N, L}
check_array_parameters(S, T, Val{N}, Val{L})
new{S, T, N, L}(convert_ntuple(T, x))
end
end

@generated function (::Type{SArray{S, T, N}})(x::Tuple) where {S <: Tuple, T, N}
return quote
@_inline_meta
SArray{S, T, N, $(tuple_prod(S))}(x)
end
end

@generated function (::Type{SArray{S, T}})(x::Tuple) where {S <: Tuple, T}
return quote
@_inline_meta
SArray{S, T, $(tuple_length(S)), $(tuple_prod(S))}(x)
end
end

@generated function (::Type{SArray{S}})(x::T) where {S <: Tuple, T <: Tuple}
return quote
@_inline_meta
SArray{S, promote_tuple_eltype(T), $(tuple_length(S)), $(tuple_prod(S))}(x)
end
end

@inline SArray{S,T,N}(x::Tuple) where {S<:Tuple,T,N} = SArray{S,T,N,tuple_prod(S)}(x)

@noinline function generator_too_short_error(inds::CartesianIndices, i::CartesianIndex)
error("Generator produced too few elements: Expected exactly $(shape_string(inds)) elements, but generator stopped at $(shape_string(i))")
Expand Down Expand Up @@ -106,8 +86,6 @@ sacollect
@inline (::Type{SA})(gen::Base.Generator) where {SA <: StaticArray} =
sacollect(SA, gen)

@inline SArray(a::StaticArray{S,T}) where {S<:Tuple,T} = SArray{S,T}(Tuple(a))

####################
## SArray methods ##
####################
Expand Down
9 changes: 4 additions & 5 deletions src/SHermitianCompact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ lowertriangletype(::Type{SHermitianCompact{N}}) where {N} = SVector{triangularnu
end

@generated function SHermitianCompact{N, T, L}(a::Tuple) where {N, T, L}
_check_hermitian_parameters(Val(N), Val(L))
expr = Vector{Expr}(undef, L)
i = 0
for col = 1 : N, row = col : N
Expand All @@ -77,15 +78,13 @@ end
SHermitianCompact{N, T, L}(a)
end

@inline SHermitianCompact{N}(a::Tuple) where {N} = SHermitianCompact{N, promote_tuple_eltype(a)}(a)
@inline SHermitianCompact{N}(a::NTuple{M, T}) where {N, T, M} = SHermitianCompact{N, T}(a)
@inline SHermitianCompact(a::StaticMatrix{N, N, T}) where {N, T} = SHermitianCompact{N, T}(a)

@inline (::Type{SSC})(a::SHermitianCompact) where {SSC <: SHermitianCompact} = SSC(a.lowertriangle)
@inline (::Type{SSC})(a::SSC) where {SSC <: SHermitianCompact} = a

@inline (::Type{SSC})(a::AbstractVector) where {SSC <: SHermitianCompact} = SSC(convert(lowertriangletype(SSC), a))

# disambiguation
@inline (::Type{SSC})(a::StaticArray{<:Tuple,<:Any,1}) where {SSC <: SHermitianCompact} = SSC(convert(SVector, a))

@generated function _hermitian_compact_indices(::Val{N}) where N
# Returns a Tuple{Pair{Int, Bool}} I such that for linear index i,
# * I[i][1] is the index into the lowertriangle field of an SHermitianCompact{N};
Expand Down
48 changes: 0 additions & 48 deletions src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,61 +16,13 @@ unknown to the compiler (the element type may optionally also be specified).
"""
const SMatrix{S1, S2, T, L} = SArray{Tuple{S1, S2}, T, 2, L}

@generated function SMatrix{S1}(x::NTuple{L,Any}) where {S1,L}
S2 = div(L, S1)
if S1*S2 != L
throw(DimensionMismatch("Incorrect matrix sizes. $S1 does not divide $L elements"))
end

return quote
$(Expr(:meta, :inline))
T = promote_tuple_eltype(typeof(x))
SMatrix{S1, $S2, T, L}(x)
end
end

@generated function SMatrix{S1,S2}(x::NTuple{L,Any}) where {S1,S2,L}
return quote
$(Expr(:meta, :inline))
T = promote_tuple_eltype(typeof(x))
SMatrix{S1, S2, T, L}(x)
end
end
SMatrixNoType{S1, S2, L, T} = SMatrix{S1, S2, T, L}
@generated function SMatrixNoType{S1, S2, L}(x::NTuple{L,Any}) where {S1,S2,L}
return quote
$(Expr(:meta, :inline))
T = promote_tuple_eltype(typeof(x))
SMatrix{S1, S2, T, L}(x)
end
end

@generated function SMatrix{S1,S2,T}(x::NTuple{L,Any}) where {S1,S2,T,L}
return quote
$(Expr(:meta, :inline))
SMatrix{S1, S2, T, L}(x)
end
end

@inline SMatrix{M, N, T}(gen::Base.Generator) where {M, N, T} =
sacollect(SMatrix{M, N, T}, gen)
@inline SMatrix{M, N}(gen::Base.Generator) where {M, N} =
sacollect(SMatrix{M, N}, gen)

@inline convert(::Type{SMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = SMatrix{S1,S2,T}(Tuple(a))
@inline SMatrix(a::StaticMatrix{S1, S2, T}) where {S1, S2, T} = SMatrix{S1, S2, T}(Tuple(a))

# Some more advanced constructor-like functions
@inline one(::Type{SMatrix{N}}) where {N} = one(SMatrix{N,N})

#####################
## SMatrix methods ##
#####################

@propagate_inbounds function getindex(v::SMatrix, i::Int)
v.data[i]
end

function check_matrix_size(x::Tuple, T = :S)
if length(x) > 2
all(isone, x[3:end]) || error("Bad input for @$(T)Matrix, must be matrix like.")
Expand Down
13 changes: 0 additions & 13 deletions src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,6 @@ compiler (the element type may optionally also be specified).
"""
const SVector{S, T} = SArray{Tuple{S}, T, 1, S}

@inline SVector(a::StaticVector{N,T}) where {N,T} = SVector{N,T}(a)
@inline SVector(x::NTuple{S,Any}) where {S} = SVector{S}(x)
@inline SVector{S}(x::NTuple{S,T}) where {S, T} = SVector{S,T}(x)
@inline SVector{S}(x::T) where {S, T <: Tuple} = SVector{S,promote_tuple_eltype(T)}(x)

@inline SVector{N, T}(gen::Base.Generator) where {N, T} =
sacollect(SVector{N, T}, gen)
@inline SVector{N}(gen::Base.Generator) where {N} =
sacollect(SVector{N}, gen)

# conversion from AbstractVector / AbstractArray (better inference than default)
#@inline convert{S,T}(::Type{SVector{S}}, a::AbstractArray{T}) = SVector{S,T}((a...))

# Some more advanced constructor-like functions
@inline zeros(::Type{SVector{N}}) where {N} = zeros(SVector{N,Float64})
@inline ones(::Type{SVector{N}}) where {N} = ones(SVector{N,Float64})
Expand Down
8 changes: 3 additions & 5 deletions src/Scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@ Construct a statically-sized 0-dimensional array that contains a single element,
"""
const Scalar{T} = SArray{Tuple{},T,0,1}

@inline Scalar(x::Tuple{T}) where {T} = Scalar{T}(x[1])
@inline Scalar(a::AbstractArray) = Scalar{typeof(a)}((a,))
@inline Scalar(a::StaticArray) = Scalar{typeof(a)}((a,)) # disambiguation

@inline Scalar(a::AbstractScalar) = Scalar{eltype(a)}((a[],)) # Do we want this to convert or wrap?
@inline function convert(::Type{SA}, a::AbstractArray) where {SA <: Scalar}
return SA((a[],))
end
@inline convert(::Type{SA}, sa::SA) where {SA <: Scalar} = sa
@inline Scalar(a::StaticScalar) = Scalar{eltype(a)}((a[],)) # disambiguation

@propagate_inbounds function getindex(v::Scalar, i::Int)
@boundscheck if i != 1
Expand Down
Loading

0 comments on commit 8ca11f8

Please sign in to comment.