
Add a dropout_mask method for ComplexF64 array types #1572

Closed
wants to merge 3 commits into from

Conversation

@ShoofLLC (Contributor) commented Apr 10, 2021

Calling similar with a type argument makes dropout work for complex-valued networks as well.
Currently, using dropout in a complex-valued network fails with the error:

ERROR: MethodError: no method matching isless(::Float64, ::Complex{Float64})

This change adds a dropout_mask method for complex arrays, with a call to similar using the Float64 type.
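
For context, a minimal sketch of what such an overload could look like (the diff below only shows the signature; this body is a reconstruction that reuses Flux's internal _dropout_shape and _dropout_kernel helpers, with rand! coming from Random):

  # Hypothetical body: same as the generic dropout_mask, but with Float64
  # passed to similar so the mask is real-valued even for complex inputs.
  function dropout_mask(x::Array{Complex{Float64}}, p; dims=:)
      y = rand!(similar(x, Float64, _dropout_shape(x, dims)))  # real-valued mask
      y .= _dropout_kernel.(y, p, 1 - p)                       # keep-and-scale vs. drop
      return y
  end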

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

In dropout_mask, calling similar with a type argument would make dropout work for complex valued networks
@ShoofLLC changed the title from "Add a dropout_method for ComplexF64 array types" to "Add a dropout_mask method for ComplexF64 array types" on Apr 10, 2021
@ShoofLLC (Contributor, Author) commented Apr 10, 2021

Modifying the dropout_mask function itself to always use a Float32 type when calling similar:

  function dropout_mask(x, p; dims=:)
      y = rand!(similar(x, Float32, _dropout_shape(x, dims)))
      y .= _dropout_kernel.(y, p, 1 - p)
      return y
  end

would make dropout work for any input type. Currently, using dropout with Int inputs fails with an InexactError, but I didn't want to modify the main method just for this use case.
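
For reference, a hedged illustration of the two failure modes discussed in this thread (the snippet below is illustrative; trainmode! is used because dropout is only applied while the layer is active):

  using Flux

  d = Dropout(0.3)
  Flux.trainmode!(d)        # activate dropout outside of a training loop

  d(randn(ComplexF64, 4))   # MethodError: no method matching isless(::Float64, ::Complex{Float64})
  d(rand(1:10, 4))          # InexactError, as described above for Int inputs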

@DhairyaLGandhi (Member) left a comment


Thanks for looking into this! I like the idea.

@@ -46,6 +46,12 @@ function dropout_mask(x, p; dims=:)
return y
end

function dropout_mask(x::Array{Complex{Float64}}, p; dims=:)
Member

Seems a bit strict. Maybe we want to generalise the existing method? It should work the same for single precision.

Contributor (Author)

I agree; I just didn't want to introduce changes that would affect any other types.

@ShoofLLC (Contributor, Author) Apr 10, 2021

I added a test with a complex array as input. I could also just modify dropout_mask to always use a Float type when calling similar:

Modifying the dropout_mask function itself to always use a Float32 type when calling similar:

  function dropout_mask(x, p; dims=:)
      y = rand!(similar(x, Float32, _dropout_shape(x, dims)))
      y .= _dropout_kernel.(y, p, 1 - p)
      return y
  end

would make dropout work for any input type. Currently, using dropout with Int inputs fails with an InexactError, but I didn't want to modify the main method just for this use case.

This might actually save some memory, because if a network uses double precision the rand would be double precision as well, which is not really needed just for the comparison with p.
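
As a rough illustration of that point (array length chosen arbitrarily):

  sizeof(rand(Float64, 1_000))  # 8000 bytes
  sizeof(rand(Float32, 1_000))  # 4000 bytes -- half the memory for the same mask length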

Added test for a complex valued array
@ToucheSir (Member)

Thanks for the PR! Is this something that could be more generally useful in NNlib, eventually?

@ShoofLLC (Contributor, Author)

Thanks for the PR! Is this something that could be more generally useful in NNlib, eventually?

Possibly. I'm not sure whether dropout is useful for NNlib users, but if it works for all input data types, it could be used as a general function.

@ToucheSir (Member) commented Feb 4, 2022

Coming back to this, I think the same objective could be accomplished without a custom overload. Updating

  y = rand!(rng, similar(x, _dropout_shape(x, dims)))

to:

  y = rand!(rng, similar(x, real(eltype(x)), _dropout_shape(x, dims)))

leaves IntNN and FloatNN intact, while changing Complex{T} to T.
@ShoofLLC would you be interested in rebasing and updating the PR to this change? If not, I'd be happy to create a new one with your tests.
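
For illustration only (not part of the PR), this is how real(eltype(x)) resolves for the array types mentioned:

  real(eltype(zeros(ComplexF64, 2)))  # Float64 -- complex arrays get a real-valued mask
  real(eltype(zeros(Float32, 2)))     # Float32 -- floating-point types are untouched
  real(eltype(zeros(Int, 2)))         # Int64   -- integer types stay integer (see the follow-up below)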

@ShoofLLC (Contributor, Author) commented Feb 4, 2022

Coming back to this, I think the same objective could be accomplished without a custom overload. Updating

  y = rand!(rng, similar(x, _dropout_shape(x, dims)))

to:

  y = rand!(rng, similar(x, real(eltype(x)), _dropout_shape(x, dims)))

leaves IntNN and FloatNN intact, while changing Complex{T} to T.
@ShoofLLC would you be interested in rebasing and updating the PR to this change? If not, I'd be happy to create a new one with your tests.

Sure, I can rebase and push these changes instead.

EDIT: Actually, on second thought, that would make dropout fail for Int types.
Correction to the edit: it would fail because of an InexactError; I just hadn't activated the dropout before using it on Int inputs.

@ToucheSir (Member)

Good catch, though it seems Dropout is already failing for Int types on master. The fix should be pretty straightforward: instead of

  y = rand!(rng, similar(x, real(eltype(x)), _dropout_shape(x, dims)))

Use

  fptype = float(real(eltype(x)))
  y = rand!(rng, similar(x, fptype, _dropout_shape(x, dims)))

where float will be a no-op for Real and Complex floating-point types, but will convert Integers.
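
A quick sanity check of the suggested element-type computation (the fptype helper function is defined here just for the example):

  fptype(x) = float(real(eltype(x)))

  fptype(zeros(Int, 2))         # Float64 -- Integers are converted
  fptype(zeros(Float32, 2))     # Float32 -- no-op for floating point
  fptype(zeros(ComplexF64, 2))  # Float64 -- complex collapses to its real part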

@ToucheSir (Member)

I think this was superseded by #1867, but feel free to reopen if there are still bits you want to land.

@ToucheSir closed this May 9, 2022