Describe the bug
In the function `_patch_dropout_layers`, once `changed` is `True`, the recursive call `_patch_dropout_layers(child)` is never executed for the remaining children that are not dropout layers, because `changed or ...` short-circuits. However, this function is supposed to recursively replace all dropout layers of a given module with MC dropout layers.
```python
import torch

# Dropout and Dropout2d below are the MC dropout layers that this function patches in.
def _patch_dropout_layers(module: torch.nn.Module) -> bool:
    changed = False
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Dropout):
            new_module = Dropout(p=child.p, inplace=child.inplace)
        elif isinstance(child, torch.nn.Dropout2d):
            new_module = Dropout2d(p=child.p, inplace=child.inplace)
        else:
            new_module = None
        if new_module is not None:
            changed = True
            module.add_module(name, new_module)
        # recursively apply to child
        # (skipped once `changed` is True, because `or` short-circuits)
        changed = changed or _patch_dropout_layers(child)
    return changed
```
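The root cause is Python's short-circuit evaluation of `or`: when the left operand is already truthy, the right operand is never evaluated. A minimal, torch-free sketch of the effect, where the hypothetical `visit` function stands in for the recursive `_patch_dropout_layers(child)` call:

```python
def visit(name, log):
    """Hypothetical stand-in for _patch_dropout_layers(child): records the call and reports a change."""
    log.append(name)
    return True


def run_with_or(children):
    log = []
    changed = False
    for name in children:
        # `or` short-circuits: once `changed` is True, visit() is never called again.
        changed = changed or visit(name, log)
    return log


def run_with_ior(children):
    log = []
    changed = False
    for name in children:
        # `|=` always evaluates the right-hand side, so every child is visited.
        changed |= visit(name, log)
    return log


print(run_with_or(["block1", "block2", "block3"]))   # ['block1']
print(run_with_ior(["block1", "block2", "block3"]))  # ['block1', 'block2', 'block3']
```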
To Reproduce
This occurs when a module has more than one child that is not itself a dropout layer but contains dropout layers. Only the dropout layers in the first such child are replaced with MC dropout; the subtrees of the following children are never visited.
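A minimal reproduction sketch that needs only PyTorch. `MCDropout` is a hypothetical stand-in for the library's MC dropout layer (so the script runs without baal), and `patch_buggy` mirrors the reported logic, restricted to `torch.nn.Dropout`. With two non-dropout children that each contain a dropout layer, the second one is left unpatched:

```python
import torch


class MCDropout(torch.nn.Module):
    """Hypothetical stand-in for an MC dropout layer: dropout stays active in eval mode."""

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # training=True keeps dropout on regardless of model.eval().
        return torch.nn.functional.dropout(x, self.p, training=True, inplace=self.inplace)


def patch_buggy(module: torch.nn.Module) -> bool:
    """Mirrors the reported _patch_dropout_layers logic (torch.nn.Dropout only)."""
    changed = False
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Dropout):
            changed = True
            module.add_module(name, MCDropout(p=child.p, inplace=child.inplace))
        # Skipped for every later child once `changed` is True.
        changed = changed or patch_buggy(child)
    return changed


def build_model() -> torch.nn.Module:
    # Two children that are not dropout layers but each contain one.
    return torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5)),
        torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5)),
    )


def count_plain_dropout(module: torch.nn.Module) -> int:
    # Counts torch.nn.Dropout layers that were not replaced.
    return sum(isinstance(m, torch.nn.Dropout) for m in module.modules())


model = build_model()
patch_buggy(model)
print(count_plain_dropout(model))  # 1: the dropout in the second block was never visited
```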
Suggested fix: change `changed = changed or _patch_dropout_layers(child)` to `changed |= _patch_dropout_layers(child)`, which always runs the recursive call.
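A sketch of the suggested `|=` fix, assuming only PyTorch: `MCDropout` is a hypothetical stand-in for the library's MC dropout layer and only `torch.nn.Dropout` is handled. Because the recursive call is evaluated for every child, dropout layers nested in any child are replaced:

```python
import torch


class MCDropout(torch.nn.Module):
    """Hypothetical stand-in for an MC dropout layer: dropout stays active in eval mode."""

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.dropout(x, self.p, training=True, inplace=self.inplace)


def patch_fixed(module: torch.nn.Module) -> bool:
    """Recursively replaces every torch.nn.Dropout with MCDropout; returns True if anything changed."""
    changed = False
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Dropout):
            changed = True
            module.add_module(name, MCDropout(p=child.p, inplace=child.inplace))
        # `|=` evaluates the recursive call unconditionally, so no subtree is skipped.
        changed |= patch_fixed(child)
    return changed


model = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5)),
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.5)),
)
print(patch_fixed(model))  # True
print(sum(isinstance(m, torch.nn.Dropout) for m in model.modules()))  # 0
```

Putting the recursive call on the left, `changed = _patch_dropout_layers(child) or changed`, would also work, since the left operand of `or` is always evaluated; `|=` states the intent more directly.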
Thank you very much for the fast update, but be careful with the `changed = changed or` pattern in `baal/bayesian/weight_drop.py` and `baal/bayesian/consistent_dropout.py`: they have the same problem.