Bug in _patch_dropout_layers #146

Closed
ngocphuonganhduong opened this issue Sep 10, 2021 · 2 comments · Fixed by #147 or #172
Labels
bug Something isn't working

Comments

@ngocphuonganhduong

Describe the bug
In _patch_dropout_layers, once changed is True, the recursive call _patch_dropout_layers(child) is never executed for the remaining children, because the or in changed = changed or _patch_dropout_layers(child) short-circuits. However, this function is supposed to recursively replace all dropout layers of a given module with MC dropout layers.

def _patch_dropout_layers(module: torch.nn.Module) -> bool:
    changed = False
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Dropout):
            new_module = Dropout(p=child.p, inplace=child.inplace)
        elif isinstance(child, torch.nn.Dropout2d):
            new_module = Dropout2d(p=child.p, inplace=child.inplace)
        else:
            new_module = None
        if new_module is not None:
            changed = True
            module.add_module(name, new_module)
        
        # recursively apply to child
        changed = changed or _patch_dropout_layers(child)
    return changed

To Reproduce
Occurs when more than one child of a module is not itself a dropout layer but contains dropout layers. In that case, only the dropout layers in the first such child are replaced with MC dropout.
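The scenario above can be sketched without PyTorch. This is a minimal illustration only: Node is a hypothetical stand-in for torch.nn.Module, the kind strings stand in for the layer classes, and patch_buggy mirrors the control flow of _patch_dropout_layers rather than reproducing the library code.

```python
class Node:
    """Hypothetical stand-in for torch.nn.Module: a named tree of children."""

    def __init__(self, kind, **children):
        self.kind = kind
        self.children = dict(children)


def patch_buggy(node):
    """Mirror the control flow of the function quoted in the issue."""
    changed = False
    for name, child in list(node.children.items()):
        if child.kind == "dropout":
            node.children[name] = Node("mc_dropout")
            changed = True
        # `or` short-circuits: once changed is True, this call is skipped.
        changed = changed or patch_buggy(child)
    return changed


# Two non-dropout children, each containing one dropout layer.
model = Node(
    "container",
    block1=Node("container", drop=Node("dropout")),
    block2=Node("container", drop=Node("dropout")),
)
patch_buggy(model)

kinds = {name: block.children["drop"].kind for name, block in model.children.items()}
print(kinds)  # {'block1': 'mc_dropout', 'block2': 'dropout'}
```

Only block1 is patched: its recursive call returns True, so the recursion into block2 is never evaluated.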

Suggested fix:
Change this:
changed = changed or _patch_dropout_layers(child)
to this:
changed |= _patch_dropout_layers(child)
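The difference between the two forms comes down to evaluation order, which a small standalone sketch can show. The visit function and the names "a", "b", "c" are hypothetical; they only record which recursive calls would actually run.

```python
calls = []


def visit(name, result):
    """Hypothetical stand-in for the recursive call; records that it ran."""
    calls.append(name)
    return result


# Original form: the right-hand side is skipped once changed is True.
changed = False
for name in ["a", "b", "c"]:
    changed = changed or visit(name, True)
or_calls = list(calls)

# Suggested form: `|=` always evaluates its right-hand side.
calls.clear()
changed = False
for name in ["a", "b", "c"]:
    changed |= visit(name, True)
ior_calls = list(calls)

print(or_calls)   # ['a']
print(ior_calls)  # ['a', 'b', 'c']
```

An equivalent alternative is to put the call first, as in changed = visit(name, True) or changed, which also guarantees the call runs on every iteration.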

@Dref360
Member

Dref360 commented Sep 10, 2021

Thanks a lot for submitting this issue! A PR is open at #147.

@ngocphuonganhduong
Author

Thank you very much for the fast update, but be careful with changed = changed or in baal/bayesian/weight_drop.py and baal/bayesian/consistent_dropout.py: they have the same problem.
