Skip to content

Commit

Permalink
Merge #848
Browse files Browse the repository at this point in the history
848: Generalize axis tensor tests r=charleskawczynski a=charleskawczynski

This PR generalizes some of the axis tensor tests. There's still some issues with being able to test a complete range of inputs, but these changes should make testing more methods (beyond transform/project) easier.

Co-authored-by: Charles Kawczynski <[email protected]>
  • Loading branch information
bors[bot] and charleskawczynski authored Aug 1, 2022
2 parents cb952ba + a54ad5b commit b1daf53
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 88 deletions.
201 changes: 113 additions & 88 deletions test/Geometry/axistensor_conversion_benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,114 +1,139 @@
using Test
using StaticArrays
import BenchmarkTools
import StatsBase
import OrderedCollections
import Random
using Test, StaticArrays
import Random, BenchmarkTools, StatsBase, OrderedCollections, LinearAlgebra
#! format: off
using ClimaCore.Geometry:
AbstractAxis,
CovariantAxis,
AxisVector,
ContravariantAxis,
LocalAxis,
CartesianAxis,
AxisTensor,
Covariant1Vector,
Covariant13Vector,
UVVector,
UWVector,
UVector,
WVector,
Covariant12Vector,
UVWVector,
Covariant123Vector,
Covariant3Vector,
Contravariant12Vector,
Contravariant3Vector,
Contravariant123Vector,
Contravariant13Vector,
Contravariant2Vector,
Axis2Tensor
using ClimaCore.Geometry:Geometry, AbstractAxis, CovariantAxis,
AxisVector, ContravariantAxis, LocalAxis, CartesianAxis, AxisTensor,
Covariant1Vector, Covariant13Vector, UVVector, UWVector, UVector,
WVector, Covariant12Vector, UVWVector, Covariant123Vector, Covariant3Vector,
Contravariant12Vector, Contravariant3Vector, Contravariant123Vector,
Contravariant13Vector, Contravariant2Vector, Axis2Tensor, Contravariant3Axis,
LocalGeometry, CovariantTensor, CartesianTensor, LocalTensor, ContravariantTensor

include("transform_project.jl") # compact, generic but unoptimized reference
include("used_transform_args.jl")
include("ref_funcs.jl")
include("used_project_args.jl")

function benchmark_axistensor_conversions(args, ::Type{FT}, func::F) where {FT, F}
results = OrderedCollections.OrderedDict()
for (aTo, x) in args
func(aTo, x) # compile first

# Setting up a benchmark is _really_ slow
# for many many microbenchmarks. So, let's
# just hardcode something simple that runs
# fast..

#### Using BenchmarkTools

# b = BenchmarkTools.@benchmarkable $func($aTo, $x)
# trial = BenchmarkTools.run(b, samples = 3)
# time = StatsBase.mean(trial.times)
# time = BenchmarkTools.@btime begin; result = func($aTo, $x); end

#### Hard-code average of 25 calls
time = @elapsed begin
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x)
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x)
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x)
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x)
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x)
end
ns = 1e9
time = time/25*ns # average and convert to ns
# time_func(func::F, x, y) where {F} = time_func_accurate(func, x, y)
time_func(func::F, x, y) where {F} = time_func_fast(func, x, y)

function time_func_accurate(func::F, x, y) where {F}
b = BenchmarkTools.@benchmarkable $func($x, $y)
trial = BenchmarkTools.run(b; samples = 3)
# show(stdout, MIME("text/plain"), trial)
time = StatsBase.mean(trial.times)
return time
end

function time_func_fast(func::F, x, y) where {F}
time = 0
nsamples = 100
for i in 1:nsamples
time += @elapsed func(x, y)
end
ns = 1e9
time = time/nsamples*ns # average and convert to ns
return time
end

key = typeof.((aTo, x))
result = func(aTo, x)
function benchmark_conversion(arg_set, ::Type{FT}, func::F) where {FT, F}
results = OrderedCollections.OrderedDict()
ref_func = reference_func(func)
for args in arg_set
key = typeof.(args)
# Reference
result = ref_func(args...) # compile first
time = time_func(ref_func, args...)
t_pretty = BenchmarkTools.prettytime(time)
@info "Benchmarking $t_pretty $func with $key"
results[key] = (time, t_pretty, result)
ref = (;time, t_pretty, result)
# Optimized
result = func(args...) # compile first
time = time_func(func, args...)
t_pretty = BenchmarkTools.prettytime(time)
opt = (;time, t_pretty, result)

@info "Benchmark: opt: $(opt.t_pretty) ref: $(ref.t_pretty). Key: $key"
results[key] = (;opt, ref)
end
return results
end

function get_conversion_args(ucs)
map(ucs) do (axt, axtt)
function all_axes()
all_Is() = [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
collect(Iterators.flatten(map(all_Is()) do I
(
CovariantAxis{I}(),
ContravariantAxis{I}(),
LocalAxis{I}(),
CartesianAxis{I}()
)
end))
end

all_observed_axistensors(::Type{FT}) where {FT} =
vcat(map(x-> rand(last(x)), used_project_arg_types(FT)),
map(x-> rand(last(x)), used_transform_arg_types(FT)))

func_args(FT, ::typeof(Geometry.project)) =
map(used_project_arg_types(FT)) do (axt, axtt)
(axt(), rand(axtt))
end
func_args(FT, ::typeof(Geometry.transform)) =
map(used_transform_arg_types(FT)) do (axt, axtt)
(axt(), rand(axtt))
end

function all_possible_func_args(FT, ::typeof(Geometry.contravariant3))
# TODO: this is not accurate yet, since we don't yet
# vary over all possible LocalGeometry's.
M = @SMatrix [
FT(4) FT(1)
FT(0.5) FT(2)
]
J = LinearAlgebra.det(M)
∂x∂ξ = rand(Geometry.AxisTensor{FT, 2, Tuple{Geometry.LocalAxis{(3,)}, Geometry.CovariantAxis{(3,)}}, SMatrix{1, 1, FT, 1}})
lg = Geometry.LocalGeometry(Geometry.XYPoint(FT(0), FT(0)), J, J, ∂x∂ξ)
# Geometry.LocalGeometry{(3,), Geometry.ZPoint{FT}, FT, SMatrix{1, 1, FT, 1}}
Iterators.flatten(
map(used_project_arg_types(FT)) do (axt, axtt)
map(all_axes()) do ax
(rand(axtt), lg)
end
end
)
end

function test_optimized_transform(::Type{FT}) where {FT}
@info "Testing optimized transform..."
utat = used_transform_arg_types(FT)
args = get_conversion_args(utat)
rt_results = benchmark_axistensor_conversions(args, FT, ref_transform)
ot_results = benchmark_axistensor_conversions(args, FT, ref_transform) # (opt_transform)
for key in keys(rt_results)
@test last(ot_results[key]) == last(rt_results[key]) # test correctness
@test_broken first(ot_results[key])*30 < first(rt_results[key]) # test performance
function func_args(FT, f::typeof(Geometry.contravariant3))
# TODO: fix this..
apfa = all_possible_func_args(FT, f)
args_dict = Dict()
for _args in apfa
hasmethod(f, typeof(_args)) || continue
args_dict[typeof.(_args)] = _args
end
return values(args_dict)
end

function test_optimized_project(::Type{FT}) where {FT}
@info "Testing optimized project..."
upat = used_project_arg_types(FT)
args = get_conversion_args(upat)
rp_results = benchmark_axistensor_conversions(args, FT, ref_project)
op_results = benchmark_axistensor_conversions(args, FT, ref_project) # (opt_project)
for key in keys(rp_results)
@test last(op_results[key]) == last(rp_results[key]) # test correctness
@test_broken first(op_results[key])*30 < first(rp_results[key]) # test performance
reference_func(::typeof(Geometry.contravariant3)) = ref_contravariant3
reference_func(::typeof(Geometry.project)) = ref_project
reference_func(::typeof(Geometry.transform)) = ref_transform

function test_optimized_function(::Type{FT}, func) where {FT}
@info "Testing optimized $func..."
args = func_args(FT, func)
bm = benchmark_conversion(args, FT, func)
for key in keys(bm)
@test bm[key].opt.result == bm[key].ref.result # test correctness
@test_broken bm[key].opt.time*10 < bm[key].ref.time # test performance
end
end

# TODO: figure out how to make error checking in `transform`

# @testset "Test optimized transform" begin
# test_optimized_transform(Float64)
# end
# TODO: figure out how to make error checking in `transform` optional

@testset "Test optimized project" begin
test_optimized_project(Float64)
@testset "Test optimized functions" begin
test_optimized_function(Float64, Geometry.project)
# test_optimized_function(Float64, Geometry.contravariant3)
# test_optimized_function(Float64, Geometry.transform)
end

#! format: on
22 changes: 22 additions & 0 deletions test/Geometry/ref_funcs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

@inline ref_contravariant1(u::AxisVector, local_geometry::LocalGeometry) =
@inbounds ref_project(Contravariant1Axis(), u, local_geometry)[1]
@inline ref_contravariant2(u::AxisVector, local_geometry::LocalGeometry) =
@inbounds ref_project(Contravariant2Axis(), u, local_geometry)[1]
@inline ref_contravariant3(u::AxisVector, local_geometry::LocalGeometry) =
@inbounds ref_project(Contravariant3Axis(), u, local_geometry)[1]

@inline ref_contravariant1(u::Axis2Tensor, local_geometry::LocalGeometry) =
@inbounds ref_project(Contravariant1Axis(), u, local_geometry)[1, :]
@inline ref_contravariant2(u::Axis2Tensor, local_geometry::LocalGeometry) =
@inbounds ref_project(Contravariant2Axis(), u, local_geometry)[1, :]
@inline ref_contravariant3(u::Axis2Tensor, local_geometry::LocalGeometry) =
@inbounds ref_project(Contravariant3Axis(), u, local_geometry)[1, :]

@inline ref_covariant1(u::AxisVector, local_geometry::LocalGeometry) =
CovariantVector(u, local_geometry).u₁
@inline ref_covariant2(u::AxisVector, local_geometry::LocalGeometry) =
CovariantVector(u, local_geometry).u₂
@inline ref_covariant3(u::AxisVector, local_geometry::LocalGeometry) =
CovariantVector(u, local_geometry).u₃

116 changes: 116 additions & 0 deletions test/Geometry/transform_project.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,119 @@ end
SMatrix{$(length(Ito)), $M}($(vals...)),
))
end


for op in (:ref_transform, :ref_project)
@eval begin
# Covariant <-> Cartesian
@inline $op(
ax::CartesianAxis,
v::CovariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂ξ∂x' *
$op(Geometry.dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
@inline $op(
ax::CovariantAxis,
v::CartesianTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂x∂ξ' *
$op(Geometry.dual(axes(local_geometry.∂x∂ξ, 1)), v),
)
@inline $op(
ax::LocalAxis,
v::CovariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂ξ∂x' *
$op(Geometry.dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
@inline $op(
ax::CovariantAxis,
v::LocalTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂x∂ξ' *
$op(Geometry.dual(axes(local_geometry.∂x∂ξ, 1)), v),
)

# Contravariant <-> Cartesian
@inline $op(
ax::ContravariantAxis,
v::CartesianTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂ξ∂x *
$op(Geometry.dual(axes(local_geometry.∂ξ∂x, 2)), v),
)
@inline $op(
ax::CartesianAxis,
v::ContravariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂x∂ξ *
$op(Geometry.dual(axes(local_geometry.∂x∂ξ, 2)), v),
)
@inline $op(
ax::ContravariantAxis,
v::LocalTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂ξ∂x *
$op(Geometry.dual(axes(local_geometry.∂ξ∂x, 2)), v),
)

@inline $op(
ax::LocalAxis,
v::ContravariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂x∂ξ *
$op(Geometry.dual(axes(local_geometry.∂x∂ξ, 2)), v),
)

# Covariant <-> Contravariant
@inline $op(
ax::ContravariantAxis,
v::CovariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂ξ∂x *
local_geometry.∂ξ∂x' *
$op(Geometry.dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
@inline $op(
ax::CovariantAxis,
v::ContravariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂x∂ξ' *
local_geometry.∂x∂ξ *
$op(Geometry.dual(axes(local_geometry.∂x∂ξ, 2)), v),
)

@inline $op(ato::CovariantAxis, v::CovariantTensor, ::LocalGeometry) =
$op(ato, v)
@inline $op(
ato::ContravariantAxis,
v::ContravariantTensor,
::LocalGeometry,
) = $op(ato, v)
@inline $op(ato::CartesianAxis, v::CartesianTensor, ::LocalGeometry) =
$op(ato, v)
@inline $op(ato::LocalAxis, v::LocalTensor, ::LocalGeometry) =
$op(ato, v)
end
end

0 comments on commit b1daf53

Please sign in to comment.