diff --git a/src/AbstractTensor.jl b/src/AbstractTensor.jl index 9290fcee..b2d38dd8 100644 --- a/src/AbstractTensor.jl +++ b/src/AbstractTensor.jl @@ -11,6 +11,7 @@ const AbstractMat{m, n, T} = AbstractTensor{Tuple{m, n}, T, 2} const AbstractMatLike{T} = Union{ AbstractMat{<: Any, <: Any, T}, AbstractSymmetricSecondOrderTensor{<: Any, T}, + Transpose{T, <: AbstractVec{<: Any, T}}, } const AbstractVecOrMatLike{T} = Union{AbstractVec{<: Any, T}, AbstractMatLike{T}} @@ -77,3 +78,6 @@ end @inbounds SArray{Tuple{$(size(x)...)}}($(exps...)) end end +function StaticArrays.SArray(x::Transpose{<: T, <: AbstractVec{dim, T}}) where {dim, T} + transpose(SArray(parent(x))) +end diff --git a/src/Space.jl b/src/Space.jl index 9b8460e8..60255a87 100644 --- a/src/Space.jl +++ b/src/Space.jl @@ -108,7 +108,7 @@ end otimes(x::Space) = x otimes(x::Space, y::Space) = Space(Tuple(x)..., Tuple(y)...) otimes(x::Space, y::Space, z::Space...) = otimes(otimes(x, y), z...) -dot(x::Space, y::Space) = contract(x, y, Val(1)) +contract1(x::Space, y::Space) = contract(x, y, Val(1)) contract2(x::Space, y::Space) = contract(x, y, Val(2)) # promote_space diff --git a/src/Tensor.jl b/src/Tensor.jl index b15044b6..dd7756d8 100644 --- a/src/Tensor.jl +++ b/src/Tensor.jl @@ -72,6 +72,9 @@ end @inline function Tensor(A::StaticArray{S, T}) where {S, T} Tensor{S, T}(Tuple(A)) end +@inline function Tensor(A::LinearAlgebra.Transpose{<: Any, <: StaticArray}) + transpose(Tensor(parent(A))) +end ## for aliases @inline function Tensor{S, T, N}(data::Tuple{Vararg{Any, L}}) where {S, T, N, L} diff --git a/src/Tensorial.jl b/src/Tensorial.jl index 06329b4e..2bb355fc 100644 --- a/src/Tensorial.jl +++ b/src/Tensorial.jl @@ -51,7 +51,6 @@ export contract2, contract3, otimes, - dotdot, symmetric, minorsymmetric, skew, @@ -104,8 +103,8 @@ include("abstractarray.jl") include("quaternion.jl") -const ⊗ = otimes const ⊡ = contract2 +const ⊗ = otimes @deprecate contraction contract true @deprecate double_contraction contract2 true diff --git a/src/ops.jl b/src/ops.jl index ae69cf14..6144b56c 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -38,11 +38,11 @@ end @inline contract2(x::AbstractSquareTensor, y::UniformScaling) = tr(x) * y.λ @inline contract2(x::UniformScaling, y::AbstractSquareTensor) = x.λ * tr(y) -# error for standard multiplications -error_multiply() = error("use `⋅` (`\\cdot`) for single contraction and `⊡` (`\\boxdot`) for double contraction instead of `*`") -Base.:*(::AbstractTensor, ::AbstractTensor) = error_multiply() -Base.:*(::AbstractTensor, ::UniformScaling) = error_multiply() -Base.:*(::UniformScaling, ::AbstractTensor) = error_multiply() +# multiplication +@inline Base.:*(x::AbstractVecOrMatLike, y::AbstractVecOrMatLike) = Tensor(SArray(x) * SArray(y)) +@inline Base.:*(x::LinearAlgebra.Transpose{T, <: AbstractVec{<: Any, T}}, y::AbstractVec{<: Any, U}) where {T <: Real, U <: Real} = parent(x) ⋅ y +@inline Base.:*(x::AbstractVecOrMatLike, I::UniformScaling) = x * I.λ +@inline Base.:*(I::UniformScaling, x::AbstractVecOrMatLike) = I.λ * x """ contract(x, y, ::Val{N}) @@ -65,9 +65,9 @@ julia> A = contract(B, C, Val(2)) Following symbols are also available for specific contractions: -- `x ⊗ y` (where `⊗` can be typed by `\\otimes`): `contract(x, y, Val(0))` -- `x ⋅ y` (where `⋅` can be typed by `\\cdot`): `contract(x, y, Val(1))` -- `x ⊡ y` (where `⊡` can be typed by `\\boxdot`): `contract(x, y, Val(2))` +- `x ⊗ y` (where `⊗` can be typed by `\\otimes` ): `contract(x, y, Val(0))` +- `x ⋅ y` (where `⋅` can be typed by `\\cdot` ): `contract(x, y, Val(1))` +- `x ⊡ y` (where `⊡` can be typed by `\\boxdot` ): `contract(x, y, Val(2))` """ @generated function contract(t1::AbstractTensor, t2::AbstractTensor, ::Val{N}) where {N} S1 = Space(t1) @@ -212,12 +212,10 @@ otimes(n::Int) = OTimes{n}() end """ - dot(x::AbstractTensor, y::AbstractTensor) - x ⋅ y + contract1(x::AbstractTensor, y::AbstractTensor) -Compute dot product such as ``a = x_i y_i``. +Compute single contraction such as ``a = x_i y_i``. This is equivalent to [`contract(::AbstractTensor, ::AbstractTensor, Val(1))`](@ref). -`x ⋅ y` (where `⋅` can be typed by `\\cdot`) is a synonym for `dot(x, y)`. # Examples ```jldoctest @@ -237,7 +235,7 @@ julia> a = x ⋅ y 0.5715585109976284 ``` """ -@inline dot(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(1)) # dot, ⋅ +@inline contract1(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(1)) @inline contract2(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(2)) # ⊡ @inline contract3(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(3)) @@ -259,15 +257,8 @@ julia> norm(x) ``` """ @inline norm(x::AbstractTensor) = sqrt(contract(x, x, Val(ndims(x)))) - @inline normalize(x::AbstractTensor) = x / norm(x) -# v_k * S_ikjl * u_l -@inline function dotdot(v1::Vec{dim}, S::SymmetricFourthOrderTensor{dim}, v2::Vec{dim}) where {dim} - S′ = SymmetricFourthOrderTensor{dim}((i,j,k,l) -> @inbounds S[j,i,l,k]) - v1 ⋅ S′ ⋅ v2 -end - """ tr(::AbstractSecondOrderTensor) tr(::AbstractSymmetricSecondOrderTensor) @@ -418,7 +409,6 @@ end @inline transpose(x::AbstractTensor{Tuple{@Symmetry{dim, dim}}}) where {dim} = x @inline transpose(x::AbstractTensor{Tuple{m, n}}) where {m, n} = Tensor{Tuple{n, m}}((i,j) -> @inbounds x[j,i]) @inline adjoint(x::AbstractTensor) = transpose(x) -@inline adjoint(::AbstractVec) = throw(ArgumentError("adjoint for `AbstractVec` is not allowed")) # det @generated function extract_vecs(x::AbstractSquareTensor{dim}) where {dim} @@ -454,15 +444,11 @@ end end """ - cross(x::Vec{3}, y::Vec{3}) -> Vec{3} - cross(x::Vec{2}, y::Vec{2}) -> Vec{3} - cross(x::Vec{1}, y::Vec{1}) -> Vec{3} + cross(x::Vec, y::Vec) x × y Compute the cross product between two vectors. -The vectors are expanded to 3D frist for dimensions 1 and 2. -The infix operator `×` (written `\\times`) can also be used. -`x × y` (where `×` can be typed by `\\times`) is a synonym for `cross(x, y)`. +The infix operation `x × y` (where `×` can be typed by `\\times`) is a synonym for `cross(x, y)`. # Examples ```jldoctest @@ -485,16 +471,14 @@ julia> x × y -0.37588028973385323 ``` """ -@inline cross(x::Vec{1, T1}, y::Vec{1, T2}) where {T1, T2} = zero(Vec{3, promote_type(T1, T2)}) -@inline function cross(x::Vec{2, T1}, y::Vec{2, T2}) where {T1, T2} - z = zero(promote_type(T1, T2)) - @inbounds Vec(z, z, x[1]*y[2] - x[2]*y[1]) -end @inline function cross(x::Vec{3}, y::Vec{3}) @inbounds Vec(x[2]*y[3] - x[3]*y[2], x[3]*y[1] - x[1]*y[3], x[1]*y[2] - x[2]*y[1]) end +@inline function cross(x::Vec{2}, y::Vec{2}) + @inbounds x[1]*y[2] - x[2]*y[1] +end # power @inline Base.literal_pow(::typeof(^), x::AbstractSquareTensor, ::Val{-1}) = inv(x) @@ -509,7 +493,7 @@ end y end ## helper functions -@inline _powdot(x::AbstractSecondOrderTensor, y::AbstractSecondOrderTensor) = dot(x, y) +@inline _powdot(x::AbstractSecondOrderTensor, y::AbstractSecondOrderTensor) = contract1(x, y) @inline function _powdot(x::AbstractSymmetricSecondOrderTensor{dim}, y::AbstractSymmetricSecondOrderTensor{dim}) where {dim} contract(SymmetricSecondOrderTensor{dim}, x, y, Val(2), Val(1)) end @@ -693,7 +677,7 @@ function rotmat(pair::Pair{Vec{dim, T}, Vec{dim, T}})::Mat{dim, dim, T} where {d # https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d/2672702#2672702 a = pair.first b = pair.second - dot(a, a) ≈ dot(b, b) || throw(ArgumentError("the norms of two vectors must be the same")) + contract1(a, a) ≈ contract1(b, b) || throw(ArgumentError("the norms of two vectors must be the same")) a == b && return one(Mat{dim, dim, T}) a == -b && return -one(Mat{dim, dim, T}) c = a + b diff --git a/src/quaternion.jl b/src/quaternion.jl index 74070afc..39f2e6e3 100644 --- a/src/quaternion.jl +++ b/src/quaternion.jl @@ -176,7 +176,7 @@ julia> rotate(v, quaternion(π/4, Vec(0,0,1))) @inline rotate(v::Vec, q::Quaternion) = (q * v / q).vector @inline Base.conj(q::Quaternion) = Quaternion(q.scalar, -q.vector) -@inline Base.abs2(q::Quaternion) = (v = Vec(q); dot(v, v)) +@inline Base.abs2(q::Quaternion) = (v = Vec(q); contract1(v, v)) @inline Base.abs(q::Quaternion) = sqrt(abs2(q)) @inline norm(q::Quaternion) = abs(q) @inline inv(q::Quaternion) = conj(q) / abs2(q) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 8530a53a..975c85e4 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -39,15 +39,19 @@ end @testset "vcat/hcat" begin @test @inferred(vcat(Vec(1,2,3))) === Vec(1,2,3) + @test @inferred(vcat(Vec(1,2,3)')) === Vec(1,2,3)' @test @inferred(vcat(Mat{1,3}(1,2,3))) === Mat{1,3}(1,2,3) @test @inferred(hcat(Vec(1,2,3))) === Mat{3,1}(1,2,3) + @test @inferred(hcat(Vec(1,2,3)')) === Vec(1,2,3)' @test @inferred(hcat(Mat{3,1}(1,2,3))) === Mat{3,1}(1,2,3) @test @inferred(vcat(Vec(1,2,3), Vec(4,5,6))) === Vec(1,2,3,4,5,6) @test @inferred(vcat(Vec(1,2,3), Mat{3,1}(4,5,6))) === Mat{6,1}(1,2,3,4,5,6) @test @inferred(vcat(Mat{1,3}(1,2,3), Mat{1,3}(4,5,6))) === @Mat [1 2 3; 4 5 6] - @test @inferred(vcat(symmetric(Mat{2,2}(1,2,2,3), :U), Mat{1,2}(4,5))) === @Mat [1 2; 2 3; 4 5] + @test @inferred(vcat(Mat{1,3}(1,2,3), Vec(4,5,6)')) === @Mat [1 2 3; 4 5 6] + @test @inferred(vcat(symmetric(Mat{2,2}(1,2,2,3), :U), Vec(4,5)')) === @Mat [1 2; 2 3; 4 5] @test @inferred(hcat(Vec(1,2,3), Vec(4,5,6))) === Mat{3,2}(1,2,3,4,5,6) + @test @inferred(hcat(Vec(1,2,3)', Vec(4,5,6)')) === Mat{1,6}(1,2,3,4,5,6) @test @inferred(hcat(Vec(1,2,3), Mat{3,2}(4,5,6,7,8,9))) === Mat{3,3}(1,2,3,4,5,6,7,8,9) @test @inferred(hcat(Mat{3,1}(1,2,3), Mat{3,2}(4,5,6,7,8,9))) === Mat{3,3}(1,2,3,4,5,6,7,8,9) @test @inferred(hcat(symmetric(Mat{2,2}(1,2,2,3), :U), Vec(4,5))) === @Mat [1 2 4; 2 3 5] @@ -85,7 +89,7 @@ for T in (Float32, Float64) x11 = rand(Mat{2,2,T}) x12 = rand(Vec{2,T}) - x21 = rand(Mat{1,2,T}) + x21 = rand(Vec{2,T})' x22 = rand(T) @test (@inferred f_2_2(x11, x12, x21, x22))::Mat{3,3,T} == f_2_2(Array.((x11, x12, x21))..., x22) x11 = rand(SymmetricSecondOrderTensor{3,T}) @@ -95,11 +99,11 @@ x23 = rand(Mat{2,3,T}) @test (@inferred f_2_3(x11, x12, x21, x22, x23))::Mat{5,6,T} == f_2_3(Array.((x11, x12, x21, x22, x23))...) x11 = rand(T) - x12 = rand(Mat{1,3,T}) + x12 = rand(Vec{3,T})' x21 = rand(Vec{3,T}) x22_11 = rand(SymmetricSecondOrderTensor{2,T}) x22_12 = rand(Vec{2,T}) - x22_21 = rand(Mat{1,2,T}) + x22_21 = rand(Vec{2,T})' x22_22 = rand(T) @test (@inferred f_2_2_recurse(x11,x12,x21,x22,x22_11,x22_12,x22_21,x22_22))::Mat{4,4,T} == f_2_2_recurse(x11,Array.((x12,x21,x22,x22_11,x22_12,x22_21))...,x22_22) end diff --git a/test/ops.jl b/test/ops.jl index fcb7f9ac..f32e576f 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -30,9 +30,16 @@ @test (@inferred a * x)::Vec{2, T} == a * Array(x) @test (@inferred x * a)::Vec{2, T} == Array(x) * a @test (@inferred x / a)::Vec{2, T} == Array(x) / a - # bad operations - @test_throws Exception x * y - @test_throws Exception y * x + # multiplication + x = rand(Vec{2,T}) + y = rand(Mat{2,2,T}) + z = Vec(1,2) + @test (@inferred x' * y)::Transpose{T, Vec{2,T}} ≈ Array(x)' * Array(y) + @test (@inferred y * x)::Vec{2,T} ≈ Array(y) * Array(x) + @test (@inferred y' * x)::Vec{2,T} ≈ Array(y)' * Array(x) + @test (@inferred x * x')::Mat{2,2,T} ≈ Array(x) * Array(x)' + @test (@inferred x'x)::T ≈ Array(x)'Array(x) + @test (@inferred x'z)::T ≈ Array(x)'Array(z) end end @@ -129,15 +136,6 @@ end @test (@inferred ⊗(x, y, x, y))::Tensor{Tuple{3,2,3,2}, T} ≈ x ⊗ y ⊗ x ⊗ y end end - @testset "dotdot" begin - for T in (Float32, Float64) - x = rand(Vec{3, T}) - y = rand(Vec{3, T}) - S = rand(SymmetricFourthOrderTensor{3, T}) - A = FourthOrderTensor{3, T}((i,j,k,l) -> S[i,k,j,l]) - @test (@inferred dotdot(x, S, y))::Tensor{Tuple{3,3}, T} ≈ A ⊡ (x ⊗ y) - end - end @testset "tr" begin for T in (Float32, Float64) x = rand(SecondOrderTensor{3, T}) @@ -205,21 +203,12 @@ end end end @testset "cross" begin - for T in (Float32, Float64), dim in 1:3 + for T in (Float32, Float64), dim in 2:3 x = rand(Vec{dim, T}) y = rand(Vec{dim, T}) - @test (@inferred x × x)::Vec{3, T} ≈ zero(Vec{3, T}) @test x × y ≈ -y × x - if dim == 2 - a = Vec{2, T}(1,0) - b = Vec{2, T}(0,1) - @test (@inferred a × b)::Vec{3, T} ≈ Vec{3, T}(0,0,1) - end - if dim == 3 - a = Vec{3, T}(1,0,0) - b = Vec{3, T}(0,1,0) - @test (@inferred a × b)::Vec{3, T} ≈ Vec{3, T}(0,0,1) - end + dim == 2 && @test (@inferred x × x)::T ≈ zero(T) + dim == 3 && @test (@inferred x × x)::Vec{3, T} ≈ zero(Vec{3, T}) end end @testset "pow" begin @@ -390,13 +379,12 @@ end @test (@inferred I ⋅ v)::Vec{3, T} == one(x) ⋅ v @test (@inferred I ⊡ x)::T == one(x) ⊡ x @test (@inferred x ⊡ I)::T == x ⊡ one(x) - # wrong input - @test_throws Exception x * I - @test_throws Exception y * I - @test_throws Exception v * I - @test_throws Exception I * x - @test_throws Exception I * y - @test_throws Exception I * v + # multiplication + @test (@inferred x * I)::typeof(x) == x + @test (@inferred y * I)::typeof(y) == y + @test (@inferred I * x)::typeof(x) == x + @test (@inferred I * y)::typeof(y) == y + @test (@inferred I * v)::typeof(v) == v end end end diff --git a/test/runtests.jl b/test/runtests.jl index f8843b88..53bd9d7d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using Tensorial using Test, Random -using LinearAlgebra: Symmetric, Eigen +using LinearAlgebra: Symmetric, Eigen, Transpose using StaticArrays: SArray, SVector, SOneTo, SUnitRange using TensorOperations: @tensor