-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
848: Generalize axis tensor tests r=charleskawczynski a=charleskawczynski This PR generalizes some of the axis tensor tests. There's still some issues with being able to test a complete range of inputs, but these changes should make testing more methods (beyond transform/project) easier. Co-authored-by: Charles Kawczynski <[email protected]>
- Loading branch information
Showing
3 changed files
with
251 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,114 +1,139 @@ | ||
using Test | ||
using StaticArrays | ||
import BenchmarkTools | ||
import StatsBase | ||
import OrderedCollections | ||
import Random | ||
using Test, StaticArrays | ||
import Random, BenchmarkTools, StatsBase, OrderedCollections, LinearAlgebra | ||
#! format: off | ||
using ClimaCore.Geometry: | ||
AbstractAxis, | ||
CovariantAxis, | ||
AxisVector, | ||
ContravariantAxis, | ||
LocalAxis, | ||
CartesianAxis, | ||
AxisTensor, | ||
Covariant1Vector, | ||
Covariant13Vector, | ||
UVVector, | ||
UWVector, | ||
UVector, | ||
WVector, | ||
Covariant12Vector, | ||
UVWVector, | ||
Covariant123Vector, | ||
Covariant3Vector, | ||
Contravariant12Vector, | ||
Contravariant3Vector, | ||
Contravariant123Vector, | ||
Contravariant13Vector, | ||
Contravariant2Vector, | ||
Axis2Tensor | ||
using ClimaCore.Geometry:Geometry, AbstractAxis, CovariantAxis, | ||
AxisVector, ContravariantAxis, LocalAxis, CartesianAxis, AxisTensor, | ||
Covariant1Vector, Covariant13Vector, UVVector, UWVector, UVector, | ||
WVector, Covariant12Vector, UVWVector, Covariant123Vector, Covariant3Vector, | ||
Contravariant12Vector, Contravariant3Vector, Contravariant123Vector, | ||
Contravariant13Vector, Contravariant2Vector, Axis2Tensor, Contravariant3Axis, | ||
LocalGeometry, CovariantTensor, CartesianTensor, LocalTensor, ContravariantTensor | ||
|
||
include("transform_project.jl") # compact, generic but unoptimized reference | ||
include("used_transform_args.jl") | ||
include("ref_funcs.jl") | ||
include("used_project_args.jl") | ||
|
||
function benchmark_axistensor_conversions(args, ::Type{FT}, func::F) where {FT, F} | ||
results = OrderedCollections.OrderedDict() | ||
for (aTo, x) in args | ||
func(aTo, x) # compile first | ||
|
||
# Setting up a benchmark is _really_ slow | ||
# for many many microbenchmarks. So, let's | ||
# just hardcode something simple that runs | ||
# fast.. | ||
|
||
#### Using BenchmarkTools | ||
|
||
# b = BenchmarkTools.@benchmarkable $func($aTo, $x) | ||
# trial = BenchmarkTools.run(b, samples = 3) | ||
# time = StatsBase.mean(trial.times) | ||
# time = BenchmarkTools.@btime begin; result = func($aTo, $x); end | ||
|
||
#### Hard-code average of 25 calls | ||
time = @elapsed begin | ||
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x) | ||
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x) | ||
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x) | ||
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x) | ||
func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x); func(aTo, x) | ||
end | ||
ns = 1e9 | ||
time = time/25*ns # average and convert to ns | ||
# time_func(func::F, x, y) where {F} = time_func_accurate(func, x, y) | ||
time_func(func::F, x, y) where {F} = time_func_fast(func, x, y) | ||
|
||
function time_func_accurate(func::F, x, y) where {F} | ||
b = BenchmarkTools.@benchmarkable $func($x, $y) | ||
trial = BenchmarkTools.run(b; samples = 3) | ||
# show(stdout, MIME("text/plain"), trial) | ||
time = StatsBase.mean(trial.times) | ||
return time | ||
end | ||
|
||
function time_func_fast(func::F, x, y) where {F} | ||
time = 0 | ||
nsamples = 100 | ||
for i in 1:nsamples | ||
time += @elapsed func(x, y) | ||
end | ||
ns = 1e9 | ||
time = time/nsamples*ns # average and convert to ns | ||
return time | ||
end | ||
|
||
key = typeof.((aTo, x)) | ||
result = func(aTo, x) | ||
function benchmark_conversion(arg_set, ::Type{FT}, func::F) where {FT, F} | ||
results = OrderedCollections.OrderedDict() | ||
ref_func = reference_func(func) | ||
for args in arg_set | ||
key = typeof.(args) | ||
# Reference | ||
result = ref_func(args...) # compile first | ||
time = time_func(ref_func, args...) | ||
t_pretty = BenchmarkTools.prettytime(time) | ||
@info "Benchmarking $t_pretty $func with $key" | ||
results[key] = (time, t_pretty, result) | ||
ref = (;time, t_pretty, result) | ||
# Optimized | ||
result = func(args...) # compile first | ||
time = time_func(func, args...) | ||
t_pretty = BenchmarkTools.prettytime(time) | ||
opt = (;time, t_pretty, result) | ||
|
||
@info "Benchmark: opt: $(opt.t_pretty) ref: $(ref.t_pretty). Key: $key" | ||
results[key] = (;opt, ref) | ||
end | ||
return results | ||
end | ||
|
||
function get_conversion_args(ucs) | ||
map(ucs) do (axt, axtt) | ||
function all_axes() | ||
all_Is() = [(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] | ||
collect(Iterators.flatten(map(all_Is()) do I | ||
( | ||
CovariantAxis{I}(), | ||
ContravariantAxis{I}(), | ||
LocalAxis{I}(), | ||
CartesianAxis{I}() | ||
) | ||
end)) | ||
end | ||
|
||
all_observed_axistensors(::Type{FT}) where {FT} = | ||
vcat(map(x-> rand(last(x)), used_project_arg_types(FT)), | ||
map(x-> rand(last(x)), used_transform_arg_types(FT))) | ||
|
||
func_args(FT, ::typeof(Geometry.project)) = | ||
map(used_project_arg_types(FT)) do (axt, axtt) | ||
(axt(), rand(axtt)) | ||
end | ||
func_args(FT, ::typeof(Geometry.transform)) = | ||
map(used_transform_arg_types(FT)) do (axt, axtt) | ||
(axt(), rand(axtt)) | ||
end | ||
|
||
function all_possible_func_args(FT, ::typeof(Geometry.contravariant3)) | ||
# TODO: this is not accurate yet, since we don't yet | ||
# vary over all possible LocalGeometry's. | ||
M = @SMatrix [ | ||
FT(4) FT(1) | ||
FT(0.5) FT(2) | ||
] | ||
J = LinearAlgebra.det(M) | ||
∂x∂ξ = rand(Geometry.AxisTensor{FT, 2, Tuple{Geometry.LocalAxis{(3,)}, Geometry.CovariantAxis{(3,)}}, SMatrix{1, 1, FT, 1}}) | ||
lg = Geometry.LocalGeometry(Geometry.XYPoint(FT(0), FT(0)), J, J, ∂x∂ξ) | ||
# Geometry.LocalGeometry{(3,), Geometry.ZPoint{FT}, FT, SMatrix{1, 1, FT, 1}} | ||
Iterators.flatten( | ||
map(used_project_arg_types(FT)) do (axt, axtt) | ||
map(all_axes()) do ax | ||
(rand(axtt), lg) | ||
end | ||
end | ||
) | ||
end | ||
|
||
function test_optimized_transform(::Type{FT}) where {FT} | ||
@info "Testing optimized transform..." | ||
utat = used_transform_arg_types(FT) | ||
args = get_conversion_args(utat) | ||
rt_results = benchmark_axistensor_conversions(args, FT, ref_transform) | ||
ot_results = benchmark_axistensor_conversions(args, FT, ref_transform) # (opt_transform) | ||
for key in keys(rt_results) | ||
@test last(ot_results[key]) == last(rt_results[key]) # test correctness | ||
@test_broken first(ot_results[key])*30 < first(rt_results[key]) # test performance | ||
function func_args(FT, f::typeof(Geometry.contravariant3)) | ||
# TODO: fix this.. | ||
apfa = all_possible_func_args(FT, f) | ||
args_dict = Dict() | ||
for _args in apfa | ||
hasmethod(f, typeof(_args)) || continue | ||
args_dict[typeof.(_args)] = _args | ||
end | ||
return values(args_dict) | ||
end | ||
|
||
function test_optimized_project(::Type{FT}) where {FT} | ||
@info "Testing optimized project..." | ||
upat = used_project_arg_types(FT) | ||
args = get_conversion_args(upat) | ||
rp_results = benchmark_axistensor_conversions(args, FT, ref_project) | ||
op_results = benchmark_axistensor_conversions(args, FT, ref_project) # (opt_project) | ||
for key in keys(rp_results) | ||
@test last(op_results[key]) == last(rp_results[key]) # test correctness | ||
@test_broken first(op_results[key])*30 < first(rp_results[key]) # test performance | ||
reference_func(::typeof(Geometry.contravariant3)) = ref_contravariant3 | ||
reference_func(::typeof(Geometry.project)) = ref_project | ||
reference_func(::typeof(Geometry.transform)) = ref_transform | ||
|
||
function test_optimized_function(::Type{FT}, func) where {FT} | ||
@info "Testing optimized $func..." | ||
args = func_args(FT, func) | ||
bm = benchmark_conversion(args, FT, func) | ||
for key in keys(bm) | ||
@test bm[key].opt.result == bm[key].ref.result # test correctness | ||
@test_broken bm[key].opt.time*10 < bm[key].ref.time # test performance | ||
end | ||
end | ||
|
||
# TODO: figure out how to make error checking in `transform` | ||
|
||
# @testset "Test optimized transform" begin | ||
# test_optimized_transform(Float64) | ||
# end | ||
# TODO: figure out how to make error checking in `transform` optional | ||
|
||
@testset "Test optimized project" begin | ||
test_optimized_project(Float64) | ||
@testset "Test optimized functions" begin | ||
test_optimized_function(Float64, Geometry.project) | ||
# test_optimized_function(Float64, Geometry.contravariant3) | ||
# test_optimized_function(Float64, Geometry.transform) | ||
end | ||
|
||
#! format: on |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
|
||
@inline ref_contravariant1(u::AxisVector, local_geometry::LocalGeometry) = | ||
@inbounds ref_project(Contravariant1Axis(), u, local_geometry)[1] | ||
@inline ref_contravariant2(u::AxisVector, local_geometry::LocalGeometry) = | ||
@inbounds ref_project(Contravariant2Axis(), u, local_geometry)[1] | ||
@inline ref_contravariant3(u::AxisVector, local_geometry::LocalGeometry) = | ||
@inbounds ref_project(Contravariant3Axis(), u, local_geometry)[1] | ||
|
||
@inline ref_contravariant1(u::Axis2Tensor, local_geometry::LocalGeometry) = | ||
@inbounds ref_project(Contravariant1Axis(), u, local_geometry)[1, :] | ||
@inline ref_contravariant2(u::Axis2Tensor, local_geometry::LocalGeometry) = | ||
@inbounds ref_project(Contravariant2Axis(), u, local_geometry)[1, :] | ||
@inline ref_contravariant3(u::Axis2Tensor, local_geometry::LocalGeometry) = | ||
@inbounds ref_project(Contravariant3Axis(), u, local_geometry)[1, :] | ||
|
||
@inline ref_covariant1(u::AxisVector, local_geometry::LocalGeometry) = | ||
CovariantVector(u, local_geometry).u₁ | ||
@inline ref_covariant2(u::AxisVector, local_geometry::LocalGeometry) = | ||
CovariantVector(u, local_geometry).u₂ | ||
@inline ref_covariant3(u::AxisVector, local_geometry::LocalGeometry) = | ||
CovariantVector(u, local_geometry).u₃ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters