
ComplexDropout2d Device Error #30

Open
lucacoma opened this issue Jun 20, 2023 · 2 comments

Comments

lucacoma commented Jun 20, 2023

Hi, thank you for the nice library.

There seems to be a small bug in the `complexPyTorch.complexLayers.ComplexDropout2d` layer: it raises a device-mismatch error when the input is on the GPU (torch version 2.0.1+cu118):

```
.... line 106, in complex_dropout
    return mask*input
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

I managed to fix it by creating the mask on the same device as the input in `complexPyTorch.complexFunctions.complex_dropout2d`, as follows:

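```python
import torch
import torch.nn.functional as F


def complex_dropout2d(input, p=0.5, training=True):
    # need to have the same dropout mask for real and imaginary part,
    # this is not a clean solution!
    device = input.device
    # create the mask on the input's device (fixes the cuda:0 / cpu mismatch)
    mask = torch.ones(*input.shape, dtype=torch.float32, device=device)
    # dropout2d already rescales kept channels by 1/(1-p) when training
    # (and returns the mask unchanged otherwise), so no extra factor here
    mask = F.dropout2d(mask, p, training)
    # a real mask scales the real and imaginary parts identically
    return mask * input
```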
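A slightly more compact variant, sketched under the assumption that `input` is a native complex tensor (e.g. `torch.complex64`) of shape (N, C, H, W): `torch.ones_like(input.real)` produces a real-valued mask that is already allocated on `input.device`, so no explicit device handling is needed. This is not the library's code, only an illustration; the usage lines pick CUDA only when it is available.

```python
import torch
import torch.nn.functional as F


def complex_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True) -> torch.Tensor:
    # Assumes `input` is a native complex tensor shaped (N, C, H, W).
    # ones_like(input.real) is real-valued and is allocated on input.device,
    # so the mask always matches the input's device.
    mask = torch.ones_like(input.real)
    # dropout2d zeroes whole channels and rescales the kept ones by 1/(1-p)
    # when training=True; with training=False it returns the mask unchanged.
    mask = F.dropout2d(mask, p, training)
    # A real mask scales the real and imaginary parts identically.
    return mask * input


# Usage: runs on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 4, 3, 3, dtype=torch.complex64, device=device)
y = complex_dropout2d(x, p=0.5, training=True)
print(y.device, y.dtype)
```

Because the mask is derived from the input tensor itself, the same function works unchanged on CPU and GPU inputs.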

Best!

nctamer commented Aug 20, 2024

The same problem affects all the dropout functions. Any update on an official fix?
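Since the same device problem reportedly affects every dropout variant, one way to fix them all at once is a shared helper that takes the underlying `torch.nn.functional` dropout function as a parameter. This is a sketch, not the library's code: `_complex_dropout_generic` is a hypothetical name, and the wrappers assume `complex_dropout` and `complex_dropout2d` keep the `(input, p=0.5, training=True)` signature shown earlier in the thread and receive native complex tensors.

```python
from typing import Callable

import torch
import torch.nn.functional as F


def _complex_dropout_generic(
    input: torch.Tensor,
    p: float,
    training: bool,
    dropout_fn: Callable[..., torch.Tensor],
) -> torch.Tensor:
    # Hypothetical shared helper: builds one real-valued mask on input's
    # device and applies it to both the real and imaginary parts.
    mask = torch.ones_like(input.real)
    # The F.dropout* functions already rescale kept entries by 1/(1-p)
    # in training mode and are the identity otherwise.
    mask = dropout_fn(mask, p, training)
    return mask * input


def complex_dropout(input, p=0.5, training=True):
    # Element-wise dropout: each entry is kept or zeroed independently.
    return _complex_dropout_generic(input, p, training, F.dropout)


def complex_dropout2d(input, p=0.5, training=True):
    # Channel-wise dropout for (N, C, H, W) inputs.
    return _complex_dropout_generic(input, p, training, F.dropout2d)
```

Any further variant only needs a one-line wrapper around the matching `torch.nn.functional` function, so the device handling lives in a single place.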

wavefrontshaping (Owner) commented

Hi,
I can look at it when I have some time, but I am no longer working on this code. I consider it obsolete and no longer need it, since current versions of PyTorch support complex tensors natively.
Feel free to fork it, and if you need these changes, please open a pull request; I will review it.
Best,
