Add some basic developer docs for custom rules #105

Merged: 3 commits, Oct 16, 2023
1 change: 1 addition & 0 deletions docs/make.jl
@@ -21,6 +21,7 @@ pages = [
"tutorials/optimizations.md",
],
"Public API" => "public_api.md",
"Developer documentation" => "devdocs.md",
"Limitations" => "limitations.md",
]

86 changes: 84 additions & 2 deletions docs/src/devdocs.md
@@ -1,3 +1,85 @@
# Developer documentation
# Developer documentation (WIP)

*Coming soon!*
## Writing a custom rule for stochastic triples

### via `StochasticAD.propagate`

To support a deterministic discrete construct that `StochasticAD` does not handle automatically (e.g. branching via `if`, boolean comparisons), it is often sufficient to add a dispatch rule that calls out to `StochasticAD.propagate`.

```@docs
StochasticAD.propagate
```

### via a custom dispatch

If a function does not meet the conditions of `StochasticAD.propagate` and is not already supported, a custom
dispatch may be necessary. For example, consider the following function which manually implements a geometric random variable:

```@example rule
import Random # hide
Random.seed!(1234) # hide
using Distributions
function mygeometric(p)
    x = 0
    while !(rand(Bernoulli(p)))
        x += 1
    end
    return x
end
```

This is equivalent to `rand(Geometric(p))`, which is already supported, but for pedagogical purposes we will
implement our own rule from scratch. Using the stochastic derivative formulas from [Automatic Differentiation of Programs with Discrete Randomness](https://doi.org/10.48550/arXiv.2210.08572), the right stochastic derivative of this program is given by
```math
Y_R = X - 1, \quad w_R = \frac{X}{p(1-p)},
```
and the left stochastic derivative of this program is given by
```math
Y_L = X + 1, \quad w_L = -\frac{X+1}{p}.
```
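
As a quick sanity check (a direct computation added here for illustration, not a derivation taken from the paper), both expressions should reproduce the ordinary derivative of the mean. Since the program counts failures before the first success, $\mathbb{E}[X] = (1-p)/p$, so $\frac{d}{dp}\mathbb{E}[X] = -1/p^2$. Averaging the weight times the jump size gives the same value on both sides:

```math
\begin{aligned}
\mathbb{E}[w_R (Y_R - X)] &= \mathbb{E}\left[-\frac{X}{p(1-p)}\right] = -\frac{1}{p(1-p)} \cdot \frac{1-p}{p} = -\frac{1}{p^2}, \\
\mathbb{E}[w_L (Y_L - X)] &= \mathbb{E}\left[-\frac{X+1}{p}\right] = -\frac{1}{p}\left(\frac{1-p}{p} + 1\right) = -\frac{1}{p^2}.
\end{aligned}
```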

Using these expressions, we can now write the dispatch rule for stochastic triples:

```@example rule
using StochasticAD
import StochasticAD: StochasticTriple, similar_new, similar_empty, combine
function mygeometric(p_st::StochasticTriple{T}) where {T}
    p = p_st.value
    x = mygeometric(p)

    # Form the new discrete perturbations (combinations of weight w and perturbation Y - X)
    Δs1 = if p_st.δ > 0
        # right stochastic derivative
        w = x / (p * (1 - p))
        x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)
    elseif p_st.δ < 0
        # left stochastic derivative
        w = (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
        similar_new(p_st.Δs, 1, w)
    else
        similar_empty(p_st.Δs, Int)
    end

    # Propagate any existing perturbations to p through the function
    function map_func(Δ)
        # Sample mygeometric(p + Δ) independently. (A better strategy would be to couple to the original sample.)
        mygeometric(p + Δ) - x
    end
    Δs2 = map(map_func, p_st.Δs)

    # Return the output stochastic triple
    StochasticTriple{T}(x, zero(x), combine((Δs2, Δs1)))
end
```
In the above, we used some of the interface functions supported by a collection of perturbations `Δs::StochasticAD.AbstractFIs`. These were `similar_empty(Δs, V)`, which creates an empty perturbation of type `V`; `similar_new(Δs, Δ, w)`, which creates a new perturbation of size `Δ` and weight `w`; `map(map_func, Δs)`,
which propagates a collection of perturbations through a mapping function; and `combine((Δs2, Δs1))`, which combines multiple collections of perturbations together.

We can test out our rule:
```@example rule
@show stochastic_triple(mygeometric, 0.1)

# try feeding an input that already has a perturbation
f(x) = mygeometric(x + 0.4 * rand(Bernoulli(x)))
@show stochastic_triple(f, 0.1)
nothing # hide
```
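
The right-derivative weight used in the rule can also be checked numerically without `StochasticAD` at all. The following is a minimal sketch under stated assumptions: `mygeometric_plain` and `right_derivative_sample` are hypothetical helper names introduced only for this check, and `rand(rng) < p` stands in for `rand(Bernoulli(p))` so that the block depends only on the standard library. It averages the weight times the jump (`-1`) over many samples and compares the result to the exact derivative of the mean, `-1 / p^2`.

```julia
using Random
using Statistics

# Hypothetical helper: counts failures before the first success.
# `rand(rng) < p` is a stand-in for `rand(Bernoulli(p))`.
function mygeometric_plain(rng, p)
    x = 0
    while !(rand(rng) < p)
        x += 1
    end
    return x
end

# Hypothetical helper: one-sample right-derivative contribution,
# i.e. weight w_R = x / (p * (1 - p)) times the jump Y_R - X = -1.
right_derivative_sample(x, p) = -x / (p * (1 - p))

rng = MersenneTwister(1234)
p = 0.5
samples = [right_derivative_sample(mygeometric_plain(rng, p), p) for _ in 1:100_000]
estimate = mean(samples)
exact = -1 / p^2  # derivative of E[X] = (1 - p) / p with respect to p

println("estimate = ", estimate, ", exact = ", exact)
```

The two printed values should agree up to Monte Carlo error; a large discrepancy would suggest a mistake in the weight or in the sign of the jump.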
57 changes: 46 additions & 11 deletions src/propagate.jl
@@ -37,21 +37,56 @@ function strip_Δs(arg)
end

"""
propagate(f, args...; keep_deltas = Val{true})

Propagates `args` through a function `f`, handling stochastic triples appropriately.
This functionality is orthogonal to dispatch: the idea is for this function to be
the "backend" for operator overloading rules.
Currently, we handle deterministic functions `f` with input and output supported by `Functors.jl`.
If `f` has a continuously differentiable component that should be kept,
This function is highly experimental, and is intentionally undocumented.
propagate(f, args...; keep_deltas = Val(false))

Propagates `args` through a function `f`, handling stochastic triples by independently running `f` on the primal
and the alternatives, rather than by inspecting the internals of `f` (which may possibly be unsupported by `StochasticAD`).
Currently handles deterministic functions `f` with any input and output that is `fmap`-able by `Functors.jl`.
If `f` has a continuously differentiable component, provide `keep_deltas = Val(true)`.

This functionality is orthogonal to dispatch: the idea is for this function to be the "backend" for operator
overloading rules based on dispatch. For example:

```jldoctest
using StochasticAD, Distributions
import Random # hide
Random.seed!(4321) # hide

function mybranch(x)
    str = repr(x) # string-valued intermediate!
    if length(str) < 2
        return 3
    else
        return 7
    end
end

function f(x)
    return mybranch(9 + rand(Bernoulli(x)))
end

# stochastic_triple(f, 0.5) # this would fail

# Add a dispatch rule for mybranch using StochasticAD.propagate
mybranch(x::StochasticAD.StochasticTriple) = StochasticAD.propagate(mybranch, x)

stochastic_triple(f, 0.5) # now works

# output

StochasticTriple of Int64:
3 + 0ε + (4 with probability 2.0ε)
```

!!! warning
    This function is experimental and subject to change.
"""
# TODO: support kwargs to f (or just use kwfunc in macro)
function propagate(f,
args...;
keep_deltas = Val{false},
keep_deltas = Val(false),
provided_st_rep = nothing,
deriv = nothing)
# TODO: support kwargs to f (or just use kwfunc in macro)
#=
TODO: maybe don't iterate through every scalar of array below,
but rather have special array dispatch
@@ -78,7 +113,7 @@ function propagate(f,
end

primal_args = structural_map(get_value, args)
input_args = keep_deltas == Val{false} ? primal_args : structural_map(strip_Δs, args)
input_args = keep_deltas isa Val{false} ? primal_args : structural_map(strip_Δs, args)
#=
TODO: the below is dangerous in general.
It should be safe so long as f does not close over stochastic triples.
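
The change from `keep_deltas == Val{false}` to `keep_deltas isa Val{false}` matters because `Val{false}` is a type while `Val(false)` is an instance of that type. The sketch below illustrates the difference in isolation; `old_check` and `new_check` are hypothetical names that mirror the two conditions in the diff and are not part of `StochasticAD`.

```julia
instance_flag = Val(false)  # an instance of the type Val{false}
type_flag = Val{false}      # the type itself (the old default value)

# Hypothetical stand-ins for the old and new conditions in `propagate`.
old_check(flag) = flag == Val{false}
new_check(flag) = flag isa Val{false}

@show old_check(type_flag)      # the old default matches the old check
@show old_check(instance_flag)  # an instance passed by a caller does not
@show new_check(instance_flag)  # the new check accepts the instance
@show new_check(type_flag)      # and no longer accepts the bare type
```

Under the old condition, a caller passing `Val(false)` explicitly would not have matched the comparison, which is presumably why the default and the check were changed together.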