fix Tensor.var with 0 in reduce dim. (tinygrad#3324)
Fix var when the correction is too big; it seems to only work when the input size is 0, though.
torch can output -inf in var when the correction is too big, which does not make sense.
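
A minimal sketch of the two behaviors described above (hypothetical session; assumes torch and tinygrad around this commit, and exact outputs may differ by version):

    import torch
    from tinygrad.tensor import Tensor

    # zero-size reduce dims: the divisor N - correction would be negative, so
    # after this fix it is clamped to 0 and var returns nan (0/0) instead of a
    # negative value; the new tests check this against torch
    print(Tensor.ones(1, 0, 3, 0, 5).var(axis=(1, 3), correction=5).numpy())

    # an oversized correction on a nonzero-size input is still unresolved (see
    # the TODO added in test_ops.py): torch lets the divisor go negative here,
    # and per the message above can even produce -inf
    print(torch.ones(10, 2).var(correction=50))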
chenyuxyz authored Feb 6, 2024
1 parent ee25f73 commit d9ef8e2
Showing 2 changed files with 12 additions and 2 deletions.
12 changes: 11 additions & 1 deletion test/test_ops.py
@@ -684,19 +684,25 @@ def test_mean(self):
   def test_mean_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)))
   def test_mean_zero_axis(self):
-    helper_test_op([], lambda: torch.ones((1,0,3,0,5)).mean(axis=(1,3)), lambda: Tensor.ones((1,0,3,0,5)).mean(axis=(1,3)), forward_only=True)
+    helper_test_op([(1,0,3,0,5)], lambda x: x.mean(axis=(1,3)))
 
   def test_var(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var())
     helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5))
+    # TODO: fix this
+    # helper_test_op([(10, 2)], lambda x: x.var(correction=50))
   def test_var_axis(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var(0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(2))
     helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2]))
     helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0))
+  def test_var_zero_axis(self):
+    helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3)))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5))
   def test_var_keepdim(self):
     helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True))
     helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0))
@@ -712,6 +718,10 @@ def test_std_axis(self):
     helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0))
     helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0))
+  def test_std_zero_axis(self):
+    helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3)))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0))
+    helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5))
   def test_std_keepdim(self):
     helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True))
     helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0))
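
Written out directly, the new zero-axis tests compare tinygrad against torch on tensors whose reduce dimensions have size 0. A rough equivalent of one case (a sketch; helper_test_op's tolerance and gradient checks are omitted):

    import torch
    from tinygrad.tensor import Tensor

    # std over size-0 dims: both frameworks should return an all-nan tensor
    # of shape (1, 3, 5), since there are no elements to reduce
    print(torch.ones(1, 0, 3, 0, 5).std(dim=(1, 3)))
    print(Tensor.ones(1, 0, 3, 0, 5).std(axis=(1, 3)).numpy())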
2 changes: 1 addition & 1 deletion tinygrad/tensor.py
@@ -546,7 +546,7 @@ def mean(self, axis=None, keepdim=False):
   def var(self, axis=None, keepdim=False, correction=1):
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction)
+    return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
   def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt()
 
   def _softmax(self, axis):
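
The divisor here is N - correction, where N = prod(self.shape)/prod(square_sum.shape) recovers how many elements the reduce collapsed, and the new max(0, ...) clamps it when correction exceeds N. A plain-Python sketch of that arithmetic (var_divisor is a hypothetical helper, not part of tinygrad):

    from math import prod

    def var_divisor(in_shape, out_shape, correction):
      # number of reduced elements, then the clamped Bessel-style divisor
      n = prod(in_shape) / prod(out_shape)
      return max(0, n - correction)

    # size-0 reduce dims: n is 0, so any positive correction clamps the
    # divisor to 0 and the division yields nan/inf, not a negative var
    print(var_divisor((1, 0, 3, 0, 5), (1, 3, 5), correction=5))  # 0
    # an oversized correction on nonzero input also clamps to 0 here, while
    # torch goes negative -- the still-open TODO case in the tests
    print(var_divisor((10, 2), (), correction=50))                # 0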