Skip to content

Commit

Permalink
inline all FD operators and axisvector conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Sep 17, 2023
1 parent 4962dfc commit 25e8f21
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 21 deletions.
11 changes: 7 additions & 4 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,10 +609,13 @@ end
push!(vals, val)
end
end
return :(@inbounds Axis2Tensor(
(ato, axes(x, 2)),
SMatrix{$(length(Ito)), $M}($(vals...)),
))
quote
Base.@_propagate_inbounds_meta
@inbounds Axis2Tensor(
(ato, axes(x, 2)),
SMatrix{$(length(Ito)), $M}($(vals...)),
)
end
end

@inline transform(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
Expand Down
40 changes: 24 additions & 16 deletions src/Geometry/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,65 +49,73 @@ CovariantVector(
) where {T, I} = local_geometry.gᵢⱼ * u

# Converting to specific dimension types
(::Type{<:ContravariantVector{<:Any, I}})(
@inline (::Type{<:ContravariantVector{<:Any, I}})(
u::ContravariantVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:ContravariantVector{<:Any, I}})(
@inline (::Type{<:ContravariantVector{<:Any, I}})(
u::ContravariantVector,
::LocalGeometry,
) where {I} = project(ContravariantAxis{I}(), u)

(::Type{<:ContravariantVector{<:Any, I}})(
@inline (::Type{<:ContravariantVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} =
project(ContravariantAxis{I}(), ContravariantVector(u, local_geometry))

(::Type{<:CovariantVector{<:Any, I}})(
@inline (::Type{<:CovariantVector{<:Any, I}})(
u::CovariantVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:CovariantVector{<:Any, I}})(
@inline (::Type{<:CovariantVector{<:Any, I}})(
u::CovariantVector,
::LocalGeometry,
) where {I} = project(CovariantAxis{I}(), u)

(::Type{<:CovariantVector{<:Any, I}})(
@inline (::Type{<:CovariantVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} = project(CovariantAxis{I}(), CovariantVector(u, local_geometry))

(::Type{<:LocalVector{<:Any, I}})(
@inline (::Type{<:LocalVector{<:Any, I}})(
u::LocalVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:LocalVector{<:Any, I}})(u::LocalVector, ::LocalGeometry) where {I} =
project(LocalAxis{I}(), u)
@inline (::Type{<:LocalVector{<:Any, I}})(
u::LocalVector,
::LocalGeometry,
) where {I} = project(LocalAxis{I}(), u)

(::Type{<:LocalVector{<:Any, I}})(
@inline (::Type{<:LocalVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} = project(LocalAxis{I}(), LocalVector(u, local_geometry))

# Generic N-axis conversion functions,
# Convert to specific local geometry dimension then convert vector type
LocalVector(u::CovariantVector, local_geometry::LocalGeometry{I}) where {I} =
@inline LocalVector(
u::CovariantVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(LocalAxis{I}(), project(CovariantAxis{I}(), u), local_geometry)

LocalVector(
@inline LocalVector(
u::ContravariantVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(LocalAxis{I}(), project(ContravariantAxis{I}(), u), local_geometry)

CovariantVector(u::LocalVector, local_geometry::LocalGeometry{I}) where {I} =
@inline CovariantVector(
u::LocalVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(CovariantAxis{I}(), project(LocalAxis{I}(), u), local_geometry)

CovariantVector(
@inline CovariantVector(
u::ContravariantVector,
local_geometry::LocalGeometry{I},
) where {I} = project(
Expand All @@ -116,13 +124,13 @@ CovariantVector(
local_geometry,
)

ContravariantVector(
@inline ContravariantVector(
u::LocalVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(ContravariantAxis{I}(), project(LocalAxis{I}(), u), local_geometry)

ContravariantVector(
@inline ContravariantVector(
u::CovariantVector,
local_geometry::LocalGeometry{I},
) where {I} = project(
Expand Down
2 changes: 1 addition & 1 deletion src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3446,7 +3446,7 @@ function Base.copyto!(
max_threads = 256
nitems = Nv * Nq * Nq * Nh # # of independent items
(nthreads, nblocks) = Spaces._configure_threadblock(max_threads, nitems)
@cuda threads = (nthreads,) blocks = (nblocks,) copyto_stencil_kernel!(
@cuda always_inline = true threads = (nthreads,) blocks = (nblocks,) copyto_stencil_kernel!(
strip_space(out, space),
strip_space(bc, space),
axes(out),
Expand Down

0 comments on commit 25e8f21

Please sign in to comment.