generalize softmax #77

Closed
CarloLucibello opened this issue Nov 28, 2018 · 4 comments

@CarloLucibello
Member

The softmax functions should be generalized to handle reduction across arbitrary dimensions, e.g.:

softmax(x, dims=2)
softmax(x, dims=(1,3))
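
To make the requested semantics concrete, here is a minimal sketch of a softmax that normalizes over an arbitrary set of dimensions. The name `softmax_dims` is hypothetical and is not NNlib's implementation; it only assumes that `maximum` and `sum` from Base accept a `dims` keyword that may be an integer or a tuple.

```julia
# Hypothetical helper for illustration only; not the NNlib implementation.
function softmax_dims(x::AbstractArray; dims=1)
    m = maximum(x; dims=dims)      # per-slice maximum, subtracted for numerical stability
    e = exp.(x .- m)               # broadcasting expands m back to the shape of x
    return e ./ sum(e; dims=dims)  # normalize so each reduced slice sums to 1
end

x = rand(3, 4, 5)
y2  = softmax_dims(x; dims=2)       # normalize along the second dimension
y13 = softmax_dims(x; dims=(1, 3))  # normalize jointly over dimensions 1 and 3

# Each slice along the reduced dimensions should sum to one.
ok2  = all(sum(y2;  dims=2)      .≈ 1)
ok13 = all(sum(y13; dims=(1, 3)) .≈ 1)
```

With `dims=(1,3)`, the output keeps the shape of `x`, and for each index along dimension 2 the entries over dimensions 1 and 3 together form one probability distribution.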
@bhvieira

bhvieira commented Aug 2, 2019

Is this solved? I have some interest in this to add attention support in Flux.

@CarloLucibello
Member Author

Looking at this https://github.com/FluxML/NNlib.jl/blob/342928eb4478da9c7b1433ec75c8eb8a9b155747/src/softmax.jl
it seems that the issue with softmax has been solved, but softmax! and logsoftmax still don't support reduction over arbitrary dimensions.
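
For reference, a log-softmax over arbitrary dimensions could be written along the same lines. This is only a sketch under the assumption that the usual max-shifted formulation is wanted; `logsoftmax_dims` is a hypothetical name, not the function in the linked file.

```julia
# Hypothetical sketch; not the NNlib implementation.
function logsoftmax_dims(x::AbstractArray; dims=1)
    m = maximum(x; dims=dims)                       # shift by the per-slice maximum for stability
    shifted = x .- m
    lse = log.(sum(exp.(shifted); dims=dims))       # log of the per-slice sum of exponentials
    return shifted .- lse
end

x = randn(4, 6)
ls = logsoftmax_dims(x; dims=2)

# Exponentiating the result should give rows that sum to one.
rows_ok = all(sum(exp.(ls); dims=2) .≈ 1)
```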

@mcabbott
Member

mcabbott commented Sep 2, 2019

The softmax on master (which takes dims=1) is only half as fast as the old one:

mm = rand(20,100);                                          
@btime softmax($mm);  # tagged, with softmax!,  23.076 μs (1 allocation: 15.75 KiB)
@btime softmax1($mm); # master, with dims=1,    47.657 μs (13 allocations: 33.45 KiB) 

Was this discussed somewhere? Are there goals besides being generic? Some variants which are almost as fast:

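function softmax2(xs::AbstractArray; dims=1)
    temp = maximum(xs, dims=dims)   # per-slice maximum, subtracted for numerical stability
    out = exp.(xs .- temp)
    out ./= sum!(temp, out)         # reuse temp for the per-slice sums, then normalize in place
end
function softmax3(xs::AbstractArray; dims=1)
    m = maximum(xs, dims=dims)      # named m to avoid shadowing Base.max
    out = exp.(xs .- m)
    out ./ sum(out, dims=dims)      # allocates a second array for the result
end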

@btime softmax2($mm); # re-using temp,   27.382 μs (11 allocations: 16.83 KiB) 
@btime softmax3($mm); # no mutation,     26.462 μs (13 allocations: 33.45 KiB)  
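
As a quick sanity check that the two variants compute the same thing, they can be compared directly. The sketch below repeats the definitions so it runs on its own and uses only Base; the timings above additionally need BenchmarkTools for `@btime`, which is left out here.

```julia
# Mutating variant: reuses the buffer that held the maxima for the sums.
function softmax2(xs::AbstractArray; dims=1)
    temp = maximum(xs, dims=dims)
    out = exp.(xs .- temp)
    out ./= sum!(temp, out)   # sum! reduces out into temp, whose shape fixes the reduced dims
end

# Non-mutating variant: allocates a fresh array for the normalized result.
function softmax3(xs::AbstractArray; dims=1)
    m = maximum(xs, dims=dims)
    out = exp.(xs .- m)
    out ./ sum(out, dims=dims)
end

mm = rand(20, 100)
same_dims1 = softmax2(mm) ≈ softmax3(mm)                  # default dims=1
same_dims2 = softmax2(mm, dims=2) ≈ softmax3(mm, dims=2)  # reduction along rows
cols_ok = all(sum(softmax2(mm), dims=1) .≈ 1)             # each column sums to one
```

The mutating variant saves one output-sized allocation, which is consistent with the smaller memory figure reported for `softmax2` above.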

@CarloLucibello
Member Author

softmax and logsoftmax now support the dims keyword, so this can be closed.
@mcabbott feel free to open a new issue or PR if you think there is a performance problem.
The current implementation seems to be exactly the same as your softmax3:

function softmax(xs::AbstractArray; dims=1)
