Add a dropout_mask method for ComplexF64 array types #1572
Conversation
Thanks for looking into this! I like the idea.
```diff
@@ -46,6 +46,12 @@ function dropout_mask(x, p; dims=:)
     return y
 end

+function dropout_mask(x::Array{Complex{Float64}}, p; dims=:)
```
Seems a bit strict. Maybe we want to generalise the existing method? It should work the same for single precision.
I agree; I just didn't want to introduce changes that would affect any other types.
I added a test with a complex array as an input. I could also just modify dropout_mask itself to always use a Float32 type when calling similar:

```julia
function dropout_mask(x, p; dims=:)
    y = rand!(similar(x, Float32, _dropout_shape(x, dims)))
    y .= _dropout_kernel.(y, p, 1 - p)
    return y
end
```

That would make dropout work for any input type; currently, using dropout on Int inputs fails with an InexactError. But I didn't want to modify the main method just for this use case.
This might actually save some memory as well: if a network uses double precision, the rand call would otherwise produce Float64 values too, which is not really needed just for the comparison with p.
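As a standalone sketch of that variant (hedged: `_dropout_shape` and `_dropout_kernel` are simplified stand-ins for Flux's private helpers, inlined so the snippet runs outside the package):

```julia
using Random

# Simplified stand-ins for Flux's private helpers, for illustration only.
_dropout_shape(x, ::Colon) = size(x)
_dropout_kernel(y, p, q) = y > p ? 1 / q : 0

# Sample the mask in Float32 regardless of the input eltype: the sampled
# values are only compared against p, so full precision is unnecessary.
function dropout_mask(x, p; dims=:)
    y = rand!(similar(x, Float32, _dropout_shape(x, dims)))
    y .= _dropout_kernel.(y, p, 1 - p)
    return y
end

m = dropout_mask(randn(ComplexF64, 4, 4), 0.5)
eltype(m)  # Float32, even for a complex (or Int) input array
```

With p = 0.5 every mask entry is either 0.0f0 or the scaling factor 1/(1-p) = 2.0f0, so the mask stays real and cheap no matter what eltype the input has.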
Added test for a complex valued array
Thanks for the PR! Is this something that could be more generally useful in NNlib, eventually?
Possibly. I'm not sure whether dropout is useful for NNlib users or not, but if it works for all input data types then it could be used as a general function.
Coming back to this, I think the same objective could be accomplished without a custom overload, by updating Flux.jl/src/layers/normalise.jl line 52 (at d151080) to:

```julia
y = rand!(rng, similar(x, real(eltype(x)), _dropout_shape(x, dims)))
```
Sure, I can rebase and push these changes instead. EDIT: Actually, on second thought, that would make dropout fail to work for Int types.
Good catch, though it seems this can still be handled in the generic method. Instead of

```julia
y = rand!(rng, similar(x, real(eltype(x)), _dropout_shape(x, dims)))
```

use

```julia
fptype = float(real(eltype(x)))
y = rand!(rng, similar(x, fptype, _dropout_shape(x, dims)))
```

where float promotes an integer element type to a floating-point one, so Int inputs keep working.
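For reference, here is how that fptype expression resolves for the element types discussed in this thread (float and real are standard Base functions, so these results follow from their documented behaviour):

```julia
# Element type of the mask implied by float(real(eltype(x))):
fptype(x) = float(real(eltype(x)))

fptype(zeros(Float32, 2))     # Float32: single precision is preserved
fptype(zeros(ComplexF64, 2))  # Float64: the real part type of ComplexF64
fptype(zeros(Int, 2))         # Float64: Int is promoted, avoiding the InexactError
```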
I think this was superseded by #1867, but feel free to reopen if there are still bits you want to land.
Calling similar with a type argument makes dropout work for complex-valued networks as well.
Currently, using dropout in a complex-valued network fails with the error:

```
ERROR: MethodError: no method matching isless(::Float64, ::Complex{Float64})
```

This change adds a dropout_mask method for complex arrays that calls similar with a Float64 type.
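In sketch form, the added overload looks like this (the two helpers are simplified stand-ins for Flux's private `_dropout_shape` and `_dropout_kernel`, included so the example is self-contained and runnable outside the package):

```julia
using Random

# Simplified stand-ins for Flux's private helpers.
_dropout_shape(x, ::Colon) = size(x)
_dropout_kernel(y, p, q) = y > p ? 1 / q : 0

# Overload for complex arrays: build the mask in a real Float64 array via
# `similar` with a type argument, so `rand!` and the comparison with `p`
# (the `isless` call that previously errored) are well defined.
function dropout_mask(x::Array{Complex{Float64}}, p; dims=:)
    y = rand!(similar(x, Float64, _dropout_shape(x, dims)))
    y .= _dropout_kernel.(y, p, 1 - p)
    return y
end

x = randn(ComplexF64, 3, 3)
x .* dropout_mask(x, 0.5)  # elementwise masking now works for complex x
```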
PR Checklist