Skip to content

Commit

Permalink
add special cases for projection from Covariant => Contravariant
Browse files Browse the repository at this point in the history
fix tensor conversion

fix dispatch
  • Loading branch information
simonbyrne authored and charleskawczynski committed Aug 3, 2022
1 parent 5c8af48 commit 845732e
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,16 @@ const Axis2TensorOrAdj{T, A, S} =

# based on 1st dimension
const Covariant2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{CovariantAxis, AbstractAxis}}
const Contravariant2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{ContravariantAxis, AbstractAxis}}
Axis2Tensor{T, A, S} where {T, A <: Tuple{CovariantAxis, AbstractAxis}, S}
const Contravariant2Tensor{T, A, S} = Axis2Tensor{
T,
A,
S,
} where {T, A <: Tuple{ContravariantAxis, AbstractAxis}, S}
const Cartesian2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{CartesianAxis, AbstractAxis}}
Axis2Tensor{T, A, S} where {T, A <: Tuple{CartesianAxis, AbstractAxis}, S}
const Local2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{LocalAxis, AbstractAxis}}
Axis2Tensor{T, A, S} where {T, A <: Tuple{LocalAxis, AbstractAxis}, S}

const CovariantTensor = Union{CovariantVector, Covariant2Tensor}
const ContravariantTensor = Union{ContravariantVector, Contravariant2Tensor}
Expand Down
108 changes: 108 additions & 0 deletions src/Geometry/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ for op in (:transform, :project)
)

# Covariant <-> Contravariant
#=
@inline $op(
ax::ContravariantAxis,
v::CovariantTensor,
Expand All @@ -302,6 +303,7 @@ for op in (:transform, :project)
local_geometry.∂ξ∂x' *
$op(dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
=#
@inline $op(
ax::CovariantAxis,
v::ContravariantTensor,
Expand All @@ -327,6 +329,112 @@ for op in (:transform, :project)
end
end

@inline transform(
ax::ContravariantAxis,
v::CovariantTensor,
local_geometry::LocalGeometry,
) = transform(
ax,
local_geometry.∂ξ∂x *
local_geometry.∂ξ∂x' *
transform(dual(axes(local_geometry.∂ξ∂x, 1)), v),
)

@generated function project(
ax::ContravariantAxis{Ito},
v::CovariantVector{T, Ifrom},
local_geometry::LocalGeometry{J},
) where {T, Ito, Ifrom, J}
Nfrom = length(Ifrom)
Nto = length(Ito)
NJ = length(J)

vals = []
for i in Ito
if i J
# e.g. i = 2, J = (1,2,3)
IJ = intersect(J, Ifrom)
if isempty(IJ)
val = 0
else
niJ = findfirst(==(i), J)
val = Expr(
:call,
:+,
[
:(
local_geometry.gⁱʲ[$niJ, $(findfirst(==(j), J))] * v[$(findfirst(==(j), Ifrom))]
) for j in IJ
]...,
)
end
elseif i Ifrom
# e.g. i = 2, J = (1,3), Ifrom = (2,)
ni = findfirst(==(i), Ifrom)
val = :(v[$ni])
else
# e.g. i = 2, J = (1,3), Ifrom = (1,)
val = 0
end
push!(vals, val)
end
quote
Base.@_propagate_inbounds_meta
AxisVector(ContravariantAxis{$Ito}(), SVector{$Nto, $T}($(vals...)))
end
end
@generated function project(
ax::ContravariantAxis{Ito},
v::Contravariant2Tensor{T, Tuple{CovariantAxis{Ifrom}, A}},
local_geometry::LocalGeometry{J},
) where {T, Ito, Ifrom, A, J}
Nfrom = length(Ifrom)
Nto = length(Ito)
NJ = length(J)
NA = length(A.instance)

vals = []
for na in 1:NA
for i in Ito
if i J
# e.g. i = 2, J = (1,2,3)
IJ = intersect(J, Ifrom)
if isempty(IJ)
val = 0
else
niJ = findfirst(==(i), J)
val = Expr(
:call,
:+,
[
:(
local_geometry.gⁱʲ[
$niJ,
$(findfirst(==(j), J)),
] * v[$(findfirst(==(j), Ifrom)), $na]
) for j in IJ
]...,
)
end
elseif i Ifrom
# e.g. i = 2, J = (1,3), Ifrom = (2,)
ni = findfirst(==(i), Ifrom)
val = :(v[$ni, $na])
else
# e.g. i = 2, J = (1,3), Ifrom = (1,)
val = 0
end
push!(vals, val)
end
end
quote
Base.@_propagate_inbounds_meta
AxisTensor(
(ContravariantAxis{$Ito}(), A.instance),
SMatrix{$Nto, $NA, $T, $(Nto * NA)}($(vals...)),
)
end
end

"""
divergence_result_type(V)
Expand Down
72 changes: 72 additions & 0 deletions test/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,75 @@ end
@test Geometry.UVVector(Geometry._cross(uⁱ, vⁱ, local_geom), local_geom) ==
Geometry.UVVector(6.0, -3.0)
end


@testset "project" begin
M = @SMatrix [
2.0 0.0
0.0 1.0
]
J = det(M)

local_geom = Geometry.LocalGeometry(
Geometry.XYPoint(0.0, 0.0),
J,
J,
Geometry.Axis2Tensor(
(Geometry.UVAxis(), Geometry.Covariant12Axis()),
M,
),
)

@test Geometry.project(
Geometry.Contravariant12Axis(),
Covariant12Vector(1.0, 1.0),
local_geom,
) == Contravariant12Vector(0.25, 1.0)
@test Geometry.project(
Geometry.Contravariant1Axis(),
Covariant12Vector(1.0, 1.0),
local_geom,
) == Contravariant1Vector(0.25)
@test Geometry.project(
Geometry.Contravariant2Axis(),
Covariant12Vector(1.0, 1.0),
local_geom,
) == Contravariant2Vector(1.0)
@test Geometry.project(
Geometry.Contravariant123Axis(),
Covariant12Vector(1.0, 1.0),
local_geom,
) == Contravariant123Vector(0.25, 1.0, 0.0)
@test Geometry.project(
Geometry.Contravariant123Axis(),
Covariant123Vector(1.0, 1.0, 1.0),
local_geom,
) == Contravariant123Vector(0.25, 1.0, 1.0)


@test Geometry.project(
Geometry.Contravariant12Axis(),
Covariant12Vector(1.0, 1.0) Covariant12Vector(2.0, 8.0),
local_geom,
) == Contravariant12Vector(0.25, 1.0) Covariant12Vector(2.0, 8.0)
@test Geometry.project(
Geometry.Contravariant1Axis(),
Covariant12Vector(1.0, 1.0) Covariant12Vector(2.0, 8.0),
local_geom,
) == Contravariant1Vector(0.25) Covariant12Vector(2.0, 8.0)
@test Geometry.project(
Geometry.Contravariant2Axis(),
Covariant12Vector(1.0, 1.0) Covariant12Vector(2.0, 8.0),
local_geom,
) == Contravariant2Vector(1.0) Covariant12Vector(2.0, 8.0)
@test Geometry.project(
Geometry.Contravariant123Axis(),
Covariant12Vector(1.0, 1.0) Covariant12Vector(2.0, 8.0),
local_geom,
) == Contravariant123Vector(0.25, 1.0, 0.0) Covariant12Vector(2.0, 8.0)
@test Geometry.project(
Geometry.Contravariant123Axis(),
Covariant123Vector(1.0, 1.0, 1.0) Covariant12Vector(2.0, 8.0),
local_geom,
) == Contravariant123Vector(0.25, 1.0, 1.0) Covariant12Vector(2.0, 8.0)
end

0 comments on commit 845732e

Please sign in to comment.