Skip to content

Commit

Permalink
Merge #853
Browse files Browse the repository at this point in the history
853: add special cases for projection from Covariant => Contravariant r=charleskawczynski a=simonbyrne



Co-authored-by: Simon Byrne <[email protected]>
Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
3 people authored Aug 25, 2022
2 parents 701090c + 5f9a0c7 commit 088c231
Show file tree
Hide file tree
Showing 7 changed files with 367 additions and 143 deletions.
44 changes: 27 additions & 17 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,16 @@ const Axis2TensorOrAdj{T, A, S} =

# based on 1st dimension
const Covariant2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{CovariantAxis, AbstractAxis}}
const Contravariant2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{ContravariantAxis, AbstractAxis}}
Axis2Tensor{T, A, S} where {T, A <: Tuple{CovariantAxis, AbstractAxis}, S}
const Contravariant2Tensor{T, A, S} = Axis2Tensor{
T,
A,
S,
} where {T, A <: Tuple{ContravariantAxis, AbstractAxis}, S}
const Cartesian2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{CartesianAxis, AbstractAxis}}
Axis2Tensor{T, A, S} where {T, A <: Tuple{CartesianAxis, AbstractAxis}, S}
const Local2Tensor{T, A, S} =
Axis2Tensor{T, A, S} where {A <: Tuple{LocalAxis, AbstractAxis}}
Axis2Tensor{T, A, S} where {T, A <: Tuple{LocalAxis, AbstractAxis}, S}

const CovariantTensor = Union{CovariantVector, Covariant2Tensor}
const ContravariantTensor = Union{ContravariantVector, Contravariant2Tensor}
Expand Down Expand Up @@ -477,6 +480,9 @@ end
x
end

#= Set `assert_exact_transform() = true` for debugging=#
assert_exact_transform() = false

@generated function _transform(
ato::Ato,
x::Axis2Tensor{T, Tuple{Afrom, A2}},
Expand All @@ -487,12 +493,14 @@ end
} where {Ito, Ifrom, J, T}
N = length(Ifrom)
M = length(J)
errcond = false
for n in 1:N
i = Ifrom[n]
if i Ito
for m in 1:M
errcond = :($errcond || x[$n, $m] != zero(T))
if assert_exact_transform()
errcond = false
for n in 1:N
i = Ifrom[n]
if i Ito
for m in 1:M
errcond = :($errcond || x[$n, $m] != zero(T))
end
end
end
end
Expand All @@ -511,8 +519,10 @@ end
end
quote
Base.@_propagate_inbounds_meta
if $errcond
throw(InexactError(:transform, Ato, x))
if assert_exact_transform()
if $errcond
throw(InexactError(:transform, Ato, x))
end
end
@inbounds Axis2Tensor(
(ato, axes(x, 2)),
Expand Down Expand Up @@ -550,11 +560,11 @@ end
))
end

@inline transform(ato::CovariantAxis, v::CovariantTensor) = _transform(ato, v)
@inline transform(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
@inline transform(ato::ContravariantAxis, v::ContravariantTensor) =
_transform(ato, v)
@inline transform(ato::CartesianAxis, v::CartesianTensor) = _transform(ato, v)
@inline transform(ato::LocalAxis, v::LocalTensor) = _transform(ato, v)
_project(ato, v)
@inline transform(ato::CartesianAxis, v::CartesianTensor) = _project(ato, v)
@inline transform(ato::LocalAxis, v::LocalTensor) = _project(ato, v)

@inline project(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
@inline project(ato::ContravariantAxis, v::ContravariantTensor) =
Expand Down
194 changes: 167 additions & 27 deletions src/Geometry/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ LocalVector(
(::Type{<:ContravariantVector{<:Any, I}})(
u::ContravariantVector,
::LocalGeometry,
) where {I} = transform(ContravariantAxis{I}(), u)
) where {I} = project(ContravariantAxis{I}(), u)

(::Type{<:ContravariantVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} =
transform(ContravariantAxis{I}(), ContravariantVector(u, local_geometry))
project(ContravariantAxis{I}(), ContravariantVector(u, local_geometry))

(::Type{<:CovariantVector{<:Any, I}})(
u::CovariantVector{<:Any, I},
Expand All @@ -50,67 +50,61 @@ LocalVector(
(::Type{<:CovariantVector{<:Any, I}})(
u::CovariantVector,
::LocalGeometry,
) where {I} = transform(CovariantAxis{I}(), u)
) where {I} = project(CovariantAxis{I}(), u)

(::Type{<:CovariantVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} = transform(CovariantAxis{I}(), CovariantVector(u, local_geometry))
) where {I} = project(CovariantAxis{I}(), CovariantVector(u, local_geometry))

(::Type{<:LocalVector{<:Any, I}})(
u::LocalVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:LocalVector{<:Any, I}})(u::LocalVector, ::LocalGeometry) where {I} =
transform(LocalAxis{I}(), u)
project(LocalAxis{I}(), u)

(::Type{<:LocalVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} = transform(LocalAxis{I}(), LocalVector(u, local_geometry))
) where {I} = project(LocalAxis{I}(), LocalVector(u, local_geometry))

# Generic N-axis conversion functions,
# Convert to specific local geometry dimension then convert vector type
LocalVector(u::CovariantVector, local_geometry::LocalGeometry{I}) where {I} =
transform(LocalAxis{I}(), transform(CovariantAxis{I}(), u), local_geometry)
project(LocalAxis{I}(), project(CovariantAxis{I}(), u), local_geometry)

LocalVector(
u::ContravariantVector,
local_geometry::LocalGeometry{I},
) where {I} = transform(
LocalAxis{I}(),
transform(ContravariantAxis{I}(), u),
local_geometry,
)
) where {I} =
project(LocalAxis{I}(), project(ContravariantAxis{I}(), u), local_geometry)

CovariantVector(u::LocalVector, local_geometry::LocalGeometry{I}) where {I} =
transform(CovariantAxis{I}(), transform(LocalAxis{I}(), u), local_geometry)
project(CovariantAxis{I}(), project(LocalAxis{I}(), u), local_geometry)

CovariantVector(
u::ContravariantVector,
local_geometry::LocalGeometry{I},
) where {I} = transform(
) where {I} = project(
CovariantAxis{I}(),
transform(ContravariantAxis{I}(), u),
project(ContravariantAxis{I}(), u),
local_geometry,
)

ContravariantVector(
u::LocalVector,
local_geometry::LocalGeometry{I},
) where {I} = transform(
ContravariantAxis{I}(),
transform(LocalAxis{I}(), u),
local_geometry,
)
) where {I} =
project(ContravariantAxis{I}(), project(LocalAxis{I}(), u), local_geometry)

ContravariantVector(
u::CovariantVector,
local_geometry::LocalGeometry{I},
) where {I} = transform(
) where {I} = project(
ContravariantAxis{I}(),
transform(CovariantAxis{I}(), u),
project(CovariantAxis{I}(), u),
local_geometry,
)

Expand Down Expand Up @@ -172,7 +166,7 @@ Base.@propagate_inbounds Jcontravariant3(
)
u₁, v, u₃ = components(vector)
vector2 = Covariant13Vector(u₁, u₃)
u, w = components(transform(LocalAxis{(1, 3)}(), vector2, local_geometry))
u, w = components(project(LocalAxis{(1, 3)}(), vector2, local_geometry))
return UVWVector(u, v, w)
end
@inline function contravariant1(
Expand All @@ -181,25 +175,25 @@ end
)
u₁, _, u₃ = components(vector)
vector2 = Covariant13Vector(u₁, u₃)
return transform(Contravariant13Axis(), vector2, local_geometry).
return project(Contravariant13Axis(), vector2, local_geometry).
end
@inline function contravariant3(
vector::CovariantVector{<:Any, (1, 2)},
local_geometry::LocalGeometry{(1, 3)},
)
u₁, _ = components(vector)
vector2 = Covariant13Vector(u₁, zero(u₁))
return transform(Contravariant13Axis(), vector2, local_geometry).
return project(Contravariant13Axis(), vector2, local_geometry).
end
@inline function ContravariantVector(
vector::CovariantVector{<:Any, (1, 2)},
local_geometry::LocalGeometry{(1, 3)},
)
u₁, v = components(vector)
vector2 = Covariant1Vector(u₁)
vector3 = transform(
vector3 = project(
ContravariantAxis{(1, 3)}(),
transform(CovariantAxis{(1, 3)}(), vector2),
project(CovariantAxis{(1, 3)}(), vector2),
local_geometry,
)
u¹, u³ = components(vector3)
Expand Down Expand Up @@ -333,6 +327,7 @@ for op in (:transform, :project)
)

# Covariant <-> Contravariant
#=
@inline $op(
ax::ContravariantAxis,
v::CovariantTensor,
Expand All @@ -343,6 +338,7 @@ for op in (:transform, :project)
local_geometry.∂ξ∂x' *
$op(dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
=#
@inline $op(
ax::CovariantAxis,
v::ContravariantTensor,
Expand All @@ -368,6 +364,150 @@ for op in (:transform, :project)
end
end

@inline transform(
ax::ContravariantAxis,
v::CovariantTensor,
local_geometry::LocalGeometry,
) = project(
ax,
local_geometry.∂ξ∂x *
local_geometry.∂ξ∂x' *
project(dual(axes(local_geometry.∂ξ∂x, 1)), v),
)

@generated function project(
ax::ContravariantAxis{Ito},
v::CovariantVector{T, Ifrom},
local_geometry::LocalGeometry{J},
) where {T, Ito, Ifrom, J}
Nfrom = length(Ifrom)
Nto = length(Ito)
NJ = length(J)

vals = []
for i in Ito
if i J
# e.g. i = 2, J = (1,2,3)
IJ = intersect(J, Ifrom)
if isempty(IJ)
val = 0
else
niJ = findfirst(==(i), J)
val = Expr(
:call,
:+,
[
:(
local_geometry.gⁱʲ[$niJ, $(findfirst(==(j), J))] * v[$(findfirst(==(j), Ifrom))]
) for j in IJ
]...,
)
end
elseif i Ifrom
# e.g. i = 2, J = (1,3), Ifrom = (2,)
ni = findfirst(==(i), Ifrom)
val = :(v[$ni])
else
# e.g. i = 2, J = (1,3), Ifrom = (1,)
val = 0
end
push!(vals, val)
end
quote
Base.@_propagate_inbounds_meta
AxisVector(ContravariantAxis{$Ito}(), SVector{$Nto, $T}($(vals...)))
end
end
@generated function project(
ax::ContravariantAxis{Ito},
v::Contravariant2Tensor{T, Tuple{CovariantAxis{Ifrom}, A}},
local_geometry::LocalGeometry{J},
) where {T, Ito, Ifrom, A, J}
Nfrom = length(Ifrom)
Nto = length(Ito)
NJ = length(J)
NA = length(A.instance)

vals = []
for na in 1:NA
for i in Ito
if i J
# e.g. i = 2, J = (1,2,3)
IJ = intersect(J, Ifrom)
if isempty(IJ)
val = 0
else
niJ = findfirst(==(i), J)
val = Expr(
:call,
:+,
[
:(
local_geometry.gⁱʲ[
$niJ,
$(findfirst(==(j), J)),
] * v[$(findfirst(==(j), Ifrom)), $na]
) for j in IJ
]...,
)
end
elseif i Ifrom
# e.g. i = 2, J = (1,3), Ifrom = (2,)
ni = findfirst(==(i), Ifrom)
val = :(v[$ni, $na])
else
# e.g. i = 2, J = (1,3), Ifrom = (1,)
val = 0
end
push!(vals, val)
end
end
quote
Base.@_propagate_inbounds_meta
AxisTensor(
(ContravariantAxis{$Ito}(), A.instance),
SMatrix{$Nto, $NA, $T, $(Nto * NA)}($(vals...)),
)
end
end

# A few other expensive ones:
#! format: off
@inline function project(
ax::ContravariantAxis{(1,)},
v::AxisTensor{FT,2,Tuple{LocalAxis{(1, 2)},LocalAxis{(1, 2)}},SMatrix{2,2,FT,4}},
lg::LocalGeometry{(1, 2, 3),XYZPoint{FT},FT,SMatrix{3,3,FT,9}}
) where {FT}
AxisTensor(
(ContravariantAxis{(1,)}(), LocalAxis{(1, 2)}()),
@inbounds @SMatrix [
lg.∂ξ∂x[1, 1]*v[1, 1]+lg.∂ξ∂x[1, 2]*v[2, 1] lg.∂ξ∂x[1, 1]*v[1, 2]+lg.∂ξ∂x[1, 2]*v[2, 2]
])
end
@inline function project(
ax::ContravariantAxis{(2,)},
v::AxisTensor{FT,2,Tuple{LocalAxis{(1,2)},LocalAxis{(1,2)}},SMatrix{2,2,FT,4}},
lg::LocalGeometry{(1,2,3),XYZPoint{FT},FT,SMatrix{3,3,FT,9}}
) where {FT}
AxisTensor(
(ContravariantAxis{(2,)}(), LocalAxis{(1, 2)}()),
@inbounds @SMatrix [
lg.∂ξ∂x[2, 1]*v[1, 1]+lg.∂ξ∂x[2, 2]*v[2, 1] lg.∂ξ∂x[2, 1]*v[1, 2]+lg.∂ξ∂x[2, 2]*v[2, 2]
]
)
end
@inline function project(
ax::ContravariantAxis{(3,)},
v::AxisTensor{FT,2,Tuple{LocalAxis{(3,)},LocalAxis{(1,2)}},SMatrix{1,2,FT,2}},
lg::LocalGeometry{(1,2,3),XYZPoint{FT},FT,SMatrix{3,3,FT,9}}
) where {FT}
AxisTensor(
(ContravariantAxis{(3,)}(), LocalAxis{(1, 2)}()),
@inbounds @SMatrix [lg.∂ξ∂x[3, 3]*v[1, 1] lg.∂ξ∂x[3, 3]*v[1, 2]]
)
end
#! format: on


"""
divergence_result_type(V)
Expand Down
2 changes: 1 addition & 1 deletion src/Geometry/globalgeometry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Cartesian123Point(pt::AbstractPoint, global_geometry::AbstractGlobalGeometry) =
u::AxisVector,
global_geometry::AbstractGlobalGeometry,
local_geometry::LocalGeometry,
) where {I} = transform(
) where {I} = project(
CartesianAxis{I}(),
CartesianVector(u, global_geometry, local_geometry),
)
Expand Down
Loading

0 comments on commit 088c231

Please sign in to comment.