diff --git a/docs/conf.py b/docs/conf.py index aba5f0845..9a524d3b8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,7 +33,15 @@ class Mock(MagicMock): def __getattr__(cls, name): return MagicMock() -MOCK_MODULES = ['mpi4py', 'torch', 'torch.optim', 'torch.nn'] +MOCK_MODULES = ['mpi4py', + 'torch', + 'torch.optim', + 'torch.nn', + 'torch.distributions', + 'torch.distributions.normal', + 'torch.distributions.categorical', + 'torch.nn.functional', + ] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # Finish imports