Skip to content

Commit

Permalink
Move @noinline into code generation utilities (#205)
Browse files Browse the repository at this point in the history
* Remove Compat.jl from tests

* Move @noinline into generated code
  • Loading branch information
adrhill authored Oct 9, 2024
1 parent b172a4c commit 6ab2198
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 31 deletions.
30 changes: 15 additions & 15 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ SCT = SparseConnectivityTracer

## 1-to-1

@noinline function gradient_tracer_1_to_1(
t::T, is_der1_zero::Bool
) where {T<:GradientTracer}
function gradient_tracer_1_to_1(t::T, is_der1_zero::Bool) where {T<:GradientTracer}
if is_der1_zero && !isemptytracer(t)
return myempty(T)
else
Expand Down Expand Up @@ -36,7 +34,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
end
end

Expand All @@ -55,7 +53,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -65,7 +63,7 @@ end

## 2-to-1

@noinline function gradient_tracer_2_to_1(
function gradient_tracer_2_to_1(
tx::T, ty::T, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool
) where {T<:GradientTracer}
# TODO: add tests for isempty
Expand Down Expand Up @@ -116,7 +114,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer}
return $SCT.gradient_tracer_2_to_1(
return @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g
)
end
Expand All @@ -141,7 +139,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)
ty = $SCT.tracer(dy)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_2_to_1(
t_out = @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, is_der1_arg1_zero, is_der1_arg2_zero
)
return $SCT.Dual(p_out, t_out)
Expand All @@ -164,12 +162,12 @@ function generate_code_gradient_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end
end
expr_type_tracer = quote
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end

Expand All @@ -188,7 +186,7 @@ function generate_code_gradient_2_to_1_typed(

tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -208,7 +206,7 @@ function generate_code_gradient_2_to_1_typed(

ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -218,7 +216,7 @@ end

## 1-to-2

@noinline function gradient_tracer_1_to_2(
function gradient_tracer_1_to_2(
t::T, is_der1_out1_zero::Bool, is_der1_out2_zero::Bool
) where {T<:GradientTracer}
if isemptytracer(t) # TODO: add test
Expand All @@ -237,7 +235,9 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_2(
t, $is_der1_out1_zero_g, $is_der1_out2_zero_g
)
end
end

Expand All @@ -257,7 +257,7 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.gradient_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.gradient_tracer_1_to_2(
t, is_der1_out1_zero, is_der1_out2_zero
)
return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test
Expand Down
34 changes: 21 additions & 13 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SCT = SparseConnectivityTracer
# 𝟙[∇γ] = 𝟙[∂φ]⋅𝟙[∇α]
# 𝟙[∇²γ] = 𝟙[∂φ]⋅𝟙[∇²α] ∨ 𝟙[∂²φ]⋅(𝟙[∇α] ∨ 𝟙[∇α]ᵀ)

@noinline function hessian_tracer_1_to_1(
function hessian_tracer_1_to_1(
t::T, is_der1_zero::Bool, is_der2_zero::Bool
) where {P<:AbstractHessianPattern,T<:HessianTracer{P}}
if isemptytracer(t) # TODO: add test
Expand Down Expand Up @@ -65,7 +65,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
expr_hessiantracer = quote
## HessianTracer
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
end
end

Expand All @@ -85,7 +85,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x)
t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -96,7 +96,7 @@ end

## 2-to-1

@noinline function hessian_tracer_2_to_1(
function hessian_tracer_2_to_1(
tx::T,
ty::T,
is_der1_arg1_zero::Bool,
Expand Down Expand Up @@ -189,7 +189,7 @@ function generate_code_hessian_2_to_1(

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer}
return $SCT.hessian_tracer_2_to_1(
return @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
$is_der1_arg1_zero_g,
Expand Down Expand Up @@ -232,7 +232,7 @@ function generate_code_hessian_2_to_1(
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_2_to_1(
t_out = @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
is_der1_arg1_zero,
Expand Down Expand Up @@ -263,12 +263,16 @@ function generate_code_hessian_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g
)
end
end
expr_type_tracer = quote
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g
)
end
end

Expand All @@ -288,7 +292,9 @@ function generate_code_hessian_2_to_1_typed(
tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
tx, is_der1_arg1_zero, is_der2_arg1_zero
)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -309,7 +315,9 @@ function generate_code_hessian_2_to_1_typed(
ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
ty, is_der1_arg2_zero, is_der2_arg2_zero
)
return $SCT.Dual(p_out, t_out)
end
end
Expand All @@ -319,7 +327,7 @@ end

## 1-to-2

@noinline function hessian_tracer_1_to_2(
function hessian_tracer_1_to_2(
t::T,
is_der1_out1_zero::Bool,
is_der2_out1_zero::Bool,
Expand All @@ -344,7 +352,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)

expr_hessiantracer = quote
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_2(
return @noinline $SCT.hessian_tracer_1_to_2(
t,
$is_der1_out1_zero_g,
$is_der2_out1_zero_g,
Expand Down Expand Up @@ -375,7 +383,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)
is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.hessian_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.hessian_tracer_1_to_2(
d,
is_der1_out1_zero,
is_der2_out1_zero,
Expand Down
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ Pkg.develop(;
)

using SparseConnectivityTracer
using Compat: pkgversion
using Documenter: Documenter, DocMeta
using Test

Expand Down
1 change: 0 additions & 1 deletion test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ using SparseConnectivityTracer
using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError
using Test

using Compat: Returns
using Random: rand, GLOBAL_RNG
using LinearAlgebra: det, dot, logdet

Expand Down

0 comments on commit 6ab2198

Please sign in to comment.