Add faster tanh implementation(s) #345

Conversation
We have been discussing a more structured approach to this, led by @chriselrod and NNlibCPU.
I could add them there. Or, maybe we should create a dedicated library for very fast but inaccurate special functions? @oscardssmith also expressed interest in such a library, and I think it makes sense as a stand-alone repository. It would be good to still have some sort of accuracy guarantees, perhaps tied to a naming convention.
I won't be able to make the ML meeting this week, as I have a vet appointment at that time. But I'll try to merge that PR with basic dense, conv, and batch layers working.
Also, that Xeon in your benchmark seems to be very slow. Using:

using LoopVectorization, BenchmarkTools, LinearAlgebra
Sys.ARCH === :x86_64 && using MKL
BLAS.set_num_threads(1)
using IfElse: ifelse
@inline function tanh_faster(x::Real)
T = float(typeof(x))
x2 = 4x^2
# y = 94.9339451088 * x * ((( x2 + 1.06315869e+03 ) * x2 + 1.88748783e+05 ) * x2 + 5.86237309e+06 ) / ((( ( x2 + 4.03183926e+03 ) * x2 + 1.64253046e+06 ) * x2 + 1.28592857e+08 ) * x2 + 1.11307745e+09 )
num = @fastmath T(94.9339451088) * 2x * ((( x2 + T(1.06315869e+03) ) * x2 + T(1.88748783e+05) ) * x2 + T(5.86237309e+06) )
den = @fastmath ((( ( x2 + T(4.03183926e+03) ) * x2 + T(1.64253046e+06) ) * x2 + T(1.28592857e+08) ) * x2 + T(1.11307745e+09) )
y = num / den
stop = _fast_stop(tanh_faster, x)
ifelse(x > stop, one(y), ifelse(x < -stop, -one(y), y))
end
_fast_stop(::typeof(tanh_faster), x) = oftype(x, 10)
@btime x * x setup=(x=randn(Float32,100,100));
@btime y .= tanh.(x * x) setup=(x=randn(Float32,100,100); y=similar(x));
@btime y .= tanh_faster.(x * x) setup=(x=randn(Float32,100,100); y=similar(x));
@btime vmap!(tanh, y, x * x) setup=(x=randn(Float32,100,100); y=similar(x));
@btime vmap!(tanh_faster, y, x * x) setup=(x=randn(Float32,100,100); y=similar(x)); # with IfElse

I am getting:

julia> @btime x * x setup=(x=randn(Float32,100,100));
16.220 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
11.967 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
25.125 μs (2 allocations: 39.11 KiB) # M1 native
julia> @btime y .= tanh.(x * x) setup=(x=randn(Float32,100,100); y=similar(x));
75.780 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
94.014 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
83.917 μs (2 allocations: 39.11 KiB) # M1 native
julia> @btime y .= tanh_faster.(x * x) setup=(x=randn(Float32,100,100); y=similar(x));
19.565 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
15.447 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
28.417 μs (2 allocations: 39.11 KiB) # M1 native
julia> @btime vmap!(tanh, y, x * x) setup=(x=randn(Float32,100,100); y=similar(x));
27.560 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
21.269 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
45.625 μs (2 allocations: 39.11 KiB) # M1 native
julia> @btime vmap!(tanh_faster, y, x * x) setup=(x=randn(Float32,100,100); y=similar(x)); # with IfElse
19.097 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
13.739 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
28.708 μs (2 allocations: 39.11 KiB) # M1 native
This was supposed to be a super-quick fix, but floating point stuff is messy! Agree the nicest possible outcome would be a super-low-precision library. I think that for most ML purposes a smooth function which is <1 pixel wrong on a graph is going to be plenty. My guess is that smoothness is more important than worst-case error, i.e. steps are worse than ramps.

NNlib might not be a terrible place for such things to live, though, if it's just …; NNlibCPU.jl looks more ambitious. From very informal investigation there's lots of room for faster CPU …. If that's an optional add-on, then the argument for ….

ps. The Xeon is indeed older than I realised, 2012, pre-AVX2. Just the easiest non-M1 around. It's also possible that …
One thing to try is using exp2 instead of exp. It may lose a tiny bit of precision, but be notably faster. exp needs some extra work to do the reduction, while exp2 gets it for free.
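For concreteness, a minimal sketch of that substitution (my example, not the PR's code; the name tanh_exp2 and the clamp threshold are placeholders). The idea is to write exp(2x) as exp2(2x * log2(e)), so the base-2 reduction constant is folded into the argument:

# Sketch only: same (e^2x - 1)/(e^2x + 1) form as above, but driven by exp2.
@inline function tanh_exp2(x::Real)
    c = 2 / log(oftype(float(x), 2))          # 2 * log2(e)
    e2x = @fastmath exp2(x * c)               # == exp(2x)
    y = (e2x - 1) / (e2x + 1)
    ifelse(x > 20, one(y), ifelse(x < -20, -one(y), y))  # clamp where tanh has saturated
end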
Is PyTorch doing something extra besides using …?
I never looked inside that PR somehow. The times for …:

julia> @btime y .= tanh.(x) setup=(x = randn(Float32, 10^3); y = similar(x););
2.963 μs (0 allocations: 0 bytes) # as used above
julia> @btime y .= tanh.(x) setup=(x = 0.1f0 .* randn(Float32, 10^3); y = similar(x););
1.129 μs (0 allocations: 0 bytes) # small => more kernel
julia> @btime y .= tanh.(x) setup=(x = 100 .* randn(Float32, 10^3); y = similar(x););
627.491 ns (0 allocations: 0 bytes) # large => more flat
julia> @btime y .= tanh_faster.(x) setup=(x = randn(Float32, 10^3); y = similar(x));
346.005 ns (0 allocations: 0 bytes) # this PR's rational function
julia> @btime y .= Base.Math.tanh_kernel.(x) setup=(x = rand(Float32, 10^3); y = similar(x););
144.404 ns (0 allocations: 0 bytes) # just the polynomial kernel
julia> Base.Math.TANH_SMALL_X(Float32) # polynomial below this
1.3862944f0
julia> Base.Math.TANH_LARGE_X(Float32) # flat above this, but why so large?
18.0f0
julia> tanh(9f0) # already constant at half that.
1.0f0
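A rough sanity check on that cutoff (my own back-of-envelope, not from the thread): since 1 - tanh(x) ≈ 2exp(-2x), Float32 tanh rounds to 1.0f0 once that difference drops below eps(Float32)/4 ≈ 3.0f-8, which happens near x ≈ 9 -- consistent with tanh(9f0) == 1f0 above, and well below the TANH_LARGE_X(Float32) == 18 cutoff being questioned.

# Solve 2exp(-2x) == eps(Float32)/4 for x; beyond this, tanh(x) should round to 1.0f0.
x_saturate = log(8 / eps(Float32)) / 2   # ≈ 9.01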
Just wondering, how did you generate the coefficients? Theoretically, Remez.jl can make minimax rational approximations, but I've never had much luck (I believe due to problems with numerical conditioning).
I copied them from here: https://math.stackexchange.com/questions/107292/rapid-approximation-of-tanhx
Ok, I deleted some details, they are in the gist. But after messing with Remez.jl a bit, I think this function is within 5 eps of true, abs error 2.929f-7. It's a little faster than my initial …. Maybe there are tricks to do even better; these coefficients come from ….

function tanh_new(x::Float32)
x2 = abs2(x)
n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
ifelse(x2 < 66f0, x * (n / d), sign(x))
end

When I try things on the GPU, I believe Julia's …
There are some papers on how to make the Remez algorithm more numerically stable, and on optimal minimax polynomials with floating-point coefficients, but neither has been implemented yet (both are on my list of things to do if I have time). The other thing that should be relatively easy to add is a way of specifying fixed values for some coefficients. This is useful for generating even/odd polynomials, as well as ones where the first few terms can be represented without rounding.
src/activations.jl
Outdated
""" | ||
@inline function tanh_fast(x) | ||
@inline function tanh_fast(x::Real) | ||
exp2x = @fastmath exp(x + x) | ||
y = (exp2x - 1) / (exp2x + 1) |
You should use expm1 here. It will be about 30% slower, but you will get correct answers for small x. This implementation has catastrophic cancellation.
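For reference, a sketch of the expm1 form being suggested (my wording; tanh_expm1 is a placeholder name). Writing t = expm1(2x) gives tanh(x) = t / (t + 2), which avoids the cancelling subtraction exp(2x) - 1 for small x:

# Sketch only: expm1-based tanh, accurate near zero at some cost in speed.
@inline function tanh_expm1(x::Real)
    t = expm1(x + x)                      # e^(2x) - 1, computed without cancellation
    y = t / (t + 2)
    ifelse(x > 20, one(y), ifelse(x < -20, -one(y), y))  # clamp once t overflows
end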
Probably! Agree this looks dodgy, will re-check Float64 times & accuracy at some point.

I also wonder if there's a clever formula for tanh(x)/x, since it seems a little odd to get accuracy near 0 via expm1 when the desired result is close to linear?
doesn't look like it.
Switching to expm1 seems to cost me the entire speed gain:
julia> xs = collect(0.01:0.001:10); ys = similar(xs);
julia> @btime @. $ys = tanh($xs);
44.291 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = tanh_fast($xs); # with expm1
47.958 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = tanh_fast($xs); # with exp(2x)-1 ... M1 native + 1.8-
11.500 μs (0 allocations: 0 bytes)
Accuracy, on this same range of points, improves 65eps -> 1.7eps relative worst case, but no change in absolute worst case 1.6e-16. So, as you say, the error is all for very small x.
Maybe it can evaluate some polynomial for abs(x) < 0.13 or so, if 4 eps is the goal.
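As an aside, the accuracy checks later in this thread all follow the same pattern; a small helper of that shape could be (ulp_error is a made-up name, not something from the PR):

# Worst-case error of `f` against BigFloat tanh over `xs`, in units of eps at the true value.
ulp_error(f, xs) = Float64(maximum(x -> abs(tanh(big(x)) - f(x)) / eps(Float64(tanh(x))), xs))

# e.g. ulp_error(tanh_fast, 1e-5:1e-6:2)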
Any news on this?
src/activations.jl
Outdated
# That has bad errors near zero; using expm1 would be slower & more accurate.
# Instead, we switch to a Taylor series; seems to add about 50% to time.
x2 = x * x
ypoly = x * evalpoly(x2, (0.9999999999999999, -0.33333333333309806, 0.13333333318143492, -0.053968217983868146, 0.021865628148606587, -0.008671836868790176))
How do these coeffs do? (1.0, 0.33333333333333337, 0.1333333333332623, -0.0539682539194502, 0.021869476267930975, -0.00886184042142138, 0.0035188503873932893)
In general, I've found that forcing the leading coefficients to exact values can help with error, since you remove the error in the coefficient. These were generated with the following code (ratfn_minimax is from Remez.jl):

coefs = [1/big(3), 2/big(15), -17/big(315), 62/big(2835), -1382/big(155925), 21844/big(6081075), -929569/big(638512875), 6404582/big(10854718875), -443861162/big(1856156927625)]
julia> g(x) = evalpoly(x, coefs)
julia> Tuple(Float64.(ratfn_minimax(g, (1e-15, .017^2), 6, 0)[1]))
That is better (with -0.3333 surely), indeed. In fact this looks to be within 1 eps on the range for which I used the polynomial:
julia> maximum(x -> abs(tanh(big(x)) - tanh_fast(x)) / eps(Float64(tanh(x))), 0.000_01:0.000_01:0.13) |> Float64
0.9878056794145185 # with yours
1.8541809732394265 # with mine
Whereas the point which reaches 5 eps is somewhere in 0.2 < x < 0.3. So perhaps the ideal changeover point is quite a bit larger... I think 0.13 was a very quick guess at where the error got to 4 eps or so. I'm not being maximally systematic about this, clearly.
If I move the cutoff to x^2 < 0.02, then this is within 3 eps everywhere, perhaps? No, within 5 still. By mistake I tested with exp, not @fastmath exp.
julia> maximum(x -> abs(tanh(big(x)) - tanh_fast(x)) / eps(Float64(tanh(x))), 1e-5:1e-6:2) |> Float64
2.4637177837743502 # with Base.exp, slow
4.923053945522114 # with @fastmath exp
julia> maximum(x -> abs(tanh(big(x)) - tanh_fast(x)) / eps(Float64(tanh(x))), -1e-5:-1e-6:-2) |> Float64
2.680360427525547 # slow
5.459406134894406 # fast
The other option is to lower the polynomial degree. The one I gave is a minimax polynomial, so it won't be anywhere near optimal outside the range. Do you have an error threshold in mind? If so, I'll do some searching to find the minimal-degree one with the required accuracy over the range.
I vaguely had 5 eps in mind, just because that was doable for the Float32 rational function --- there I could fit more terms, but after rounding etc. couldn't do much better. I'm not sure how much NNlib cares about the Float64 case, it might just be for our amusement. But sure, if you are set up to search then a better polynomial + threshold would be welcome.

BTW I tried replacing exp(x+x) with exp2(α x) as suggested, but didn't see much change in speed. The accuracy does seem a little worse on one side, so perhaps it wants to be exp(abs(2x)) & restore the sign later.
So, it looks like to maintain 5 ULPs, we need a polynomial for |x| < 0.063. Given that, we can lower the degree to 4, which gives coefficients of (1.0, -0.3333333333331805, 0.13333333295910244, -0.05396796907132604, 0.02178397571543739), with a maximal error of 1.2 ULPs in the range. Alternatively, bumping up to degree 5 lets the polynomial stay accurate to 2.5 ULPs over the whole range, with coefficients (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953). I would probably go for this version, since the 1 extra fma won't change performance too much, and roughly halving the error is imo worth it.
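Putting the pieces of this thread together, here is a sketch of how the branches might combine for Float64 (my assembly, not the PR's final code; the changeover point and the saturation cutoff are assumptions still being tuned above):

# Sketch only: exp-based formula away from zero, the degree-5 polynomial for tanh(x)/x
# near zero, and a clamp where Float64 tanh has saturated to ±1.
@inline function tanh_fast_sketch(x::Float64)
    exp2x = @fastmath exp(x + x)
    y = (exp2x - 1) / (exp2x + 1)
    x2 = x * x
    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604,
                              -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
    y = ifelse(x2 < 0.063^2, ypoly, y)                    # assumed changeover point
    ifelse(x > 20, one(y), ifelse(x < -20, -one(y), y))   # assumed saturation cutoff
end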
exp2x = @fastmath exp(x + x)
y = (exp2x - 1) / (exp2x + 1)
ifelse(x > 30, one(y), ifelse(x < -30, -one(y), y))
exp2x = @fastmath exp(x + x) |
What about

@inline exp_inline(x) = Base.Math.exp_impl(x, Val(:ℯ))

and defining an associated diff rule, for the versions of Julia that define an inline Base.Math.exp_impl? This should let the compiler SIMD the code.

Of course, for tanh_fast itself, you shouldn't have to define an exp_inline diff rule, because tanh_fast should have a rule itself.
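For illustration, a sketch of what such a rule could look like, assuming ChainRulesCore (which NNlib already uses for its rules); this is my example, not the rule the PR actually adds, and the stand-in tanh_fast below just forwards to Base:

using ChainRulesCore

# Stand-in so the sketch is self-contained; the PR's real tanh_fast is defined above.
@inline tanh_fast(x::Real) = tanh(x)

# d/dx tanh(x) = 1 - tanh(x)^2, so the pullback can reuse the forward value.
function ChainRulesCore.rrule(::typeof(tanh_fast), x::Real)
    y = tanh_fast(x)
    tanh_fast_pullback(dy) = (NoTangent(), dy * (1 - y^2))
    return y, tanh_fast_pullback
end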
I saw this in the Zulip thread and had a go... but for me this is somewhere between the @fastmath variant and the plain exp in speed:
julia> @btime @. $ys = tanh_fast($xs); # with @fastmath exp(x + x), and polynomial etc, M1 + Rosetta, Julia 1.7β
30.125 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = tanh_fast($xs); # with exp_inline
47.208 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = tanh_fast($xs); # with Base.exp
85.125 μs (0 allocations: 0 bytes)
What about Base.Math.exp_impl_fast(x, Val(:ℯ))?
(I think that needs 1.7)
Using Base.Math.exp_impl_fast seems to make no difference at all, strangely? I don't see an @inline.
@inline function sigmoid_fast(x::Real)
    t = @fastmath exp(-abs(x))
    y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
    ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
end
For Float32, this sigmoid_fast isn't as big an improvement as tanh_fast, but it is quite a bit faster than the default, and barely less accurate: maybe 2 eps instead of 1 worst-case, but about the same average.

The rational function symmetric about zero gives you, if you shift it, terrible relative accuracy for negative x, as the function now goes to zero.
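To make that last point concrete (my example, not from the PR): since σ(x) = (1 + tanh(x/2)) / 2, any fixed absolute error in the tanh approximation becomes a large relative error once σ(x) itself is tiny. Reusing the tanh_faster defined earlier in this thread:

# Hypothetical shifted-tanh sigmoid; not proposed by the PR, just for illustration.
sigmoid_via_tanh(x) = (1 + tanh_faster(x / 2)) / 2

x = -15f0
exact  = inv(1 + exp(-x))            # ≈ 3.06f-7
approx = sigmoid_via_tanh(x)
abs(approx - exact) / exact          # expect a large relative error here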
julia> xs = collect(0.01:0.001:10); ys = similar(xs);
julia> xs_32 = Float32.(collect(xs)); ys_32 = similar(xs_32);
# Float32
julia> @btime @. $ys_32 = sigmoid($xs_32);
26.250 μs (0 allocations: 0 bytes)
julia> @btime @. $ys_32 = sigmoid_fast($xs_32);
7.135 μs (0 allocations: 0 bytes)
julia> @btime @. $ys_32 = sigmoid_fast_old($xs_32); # version with polynomial from tanh_fast
3.870 μs (0 allocations: 0 bytes)
# Float64
julia> @btime @. $ys = sigmoid($xs);
34.250 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = sigmoid_fast($xs);
12.000 μs (0 allocations: 0 bytes)
Compare tanh:
julia> @btime @. $ys_32 = tanh($xs_32);
27.125 μs (0 allocations: 0 bytes)
julia> @btime @. $ys_32 = tanh_fast($xs_32);
3.594 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = tanh($xs);
44.958 μs (0 allocations: 0 bytes)
julia> @btime @. $ys = tanh_fast($xs);
17.041 μs (0 allocations: 0 bytes)
Where are we here?
This reverts commit e6befa1.
Status is that I drifted off. When last I was fiddling, I was trying to use Oscar's numbers here #345 (comment) but IIRC got terrible accuracy & wasn't sure why. It was nearly done, though.
My numbers were minimax for a specific range, so if you were using a different range, they would be all wrong.
OK, b53155f changes to use these coefficients. I may well have messed up the ranges last time. Re-reading today, I'm not entirely sure what range these are for, ….

Something I don't think I checked until today: if I time …
Those coefs with that range (you got the range right) should get you 2.5 ULP if I'm doing the math right.
OK, great. I don't see more than 2, at …. And the main goal anyway is the Float32 version.
[skip ci]
Ah, I'd done the testing assuming the regular (not fastmath) versions.
Oh right, with that I get worst 2 everywhere. But it's no longer obviously faster than …
No, you definitely want the fastmath version; it just changes the tradeoff a little. I'm looking now to see if I can quickly get a better parameter tune given the tradeoff.
After a little experimenting, it appears that the coefficients were actually pretty good. Lowering the error further isn't really possible (since the fastmath version needs to go up to …).
OK, thanks for checking. Maybe we should call this ready? My last change is to restrict this polynomial story to Float64, so that if anyone tries it with dual numbers or bigfloats or something, they will get ordinary tanh. We can speed up other functions in this file by large factors, too, e.g. …
Co-authored-by: Carlo Lucibello <[email protected]>
My guess is people do use gelu since that appears to be what Transformer models are moving towards.
This proposes to add faster, lower-precision versions of tanh and friends. Motivation:

There has been talk of using LoopVectorization to speed this up instead. At least for the computers I tried, this lower-precision function is faster. Of course that may speed up other things, but tanh seems the outlier in terms of speed -- I quote times with matrix multiplication above, to illustrate that it dominates.

WIP: tanh_fast was the initial idea, but at least for Float32, a rational approximation like tanh_faster seems like a better idea, roughly one less digit & another factor of 3 in speed. Both need some care not to give crazy answers at large x; the precise implementation could use some adjusting. Times & error measures in this gist.

I also have no idea how this interacts with GPU just yet.