Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update code generation utilities #183

Merged
merged 4 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 8 additions & 15 deletions docs/src/dev/adding_overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,13 @@ Depending on the type of function you're dealing with, you will have to specify
## [Overloading](@id code-gen)

After implementing the required classification methods for a function, the function has not been overloaded on our tracer types yet.
SCT provides six functions that generate code via meta-programming:
SCT provides three functions that generate code via meta-programming:

* 1-to-1
* `eval(SCT.overload_gradient_1_to_1(module_symbol, f))`
* `eval(SCT.overload_hessian_1_to_1(module_symbol, f))`
* 2-to-1
* `eval(SCT.overload_gradient_1_to_2(module_symbol, f))`
* `eval(SCT.overload_hessian_1_to_2(module_symbol, f))`
* 1-to-2
* `eval(SCT.overload_gradient_2_to_1(module_symbol, f))`
* `eval(SCT.overload_hessian_2_to_1(module_symbol, f))`
* 1-to-1: `eval(SCT.generate_code_1_to_1(module_symbol, f))`
* 2-to-1: `eval(SCT.generate_code_1_to_2(module_symbol, f))`
* 1-to-2: `eval(SCT.generate_code_2_to_1(module_symbol, f))`

You are required to call the two functions that match your type of operator.
You are required to call the function that matches your type of operator.

!!! tip "Code generation"
We will take a look at the code generation mechanism in the example below.
Expand Down Expand Up @@ -170,17 +164,16 @@ The `relu` function has not been overloaded on our tracer types yet.
Let's call the code generation utilities from the [*"Overloading"*](@ref code-gen) section for this purpose:

```@example overload
eval(SCT.overload_gradient_1_to_1(:NNlib, relu))
eval(SCT.overload_hessian_1_to_1(:NNlib, relu))
eval(SCT.generate_code_1_to_1(:NNlib, relu))
```

The `relu` function is now ready to be called with SCT's tracer types.

!!! details "What is the eval call doing?"
Let's call `overload_gradient_1_to_1` without wrapping it `eval`:
Let's call `generate_code_1_to_1` without wrapping it `eval`:

```@example overload
SCT.overload_gradient_1_to_1(:NNlib, relu)
SCT.generate_code_1_to_1(:NNlib, relu)
```

As you can see, this returns a `quote`, a type of expression containing our generated Julia code.
Expand Down
3 changes: 1 addition & 2 deletions ext/SparseConnectivityTracerNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ SCT.is_der1_zero_local(::typeof(softshrink), x) = x > -0.5 && x < 0.5
ops_1_to_1 = union(ops_1_to_1_s, ops_1_to_1_f)

## Overload
eval(SCT.overload_gradient_1_to_1(:NNlib, ops_1_to_1))
eval(SCT.overload_hessian_1_to_1(:NNlib, ops_1_to_1))
eval(SCT.generate_code_1_to_1(:NNlib, ops_1_to_1))

## List operators for later testing
SCT.test_operators_1_to_1(::Val{:NNlib}) = ops_1_to_1
Expand Down
6 changes: 2 additions & 4 deletions ext/SparseConnectivityTracerSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ end
ops_2_to_1 = ops_2_to_1_ssc

## Overloads
eval(SCT.overload_gradient_1_to_1(:SpecialFunctions, ops_1_to_1))
eval(SCT.overload_gradient_2_to_1(:SpecialFunctions, ops_2_to_1))
eval(SCT.overload_hessian_1_to_1(:SpecialFunctions, ops_1_to_1))
eval(SCT.overload_hessian_2_to_1(:SpecialFunctions, ops_2_to_1))
eval(SCT.generate_code_1_to_1(:SpecialFunctions, ops_1_to_1))
eval(SCT.generate_code_2_to_1(:SpecialFunctions, ops_2_to_1))

## List operators for later testing
SCT.test_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1
Expand Down
6 changes: 3 additions & 3 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function gradient_tracer_1_to_1_inner(
end
end

function overload_gradient_1_to_1(M::Symbol, f)
function generate_code_gradient_1_to_1(M::Symbol, f)
fname = nameof(f)
is_der1_zero_g = is_der1_zero_global(f)

Expand Down Expand Up @@ -109,7 +109,7 @@ function gradient_tracer_2_to_1_inner(
end
end

function overload_gradient_2_to_1(M::Symbol, f)
function generate_code_gradient_2_to_1(M::Symbol, f)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der1_arg2_zero_g = is_der1_arg2_zero_global(f)
Expand Down Expand Up @@ -224,7 +224,7 @@ end
end
end

function overload_gradient_1_to_2(M::Symbol, f)
function generate_code_gradient_1_to_2(M::Symbol, f)
fname = nameof(f)
is_der1_out1_zero_g = is_der1_out1_zero_global(f)
is_der1_out2_zero_g = is_der1_out2_zero_global(f)
Expand Down
6 changes: 3 additions & 3 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ function hessian_tracer_1_to_1_inner(
return P(g_out, h_out) # return pattern
end

function overload_hessian_1_to_1(M::Symbol, f)
function generate_code_hessian_1_to_1(M::Symbol, f)
fname = nameof(f)
is_der1_zero_g = is_der1_zero_global(f)
is_der2_zero_g = is_der2_zero_global(f)
Expand Down Expand Up @@ -175,7 +175,7 @@ function hessian_tracer_2_to_1_inner(
return P(g_out, h_out) # return pattern
end

function overload_hessian_2_to_1(M::Symbol, f)
function generate_code_hessian_2_to_1(M::Symbol, f)
fname = nameof(f)
is_der1_arg1_zero_g = is_der1_arg1_zero_global(f)
is_der2_arg1_zero_g = is_der2_arg1_zero_global(f)
Expand Down Expand Up @@ -315,7 +315,7 @@ end
end
end

function overload_hessian_1_to_2(M::Symbol, f)
function generate_code_hessian_1_to_2(M::Symbol, f)
fname = nameof(f)
is_der1_out1_zero_g = is_der1_out1_zero_global(f)
is_der2_out1_zero_g = is_der2_out1_zero_global(f)
Expand Down
46 changes: 29 additions & 17 deletions src/overloads/utils.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
for overload in (
:overload_gradient_1_to_1,
:overload_gradient_2_to_1,
:overload_gradient_1_to_2,
:overload_hessian_1_to_1,
:overload_hessian_2_to_1,
:overload_hessian_1_to_2,
)
@eval function $overload(M::Symbol, ops::Union{AbstractVector,Tuple})
exprs = [$overload(M, op) for op in ops]
return Expr(:block, exprs...)
dims = (Symbol("1_to_1"), Symbol("2_to_1"), Symbol("1_to_2"))

# Generate both Gradient and Hessian code with one call to `generate_code_X_to_Y`
for d in dims
f = Symbol("generate_code_", d)
g = Symbol("generate_code_gradient_", d)
h = Symbol("generate_code_hessian_", d)

@eval function $f(M::Symbol, f)
expr_g = $g(M, f)
expr_h = $h(M, f)
return Expr(:block, expr_g, expr_h)
end
end

# Allow all `generate_code_*` functions to be called on several operators at once
for d in dims
for f in (
Symbol("generate_code_", d),
Symbol("generate_code_gradient_", d),
Symbol("generate_code_hessian_", d),
)
@eval function $f(M::Symbol, ops::Union{AbstractVector,Tuple})
exprs = [$f(M, op) for op in ops]
return Expr(:block, exprs...)
end
end
end

## Overload operators
eval(overload_gradient_1_to_1(:Base, ops_1_to_1))
eval(overload_gradient_2_to_1(:Base, ops_2_to_1))
eval(overload_gradient_1_to_2(:Base, ops_1_to_2))
eval(overload_hessian_1_to_1(:Base, ops_1_to_1))
eval(overload_hessian_2_to_1(:Base, ops_2_to_1))
eval(overload_hessian_1_to_2(:Base, ops_1_to_2))
eval(generate_code_1_to_1(:Base, ops_1_to_1))
eval(generate_code_2_to_1(:Base, ops_2_to_1))
eval(generate_code_1_to_2(:Base, ops_1_to_2))

## List operators for later testing
test_operators_1_to_1(::Val{:Base}) = ops_1_to_1
Expand Down
Loading