generic_matmul! hit in back!
because type-promotion in activation function
#613
Comments
Isn't this expected though? Why aren't you using the eltype of the input to determine the type of the float constants?
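As a minimal sketch of what deriving the constant from the input could look like: the function names `leaky_oftype` and `leaky_param` and the slope `0.01` are hypothetical, chosen only for illustration and not taken from the issue.

```julia
# Hypothetical leaky activations that keep the input's float type.

# Variant 1: convert the literal to the type of x with `oftype`.
leaky_oftype(x) = max(oftype(x, 0.01) * x, x)

# Variant 2: capture the type parameter and construct the constant from it.
leaky_param(x::T) where {T<:AbstractFloat} = max(T(0.01) * x, x)

for x in (1.5f0, -2.0f0, 1.5, -2.0)
    @show typeof(x) typeof(leaky_oftype(x)) typeof(leaky_param(x))
end
```

Either form returns a `Float32` for a `Float32` input, so broadcasting it over a `Float32` array leaves the element type unchanged.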
Depends what you mean by "expected". As a rule, when looking for code that might be causing a slowdown, one doesn't immediately go looking for constants whose types don't match. It is of course obvious in retrospect. We could certainly consider giving a suppressible warning if the return type of the activation does not match its inputs. Or we could do other things.
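One way such a check might look, as a rough sketch only: `check_activation_type` is a hypothetical helper, not an existing Flux or NNlib function, and it probes the activation with a single scalar rather than inspecting a real layer.

```julia
# Hypothetical helper: warn when an activation changes the float type.
# Returns true when the output type matches the input type.
function check_activation_type(f, ::Type{T}) where {T<:AbstractFloat}
    R = typeof(f(one(T)))          # probe the activation with a scalar of type T
    matches = R === T
    matches || @warn "Activation changes the element type" f input = T output = R
    return matches
end

check_activation_type(x -> 0.5 * x, Float32)   # warns: Float32 in, Float64 out
check_activation_type(x -> x / 2, Float32)     # no warning
```

Because the probe runs once on a scalar, it could be called at layer construction time rather than inside the forward pass.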
If nothing else we can have a "performance tips" page, saying to be careful of the types of literals in your activation functions. I also think we can probably make this faster than it is. If nothing else we can promote the matrix and use BLAS rather than the generic matmul.
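A sketch of the "promote, then multiply" idea, under the assumption that converting both operands to a common element type is acceptable: `promoted_mul` is a hypothetical name, and the conversion allocates a copy of whichever operand has the narrower type.

```julia
using LinearAlgebra

# Hypothetical helper: bring both operands to a common element type
# so the product is computed on arrays of a single float type.
function promoted_mul(A::AbstractMatrix, B::AbstractVecOrMat)
    T = promote_type(eltype(A), eltype(B))
    # convert is a no-op for an operand that already has eltype T
    return convert(AbstractArray{T}, A) * convert(AbstractArray{T}, B)
end

W = rand(Float32, 3, 4)   # Float32 weights
x = rand(Float64, 4, 2)   # Float64 activations
y = promoted_mul(W, x)

@show eltype(y) size(y)
```

This trades extra memory for a uniform element type, and it silently widens `Float32` data, which may not be what a mixed-precision setup wants.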
I think likely what we should do is trigger a warning on our fallback operations; by default I don’t think any user ever wants to use the generic matmul, and so while I like that we support using it, we should have a (silenceable) warning that is printed the first time it is invoked, along with a backtrace to figure out where it’s coming from.
At the risk of flagrant type piracy, we could just override the behaviour of It would be worth some notes in the activation functions section of the docs though. NNlib's ones are all set up to preserve input type and there's testing infrastructure for this as well; it's really a matter of following standard Julia style.
Given that #615 added this to the docs, do we still want to address this with a warning somehow?
Yeah, I think we should, it can wreck your performance.
I believe changing the types to hit BLAS makes things troublesome for mixed precision. I'm not an expert on the topic, but I've heard that mentioned quite a few times on orthogonal issues/PRs. A warning would be good though. Only concern is the type-piracy. Either way I added this to the triage project so it gets discussed during the next ML community call.
Please no type piracy.
You don't need type piracy.
Doing these things generically in a manner that doesn't touch ad and runtime performance in forward or backwards pass can be tough.
You can have a disable-able safety rails mode that compiles away.

# Safety rails default on
has_safety_rails() = true

# Function to let advanced user turn them off.
# Triggers recompilation
safety_rails!(enable) = @eval has_safety_rails() = $enable

macro safety_rail(cond, msg::String, logargs...)
    disable = " To disable this warning, run `Flux.safety_rails!(false)`."
    # TODO this doesn't quite display logargs right.
    warning = :(@warn($msg*$disable, $(esc.(logargs)...)))
    return quote
        has_safety_rails() && Zygote.ignore() do
            $(esc(cond)) && $warning
        end
    end
end

function *(a::T, b::S) where {T, S}
    @safety_rail(
        T!==S,
        "Mixed type multiplication encountered. This probably means you ...",
        T, S
    )
    return Base.:(*)(a, b)
end

1.0 * 2

I think I stole this trick from TimerOutputs.jl, to have a debug mode that compiles away when not in use.
Summarizing what was discussed during triage today: The appropriate place for a What was suggested instead is to package the promotion check into a utility function. Something like
It is wanted in Base all the time though so I think you will have a hard time putting such a warning there.
Would a warning via a utility function be acceptable then @oxinabox?
I honestly don't care how it is done. We should understand the context. The thing that matters here is that it is very easy to get Float64s in your network. But in NN code you often intentionally want One day we might be able to do
Times with FluxML/Zygote.jl#1044: now hardly any slowdown, but a few more allocations than the all-Float32 version:
Compared to tagged version without that PR, just the slow case -- 10x slower than Float32, as above:
This version without the PR has some ProjectTo stuff in place, e.g. in the rule for
Sometimes `generic_matmul!` is hit in `back!`.
For example, adding a `leak` to a unit can be done by writing an activation function like
And this is well and good, if `x` is a `Float64`. But if `x` is a `Float32` this will trigger a type-promotion, which is bad, because the user almost certainly did not intend the type promotion.
But worse, it means that rather than hitting fast BLAS, we fall back to the slow `generic_matmul!`.
Here is a MWE:
Time if it has to promote:
@time demo_flux()
0.143774 seconds (607 allocations: 19.635 MiB)
Time normally:
@time demo_flux()
0.016475 seconds (568 allocations: 13.218 MiB, 47.67% gc time)
That is roughly a 10x time difference, and it grows as your matrix sizes scale up.
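To make the mechanism concrete, here is a small sketch of how one `Float64` literal changes the element types that the matrix multiplication sees. It is not the MWE from this issue: `leaky`, the slope `0.01`, and the array sizes are assumptions chosen for illustration.

```julia
# Hypothetical stand-in for the activation discussed above;
# the literal 0.01 is a Float64.
leaky(x) = max(0.01 * x, x)

W = rand(Float32, 2, 4)   # Float32 weights
x = rand(Float32, 4, 3)   # Float32 input batch

h = leaky.(x)             # broadcasting promotes every element to Float64
y = W * h                 # Float32 matrix times Float64 matrix

@show eltype(x) eltype(h) eltype(y)
```

The product `W * h` mixes element types, which is the situation reported above as falling back to `generic_matmul!` instead of BLAS.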