From 26d0e05105bba10c3ac75ae2b490c88b9b93090d Mon Sep 17 00:00:00 2001 From: Chris Foster Date: Sat, 15 Feb 2020 15:00:32 +1000 Subject: [PATCH] Default implementation of similar_type for FieldArray (#731) This should free users from defining similar_type in many cases, though they'll still need to do so when their type is parametric on the eltype. (There's no general way for us to know how to reparameterize such user-defined types.) --- src/FieldArray.jl | 57 +++++++++++++++++++++++++++++++++------------ test/FieldMatrix.jl | 2 +- test/FieldVector.jl | 20 +++++++++++++++- 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/src/FieldArray.jl b/src/FieldArray.jl index 30e2ea25..f479b845 100644 --- a/src/FieldArray.jl +++ b/src/FieldArray.jl @@ -6,7 +6,15 @@ will automatically define `getindex` and `setindex!` appropriately. An immutable `FieldArray` will be as performant as an `SArray` of similar length and element type, while a mutable `FieldArray` will behave similarly to an `MArray`. -For example: +Note that you must define the fields of any `FieldArray` subtype in column major order. If you +want to use an alternative ordering you will need to pay special attention in providing your +own definitions of `getindex`, `setindex!` and tuple conversion. + +If you define a `FieldArray` which is parametric on the element type you should +consider defining `similar_type` as in the `FieldVector` example. + + +# Example struct Stiffness <: FieldArray{Tuple{2,2,2,2}, Float64, 4} xxxx::Float64 @@ -26,10 +34,6 @@ For example: xyyy::Float64 yyyy::Float64 end - - Note that you must define the fields of any `FieldArray` subtype in column major order. If you - want to use an alternative ordering you will need to pay special attention in providing your - own definitions of `getindex`, `setindex!` and tuple conversion. """ abstract type FieldArray{N, T, D} <: StaticArray{N, T, D} end @@ -41,7 +45,13 @@ will automatically define `getindex` and `setindex!` appropriately. An immutable `FieldMatrix` will be as performant as an `SMatrix` of similar length and element type, while a mutable `FieldMatrix` will behave similarly to an `MMatrix`. -For example: +Note that the fields of any subtype of `FieldMatrix` must be defined in column +major order unless you are willing to implement your own `getindex`. + +If you define a `FieldMatrix` which is parametric on the element type you +should consider defining `similar_type` as in the `FieldVector` example. + +# Example struct Stress <: FieldMatrix{3, 3, Float64} xx::Float64 @@ -67,13 +77,12 @@ For example: 2.0 5.0 8.0 3.0 6.0 9.0 - will give you the transpose of what the multi-argument formatting suggests. For clarity, you may consider using the alternative - sigma = Stress(@SArray[1.0 2.0 3.0; - 4.0 5.0 6.0; - 7.0 8.0 9.0]) + sigma = Stress(SA[1.0 2.0 3.0; + 4.0 5.0 6.0; + 7.0 8.0 9.0]) """ abstract type FieldMatrix{N1, N2, T} <: FieldArray{Tuple{N1, N2}, T, 2} end @@ -85,13 +94,19 @@ will automatically define `getindex` and `setindex!` appropriately. An immutable `FieldVector` will be as performant as an `SVector` of similar length and element type, while a mutable `FieldVector` will behave similarly to an `MVector`. -For example: +If you define a `FieldVector` which is parametric on the element type you +should consider defining `similar_type` to preserve your array type through +array operations as in the example below. + +# Example - struct Point3D <: FieldVector{3, Float64} - x::Float64 - y::Float64 - z::Float64 + struct Vec3D{T} <: FieldVector{3, T} + x::T + y::T + z::T end + + StaticArrays.similar_type(::Type{<:Vec3D}, ::Type{T}, s::Size{(3,)}) where {T} = Vec3D{T} """ abstract type FieldVector{N, T} <: FieldArray{Tuple{N}, T, 1} end @@ -109,3 +124,15 @@ end Base.cconvert(::Type{<:Ptr}, a::FieldArray) = Base.RefValue(a) Base.unsafe_convert(::Type{Ptr{T}}, m::Base.RefValue{FA}) where {N,T,D,FA<:FieldArray{N,T,D}} = Ptr{T}(Base.unsafe_convert(Ptr{FA}, m)) + +# We can automatically preserve FieldArrays in array operations which do not +# change their eltype or Size. This should cover all non-parametric FieldArray, +# but for those which are parametric on the eltype the user will still need to +# overload similar_type themselves. +similar_type(::Type{A}, ::Type{T}, S::Size) where {N, T, A<:FieldArray{N, T}} = + _fieldarray_similar_type(A, T, S, Size(A)) + +# Extra layer of dispatch to match NewSize and OldSize +_fieldarray_similar_type(A, T, NewSize::S, OldSize::S) where {S} = A +_fieldarray_similar_type(A, T, NewSize, OldSize) = + default_similar_type(T, NewSize, length_val(NewSize)) diff --git a/test/FieldMatrix.jl b/test/FieldMatrix.jl index 7eb6ef35..c1649f73 100644 --- a/test/FieldMatrix.jl +++ b/test/FieldMatrix.jl @@ -13,7 +13,7 @@ zz::Float64 end - StaticArrays.similar_type(::Type{Tensor3x3}, ::Type{Float64}, s::Size{(3,3)}) = Tensor3x3 + # No need to define similar_type for non-parametric FieldMatrix (#792) end) p = Tensor3x3(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0) diff --git a/test/FieldVector.jl b/test/FieldVector.jl index 72cf249e..313c240d 100644 --- a/test/FieldVector.jl +++ b/test/FieldVector.jl @@ -7,7 +7,7 @@ z::Float64 end - StaticArrays.similar_type(::Type{Point3D}, ::Type{Float64}, s::Size{(3,)}) = Point3D + # No need to define similar_type for non-parametric FieldVector (#792) end) p = Point3D(1.0, 2.0, 3.0) @@ -29,6 +29,8 @@ 0.0 0.0 2.0] @test @inferred(m*p) === Point3D(2.0, 4.0, 6.0) + @test @inferred(SA[2.0 0.0 0.0; + 0.0 2.0 0.0]*p) === SVector((2.0, 4.0)) @test @inferred(similar_type(Point3D)) == Point3D @test @inferred(similar_type(Point3D, Float64)) == Point3D @@ -92,4 +94,20 @@ @test length(x[1]) == 2 @test x.x == (1, 2) end + + @testset "FieldVector with parametric eltype and without similar_type" begin + eval(quote + struct FVT{T} <: FieldVector{2, T} + x::T + y::T + end + + # No similar_type defined - test fallback codepath + end) + + @test @inferred(similar_type(FVT{Float64}, Float32)) == SVector{2,Float32} # Fallback code path + @test @inferred(similar_type(FVT{Float64}, Size(2))) == FVT{Float64} + @test @inferred(similar_type(FVT{Float64}, Size(3))) == SVector{3,Float64} + @test @inferred(similar_type(FVT{Float64}, Float32, Size(3))) == SVector{3,Float32} + end end