Skip to content

Commit

Permalink
fix unit tests for init
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabri95 committed Aug 25, 2020
2 parents 2604f67 + ceddf83 commit e3d6093
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion test/nn/test_deltaorth_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def check(self, r1: FieldType, r2: FieldType):
init.deltaorthonormal_init(cl.weights.data, cl.basisexpansion)
# init.generalized_he_init(cl.weights.data, cl.basisexpansion)

filter, _ = cl.expand_weights()
filter, _ = cl.expand_parameters()

center = filter[..., c, c]

Expand Down
6 changes: 3 additions & 3 deletions test/nn/test_he_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def check(self, r1: FieldType, r2: FieldType):
init.generalized_he_init(cl.weights.data, cl.basisexpansion)
cl.eval()

x = torch.randn(5, r1.size, s, s)
x = torch.randn(10, r1.size, s, s)

xg = GeometricTensor(x, r1)
y = cl(xg).tensor
Expand All @@ -80,8 +80,8 @@ def check(self, r1: FieldType, r2: FieldType):
print(mean)
print(std)

self.assertTrue(torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=3e-2))
self.assertTrue(torch.allclose(torch.ones_like(std), std, rtol=2e-2, atol=3e-2))
self.assertTrue(torch.allclose(torch.zeros_like(mean), mean, rtol=2e-2, atol=5e-2))
self.assertTrue(torch.allclose(torch.ones_like(std), std, rtol=1e-1, atol=6e-2))


if __name__ == '__main__':
Expand Down

0 comments on commit e3d6093

Please sign in to comment.