diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 5d49711ff0f9..5444e600c911 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -283,7 +283,6 @@ def test_tensor_cores_multi_reduce(self):
     tc_actions = [k for i, k in get_linearizer_actions(Linearizer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
     assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
-
 
   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
     sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
diff --git a/test/test_search.py b/test/test_search.py
index 17a094a6f243..fc1a943e8487 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -1,8 +1,9 @@
 import unittest
+from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.features.search import time_linearizer, bufs_from_lin
+from tinygrad.features.search import time_linearizer, bufs_from_lin, actions
 from tinygrad.device import Device, Buffer
 from tinygrad.ops import LoadOps, BufferOps
 from tinygrad.tensor import Tensor
@@ -42,5 +43,25 @@ def add(self, x): self.captured.append(x)
       capturing.clear()
     assert k_beam_0[-1].prg.prg != k_beam_1[-1].prg.prg
 
+  def test_get_linearizer_actions(self):
+    from test.test_linearizer import helper_realized_ast
+    a = Tensor.rand(4, 3)
+    b = Tensor.rand(3)
+    realized_ast, _ = helper_realized_ast(a @ b)
+    from tinygrad.features.search import get_linearizer_actions
+    lins = get_linearizer_actions(Linearizer(realized_ast), False).values()
+
+    # ensure amt=0 are not duplicated
+    if Opt(OptOps.UPCAST, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, amt=4)]) == 0, "did not de-dup UPCAST"
+    if Opt(OptOps.LOCAL, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, amt=4)]) == 0, "did not de-dup LOCAL"
+    if Opt(OptOps.UNROLL, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, amt=3)]) == 0, "did not de-dup UNROLL"
+    if Opt(OptOps.GROUP, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, amt=3)]) == 0, "did not de-dup GROUP"
+    if Opt(OptOps.GROUPTOP, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP"
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index d73ec1fd6a72..1b155eade53f 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -27,6 +27,11 @@ class Opt:
   axis: Optional[int] = None
   amt: Optional[int] = None
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
+  def real_axis(self, k:Kernel):
+    if self.axis is None: return -1
+    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
+    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
+    return self.axis
 
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
@@ -431,9 +436,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
       self.applied_opts.append(opt)
       return
 
-    if opt.axis is not None:
-      axis = opt.axis + (self.first_reduce if opt.op is OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0))  # noqa: E501
-    else: axis = -1
+    axis = opt.real_axis(self)
 
     check(axis < len(self.full_shape), "invalid axis")
     if opt.amt is not None:
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index ff2434193406..cb0e407a422d 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -87,7 +87,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 256)
   for i,a in enumerate(actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if (a.axis >= lin.shape_len) or (lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions): continue
+      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
     lin2 = lin.copy()
     try: lin2.apply_opt(a)
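Below is a minimal standalone sketch of the axis remapping that the new Opt.real_axis helper centralizes. FakeKernel is a hypothetical stand-in for tinygrad's Kernel (only the two attributes the helper reads, first_reduce and group_for_reduces, are modeled, with assumed example values); the logic mirrors the real_axis body in the diff above. The point of the change in search.py is visible in the asserts: UNROLL/GROUP/GROUPTOP axes are relative offsets, so before this patch the de-dup check indexed full_shape with the raw a.axis and compared against the wrong dimension.

# Standalone sketch of the Opt.real_axis remapping; FakeKernel and its default
# values are hypothetical, not part of the patch.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

class OptOps(Enum):
  UPCAST = auto(); LOCAL = auto(); UNROLL = auto(); GROUP = auto(); GROUPTOP = auto()

@dataclass(frozen=True)
class FakeKernel:
  first_reduce: int = 2       # assumed: dims [0, first_reduce) are global/local dims
  group_for_reduces: int = 1  # assumed: one reduce dim already split off for grouping

def real_axis(op:OptOps, axis:Optional[int], k:FakeKernel) -> int:
  # mirrors Opt.real_axis from the diff
  if axis is None: return -1                           # opts that take no axis
  if op is OptOps.UNROLL: return k.first_reduce+axis   # UNROLL counts from the first reduce dim
  if op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+axis
  return axis                                          # UPCAST/LOCAL address dims directly

k = FakeKernel()
assert real_axis(OptOps.UPCAST, 0, k) == 0  # global axes pass through unchanged
assert real_axis(OptOps.UNROLL, 0, k) == 2  # offset past the global/local dims
assert real_axis(OptOps.GROUP, 0, k) == 3   # also offset past already-grouped reduce dims

With the remapped axis, full_shape[ax] == a.amt compares the opt's amount against the dimension the opt would actually act on, so a full-size UNROLL/GROUP/GROUPTOP action is correctly recognized as a duplicate of its amt=0 form, which is exactly what the new test_get_linearizer_actions checks.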