Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document limitations #175

Merged
merged 14 commits into from
Aug 20, 2024
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ makedocs(;
),
pages=[
"Getting Started" => "index.md",
"User Documentation" => ["API Reference" => "user/api.md"],
"User Documentation" =>
["Limitations" => "user/limitations.md", "API Reference" => "user/api.md"],
"Developer Documentation" => [
"How SCT works" => "dev/how_it_works.md",
"Internals Reference" => "dev/api.md",
Expand Down
4 changes: 2 additions & 2 deletions docs/src/dev/how_it_works.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ flowchart LR
```
To obtain a sparsity pattern, each scalar input $x_i$ gets seeded with a corresponding singleton index set $\{i\}$ [^1].
Since addition and multiplication have non-zero derivatives with respect to both of their inputs,
the resulting values accumulate and propagate their index sets (annotated on the edges of the graph above).
their outputs accumulate and propagate the index sets of their inputs (annotated on the edges of the graph above).
The sign function has zero derivatives for any input value. It therefore doesn't propagate the index set ${4}$ corresponding to the input $x_4$. Instead, it returns an empty set.

[^1]: since $\frac{\partial x_i}{\partial x_j} \neq 0$ iff $i \neq j$
[^1]: $\frac{\partial x_i}{\partial x_j} \neq 0$ only holds for $i=j$

The resulting **global** gradient sparsity pattern $\left(\nabla f(\mathbf{x})\right)_{i} \neq 1$ for $i$ in $\{1, 2, 3\}$ matches the analytical gradient

Expand Down
6 changes: 1 addition & 5 deletions docs/src/user/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@ CollapsedDocStrings = true
```

# [API Reference](@id api)
```@index
```

## ADTypes Interface

SparseConnectivityTracer uses [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s interface for [sparsity detection](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector).
SparseConnectivityTracer uses [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s [interface for sparsity detection](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector).
In fact, the functions `jacobian_sparsity` and `hessian_sparsity` are re-exported from ADTypes.

To compute **global** sparsity patterns of `f(x)` over the entire input domain `x`, use
Expand Down
126 changes: 126 additions & 0 deletions docs/src/user/limitations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# [Limitations](@id limitations)

## Sparsity patterns are conservative approximations

Sparsity patterns returned by SparseConnectivityTracer (SCT) can in some cases be overly conservative, meaning that they might contain "too many ones".
If you observe an overly conservative pattern, [please open a feature request](https://github.com/adrhill/SparseConnectivityTracer.jl/issues) so we know where to add more method overloads to increase the sparsity.

!!! warning "SCT's no-false-negatives policy"
If you ever observe a sparsity pattern that contains too many zeros, we urge you to [open a bug report](https://github.com/adrhill/SparseConnectivityTracer.jl/issues)!

## Function must be composed of generic Julia functions

SCT can't trace through non-Julia code.
However, if you know the sparsity pattern of an external, non-Julia function,
you might be able to work around it by adding methods on SCT's tracer types.

## Function types must be generic

When computing the sparsity pattern of a function,
it must be written generically enough to accept numbers of type `T<:Real` as (or `AbstractArray{<:Real}`) as inputs.

!!! details "Example: Overly restrictive type annotations"
Let's see this mistake in action:

```@example notgeneric
using SparseConnectivityTracer
method = TracerSparsityDetector()

relu_bad(x::AbstractFloat) = max(zero(x), x)
outer_function_bad(xs) = sum(relu_bad, xs)
nothing # hide
```

Since tracers and dual numbers are `Real` numbers and not `AbstractFloat`s,
`relu_bad` throws a `MethodError`:

```@repl notgeneric
xs = [1.0, -2.0, 3.0];

outer_function_bad(xs)

jacobian_sparsity(outer_function_bad, xs, method)
```

This is easily fixed by loosening type restrictions or adding an additional methods on `Real`:

```@example notgeneric
relu_good(x) = max(zero(x), x)
outer_function_good(xs) = sum(relu_good, xs)
nothing # hide
```

```@repl notgeneric
jacobian_sparsity(outer_function_good, xs, method)
```

## Limited control flow

Only [`TracerLocalSparsityDetector`](@ref) supports comparison operators (`<`, `==`, ...), indicator functions (`iszero`, `iseven`, ...) and control flow.

[`TracerSparsityDetector`](@ref) does not support any boolean functions and control flow (with the exception of `iselse`).
This might seem unintuitive but follows from our policy stated above: SCT guarantees conservative sparsity patterns.
Using an approach based on operator-overloading, this means that global sparsity detection isn't allowed to hit any branching code.
`ifelse` is the only exception, since it allows us to evaluate both branches.


!!! warning "Common control flow errors"
By design, SCT will throw errors instead of returning wrong sparsity patterns. Common error messages include:

```julia
ERROR: TypeError: non-boolean [tracer type] used in boolean context
```

```julia
ERROR: Function [function] requires primal value(s).
A dual-number tracer for local sparsity detection can be used via `TracerLocalSparsityDetector`.
```

!!! details "Why does TracerSparsityDetector not support control flow and comparisons?"
Let us motivate the design decision above by a simple example function:

```@example ctrlflow
function f(x)
if x[1] > x[2]
return x[1]
else
return x[2]
end
end
nothing # hide
```

The desired **global** Jacobian sparsity pattern over the entire input domain $x \in \mathbb{R}^2$ is `[1 1]`.
Two **local** sparsity patterns are possible:
`[1 0]` for $\{x | x_1 > x_2\}$,
`[0 1]` for $\{x | x_1 \le x_2\}$.

The local sparsity patterns of [`TracerLocalSparsityDetector`](@ref) are easy to compute using operator overloading by using [dual numbers](@ref SparseConnectivityTracer.Dual)
which contain primal values on which we can evaluate comparisons like `>`:

```@repl ctrlflow
using SparseConnectivityTracer

jacobian_sparsity(f, [2, 1], TracerLocalSparsityDetector())

jacobian_sparsity(f, [1, 2], TracerLocalSparsityDetector())
```

The global sparsity pattern is **impossible** to compute when code branches with an if-else condition,
since we can only ever hit one branch during run-time.
If we made comparisons like `>` return `true` or `false`, we'd get the local patterns `[1 0]` and `[0 1]` respectively.
But SCT's policy is to guarantee conservative sparsity patterns, which means that "false positives" (ones) are acceptable, but "false negatives" (zeros) are not.
In my our opinion, the right thing to do here is to throw an error:

```@repl ctrlflow
jacobian_sparsity(f, [1, 2], TracerSparsityDetector())
```

In some cases, we can work around this by using `ifelse`.
Since `ifelse` is a method, it can evaluate "both branches" and take a conservative union of both resulting sparsity patterns:

```@repl ctrlflow
f(x) = ifelse(x[1] > x[2], x[1], x[2])

jacobian_sparsity(f, [1, 2], TracerSparsityDetector())
```
35 changes: 29 additions & 6 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ julia> jacobian_sparsity(diff, rand(4), TracerSparsityDetector())
```

```jldoctest
julia> using SparseConnectivityTracer

julia> f(x) = x[1] + x[2]*x[3] + 1/x[4];

julia> hessian_sparsity(f, rand(4), TracerSparsityDetector())
Expand Down Expand Up @@ -67,27 +65,52 @@ For global sparsity patterns, use [`TracerSparsityDetector`](@ref).

# Example

Local sparsity patterns are less convervative than global patterns and need to be recomputed for each input `x`:

```jldoctest
julia> using SparseConnectivityTracer

julia> method = TracerLocalSparsityDetector();

julia> f(x) = x[1] * x[2]; # J_f = [x[2], x[1]]

julia> jacobian_sparsity(f, [1, 0], method)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry:
⋅ 1

julia> jacobian_sparsity(f, [0, 1], method)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 1 stored entry:
1 ⋅

julia> jacobian_sparsity(f, [0, 0], method)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 0 stored entries:
⋅ ⋅

julia> jacobian_sparsity(f, [1, 1], method)
1×2 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries:
1 1
```

`TracerLocalSparsityDetector` can compute sparsity patterns of functions that contain comparisons and `ifelse` statements:


```jldoctest
julia> f(x) = x[1] > x[2] ? x[1:3] : x[2:4];

julia> jacobian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
julia> jacobian_sparsity(f, [1, 2, 3, 4], TracerLocalSparsityDetector())
3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
⋅ 1 ⋅ ⋅
⋅ ⋅ 1 ⋅
⋅ ⋅ ⋅ 1

julia> jacobian_sparsity(f, [2.0, 1.0, 3.0, 4.0], TracerLocalSparsityDetector())
julia> jacobian_sparsity(f, [2, 1, 3, 4], TracerLocalSparsityDetector())
3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries:
1 ⋅ ⋅ ⋅
⋅ 1 ⋅ ⋅
⋅ ⋅ 1 ⋅
```

```jldoctest
julia> using SparseConnectivityTracer

julia> f(x) = x[1] + max(x[2], x[3]) * x[3] + 1/x[4];

julia> hessian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
Expand Down
Loading