diff --git a/pyro/distributions/transforms/batch_norm.py b/pyro/distributions/transforms/batch_norm.py index ebf82d36e7..a2dae2cc7e 100644 --- a/pyro/distributions/transforms/batch_norm.py +++ b/pyro/distributions/transforms/batch_norm.py @@ -34,7 +34,7 @@ class BatchNormTransform(TransformModule): Example usage: >>> from pyro.nn import AutoRegressiveNN - >>> from pyro.distributions import InverseAutoregressiveFlow + >>> from pyro.distributions.transforms import InverseAutoregressiveFlow >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) >>> iafs = [InverseAutoregressiveFlow(AutoRegressiveNN(10, [40])) for _ in range(2)] >>> bn = BatchNormTransform(10) diff --git a/pyro/distributions/transforms/permute.py b/pyro/distributions/transforms/permute.py index f1e914439c..8f02fd15ce 100644 --- a/pyro/distributions/transforms/permute.py +++ b/pyro/distributions/transforms/permute.py @@ -20,7 +20,7 @@ class PermuteTransform(Transform): Example usage: >>> from pyro.nn import AutoRegressiveNN - >>> from pyro.distributions import InverseAutoregressiveFlow, PermuteTransform + >>> from pyro.distributions.transforms import InverseAutoregressiveFlow, PermuteTransform >>> base_dist = dist.Normal(torch.zeros(10), torch.ones(10)) >>> iaf1 = InverseAutoregressiveFlow(AutoRegressiveNN(10, [40])) >>> ff = PermuteTransform(torch.randperm(10, dtype=torch.long))