From 6ab21982e43b1ceb7260789c5afe8f39fc48f973 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Wed, 9 Oct 2024 18:03:13 +0200 Subject: [PATCH] Move `@noinline` into code generation utilities (#205) * Remove Compat.jl from tests * Move @noinline into generated code --- src/overloads/gradient_tracer.jl | 30 ++++++++++++++-------------- src/overloads/hessian_tracer.jl | 34 ++++++++++++++++++++------------ test/Project.toml | 1 - test/runtests.jl | 1 - test/test_gradient.jl | 1 - 5 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 2ff4778..e1e9bd5 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -2,9 +2,7 @@ SCT = SparseConnectivityTracer ## 1-to-1 -@noinline function gradient_tracer_1_to_1( - t::T, is_der1_zero::Bool -) where {T<:GradientTracer} +function gradient_tracer_1_to_1(t::T, is_der1_zero::Bool) where {T<:GradientTracer} if is_der1_zero && !isemptytracer(t) return myempty(T) else @@ -36,7 +34,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function) expr_gradienttracer = quote function $M.$fname(t::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g) + return @noinline $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g) end end @@ -55,7 +53,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function) t = $SCT.tracer(d) is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) - t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero) + t_out = @noinline $SCT.gradient_tracer_1_to_1(t, is_der1_zero) return $SCT.Dual(p_out, t_out) end end @@ -65,7 +63,7 @@ end ## 2-to-1 -@noinline function gradient_tracer_2_to_1( +function gradient_tracer_2_to_1( tx::T, ty::T, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool ) where {T<:GradientTracer} # TODO: add tests for isempty @@ -116,7 +114,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function) expr_tracer_tracer = quote function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer} - return $SCT.gradient_tracer_2_to_1( + return @noinline $SCT.gradient_tracer_2_to_1( tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g ) end @@ -141,7 +139,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function) ty = $SCT.tracer(dy) is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_2_to_1( + t_out = @noinline $SCT.gradient_tracer_2_to_1( tx, ty, is_der1_arg1_zero, is_der1_arg2_zero ) return $SCT.Dual(p_out, t_out) @@ -164,12 +162,12 @@ function generate_code_gradient_2_to_1_typed( expr_tracer_type = quote function $M.$fname(tx::$SCT.GradientTracer, ::$Z) - return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g) + return @noinline $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g) end end expr_type_tracer = quote function $M.$fname(::$Z, ty::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g) + return @noinline $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g) end end @@ -188,7 +186,7 @@ function generate_code_gradient_2_to_1_typed( tx = $SCT.tracer(dx) is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) + t_out = @noinline $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero) return $SCT.Dual(p_out, t_out) end end @@ -208,7 +206,7 @@ function generate_code_gradient_2_to_1_typed( ty = $SCT.tracer(dy) is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) + t_out = @noinline $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero) return $SCT.Dual(p_out, t_out) end end @@ -218,7 +216,7 @@ end ## 1-to-2 -@noinline function gradient_tracer_1_to_2( +function gradient_tracer_1_to_2( t::T, is_der1_out1_zero::Bool, is_der1_out2_zero::Bool ) where {T<:GradientTracer} if isemptytracer(t) # TODO: add test @@ -237,7 +235,9 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function) expr_gradienttracer = quote function $M.$fname(t::$SCT.GradientTracer) - return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g) + return @noinline $SCT.gradient_tracer_1_to_2( + t, $is_der1_out1_zero_g, $is_der1_out2_zero_g + ) end end @@ -257,7 +257,7 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function) t = $SCT.tracer(d) is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x) - t_out1, t_out2 = $SCT.gradient_tracer_1_to_2( + t_out1, t_out2 = @noinline $SCT.gradient_tracer_1_to_2( t, is_der1_out1_zero, is_der1_out2_zero ) return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index e19cba5..cc7c90d 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -5,7 +5,7 @@ SCT = SparseConnectivityTracer # ๐Ÿ™[โˆ‡ฮณ] = ๐Ÿ™[โˆ‚ฯ†]โ‹…๐Ÿ™[โˆ‡ฮฑ] # ๐Ÿ™[โˆ‡ยฒฮณ] = ๐Ÿ™[โˆ‚ฯ†]โ‹…๐Ÿ™[โˆ‡ยฒฮฑ] โˆจ ๐Ÿ™[โˆ‚ยฒฯ†]โ‹…(๐Ÿ™[โˆ‡ฮฑ] โˆจ ๐Ÿ™[โˆ‡ฮฑ]แต€) -@noinline function hessian_tracer_1_to_1( +function hessian_tracer_1_to_1( t::T, is_der1_zero::Bool, is_der2_zero::Bool ) where {P<:AbstractHessianPattern,T<:HessianTracer{P}} if isemptytracer(t) # TODO: add test @@ -65,7 +65,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function) expr_hessiantracer = quote ## HessianTracer function $M.$fname(t::$SCT.HessianTracer) - return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g) + return @noinline $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g) end end @@ -85,7 +85,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function) t = $SCT.tracer(d) is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x) is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x) - t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) + t_out = @noinline $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero) return $SCT.Dual(p_out, t_out) end end @@ -96,7 +96,7 @@ end ## 2-to-1 -@noinline function hessian_tracer_2_to_1( +function hessian_tracer_2_to_1( tx::T, ty::T, is_der1_arg1_zero::Bool, @@ -189,7 +189,7 @@ function generate_code_hessian_2_to_1( expr_tracer_tracer = quote function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer} - return $SCT.hessian_tracer_2_to_1( + return @noinline $SCT.hessian_tracer_2_to_1( tx, ty, $is_der1_arg1_zero_g, @@ -232,7 +232,7 @@ function generate_code_hessian_2_to_1( is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_2_to_1( + t_out = @noinline $SCT.hessian_tracer_2_to_1( tx, ty, is_der1_arg1_zero, @@ -263,12 +263,16 @@ function generate_code_hessian_2_to_1_typed( expr_tracer_type = quote function $M.$fname(tx::$SCT.HessianTracer, y::$Z) - return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g) + return @noinline $SCT.hessian_tracer_1_to_1( + tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g + ) end end expr_type_tracer = quote function $M.$fname(x::$Z, ty::$SCT.HessianTracer) - return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g) + return @noinline $SCT.hessian_tracer_1_to_1( + ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g + ) end end @@ -288,7 +292,9 @@ function generate_code_hessian_2_to_1_typed( tx = $SCT.tracer(dx) is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y) is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) + t_out = @noinline $SCT.hessian_tracer_1_to_1( + tx, is_der1_arg1_zero, is_der2_arg1_zero + ) return $SCT.Dual(p_out, t_out) end end @@ -309,7 +315,9 @@ function generate_code_hessian_2_to_1_typed( ty = $SCT.tracer(dy) is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y) is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y) - t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) + t_out = @noinline $SCT.hessian_tracer_1_to_1( + ty, is_der1_arg2_zero, is_der2_arg2_zero + ) return $SCT.Dual(p_out, t_out) end end @@ -319,7 +327,7 @@ end ## 1-to-2 -@noinline function hessian_tracer_1_to_2( +function hessian_tracer_1_to_2( t::T, is_der1_out1_zero::Bool, is_der2_out1_zero::Bool, @@ -344,7 +352,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function) expr_hessiantracer = quote function $M.$fname(t::$SCT.HessianTracer) - return $SCT.hessian_tracer_1_to_2( + return @noinline $SCT.hessian_tracer_1_to_2( t, $is_der1_out1_zero_g, $is_der2_out1_zero_g, @@ -375,7 +383,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function) is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x) is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x) is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x) - t_out1, t_out2 = $SCT.hessian_tracer_1_to_2( + t_out1, t_out2 = @noinline $SCT.hessian_tracer_1_to_2( d, is_der1_out1_zero, is_der2_out1_zero, diff --git a/test/Project.toml b/test/Project.toml index 26182e3..5ae20d2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,7 +3,6 @@ ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/test/runtests.jl b/test/runtests.jl index 453f40c..8e9fd75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ Pkg.develop(; ) using SparseConnectivityTracer -using Compat: pkgversion using Documenter: Documenter, DocMeta using Test diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 02e6989..45fe991 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -2,7 +2,6 @@ using SparseConnectivityTracer using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError using Test -using Compat: Returns using Random: rand, GLOBAL_RNG using LinearAlgebra: det, dot, logdet