
Add faster tanh implementation(s) #345

Merged
merged 15 commits into from Nov 8, 2021

Conversation

mcabbott
Member

@mcabbott mcabbott commented Aug 7, 2021

This proposes to add faster, lower-precision versions of tanh and friends. Motivation:

julia> @btime x * x   setup=(x=randn(Float32,100,100));
  20.250 μs (2 allocations: 39.11 KiB)  # mac M1, julia 1.7 + rosetta
  52.718 μs (2 allocations: 39.11 KiB)  # xeon E5-2603

julia> @btime y .= tanh.(x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
  107.167 μs (2 allocations: 39.11 KiB)
  229.725 μs (2 allocations: 39.11 KiB)

julia> @btime y .= tanh_faster.(x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
  27.125 μs (2 allocations: 39.11 KiB)
  67.623 μs (2 allocations: 39.11 KiB)

There has been talk of using LoopVectorization to speed this up instead. At least on the computers I tried, this lower-precision function is faster. Of course LoopVectorization may speed up other things too, but tanh seems to be the outlier in terms of speed -- I quote times alongside matrix multiplication above to illustrate that it dominates.

julia> using LoopVectorization

julia> @btime vmap!(tanh, y, x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
  62.542 μs (2 allocations: 39.11 KiB)
  99.213 μs (2 allocations: 39.11 KiB)

julia> @btime vmap!(tanh_faster, y, x * x)   setup=(x=randn(Float32,100,100); y=similar(x));  # with IfElse
  27.292 μs (2 allocations: 39.11 KiB)
  60.970 μs (2 allocations: 39.11 KiB)

WIP: tanh_fast was the initial idea, but at least for Float32 a rational approximation like tanh_faster seems like a better idea -- roughly one less digit of accuracy for another factor of 3 in speed. Both need some care not to give crazy answers at large x, and the precise implementation could use some adjusting. Times & error measures are in this gist.

I also have no idea how this interacts with GPU just yet.

@DhairyaLGandhi
Member

We have been discussing a more structured approach to this, led by @chriselrod and NNlibCPU.

@chriselrod

chriselrod commented Aug 8, 2021

I could add tanh_faster here.

Or, maybe we should create a dedicated library for very fast but inaccurate special functions? @oscardssmith also expressed interest in such a library, and I think it makes sense as a stand-alone repository.

Would be good to still have some sort of accuracy guarantee, perhaps tied to a naming convention.
Most _fast functions in SLEEFPirates are accurate to 3 ULP, but I think SLEEFPirates.sigmoid_fast and SLEEFPirates.pow_fast are both substantially worse than that. tanh_fast is tested to 3 ULP.
Would be great to come up with some conventions here so that:

  1. users know what they're getting
  2. we can support options in libraries like LoopVectorization, where users can specify the level of accuracy they require for individual functions.

I won't be able to make the ML meeting this week, as I have a vet appointment at that time. But I'll try and merge that PR with basic dense, conv, and batch layers working.

@chriselrod

Also, that Xeon in your benchmark seems to be very slow. Using:

using LoopVectorization, BenchmarkTools, LinearAlgebra
Sys.ARCH === :x86_64 && using MKL
BLAS.set_num_threads(1)
using IfElse: ifelse
@inline function tanh_faster(x::Real)
  T = float(typeof(x))
  x2 = 4x^2
  # y = 94.9339451088 * x * ((( x2 + 1.06315869e+03 ) * x2 + 1.88748783e+05 ) * x2 + 5.86237309e+06 ) / ((( ( x2 + 4.03183926e+03 ) * x2 + 1.64253046e+06 ) * x2 + 1.28592857e+08 ) * x2 + 1.11307745e+09 )
  num = @fastmath T(94.9339451088) * 2x * ((( x2 + T(1.06315869e+03) ) * x2 + T(1.88748783e+05) ) * x2 + T(5.86237309e+06) )
  den = @fastmath ((( ( x2 + T(4.03183926e+03) ) * x2 + T(1.64253046e+06) ) * x2 + T(1.28592857e+08) ) * x2 + T(1.11307745e+09) )
  y = num / den
  stop = _fast_stop(tanh_faster, x)
  ifelse(x > stop, one(y), ifelse(x < -stop, -one(y), y))
end
_fast_stop(::typeof(tanh_faster), x) = oftype(x, 10)

@btime x * x   setup=(x=randn(Float32,100,100));
@btime y .= tanh.(x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
@btime y .= tanh_faster.(x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
@btime vmap!(tanh, y, x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
@btime vmap!(tanh_faster, y, x * x)   setup=(x=randn(Float32,100,100); y=similar(x));  # with IfElse

I am getting:

julia> @btime x * x   setup=(x=randn(Float32,100,100));
  16.220 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
  11.967 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
  25.125 μs (2 allocations: 39.11 KiB) # M1 native

julia> @btime y .= tanh.(x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
  75.780 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
  94.014 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
  83.917 μs (2 allocations: 39.11 KiB) # M1 native

julia> @btime y .= tanh_faster.(x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
  19.565 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
  15.447 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
  28.417 μs (2 allocations: 39.11 KiB) # M1 native

julia> @btime vmap!(tanh, y, x * x)   setup=(x=randn(Float32,100,100); y=similar(x));
  27.560 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
  21.269 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
  45.625 μs (2 allocations: 39.11 KiB) # M1 native

julia> @btime vmap!(tanh_faster, y, x * x)   setup=(x=randn(Float32,100,100); y=similar(x));  # with IfElse
  19.097 μs (2 allocations: 39.11 KiB) # i7-1165G7 (laptop)
  13.739 μs (2 allocations: 39.11 KiB) # i9-10980XE (HEDT)
  28.708 μs (2 allocations: 39.11 KiB) # M1 native

@mcabbott
Member Author

mcabbott commented Aug 8, 2021

This was supposed to be a super-quick fix, but floating point stuff is messy!

Agree the nicest possible outcome would be a super-low-precision library. I think that for most ML purposes a smooth function which is <1 pixel wrong on a graph is going to be plenty. My guess is that smoothness is more important than worst-case error, i.e. steps are worse than ramps. NNlib might not be a terrible place for such things to live, though, if it's just tanh and some friends -- it's supposed to be a collection of NN-ish functions you can use for other purposes. They could move out once someone finds the need & time to write other functions.

NNlibCPU.jl looks more ambitious. From very informal investigation there's lots of room for faster CPU conv etc.

If that's an optional add-on, then the argument for tanh_faster not living only there is that it might be useful without it. Motivated by things like this thread -- it seems embarrassing that in the most obvious casual benchmark, 80% of the time is spent in tanh.

ps. The Xeon is indeed older than I realised: 2012, pre-AVX2. Just the easiest non-M1 machine around. It's also possible that tanh_faster can be written in a more digestible way, e.g. with evalpoly; this is just copy-pasted.

@oscardssmith

One thing to try is using exp2 instead of exp. It may lose a tiny bit of precision, but be notably faster. exp needs some extra work to do the reduction, while exp2 gets it for free.
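
A minimal sketch of the exp2 idea, not code from this PR: since e^(2x) == 2^(2x * log2(e)), exp2 gets its range reduction essentially for free at the cost of folding log2(e) into the argument. The names tanh_exp2_sketch and LOG2E are hypothetical.

const LOG2E = log2(ℯ)

@inline function tanh_exp2_sketch(x::Real)
    T = float(typeof(x))
    e2x = exp2(2 * T(LOG2E) * x)        # ≈ exp(2x), with the reduction folded into the argument
    y = (e2x - 1) / (e2x + 1)
    ifelse(x > 20, one(y), ifelse(x < -20, -one(y), y))  # clamp well past saturation
end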

@ToucheSir
Member

Is PyTorch doing something extra besides using tanh(f)? If not, why would there still be a discrepancy between libm and JuliaLang/julia#38382?

@mcabbott
Member Author

mcabbott commented Aug 8, 2021

I never looked inside that PR somehow. The times for tanh will depend a lot on which branches you hit. And the time for tanh_kernel suggests that my rational tanh_faster has room for improvement:

julia> @btime y .= tanh.(x)  setup=(x = randn(Float32, 10^3); y = similar(x););
  2.963 μs (0 allocations: 0 bytes)  # as used above

julia> @btime y .= tanh.(x)  setup=(x = 0.1f0 .* randn(Float32, 10^3); y = similar(x););
  1.129 μs (0 allocations: 0 bytes)  # small => more kernel

julia> @btime y .= tanh.(x)  setup=(x = 100 .* randn(Float32, 10^3); y = similar(x););
  627.491 ns (0 allocations: 0 bytes) # large => more flat

julia> @btime y .= tanh_faster.(x)  setup=(x = randn(Float32, 10^3); y = similar(x));
  346.005 ns (0 allocations: 0 bytes)  # this PR's rational function

julia> @btime y .= Base.Math.tanh_kernel.(x)  setup=(x = rand(Float32, 10^3); y = similar(x););
  144.404 ns (0 allocations: 0 bytes)  # just the polynomial kernel

julia> Base.Math.TANH_SMALL_X(Float32)  # polynomial below this
1.3862944f0

julia> Base.Math.TANH_LARGE_X(Float32)  # flat above this, but why so large?
18.0f0

julia> tanh(9f0)  # already constant at half that.
1.0f0

@oscardssmith

Just wondering, how did you generate the coefficients? Theoretically, Remez.jl can make minimax rational approximations, but I've never had much luck (I believe due to problems with numerical conditioning).

@mcabbott
Member Author

mcabbott commented Aug 8, 2021

@mcabbott
Member Author

mcabbott commented Aug 9, 2021

Ok, I deleted some details; they are in the gist. But after messing with Remez.jl a bit, I think this function is within 5 eps of the true value, abs error 2.929f-7. It's a little faster than my initial tanh_faster (a similar rational function), which was much less accurate: 300 eps, abs error 3.629f-5.

Maybe there are tricks to do even better; these coefficients come from applying Float32.(...) to the BigFloat ones, but perhaps the rounding matters? With the BigFloat coefficients, the error is more like 1.5e-9.

function tanh_new(x::Float32)
    x2 = abs2(x)
    n = evalpoly(x2, (1.0f0, 0.1346604f0, 0.0035974074f0, 2.2332108f-5, 1.587199f-8))
    d = evalpoly(x2, (1.0f0, 0.4679937f0, 0.026262015f0, 0.0003453992f0, 8.7767893f-7))
    ifelse(x2 < 66f0, x * (n / d), sign(x))
end

When I try things on the GPU, I believe Julia's tanh is not being used, but the function used has only slightly larger errors. If you use this tanh_new you seem to get almost identical accuracy to using it on the CPU. I can't measure a time for any of these, they all take exactly as long as broadcasting identity.
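
For reference, a rough sketch (not from the PR; the real measurements are in the gist) of how the quoted eps and absolute-error figures for tanh_new could be checked against a BigFloat reference:

julia> xs = Float32.(range(-10, 10; length=100_001));

julia> maximum(x -> abs(tanh(big(x)) - tanh_new(x)) / eps(Float32(tanh(x))), xs) |> Float64   # worst case in eps

julia> maximum(x -> Float32(abs(tanh(big(x)) - tanh_new(x))), xs)                             # worst absolute error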

@oscardssmith

There are some papers on how to make the Remez algorithm more numerically stable, and on optimal minimax polynomials with floating-point coefficients, but neither has been implemented yet (both are on my list of things to do if I have time). The other thing that should be relatively easy to add is a way of specifying fixed values for some coefficients. This is useful for generating even/odd polynomials, as well as ones where the first few terms can be represented without rounding.

"""
@inline function tanh_fast(x)
@inline function tanh_fast(x::Real)
exp2x = @fastmath exp(x + x)
y = (exp2x - 1) / (exp2x + 1)


You should use expm1 here. It will be about 30% slower, but you will get correct answers for small x. This implementation has catastrophic cancellation.
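
For concreteness, a minimal sketch of the expm1 formulation being suggested (hypothetical name, not the code this PR ended up with): since tanh(x) == expm1(2x) / (expm1(2x) + 2), the subtraction that cancels near zero disappears.

@inline function tanh_expm1_sketch(x::Real)
    em1 = expm1(x + x)          # e^(2x) - 1, accurate even for tiny x
    y = em1 / (em1 + 2)
    ifelse(x > 20, one(y), ifelse(x < -20, -one(y), y))
end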

Member Author

Probably! Agree this looks dodgy, will re-check Float64 times & accuracy at some point.

I also wonder if there's a clever formula for tanh(x)/x, since it seems a little odd to get accuracy near 0 by expm1 when the desired result is close to linear?


doesn't look like it.

Member Author
@mcabbott mcabbott Aug 28, 2021

Switching to expm1 seems to cost me the entire speed gain:

julia> xs = collect(0.01:0.001:10); ys = similar(xs);

julia> @btime @. $ys = tanh($xs);
  44.291 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = tanh_fast($xs);  # with expm1
  47.958 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = tanh_fast($xs);  # with exp(2x)-1   ... M1 native + 1.8-
  11.500 μs (0 allocations: 0 bytes)

Accuracy, on this same range of points, improves 65eps -> 1.7eps relative worst case, but no change in absolute worst case 1.6e-16. So, as you say, the error is all for very small x.

Maybe it can evaluate some polynomial for abs(x) < 0.13 or so, if 4eps is the goal.

@oscardssmith

Any news on this?

# That has bad errors near zero; using expm1 would be slower & more accurate.
# Instead, we switch to a Taylor series; seems to add about 50% to time.
x2 = x * x
ypoly = x * evalpoly(x2, (0.9999999999999999, -0.33333333333309806, 0.13333333318143492, -0.053968217983868146, 0.021865628148606587, -0.008671836868790176))


How do these coeffs do? (1.0, 0.33333333333333337, 0.1333333333332623, -0.0539682539194502, 0.021869476267930975, -0.00886184042142138, 0.0035188503873932893) In general, I've found that forcing the leading coefficients to exact values can help with error, since you remove the error in the coefficient. These were generated with the following code.

julia> using Remez  # provides ratfn_minimax

julia> coefs = [1/big(3), 2/big(15), -17/big(315), 62/big(2835), -1382/big(155925), 21844/big(6081075), -929569/big(638512875), 6404582/big(10854718875), -443861162/big(1856156927625)];

julia> g(x) = evalpoly(x, coefs)

julia> Tuple(Float64.(ratfn_minimax(g, (1e-15, 0.017^2), 6, 0)[1]))

Member Author
@mcabbott mcabbott Aug 28, 2021

That is better (with -0.3333 surely), indeed. In fact this looks to be within 1eps within the range for which I used the polynomial:

julia> maximum(x -> abs(tanh(big(x)) - tanh_fast(x)) / eps(Float64(tanh(x))), 0.000_01:0.000_01:0.13) |> Float64
  0.9878056794145185  # with yours
1.8541809732394265  # with mine

Whereas the point that hits 5 eps is somewhere in 0.2 < x < 0.3. So perhaps the ideal changeover point is quite a bit larger... I think 0.13 was a very quick guess at where the error got to 4 eps or so. I'm not being maximally systematic about this, clearly.

If I move the cutoff to x^2 < 0.02, then this is within 3 eps everywhere, perhaps? No, still within 5 -- by mistake I had tested with exp, not @fastmath exp.

julia> maximum(x -> abs(tanh(big(x)) - tanh_fast(x)) / eps(Float64(tanh(x))), 1e-5:1e-6:2) |> Float64
2.4637177837743502  #  with Base.exp, slow
4.923053945522114  #  with @fastmath exp

julia> maximum(x -> abs(tanh(big(x)) - tanh_fast(x)) / eps(Float64(tanh(x))), -1e-5:-1e-6:-2) |> Float64
2.680360427525547  # slow
5.459406134894406  # fast


The other option is to lower the polynomial degree. The one I gave is a minimax polynomial, so it won't be anywhere near optimal outside the range. Do you have an error threshold in mind? If so, I'll do some searching to find the minimal-degree polynomial with the required accuracy over the range.

Member Author

I vaguely had 5 eps in mind, just because that was doable for the Float32 rational function -- there I could fit more terms, but after rounding etc. couldn't do much better. I'm not sure how much NNlib cares about the Float64 case; it might just be for our amusement. But sure, if you are set up to search, then a better polynomial + threshold would be welcome.

BTW I tried replacing exp(x+x) with exp2(α x) as suggested, but didn't see much change in speed. The accuracy does seem a little worse on one side, so perhaps it wants to be exp(abs(2x)) & restore the sign later.
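
A tiny sketch of that exp(abs(2x))-and-restore-the-sign idea (hypothetical, not code from the PR):

@inline function tanh_abs_sketch(x::Real)
    e2 = @fastmath exp(2 * abs(x))      # always ≥ 1, so both signs hit the same code path
    y = (e2 - 1) / (e2 + 1)
    copysign(ifelse(abs(x) > 20, one(y), y), x)
end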


So, it looks like, to maintain 5 ULPs, we need a polynomial for |x| < 0.063. Given that, we can lower the degree to 4, which gives coefficients of (1.0, -0.3333333333331805, 0.13333333295910244, -0.05396796907132604, 0.02178397571543739) with a maximal error of 1.2 ULPs in the range. Alternatively, bumping up to degree 5 lets the polynomial stay accurate to 2.5 ULPs over the whole range, with coefficients (1.0, -0.33333333333324583, 0.13333333325511604, -0.05396823125794372, 0.02186660872609521, -0.008697141630499953). I would probably go for this version, since the 1 extra fma won't change performance too much and it roughly halves the error, which imo is worth it.
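
Putting the pieces together, a sketch of how these degree-5 coefficients slot into the Float64 path, assuming the x^2 < 0.017 changeover settled on later in the thread and the ±30 clamp from the quoted code below; not necessarily the exact merged code:

@inline function tanh_fast64_sketch(x::Float64)
    exp2x = @fastmath exp(x + x)
    y = (exp2x - 1) / (exp2x + 1)
    x2 = x * x
    ypoly = x * evalpoly(x2, (1.0, -0.33333333333324583, 0.13333333325511604,
                              -0.05396823125794372, 0.02186660872609521, -0.008697141630499953))
    ifelse(x2 < 0.017, ypoly, ifelse(x > 30, one(y), ifelse(x < -30, -one(y), y)))
end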

exp2x = @fastmath exp(x + x)
y = (exp2x - 1) / (exp2x + 1)
ifelse(x > 30, one(y), ifelse(x < -30, -one(y), y))
@chriselrod chriselrod Aug 28, 2021

What about

 @inline exp_inline(x) = Base.Math.exp_impl(x, Val(:ℯ))

and defining an associated diff rule, for the versions of Julia that define an inline Base.Math.exp_impl?

This should let the compiler SIMD the code.

Of course, for tanh_fast itself, you shouldn't have to define an exp_inline diff rule, because tanh_fast should have a rule itself.

Member Author
@mcabbott mcabbott Aug 28, 2021

I saw this in the Zulip thread and had a go... but for me this is somewhere between the @fastmath variant and the plain exp in speed:

julia> @btime @. $ys = tanh_fast($xs);  # with @fastmath exp(x + x), and polynomial etc, M1 + Rosetta, Julia 1.7β
  30.125 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = tanh_fast($xs);  # with exp_inline
  47.208 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = tanh_fast($xs);  # with Base.exp
  85.125 μs (0 allocations: 0 bytes)


What about Base.Math.exp_impl_fast(x, Val(:ℯ))


(I think that needs 1.7)

Member Author

Using Base.Math.exp_impl_fast seems to make no difference at all, strangely? I don't see an @inline.

@inline function sigmoid_fast(x::Real)
t = @fastmath exp(-abs(x))
y = ifelse(x ≥ 0, inv(1 + t), t / (1 + t))
ifelse(x > 40, one(y), ifelse(x < -80, zero(y), y))
end
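
(For reference, the two branches above are the standard numerically stable split: for x ≥ 0, σ(x) = 1/(1 + e^(-x)); for x < 0, σ(x) = e^(-|x|)/(1 + e^(-|x|)). Both are computed from t = e^(-|x|), so the exponential never overflows.)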
Member Author
@mcabbott mcabbott Aug 29, 2021

For Float32, this sigmoid_fast isn't as big an improvement as tanh_fast, but it's still quite a bit faster, and barely less accurate than the default -- maybe 2 eps instead of 1 worst-case, but about the same on average.

Shifting the rational function (which is symmetric about zero) gives terrible relative accuracy for negative x, since the shifted function now goes to zero there.
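
The shift being described is the identity σ(x) == (1 + tanh(x/2)) / 2; a hypothetical one-liner along the lines of the sigmoid_fast_old benchmarked below:

sigmoid_via_tanh(x) = (1 + tanh_fast(x / 2)) / 2   # fast for all x, but poor relative error as σ(x) → 0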

julia> xs = collect(0.01:0.001:10); ys = similar(xs);

julia> xs_32 = Float32.(collect(xs)); ys_32 = similar(xs_32);

# Float32

julia> @btime @. $ys_32 = sigmoid($xs_32);
  26.250 μs (0 allocations: 0 bytes)

julia> @btime @. $ys_32 = sigmoid_fast($xs_32);
  7.135 μs (0 allocations: 0 bytes)

julia> @btime @. $ys_32 = sigmoid_fast_old($xs_32);  # version with polynomial from tanh_fast
  3.870 μs (0 allocations: 0 bytes)

# Float64

julia> @btime @. $ys = sigmoid($xs);
  34.250 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = sigmoid_fast($xs);
  12.000 μs (0 allocations: 0 bytes)

Compare tanh:

julia> @btime @. $ys_32 = tanh($xs_32);
  27.125 μs (0 allocations: 0 bytes)

julia> @btime @. $ys_32 = tanh_fast($xs_32);
  3.594 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = tanh($xs);
  44.958 μs (0 allocations: 0 bytes)

julia> @btime @. $ys = tanh_fast($xs);
  17.041 μs (0 allocations: 0 bytes)

@mcabbott mcabbott marked this pull request as ready for review August 29, 2021 04:11
@CarloLucibello
Member

Where are we here?

@mcabbott
Member Author

mcabbott commented Nov 6, 2021

Status is that I drifted off. When last I was fiddling, I was trying to use Oscar's numbers here #345 (comment) but IIRC got terrible accuracy & wasn't sure why. It was nearly done, though.

@oscardssmith

my numbers were minimax for a specific range, so if you were using a different range, they would be all wrong.

@mcabbott
Member Author

mcabbott commented Nov 8, 2021

OK, b53155f changes to use these coefficients. I may well have messed up the ranges last time. Re-reading today, I'm not entirely sure what range these are for; x^2 < 0.017 comes from trying it out, and seems to keep within 5 eps.

Something I don't think I checked until today: if I time @btime ForwardDiff.derivative(tanh, x[]) setup=(x=Ref(rand(Float32))); the fast versions are a bit slower. I guess ideally they would have their own rules somewhere. (For ChainRules there are rules here, including broadcasting.)
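
A sketch of the kind of rule meant here, hypothetical and not part of this PR (the ForwardDiff case would need a Dual method instead, and NNlib's actual ChainRules definitions live elsewhere): reuse y so the derivative costs only 1 - y^2.

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(tanh_fast), x::Real)
    y = tanh_fast(x)
    tanh_fast_pullback(ȳ) = (NoTangent(), ȳ * (1 - y^2))
    return y, tanh_fast_pullback
end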

@oscardssmith

Those coefs with that range (you got the range right) should get you 2.5 ULP if I'm doing the math right.

@mcabbott
Member Author

mcabbott commented Nov 8, 2021

OK, great. I don't see more than 2 ULP, at 0.129701, inside the range. But outside the polynomial's range, 5 at 0.140901 and -0.206101. This seems fine really; it does what we wanted.

And the main goal anyway is the Float32 version.

@oscardssmith

Ah, I'd done the testing assuming the regular (not fastmath) versions.

@mcabbott
Member Author

mcabbott commented Nov 8, 2021

Oh right, with that I get a worst case of 2 everywhere. But it's no longer obviously faster than tanh:

julia> xs = collect(0.01:0.001:10); ys = similar(xs);

julia> @btime @. $ys = tanh($xs);
  98.333 μs (0 allocations: 0 bytes)             # M1 + rosetta + 1.7
  min 44.916 μs, mean 45.177 μs (0 allocations)  # M1 native + 1.8

julia> @btime @. $ys = tanh_fast($xs);  # without @fastmath
  90.416 μs (0 allocations: 0 bytes)
  min 48.083 μs, mean 48.747 μs (0 allocations)

julia> @btime @. $ys = tanh_fast($xs);  # with @fastmath
  30.375 μs (0 allocations: 0 bytes)
  min 33.167 μs, mean 33.580 μs (0 allocations)

@oscardssmith

no, you definitely want the fastmath version, it just changes the tradeoff a little. I'm looking now to see if I can quickly get a better parameter tune given the tradeoff.

@oscardssmith

After a little experimenting, it appears that the coefficients were actually pretty good. Lowering the error further isn't really possible: the fastmath version would need the polynomial to go up to x = 2.55 for the error to go down noticeably (which would take 2 extra coefficients to match the accuracy), and you can't remove a coefficient without significantly raising the error.

@mcabbott
Member Author

mcabbott commented Nov 8, 2021

OK, thanks for checking. Maybe we should call this ready?

My last change is to restrict this polynomial story to Float64, so that if anyone tries it with dual numbers or bigfloats or something, they will get ordinary tanh.

We can speed up other functions in this file by large factors too, e.g. elu_fast, gelu_fast, swish_fast. But maybe they don't need to have slow versions at all. And maybe nobody uses them anyway? That can wait for another PR.

Co-authored-by: Carlo Lucibello <[email protected]>
@mcabbott mcabbott merged commit 37093c7 into FluxML:master Nov 8, 2021
@mcabbott mcabbott deleted the tanh_faster branch November 8, 2021 11:09
@oscardssmith

My guess is people do use gelu since that appears to be what Transformer models are moving towards.
