Skip to content

Commit

Permalink
test for LazyBuffer._view when mask out and degrade into const (tinyg…
Browse files Browse the repository at this point in the history
…rad#4465)

changed the condition from all 0 in masked dims to any 0 in masked. it's no-op because shapetracker rewrites whole mask to 0 if any dim has 0 as part of canonicalization
  • Loading branch information
chenyuxyz authored May 7, 2024
1 parent a1d350a commit 46a7931
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
23 changes: 22 additions & 1 deletion test/test_lazybuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.lazy import LazyBuffer, ReduceOps
from tinygrad.lazy import LazyBuffer, ReduceOps, LoadOps
from tinygrad.engine.schedule import create_schedule

class TestLazyBuffer(unittest.TestCase):
Expand Down Expand Up @@ -92,5 +92,26 @@ def test_split_reduce_kernel_dim1(self):
for s in sched:
assert s.ast[0].src[0].op is ReduceOps.SUM

class TestView(unittest.TestCase):
def test_all_masked_out(self):
# start with non CONST LoadOps
a = Tensor.rand(10, 10)
assert a.lazydata.base.op is not LoadOps.CONST

# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is LoadOps.CONST and b.lazydata.base.arg == 0

# mask out dim = 1 works too
b = a.pad((None, (0, 10)))[:, 10:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is LoadOps.CONST and b.lazydata.base.arg == 0

# partial masked out does not degrade into CONST
b = a.pad(((0, 5), None))[5:]
assert b.shape == (10, 10)
assert b.lazydata.base.op is not LoadOps.CONST

if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion tinygrad/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
# *** movement ops ***

def _view(self, new_st:ShapeTracker) -> LazyBuffer:
if self.st.size == 0 or (new_st.views[-1].mask is not None and all((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return self.const(0, new_st.shape)
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
Expand Down

0 comments on commit 46a7931

Please sign in to comment.