logsoftmax with dims #135

Merged 4 commits on Sep 27, 2019
44 changes: 27 additions & 17 deletions src/softmax.jl
@@ -8,18 +8,21 @@ export softmax, softmax!, ∇softmax, ∇softmax!,
log-probabilities (any real vector) and returns a probability distribution that
sums to 1.

If given a matrix it will treat it as a batch of vectors, with each column
independent.
If given a matrix it will by default (`dims=1`) treat it as a batch of vectors,
with each column independent. Keyword `dims=2` will instead treat rows independently, etc.

julia> softmax([1,2,3.])
3-element Array{Float64,1}:
0.0900306
0.244728
0.665241
```
julia> softmax([1,2,3.])
3-element Array{Float64,1}:
0.0900306
0.244728
0.665241
```
"""
function softmax(xs::AbstractArray{T}; dims=1) where {T}
max = maximum(xs, dims=dims)
out = exp.(xs .- max) ./ sum(exp.(xs .- max), dims=dims)
function softmax(xs::AbstractArray; dims=1)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
exp_ ./ sum(exp_, dims=dims)
end
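
As a quick illustration of the `dims` keyword added above (not part of the diff; the values in comments are rounded):

```
x = [1.0 2.0; 3.0 4.0]

softmax(x)           # dims=1 (default): each column sums to 1 → [0.1192 0.1192; 0.8808 0.8808]
softmax(x; dims=2)   # each row sums to 1                      → [0.2689 0.7311; 0.2689 0.7311]

softmax(vec(x))      # the vector case is unchanged and still sums to 1
```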

function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
@@ -51,23 +54,29 @@ end

function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVecOrMat)
sf = softmax(xs)
out .= sf .* (Δ .- sum(Δ .*sf, dims = 1))
out .= sf .* (Δ .- sum(Δ .* sf, dims = 1))
end
function ∇softmax(Δ, xs; dims=1)
function ∇softmax(Δ, xs; dims=1)
sf = softmax(xs, dims=dims)
out = sf .* (Δ .- sum(Δ .* sf, dims=dims))
sf .* (Δ .- sum(Δ .* sf, dims=dims))
end
∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
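
The backward rule above follows from the softmax Jacobian ∂sᵢ/∂xⱼ = sᵢ(δᵢⱼ − sⱼ), so the vector-Jacobian product is sᵢ(Δᵢ − Σⱼ Δⱼ sⱼ), i.e. `sf .* (Δ .- sum(Δ .* sf, dims=dims))`. A rough finite-difference sanity check, assuming `softmax` and `∇softmax` from this file are in scope (not part of the PR):

```
x = randn(4, 3)
Δ = randn(4, 3)

g = ∇softmax(Δ, x; dims=1)   # analytic vector-Jacobian product

# Numerical gradient of the scalar L(x) = sum(Δ .* softmax(x; dims=1))
ϵ = 1e-6
g_fd = similar(x)
for i in eachindex(x)
    xp = copy(x); xp[i] += ϵ
    xm = copy(x); xm[i] -= ϵ
    g_fd[i] = (sum(Δ .* softmax(xp)) - sum(Δ .* softmax(xm))) / (2ϵ)
end

maximum(abs.(g .- g_fd))     # should be tiny, well below 1e-6
```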


"""
logsoftmax(xs) = log.(exp.(xs) ./ sum(exp.(xs)))

`logsoftmax(xs)` computes the log of `softmax(xs)`, but in a more numerically stable
way than directly taking the log of the softmax function, which is commonly used in
Computes the log of softmax in a more numerically stable
way than directly taking `log.(softmax(xs))`. Commonly used in
computing cross entropy loss.
"""
logsoftmax(xs) = logsoftmax!(similar(xs), xs)
function logsoftmax(xs::AbstractArray; dims=1)
max_ = maximum(xs, dims=dims)
exp_ = exp.(xs .- max_)
log_ = log.(sum(exp_, dims=dims))
(xs .- max_) .- log_
end
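
A small sketch of why the shifted formulation above matters numerically (not part of the diff): with large inputs, `exp` underflows before the log can recover the small terms, whereas subtracting the maximum first keeps everything finite.

```
x = [1000.0, 2000.0, 3000.0]

log.(softmax(x))   # [-Inf, -Inf, 0.0] — exp.(x .- max) underflows to 0 for the small entries
logsoftmax(x)      # ≈ [-2000.0, -1000.0, 0.0]
```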

function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
for j = 1:size(xs, 2)
@inbounds begin
@@ -86,5 +95,6 @@ function logsoftmax!(out::AbstractVecOrMat, xs::AbstractVecOrMat)
end
return out
end
∇logsoftmax(Δ, xs) = Δ - sum(Δ, dims=1) .* softmax(xs)

∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs)
∇logsoftmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)
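
For reference, the `∇logsoftmax` rule above comes from `logsoftmax(x) = x .- log.(sum(exp.(x), dims=dims))`, whose vector-Jacobian product with `Δ` is `Δ .- sum(Δ, dims=dims) .* softmax(x)`. It is also consistent with `∇softmax` via the chain rule through `log`; a quick sanity check, assuming the functions defined in this file (not part of the PR):

```
x = randn(3, 4)
Δ = randn(3, 4)

lhs = ∇logsoftmax(Δ, x; dims=1)
rhs = ∇softmax(Δ ./ softmax(x), x; dims=1)   # chain rule: d log(s)/dx pulls Δ back as Δ ./ s

lhs ≈ rhs   # true up to floating-point error
```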
4 changes: 3 additions & 1 deletion test/activation.jl
@@ -84,7 +84,9 @@ end
@testset "softmax" begin
xs = rand(5,5)
@test all(sum(softmax(xs), dims = 1) .≈ 1)
@test all(sum(softmax(xs; dims=2), dims = 2) .≈ 1)
@test sum(softmax(vec(xs))) ≈ 1
@test log.(softmax(xs; dims=2)) ≈ logsoftmax(xs; dims=2)

xs = [-100_000, -100_000.]
@test softmax(xs) ≈ [0.5, 0.5]
@@ -100,7 +102,7 @@ end
xs = Float32[1 2 3; 1000 2000 3000]
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.]

@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test NNlib.∇logsoftmax(ones(size(xs)), xs) ≈ Float32[1 1 1; -1 -1 -1]
@test NNlib.∇softmax(ones(size(xs)), xs) ≈ zeros(Float32, size(xs))

# These values precalculated using PyTorch's nn.LogSoftmax