Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move @noinline into code generation utilities #205

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Move @noinline into generated code
adrhill committed Oct 9, 2024
commit ccef7f6ce2442a6899d0e093b5eaf427e387f2c2
30 changes: 15 additions & 15 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
@@ -2,9 +2,7 @@ SCT = SparseConnectivityTracer

## 1-to-1

@noinline function gradient_tracer_1_to_1(
t::T, is_der1_zero::Bool
) where {T<:GradientTracer}
function gradient_tracer_1_to_1(t::T, is_der1_zero::Bool) where {T<:GradientTracer}
if is_der1_zero && !isemptytracer(t)
return myempty(T)
else
@@ -36,7 +34,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(t, $is_der1_zero_g)
end
end

@@ -55,7 +53,7 @@ function generate_code_gradient_1_to_1(M::Symbol, f::Function)

t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
t_out = $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(t, is_der1_zero)
return $SCT.Dual(p_out, t_out)
end
end
@@ -65,7 +63,7 @@ end

## 2-to-1

@noinline function gradient_tracer_2_to_1(
function gradient_tracer_2_to_1(
tx::T, ty::T, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool
) where {T<:GradientTracer}
# TODO: add tests for isempty
@@ -116,7 +114,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.GradientTracer}
return $SCT.gradient_tracer_2_to_1(
return @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, $is_der1_arg1_zero_g, $is_der1_arg2_zero_g
)
end
@@ -141,7 +139,7 @@ function generate_code_gradient_2_to_1(M::Symbol, f::Function)
ty = $SCT.tracer(dy)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_2_to_1(
t_out = @noinline $SCT.gradient_tracer_2_to_1(
tx, ty, is_der1_arg1_zero, is_der1_arg2_zero
)
return $SCT.Dual(p_out, t_out)
@@ -164,12 +162,12 @@ function generate_code_gradient_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.GradientTracer, ::$Z)
return $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(tx, $is_der1_arg1_zero_g)
end
end
expr_type_tracer = quote
function $M.$fname(::$Z, ty::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_1(ty, $is_der1_arg2_zero_g)
end
end

@@ -188,7 +186,7 @@ function generate_code_gradient_2_to_1_typed(

tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(tx, is_der1_arg1_zero)
return $SCT.Dual(p_out, t_out)
end
end
@@ -208,7 +206,7 @@ function generate_code_gradient_2_to_1_typed(

ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
t_out = @noinline $SCT.gradient_tracer_1_to_1(ty, is_der1_arg2_zero)
return $SCT.Dual(p_out, t_out)
end
end
@@ -218,7 +216,7 @@ end

## 1-to-2

@noinline function gradient_tracer_1_to_2(
function gradient_tracer_1_to_2(
t::T, is_der1_out1_zero::Bool, is_der1_out2_zero::Bool
) where {T<:GradientTracer}
if isemptytracer(t) # TODO: add test
@@ -237,7 +235,9 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)

expr_gradienttracer = quote
function $M.$fname(t::$SCT.GradientTracer)
return $SCT.gradient_tracer_1_to_2(t, $is_der1_out1_zero_g, $is_der1_out2_zero_g)
return @noinline $SCT.gradient_tracer_1_to_2(
t, $is_der1_out1_zero_g, $is_der1_out2_zero_g
)
end
end

@@ -257,7 +257,7 @@ function generate_code_gradient_1_to_2(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der1_out1_zero = $SCT.is_der1_out1_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.gradient_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.gradient_tracer_1_to_2(
t, is_der1_out1_zero, is_der1_out2_zero
)
return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) # TODO: this was wrong, add test
34 changes: 21 additions & 13 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ SCT = SparseConnectivityTracer
# 𝟙[∇γ] = 𝟙[∂φ]⋅𝟙[∇α]
# 𝟙[∇²γ] = 𝟙[∂φ]⋅𝟙[∇²α] ∨ 𝟙[∂²φ]⋅(𝟙[∇α] ∨ 𝟙[∇α]ᵀ)

@noinline function hessian_tracer_1_to_1(
function hessian_tracer_1_to_1(
t::T, is_der1_zero::Bool, is_der2_zero::Bool
) where {P<:AbstractHessianPattern,T<:HessianTracer{P}}
if isemptytracer(t) # TODO: add test
@@ -65,7 +65,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
expr_hessiantracer = quote
## HessianTracer
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(t, $is_der1_zero_g, $is_der2_zero_g)
end
end

@@ -85,7 +85,7 @@ function generate_code_hessian_1_to_1(M::Symbol, f::Function)
t = $SCT.tracer(d)
is_der1_zero = $SCT.is_der1_zero_local($M.$fname, x)
is_der2_zero = $SCT.is_der2_zero_local($M.$fname, x)
t_out = $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(t, is_der1_zero, is_der2_zero)
return $SCT.Dual(p_out, t_out)
end
end
@@ -96,7 +96,7 @@ end

## 2-to-1

@noinline function hessian_tracer_2_to_1(
function hessian_tracer_2_to_1(
tx::T,
ty::T,
is_der1_arg1_zero::Bool,
@@ -189,7 +189,7 @@ function generate_code_hessian_2_to_1(

expr_tracer_tracer = quote
function $M.$fname(tx::T, ty::T) where {T<:$SCT.HessianTracer}
return $SCT.hessian_tracer_2_to_1(
return @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
$is_der1_arg1_zero_g,
@@ -232,7 +232,7 @@ function generate_code_hessian_2_to_1(
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
is_der_cross_zero = $SCT.is_der_cross_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_2_to_1(
t_out = @noinline $SCT.hessian_tracer_2_to_1(
tx,
ty,
is_der1_arg1_zero,
@@ -263,12 +263,16 @@ function generate_code_hessian_2_to_1_typed(

expr_tracer_type = quote
function $M.$fname(tx::$SCT.HessianTracer, y::$Z)
return $SCT.hessian_tracer_1_to_1(tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
tx, $is_der1_arg1_zero_g, $is_der2_arg1_zero_g
)
end
end
expr_type_tracer = quote
function $M.$fname(x::$Z, ty::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_1(ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g)
return @noinline $SCT.hessian_tracer_1_to_1(
ty, $is_der1_arg2_zero_g, $is_der2_arg2_zero_g
)
end
end

@@ -288,7 +292,9 @@ function generate_code_hessian_2_to_1_typed(
tx = $SCT.tracer(dx)
is_der1_arg1_zero = $SCT.is_der1_arg1_zero_local($M.$fname, x, y)
is_der2_arg1_zero = $SCT.is_der2_arg1_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
tx, is_der1_arg1_zero, is_der2_arg1_zero
)
return $SCT.Dual(p_out, t_out)
end
end
@@ -309,7 +315,9 @@ function generate_code_hessian_2_to_1_typed(
ty = $SCT.tracer(dy)
is_der1_arg2_zero = $SCT.is_der1_arg2_zero_local($M.$fname, x, y)
is_der2_arg2_zero = $SCT.is_der2_arg2_zero_local($M.$fname, x, y)
t_out = $SCT.hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero)
t_out = @noinline $SCT.hessian_tracer_1_to_1(
ty, is_der1_arg2_zero, is_der2_arg2_zero
)
return $SCT.Dual(p_out, t_out)
end
end
@@ -319,7 +327,7 @@ end

## 1-to-2

@noinline function hessian_tracer_1_to_2(
function hessian_tracer_1_to_2(
t::T,
is_der1_out1_zero::Bool,
is_der2_out1_zero::Bool,
@@ -344,7 +352,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)

expr_hessiantracer = quote
function $M.$fname(t::$SCT.HessianTracer)
return $SCT.hessian_tracer_1_to_2(
return @noinline $SCT.hessian_tracer_1_to_2(
t,
$is_der1_out1_zero_g,
$is_der2_out1_zero_g,
@@ -375,7 +383,7 @@ function generate_code_hessian_1_to_2(M::Symbol, f::Function)
is_der2_out1_zero = $SCT.is_der2_out1_zero_local($M.$fname, x)
is_der1_out2_zero = $SCT.is_der1_out2_zero_local($M.$fname, x)
is_der2_out2_zero = $SCT.is_der2_out2_zero_local($M.$fname, x)
t_out1, t_out2 = $SCT.hessian_tracer_1_to_2(
t_out1, t_out2 = @noinline $SCT.hessian_tracer_1_to_2(
d,
is_der1_out1_zero,
is_der2_out1_zero,