Skip to content

Commit

Permalink
gradient hook
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Jul 2, 2018
1 parent 5d8b63d commit ce88273
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/tracker/Tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ include("scalar.jl")
include("array.jl")
include("numeric.jl")

"""
hook(f, x) -> x′
Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ))

param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))

Expand Down

0 comments on commit ce88273

Please sign in to comment.