search: fix edge cases on screening potential ops (tinygrad#4394)
* search: fix edge cases on screening potential ops

won't change correctness, but will save a little python time by
properly deduplicating potential actions

* check for de-duplication instead of exact valid actions

* refactor long line
flammit authored May 2, 2024
1 parent 89030b2 commit 5c5b408
Showing 4 changed files with 29 additions and 6 deletions.
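
The edge case being screened: an action whose amt equals the full size of the axis it acts on produces the same kernel as the amt=0 ("take the whole axis") form of that action, so get_linearizer_actions only needs to keep one of them instead of timing both during beam search. Below is a minimal sketch of that equivalence, not part of the commit, built the same way as the new test; the size-4 axis value is an assumption about how this (4, 3) @ (3,) matmul linearizes, not a guarantee.

# sketch: amt=4 on a size-4 axis duplicates the amt=0 ("whole axis") action
from test.test_linearizer import helper_realized_ast
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.tensor import Tensor

realized_ast, _ = helper_realized_ast(Tensor.rand(4, 3) @ Tensor.rand(3))
base = Linearizer(realized_ast)

whole = base.copy(); whole.apply_opt(Opt(OptOps.UPCAST, 0, 0))  # amt=0 means "upcast the whole axis"
four = base.copy(); four.apply_opt(Opt(OptOps.UPCAST, 0, 4))    # axis 0 is assumed to have size 4 here
assert whole.full_shape == four.full_shape and whole.upcasted == four.upcasted  # same kernel either way

The screen for such duplicates already existed; the diffs below make it look up the correct axis for reduce-relative ops.
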
1 change: 0 additions & 1 deletion test/test_linearizer.py
@@ -283,7 +283,6 @@ def test_tensor_cores_multi_reduce(self):
     tc_actions = [k for i, k in get_linearizer_actions(Linearizer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
     assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
-
 
   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
     sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
23 changes: 22 additions & 1 deletion test/test_search.py
@@ -1,8 +1,9 @@
 import unittest
 
+from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.features.search import time_linearizer, bufs_from_lin
+from tinygrad.features.search import time_linearizer, bufs_from_lin, actions
 from tinygrad.device import Device, Buffer
 from tinygrad.ops import LoadOps, BufferOps
 from tinygrad.tensor import Tensor
@@ -42,5 +43,25 @@ def add(self, x): self.captured.append(x)
     capturing.clear()
     assert k_beam_0[-1].prg.prg != k_beam_1[-1].prg.prg
 
+  def test_get_linearizer_actions(self):
+    from test.test_linearizer import helper_realized_ast
+    a = Tensor.rand(4, 3)
+    b = Tensor.rand(3)
+    realized_ast, _ = helper_realized_ast(a @ b)
+    from tinygrad.features.search import get_linearizer_actions
+    lins = get_linearizer_actions(Linearizer(realized_ast), False).values()
+
+    # ensure amt=0 are not duplicated
+    if Opt(OptOps.UPCAST, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, amt=4)]) == 0, "did not de-dup UPCAST"
+    if Opt(OptOps.LOCAL, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, amt=4)]) == 0, "did not de-dup LOCAL"
+    if Opt(OptOps.UNROLL, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, amt=3)]) == 0, "did not de-dup UNROLL"
+    if Opt(OptOps.GROUP, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, amt=3)]) == 0, "did not de-dup GROUP"
+    if Opt(OptOps.GROUPTOP, 0, 0) in actions:
+      assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP"
+
 if __name__ == '__main__':
   unittest.main()
9 changes: 6 additions & 3 deletions tinygrad/codegen/kernel.py
@@ -27,6 +27,11 @@ class Opt:
   axis: Optional[int] = None
   amt: Optional[int] = None
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
+  def real_axis(self, k:Kernel):
+    if self.axis is None: return -1
+    if self.op is OptOps.UNROLL: return k.first_reduce+self.axis
+    if self.op in {OptOps.GROUP, OptOps.GROUPTOP}: return k.first_reduce+k.group_for_reduces+self.axis
+    return self.axis
 
 @dataclass(frozen=True)
 class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
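
Opt.real_axis pulls the axis translation that apply_opt used to do inline (see the next hunk) into a reusable helper: UPCAST/LOCAL-style opts index the kernel's dims directly, UNROLL is relative to the reduce dims (offset by first_reduce), and GROUP/GROUPTOP additionally skip past group_for_reduces. A rough sketch of the mapping, continuing from the imports and realized_ast in the first sketch above; the concrete values in the comments assume this matmul linearizes to full_shape (4, 3) with first_reduce 1.

k = Linearizer(realized_ast)
print(k.full_shape, k.first_reduce)            # assumed here: (4, 3) and 1
print(Opt(OptOps.UPCAST, 0, 4).real_axis(k))   # 0: global axes pass through unchanged
print(Opt(OptOps.UNROLL, 0, 3).real_axis(k))   # k.first_reduce + 0: reduce-relative axis
print(Opt(OptOps.GROUP, 0, 3).real_axis(k))    # k.first_reduce + k.group_for_reduces + 0
print(Opt(OptOps.NOLOCALS).real_axis(k))       # -1: this opt has no axis at all
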
@@ -431,9 +436,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
       self.applied_opts.append(opt)
       return
 
-    if opt.axis is not None:
-      axis = opt.axis + (self.first_reduce if opt.op is OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
-    else: axis = -1
+    axis = opt.real_axis(self)
     check(axis < len(self.full_shape), "invalid axis")
 
     if opt.amt is not None:
2 changes: 1 addition & 1 deletion tinygrad/features/search.py
@@ -87,7 +87,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
   acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 256)
   for i,a in enumerate(actions):
     if a.axis is not None and a.op is not OptOps.TC:
-      if (a.axis >= lin.shape_len) or (lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions): continue
+      if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue
     lin2 = lin.copy()
     try:
       lin2.apply_opt(a)
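
This one-line change is the actual fix. The old screen indexed lin.full_shape with the Opt's raw axis, which is only correct for ops that address the kernel dims directly; for UNROLL/GROUP/GROUPTOP the axis is reduce-relative, so the size comparison looked at the wrong dimension and the redundant action still got benchmarked. A small sketch of the difference, continuing from the first sketch and again assuming this kernel's full_shape is (4, 3) with first_reduce 1.

lin = Linearizer(realized_ast)
a = Opt(OptOps.UNROLL, 0, 3)              # unroll reduce axis 0, i.e. the size-3 dim of full_shape
print(lin.full_shape[a.axis])             # 4: the dimension the old check compared a.amt against
print(lin.full_shape[a.real_axis(lin)])   # 3: equals a.amt, so the new check can now drop this
                                          # action as a duplicate of the amt=0 UNROLL (when that
                                          # amt=0 variant is present in the actions list)
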
