From ceddf83f9cd19df75e5b0378eda37aae6f5762b9 Mon Sep 17 00:00:00 2001 From: Gabriele Cesa Date: Tue, 25 Aug 2020 14:38:55 +0200 Subject: [PATCH] fixed init tests --- test/nn/test_deltaorth_init.py | 2 +- test/nn/test_he_init.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/nn/test_deltaorth_init.py b/test/nn/test_deltaorth_init.py index dda32570..f9dda753 100644 --- a/test/nn/test_deltaorth_init.py +++ b/test/nn/test_deltaorth_init.py @@ -56,7 +56,7 @@ def check(self, r1: FieldType, r2: FieldType): init.deltaorthonormal_init(cl.weights.data, cl.basisexpansion) # init.generalized_he_init(cl.weights.data, cl.basisexpansion) - filter, _ = cl.expand_weights() + filter, _ = cl.expand_parameters() center = filter[..., c, c] diff --git a/test/nn/test_he_init.py b/test/nn/test_he_init.py index 39a42b15..efbe3d76 100644 --- a/test/nn/test_he_init.py +++ b/test/nn/test_he_init.py @@ -63,7 +63,7 @@ def check(self, r1: FieldType, r2: FieldType): init.generalized_he_init(cl.weights.data, cl.basisexpansion) cl.eval() - x = torch.randn(5, r1.size, s, s) + x = torch.randn(10, r1.size, s, s) xg = GeometricTensor(x, r1) y = cl(xg).tensor @@ -80,8 +80,8 @@ def check(self, r1: FieldType, r2: FieldType): print(mean) print(std) - self.assertTrue(torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=3e-2)) - self.assertTrue(torch.allclose(torch.ones_like(std), std, rtol=2e-2, atol=3e-2)) + self.assertTrue(torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=5e-2)) + self.assertTrue(torch.allclose(torch.ones_like(std), std, rtol=1e-1, atol=6e-2)) if __name__ == '__main__':