From cda26bd210bf4ebc8de7554a3e8d509b8f5edf1e Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 2 Sep 2024 16:24:47 +0200 Subject: [PATCH 1/4] Rename code-gen utils --- docs/src/dev/adding_overloads.md | 20 ++++++++-------- ext/SparseConnectivityTracerNNlibExt.jl | 4 ++-- ...seConnectivityTracerSpecialFunctionsExt.jl | 8 +++---- src/overloads/gradient_tracer.jl | 6 ++--- src/overloads/hessian_tracer.jl | 6 ++--- src/overloads/utils.jl | 24 +++++++++---------- 6 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/src/dev/adding_overloads.md b/docs/src/dev/adding_overloads.md index 99b1c444..2a74ac81 100644 --- a/docs/src/dev/adding_overloads.md +++ b/docs/src/dev/adding_overloads.md @@ -126,14 +126,14 @@ After implementing the required classification methods for a function, the funct SCT provides six functions that generate code via meta-programming: * 1-to-1 - * `eval(SCT.overload_gradient_1_to_1(module_symbol, f))` - * `eval(SCT.overload_hessian_1_to_1(module_symbol, f))` + * `eval(SCT.generate_code_gradient_1_to_1(module_symbol, f))` + * `eval(SCT.generate_code_hessian_1_to_1(module_symbol, f))` * 2-to-1 - * `eval(SCT.overload_gradient_1_to_2(module_symbol, f))` - * `eval(SCT.overload_hessian_1_to_2(module_symbol, f))` + * `eval(SCT.generate_code_gradient_1_to_2(module_symbol, f))` + * `eval(SCT.generate_code_hessian_1_to_2(module_symbol, f))` * 1-to-2 - * `eval(SCT.overload_gradient_2_to_1(module_symbol, f))` - * `eval(SCT.overload_hessian_2_to_1(module_symbol, f))` + * `eval(SCT.generate_code_gradient_2_to_1(module_symbol, f))` + * `eval(SCT.generate_code_hessian_2_to_1(module_symbol, f))` You are required to call the two functions that match your type of operator. @@ -170,17 +170,17 @@ The `relu` function has not been overloaded on our tracer types yet. Let's call the code generation utilities from the [*"Overloading"*](@ref code-gen) section for this purpose: ```@example overload -eval(SCT.overload_gradient_1_to_1(:NNlib, relu)) -eval(SCT.overload_hessian_1_to_1(:NNlib, relu)) +eval(SCT.generate_code_gradient_1_to_1(:NNlib, relu)) +eval(SCT.generate_code_hessian_1_to_1(:NNlib, relu)) ``` The `relu` function is now ready to be called with SCT's tracer types. !!! details "What is the eval call doing?" - Let's call `overload_gradient_1_to_1` without wrapping it `eval`: + Let's call `generate_code_gradient_1_to_1` without wrapping it `eval`: ```@example overload - SCT.overload_gradient_1_to_1(:NNlib, relu) + SCT.generate_code_gradient_1_to_1(:NNlib, relu) ``` As you can see, this returns a `quote`, a type of expression containing our generated Julia code. diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index e5457496..52fe8de2 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -83,8 +83,8 @@ SCT.is_der1_zero_local(::typeof(softshrink), x) = x > -0.5 && x < 0.5 ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f) ## Overload -eval(SCT.overload_gradient_1_to_1(:NNlib, ops_1_to_1)) -eval(SCT.overload_hessian_1_to_1(:NNlib, ops_1_to_1)) +eval(SCT.generate_code_gradient_1_to_1(:NNlib, ops_1_to_1)) +eval(SCT.generate_code_hessian_1_to_1(:NNlib, ops_1_to_1)) ## List operators for later testing SCT.test_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1 diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index f25d0d87..5fd13810 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -111,10 +111,10 @@ end ops_2_to_1 = ops_2_to_1_ssc ## Overloads -eval(SCT.overload_gradient_1_to_1(:SpecialFunctions, ops_1_to_1)) -eval(SCT.overload_gradient_2_to_1(:SpecialFunctions, ops_2_to_1)) -eval(SCT.overload_hessian_1_to_1(:SpecialFunctions, ops_1_to_1)) -eval(SCT.overload_hessian_2_to_1(:SpecialFunctions, ops_2_to_1)) +eval(SCT.generate_code_gradient_1_to_1(:SpecialFunctions, ops_1_to_1)) +eval(SCT.generate_code_gradient_2_to_1(:SpecialFunctions, ops_2_to_1)) +eval(SCT.generate_code_hessian_1_to_1(:SpecialFunctions, ops_1_to_1)) +eval(SCT.generate_code_hessian_2_to_1(:SpecialFunctions, ops_2_to_1)) ## List operators for later testing SCT.test_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1 diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index 1e255161..40ae1d1b 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -30,7 +30,7 @@ function gradient_tracer_1_to_1_inner( end end -function overload_gradient_1_to_1(M::Symbol, f) +function generate_code_gradient_1_to_1(M::Symbol, f) fname = nameof(f) is_der1_zero_g = is_der1_zero_global(f) @@ -109,7 +109,7 @@ function gradient_tracer_2_to_1_inner( end end -function overload_gradient_2_to_1(M::Symbol, f) +function generate_code_gradient_2_to_1(M::Symbol, f) fname = nameof(f) is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) is_der1_arg2_zero_g = is_der1_arg2_zero_global(f) @@ -224,7 +224,7 @@ end end end -function overload_gradient_1_to_2(M::Symbol, f) +function generate_code_gradient_1_to_2(M::Symbol, f) fname = nameof(f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der1_out2_zero_g = is_der1_out2_zero_global(f) diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 59201cb3..79e61968 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -57,7 +57,7 @@ function hessian_tracer_1_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_1_to_1(M::Symbol, f) +function generate_code_hessian_1_to_1(M::Symbol, f) fname = nameof(f) is_der1_zero_g = is_der1_zero_global(f) is_der2_zero_g = is_der2_zero_global(f) @@ -175,7 +175,7 @@ function hessian_tracer_2_to_1_inner( return P(g_out, h_out) # return pattern end -function overload_hessian_2_to_1(M::Symbol, f) +function generate_code_hessian_2_to_1(M::Symbol, f) fname = nameof(f) is_der1_arg1_zero_g = is_der1_arg1_zero_global(f) is_der2_arg1_zero_g = is_der2_arg1_zero_global(f) @@ -315,7 +315,7 @@ end end end -function overload_hessian_1_to_2(M::Symbol, f) +function generate_code_hessian_1_to_2(M::Symbol, f) fname = nameof(f) is_der1_out1_zero_g = is_der1_out1_zero_global(f) is_der2_out1_zero_g = is_der2_out1_zero_global(f) diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl index 98e4dc85..181e0098 100644 --- a/src/overloads/utils.jl +++ b/src/overloads/utils.jl @@ -1,10 +1,10 @@ for overload in ( - :overload_gradient_1_to_1, - :overload_gradient_2_to_1, - :overload_gradient_1_to_2, - :overload_hessian_1_to_1, - :overload_hessian_2_to_1, - :overload_hessian_1_to_2, + :generate_code_gradient_1_to_1, + :generate_code_gradient_2_to_1, + :generate_code_gradient_1_to_2, + :generate_code_hessian_1_to_1, + :generate_code_hessian_2_to_1, + :generate_code_hessian_1_to_2, ) @eval function $overload(M::Symbol, ops::Union{AbstractVector,Tuple}) exprs = [$overload(M, op) for op in ops] @@ -13,12 +13,12 @@ for overload in ( end ## Overload operators -eval(overload_gradient_1_to_1(:Base, ops_1_to_1)) -eval(overload_gradient_2_to_1(:Base, ops_2_to_1)) -eval(overload_gradient_1_to_2(:Base, ops_1_to_2)) -eval(overload_hessian_1_to_1(:Base, ops_1_to_1)) -eval(overload_hessian_2_to_1(:Base, ops_2_to_1)) -eval(overload_hessian_1_to_2(:Base, ops_1_to_2)) +eval(generate_code_gradient_1_to_1(:Base, ops_1_to_1)) +eval(generate_code_gradient_2_to_1(:Base, ops_2_to_1)) +eval(generate_code_gradient_1_to_2(:Base, ops_1_to_2)) +eval(generate_code_hessian_1_to_1(:Base, ops_1_to_1)) +eval(generate_code_hessian_2_to_1(:Base, ops_2_to_1)) +eval(generate_code_hessian_1_to_2(:Base, ops_1_to_2)) ## List operators for later testing test_operators_1_to_1(::Val{:Base}) = ops_1_to_1 From b140297fe645f7903d4ffa57ad7edfbf064e5299 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 2 Sep 2024 16:38:44 +0200 Subject: [PATCH 2/4] Add `generate_code_X_to_Y` utility --- src/overloads/utils.jl | 46 ++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/src/overloads/utils.jl b/src/overloads/utils.jl index 181e0098..e9118e27 100644 --- a/src/overloads/utils.jl +++ b/src/overloads/utils.jl @@ -1,24 +1,36 @@ -for overload in ( - :generate_code_gradient_1_to_1, - :generate_code_gradient_2_to_1, - :generate_code_gradient_1_to_2, - :generate_code_hessian_1_to_1, - :generate_code_hessian_2_to_1, - :generate_code_hessian_1_to_2, -) - @eval function $overload(M::Symbol, ops::Union{AbstractVector,Tuple}) - exprs = [$overload(M, op) for op in ops] - return Expr(:block, exprs...) +dims = (Symbol("1_to_1"), Symbol("2_to_1"), Symbol("1_to_2")) + +# Generate both Gradient and Hessian code with one call to `generate_code_X_to_Y` +for d in dims + f = Symbol("generate_code_", d) + g = Symbol("generate_code_gradient_", d) + h = Symbol("generate_code_hessian_", d) + + @eval function $f(M::Symbol, f) + expr_g = $g(M, f) + expr_h = $h(M, f) + return Expr(:block, expr_g, expr_h) + end +end + +# Allow all `generate_code_*` functions to be called on several operators at once +for d in dims + for f in ( + Symbol("generate_code_", d), + Symbol("generate_code_gradient_", d), + Symbol("generate_code_hessian_", d), + ) + @eval function $f(M::Symbol, ops::Union{AbstractVector,Tuple}) + exprs = [$f(M, op) for op in ops] + return Expr(:block, exprs...) + end end end ## Overload operators -eval(generate_code_gradient_1_to_1(:Base, ops_1_to_1)) -eval(generate_code_gradient_2_to_1(:Base, ops_2_to_1)) -eval(generate_code_gradient_1_to_2(:Base, ops_1_to_2)) -eval(generate_code_hessian_1_to_1(:Base, ops_1_to_1)) -eval(generate_code_hessian_2_to_1(:Base, ops_2_to_1)) -eval(generate_code_hessian_1_to_2(:Base, ops_1_to_2)) +eval(generate_code_1_to_1(:Base, ops_1_to_1)) +eval(generate_code_2_to_1(:Base, ops_2_to_1)) +eval(generate_code_1_to_2(:Base, ops_1_to_2)) ## List operators for later testing test_operators_1_to_1(::Val{:Base}) = ops_1_to_1 From cd4cf862b8495e94f8e23229a56287558567cdf5 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 2 Sep 2024 16:41:57 +0200 Subject: [PATCH 3/4] Update package extensions --- ext/SparseConnectivityTracerNNlibExt.jl | 3 +-- ext/SparseConnectivityTracerSpecialFunctionsExt.jl | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index 52fe8de2..ee74ecb4 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -83,8 +83,7 @@ SCT.is_der1_zero_local(::typeof(softshrink), x) = x > -0.5 && x < 0.5 ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f) ## Overload -eval(SCT.generate_code_gradient_1_to_1(:NNlib, ops_1_to_1)) -eval(SCT.generate_code_hessian_1_to_1(:NNlib, ops_1_to_1)) +eval(SCT.generate_code_1_to_1(:NNlib, ops_1_to_1)) ## List operators for later testing SCT.test_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1 diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index 5fd13810..210a2604 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -111,10 +111,8 @@ end ops_2_to_1 = ops_2_to_1_ssc ## Overloads -eval(SCT.generate_code_gradient_1_to_1(:SpecialFunctions, ops_1_to_1)) -eval(SCT.generate_code_gradient_2_to_1(:SpecialFunctions, ops_2_to_1)) -eval(SCT.generate_code_hessian_1_to_1(:SpecialFunctions, ops_1_to_1)) -eval(SCT.generate_code_hessian_2_to_1(:SpecialFunctions, ops_2_to_1)) +eval(SCT.generate_code_1_to_1(:SpecialFunctions, ops_1_to_1)) +eval(SCT.generate_code_2_to_1(:SpecialFunctions, ops_2_to_1)) ## List operators for later testing SCT.test_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1 From 7746e81eca2f731fca51004a5712079021668fe3 Mon Sep 17 00:00:00 2001 From: adrhill Date: Mon, 2 Sep 2024 16:42:06 +0200 Subject: [PATCH 4/4] Update docs --- docs/src/dev/adding_overloads.md | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/docs/src/dev/adding_overloads.md b/docs/src/dev/adding_overloads.md index 2a74ac81..198a18eb 100644 --- a/docs/src/dev/adding_overloads.md +++ b/docs/src/dev/adding_overloads.md @@ -123,19 +123,13 @@ Depending on the type of function you're dealing with, you will have to specify ## [Overloading](@id code-gen) After implementing the required classification methods for a function, the function has not been overloaded on our tracer types yet. -SCT provides six functions that generate code via meta-programming: +SCT provides three functions that generate code via meta-programming: -* 1-to-1 - * `eval(SCT.generate_code_gradient_1_to_1(module_symbol, f))` - * `eval(SCT.generate_code_hessian_1_to_1(module_symbol, f))` -* 2-to-1 - * `eval(SCT.generate_code_gradient_1_to_2(module_symbol, f))` - * `eval(SCT.generate_code_hessian_1_to_2(module_symbol, f))` -* 1-to-2 - * `eval(SCT.generate_code_gradient_2_to_1(module_symbol, f))` - * `eval(SCT.generate_code_hessian_2_to_1(module_symbol, f))` +* 1-to-1: `eval(SCT.generate_code_1_to_1(module_symbol, f))` +* 2-to-1: `eval(SCT.generate_code_1_to_2(module_symbol, f))` +* 1-to-2: `eval(SCT.generate_code_2_to_1(module_symbol, f))` -You are required to call the two functions that match your type of operator. +You are required to call the function that matches your type of operator. !!! tip "Code generation" We will take a look at the code generation mechanism in the example below. @@ -170,17 +164,16 @@ The `relu` function has not been overloaded on our tracer types yet. Let's call the code generation utilities from the [*"Overloading"*](@ref code-gen) section for this purpose: ```@example overload -eval(SCT.generate_code_gradient_1_to_1(:NNlib, relu)) -eval(SCT.generate_code_hessian_1_to_1(:NNlib, relu)) +eval(SCT.generate_code_1_to_1(:NNlib, relu)) ``` The `relu` function is now ready to be called with SCT's tracer types. !!! details "What is the eval call doing?" - Let's call `generate_code_gradient_1_to_1` without wrapping it `eval`: + Let's call `generate_code_1_to_1` without wrapping it `eval`: ```@example overload - SCT.generate_code_gradient_1_to_1(:NNlib, relu) + SCT.generate_code_1_to_1(:NNlib, relu) ``` As you can see, this returns a `quote`, a type of expression containing our generated Julia code.