Skip to content

Commit

Permalink
Merge pull request #126 from ornithos/logsoftmaxgrad
Browse files Browse the repository at this point in the history
Improve numerical stability of logsoftmax gradient
  • Loading branch information
MikeInnes authored Jun 12, 2019
2 parents f5fce7a + d219cdb commit a80bdff
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
end
return out
end
∇logsoftmax(Δ, xs) = ∇softmax(Δ ./ max.(eps(eltype(xs)), softmax(xs)), xs)
∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)
∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
2 changes: 1 addition & 1 deletion test/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ end
xs = Float32[1 2 3; 1000 2000 3000]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]

@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))
@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))

# These values precalculated using PyTorch's nn.LogSoftmax
Expand Down

0 comments on commit a80bdff

Please sign in to comment.