Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it more like StaticArrays.jl #222

Merged
merged 4 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/AbstractTensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const AbstractMat{m, n, T} = AbstractTensor{Tuple{m, n}, T, 2}
const AbstractMatLike{T} = Union{
AbstractMat{<: Any, <: Any, T},
AbstractSymmetricSecondOrderTensor{<: Any, T},
Transpose{T, <: AbstractVec{<: Any, T}},
}
const AbstractVecOrMatLike{T} = Union{AbstractVec{<: Any, T}, AbstractMatLike{T}}

Expand Down Expand Up @@ -77,3 +78,6 @@ end
@inbounds SArray{Tuple{$(size(x)...)}}($(exps...))
end
end
function StaticArrays.SArray(x::Transpose{<: T, <: AbstractVec{dim, T}}) where {dim, T}
transpose(SArray(parent(x)))
end
2 changes: 1 addition & 1 deletion src/Space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
otimes(x::Space) = x
otimes(x::Space, y::Space) = Space(Tuple(x)..., Tuple(y)...)
otimes(x::Space, y::Space, z::Space...) = otimes(otimes(x, y), z...)
dot(x::Space, y::Space) = contract(x, y, Val(1))
contract1(x::Space, y::Space) = contract(x, y, Val(1))

Check warning on line 111 in src/Space.jl

View check run for this annotation

Codecov / codecov/patch

src/Space.jl#L111

Added line #L111 was not covered by tests
contract2(x::Space, y::Space) = contract(x, y, Val(2))

# promote_space
Expand Down
3 changes: 3 additions & 0 deletions src/Tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ end
@inline function Tensor(A::StaticArray{S, T}) where {S, T}
Tensor{S, T}(Tuple(A))
end
@inline function Tensor(A::LinearAlgebra.Transpose{<: Any, <: StaticArray})
transpose(Tensor(parent(A)))
end

## for aliases
@inline function Tensor{S, T, N}(data::Tuple{Vararg{Any, L}}) where {S, T, N, L}
Expand Down
3 changes: 1 addition & 2 deletions src/Tensorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ export
contract2,
contract3,
otimes,
dotdot,
symmetric,
minorsymmetric,
skew,
Expand Down Expand Up @@ -104,8 +103,8 @@ include("abstractarray.jl")

include("quaternion.jl")

const ⊗ = otimes
const ⊡ = contract2
const ⊗ = otimes

@deprecate contraction contract true
@deprecate double_contraction contract2 true
Expand Down
52 changes: 18 additions & 34 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ end
@inline contract2(x::AbstractSquareTensor, y::UniformScaling) = tr(x) * y.λ
@inline contract2(x::UniformScaling, y::AbstractSquareTensor) = x.λ * tr(y)

# error for standard multiplications
error_multiply() = error("use `⋅` (`\\cdot`) for single contraction and `⊡` (`\\boxdot`) for double contraction instead of `*`")
Base.:*(::AbstractTensor, ::AbstractTensor) = error_multiply()
Base.:*(::AbstractTensor, ::UniformScaling) = error_multiply()
Base.:*(::UniformScaling, ::AbstractTensor) = error_multiply()
# multiplication
@inline Base.:*(x::AbstractVecOrMatLike, y::AbstractVecOrMatLike) = Tensor(SArray(x) * SArray(y))
@inline Base.:*(x::LinearAlgebra.Transpose{T, <: AbstractVec{<: Any, T}}, y::AbstractVec{<: Any, U}) where {T <: Real, U <: Real} = parent(x) ⋅ y
@inline Base.:*(x::AbstractVecOrMatLike, I::UniformScaling) = x * I.λ
@inline Base.:*(I::UniformScaling, x::AbstractVecOrMatLike) = I.λ * x

"""
contract(x, y, ::Val{N})
Expand All @@ -65,9 +65,9 @@ julia> A = contract(B, C, Val(2))

Following symbols are also available for specific contractions:

- `x ⊗ y` (where `⊗` can be typed by `\\otimes<tab>`): `contract(x, y, Val(0))`
- `x ⋅ y` (where `⋅` can be typed by `\\cdot<tab>`): `contract(x, y, Val(1))`
- `x ⊡ y` (where `⊡` can be typed by `\\boxdot<tab>`): `contract(x, y, Val(2))`
- `x ⊗ y` (where `⊗` can be typed by `\\otimes<tab>` ): `contract(x, y, Val(0))`
- `x ⋅ y` (where `⋅` can be typed by `\\cdot<tab>` ): `contract(x, y, Val(1))`
- `x ⊡ y` (where `⊡` can be typed by `\\boxdot<tab>` ): `contract(x, y, Val(2))`
"""
@generated function contract(t1::AbstractTensor, t2::AbstractTensor, ::Val{N}) where {N}
S1 = Space(t1)
Expand Down Expand Up @@ -212,12 +212,10 @@ otimes(n::Int) = OTimes{n}()
end

"""
dot(x::AbstractTensor, y::AbstractTensor)
x ⋅ y
contract1(x::AbstractTensor, y::AbstractTensor)

Compute dot product such as ``a = x_i y_i``.
Compute single contraction such as ``a = x_i y_i``.
This is equivalent to [`contract(::AbstractTensor, ::AbstractTensor, Val(1))`](@ref).
`x ⋅ y` (where `⋅` can be typed by `\\cdot<tab>`) is a synonym for `dot(x, y)`.

# Examples
```jldoctest
Expand All @@ -237,7 +235,7 @@ julia> a = x ⋅ y
0.5715585109976284
```
"""
@inline dot(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(1)) # dot, ⋅
@inline contract1(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(1))
@inline contract2(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(2)) # ⊡
@inline contract3(x1::AbstractTensor, x2::AbstractTensor) = contract(x1, x2, Val(3))

Expand All @@ -259,15 +257,8 @@ julia> norm(x)
```
"""
@inline norm(x::AbstractTensor) = sqrt(contract(x, x, Val(ndims(x))))

@inline normalize(x::AbstractTensor) = x / norm(x)

# v_k * S_ikjl * u_l
@inline function dotdot(v1::Vec{dim}, S::SymmetricFourthOrderTensor{dim}, v2::Vec{dim}) where {dim}
S′ = SymmetricFourthOrderTensor{dim}((i,j,k,l) -> @inbounds S[j,i,l,k])
v1 ⋅ S′ ⋅ v2
end

"""
tr(::AbstractSecondOrderTensor)
tr(::AbstractSymmetricSecondOrderTensor)
Expand Down Expand Up @@ -418,7 +409,6 @@ end
@inline transpose(x::AbstractTensor{Tuple{@Symmetry{dim, dim}}}) where {dim} = x
@inline transpose(x::AbstractTensor{Tuple{m, n}}) where {m, n} = Tensor{Tuple{n, m}}((i,j) -> @inbounds x[j,i])
@inline adjoint(x::AbstractTensor) = transpose(x)
@inline adjoint(::AbstractVec) = throw(ArgumentError("adjoint for `AbstractVec` is not allowed"))

# det
@generated function extract_vecs(x::AbstractSquareTensor{dim}) where {dim}
Expand Down Expand Up @@ -454,15 +444,11 @@ end
end

"""
cross(x::Vec{3}, y::Vec{3}) -> Vec{3}
cross(x::Vec{2}, y::Vec{2}) -> Vec{3}
cross(x::Vec{1}, y::Vec{1}) -> Vec{3}
cross(x::Vec, y::Vec)
x × y

Compute the cross product between two vectors.
The vectors are expanded to 3D frist for dimensions 1 and 2.
The infix operator `×` (written `\\times`) can also be used.
`x × y` (where `×` can be typed by `\\times<tab>`) is a synonym for `cross(x, y)`.
The infix operation `x × y` (where `×` can be typed by `\\times<tab>`) is a synonym for `cross(x, y)`.

# Examples
```jldoctest
Expand All @@ -485,16 +471,14 @@ julia> x × y
-0.37588028973385323
```
"""
@inline cross(x::Vec{1, T1}, y::Vec{1, T2}) where {T1, T2} = zero(Vec{3, promote_type(T1, T2)})
@inline function cross(x::Vec{2, T1}, y::Vec{2, T2}) where {T1, T2}
z = zero(promote_type(T1, T2))
@inbounds Vec(z, z, x[1]*y[2] - x[2]*y[1])
end
@inline function cross(x::Vec{3}, y::Vec{3})
@inbounds Vec(x[2]*y[3] - x[3]*y[2],
x[3]*y[1] - x[1]*y[3],
x[1]*y[2] - x[2]*y[1])
end
@inline function cross(x::Vec{2}, y::Vec{2})
@inbounds x[1]*y[2] - x[2]*y[1]
end

# power
@inline Base.literal_pow(::typeof(^), x::AbstractSquareTensor, ::Val{-1}) = inv(x)
Expand All @@ -509,7 +493,7 @@ end
y
end
## helper functions
@inline _powdot(x::AbstractSecondOrderTensor, y::AbstractSecondOrderTensor) = dot(x, y)
@inline _powdot(x::AbstractSecondOrderTensor, y::AbstractSecondOrderTensor) = contract1(x, y)
@inline function _powdot(x::AbstractSymmetricSecondOrderTensor{dim}, y::AbstractSymmetricSecondOrderTensor{dim}) where {dim}
contract(SymmetricSecondOrderTensor{dim}, x, y, Val(2), Val(1))
end
Expand Down Expand Up @@ -693,7 +677,7 @@ function rotmat(pair::Pair{Vec{dim, T}, Vec{dim, T}})::Mat{dim, dim, T} where {d
# https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d/2672702#2672702
a = pair.first
b = pair.second
dot(a, a) ≈ dot(b, b) || throw(ArgumentError("the norms of two vectors must be the same"))
contract1(a, a) ≈ contract1(b, b) || throw(ArgumentError("the norms of two vectors must be the same"))
a == b && return one(Mat{dim, dim, T})
a == -b && return -one(Mat{dim, dim, T})
c = a + b
Expand Down
2 changes: 1 addition & 1 deletion src/quaternion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ julia> rotate(v, quaternion(π/4, Vec(0,0,1)))
@inline rotate(v::Vec, q::Quaternion) = (q * v / q).vector

@inline Base.conj(q::Quaternion) = Quaternion(q.scalar, -q.vector)
@inline Base.abs2(q::Quaternion) = (v = Vec(q); dot(v, v))
@inline Base.abs2(q::Quaternion) = (v = Vec(q); contract1(v, v))
@inline Base.abs(q::Quaternion) = sqrt(abs2(q))
@inline norm(q::Quaternion) = abs(q)
@inline inv(q::Quaternion) = conj(q) / abs2(q)
Expand Down
12 changes: 8 additions & 4 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@
end
@testset "vcat/hcat" begin
@test @inferred(vcat(Vec(1,2,3))) === Vec(1,2,3)
@test @inferred(vcat(Vec(1,2,3)')) === Vec(1,2,3)'
@test @inferred(vcat(Mat{1,3}(1,2,3))) === Mat{1,3}(1,2,3)
@test @inferred(hcat(Vec(1,2,3))) === Mat{3,1}(1,2,3)
@test @inferred(hcat(Vec(1,2,3)')) === Vec(1,2,3)'
@test @inferred(hcat(Mat{3,1}(1,2,3))) === Mat{3,1}(1,2,3)

@test @inferred(vcat(Vec(1,2,3), Vec(4,5,6))) === Vec(1,2,3,4,5,6)
@test @inferred(vcat(Vec(1,2,3), Mat{3,1}(4,5,6))) === Mat{6,1}(1,2,3,4,5,6)
@test @inferred(vcat(Mat{1,3}(1,2,3), Mat{1,3}(4,5,6))) === @Mat [1 2 3; 4 5 6]
@test @inferred(vcat(symmetric(Mat{2,2}(1,2,2,3), :U), Mat{1,2}(4,5))) === @Mat [1 2; 2 3; 4 5]
@test @inferred(vcat(Mat{1,3}(1,2,3), Vec(4,5,6)')) === @Mat [1 2 3; 4 5 6]
@test @inferred(vcat(symmetric(Mat{2,2}(1,2,2,3), :U), Vec(4,5)')) === @Mat [1 2; 2 3; 4 5]
@test @inferred(hcat(Vec(1,2,3), Vec(4,5,6))) === Mat{3,2}(1,2,3,4,5,6)
@test @inferred(hcat(Vec(1,2,3)', Vec(4,5,6)')) === Mat{1,6}(1,2,3,4,5,6)
@test @inferred(hcat(Vec(1,2,3), Mat{3,2}(4,5,6,7,8,9))) === Mat{3,3}(1,2,3,4,5,6,7,8,9)
@test @inferred(hcat(Mat{3,1}(1,2,3), Mat{3,2}(4,5,6,7,8,9))) === Mat{3,3}(1,2,3,4,5,6,7,8,9)
@test @inferred(hcat(symmetric(Mat{2,2}(1,2,2,3), :U), Vec(4,5))) === @Mat [1 2 4; 2 3 5]
Expand Down Expand Up @@ -85,7 +89,7 @@
for T in (Float32, Float64)
x11 = rand(Mat{2,2,T})
x12 = rand(Vec{2,T})
x21 = rand(Mat{1,2,T})
x21 = rand(Vec{2,T})'
x22 = rand(T)
@test (@inferred f_2_2(x11, x12, x21, x22))::Mat{3,3,T} == f_2_2(Array.((x11, x12, x21))..., x22)
x11 = rand(SymmetricSecondOrderTensor{3,T})
Expand All @@ -95,11 +99,11 @@
x23 = rand(Mat{2,3,T})
@test (@inferred f_2_3(x11, x12, x21, x22, x23))::Mat{5,6,T} == f_2_3(Array.((x11, x12, x21, x22, x23))...)
x11 = rand(T)
x12 = rand(Mat{1,3,T})
x12 = rand(Vec{3,T})'
x21 = rand(Vec{3,T})
x22_11 = rand(SymmetricSecondOrderTensor{2,T})
x22_12 = rand(Vec{2,T})
x22_21 = rand(Mat{1,2,T})
x22_21 = rand(Vec{2,T})'
x22_22 = rand(T)
@test (@inferred f_2_2_recurse(x11,x12,x21,x22,x22_11,x22_12,x22_21,x22_22))::Mat{4,4,T} == f_2_2_recurse(x11,Array.((x12,x21,x22,x22_11,x22_12,x22_21))...,x22_22)
end
Expand Down
50 changes: 19 additions & 31 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,16 @@
@test (@inferred a * x)::Vec{2, T} == a * Array(x)
@test (@inferred x * a)::Vec{2, T} == Array(x) * a
@test (@inferred x / a)::Vec{2, T} == Array(x) / a
# bad operations
@test_throws Exception x * y
@test_throws Exception y * x
# multiplication
x = rand(Vec{2,T})
y = rand(Mat{2,2,T})
z = Vec(1,2)
@test (@inferred x' * y)::Transpose{T, Vec{2,T}} ≈ Array(x)' * Array(y)
@test (@inferred y * x)::Vec{2,T} ≈ Array(y) * Array(x)
@test (@inferred y' * x)::Vec{2,T} ≈ Array(y)' * Array(x)
@test (@inferred x * x')::Mat{2,2,T} ≈ Array(x) * Array(x)'
@test (@inferred x'x)::T ≈ Array(x)'Array(x)
@test (@inferred x'z)::T ≈ Array(x)'Array(z)
end
end

Expand Down Expand Up @@ -129,15 +136,6 @@ end
@test (@inferred ⊗(x, y, x, y))::Tensor{Tuple{3,2,3,2}, T} ≈ x ⊗ y ⊗ x ⊗ y
end
end
@testset "dotdot" begin
for T in (Float32, Float64)
x = rand(Vec{3, T})
y = rand(Vec{3, T})
S = rand(SymmetricFourthOrderTensor{3, T})
A = FourthOrderTensor{3, T}((i,j,k,l) -> S[i,k,j,l])
@test (@inferred dotdot(x, S, y))::Tensor{Tuple{3,3}, T} ≈ A ⊡ (x ⊗ y)
end
end
@testset "tr" begin
for T in (Float32, Float64)
x = rand(SecondOrderTensor{3, T})
Expand Down Expand Up @@ -205,21 +203,12 @@ end
end
end
@testset "cross" begin
for T in (Float32, Float64), dim in 1:3
for T in (Float32, Float64), dim in 2:3
x = rand(Vec{dim, T})
y = rand(Vec{dim, T})
@test (@inferred x × x)::Vec{3, T} ≈ zero(Vec{3, T})
@test x × y ≈ -y × x
if dim == 2
a = Vec{2, T}(1,0)
b = Vec{2, T}(0,1)
@test (@inferred a × b)::Vec{3, T} ≈ Vec{3, T}(0,0,1)
end
if dim == 3
a = Vec{3, T}(1,0,0)
b = Vec{3, T}(0,1,0)
@test (@inferred a × b)::Vec{3, T} ≈ Vec{3, T}(0,0,1)
end
dim == 2 && @test (@inferred x × x)::T ≈ zero(T)
dim == 3 && @test (@inferred x × x)::Vec{3, T} ≈ zero(Vec{3, T})
end
end
@testset "pow" begin
Expand Down Expand Up @@ -390,13 +379,12 @@ end
@test (@inferred I ⋅ v)::Vec{3, T} == one(x) ⋅ v
@test (@inferred I ⊡ x)::T == one(x) ⊡ x
@test (@inferred x ⊡ I)::T == x ⊡ one(x)
# wrong input
@test_throws Exception x * I
@test_throws Exception y * I
@test_throws Exception v * I
@test_throws Exception I * x
@test_throws Exception I * y
@test_throws Exception I * v
# multiplication
@test (@inferred x * I)::typeof(x) == x
@test (@inferred y * I)::typeof(y) == y
@test (@inferred I * x)::typeof(x) == x
@test (@inferred I * y)::typeof(y) == y
@test (@inferred I * v)::typeof(v) == v
end
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Tensorial
using Test, Random

using LinearAlgebra: Symmetric, Eigen
using LinearAlgebra: Symmetric, Eigen, Transpose
using StaticArrays: SArray, SVector, SOneTo, SUnitRange
using TensorOperations: @tensor

Expand Down
Loading