diff --git a/src/softmax.jl b/src/softmax.jl
index 4710f373c..4a42768c0 100644
--- a/src/softmax.jl
+++ b/src/softmax.jl
@@ -17,7 +17,12 @@ independent.
       0.244728
       0.665241
 """
-softmax(xs) = softmax!(similar(xs), xs)
+function softmax(xs::AbstractArray; dims=1)
+    # Shift by the maximum along `dims` so `exp` cannot overflow; the shift
+    # cancels in the ratio below, and the exponentials are computed only once.
+    expxs = exp.(xs .- maximum(xs, dims=dims))
+    return expxs ./ sum(expxs, dims=dims)
+end
 
 function softmax!(out::AbstractVecOrMat{T}, xs::AbstractVecOrMat{T}) where {T}
     @inbounds for j = 1:size(xs, 2)
@@ -50,6 +55,15 @@ function ∇softmax!(out::AbstractVecOrMat, Δ::AbstractVecOrMat, xs::AbstractVe
     sf = softmax(xs)
     out .= sf .* (Δ .- sum(Δ .*sf, dims = 1))
 end
 
-∇softmax(Δ, xs) = ∇softmax!(similar(Δ), Δ, xs)
+"""
+    ∇softmax(Δ, xs; dims=1)
+
+Return `sf .* (Δ .- sum(Δ .* sf, dims=dims))` where `sf = softmax(xs, dims=dims)`,
+i.e. the incoming gradient `Δ` pulled back through `softmax` along `dims`.
+"""
+function ∇softmax(Δ, xs; dims=1)
+    sf = softmax(xs, dims=dims)
+    return sf .* (Δ .- sum(Δ .* sf, dims=dims))
+end
 ∇softmax!(Δ, xs) = ∇softmax!(Δ, Δ, xs)