From 668324d92bc6a974799896bb58a352b5f11727ff Mon Sep 17 00:00:00 2001
From: Francis Lam
Date: Tue, 13 Feb 2024 01:19:35 -0800
Subject: [PATCH] wmma: protect TC locals from modification and use only LOCAL
 (#3379)

also remove unnecessary upcast_dim from tensor_core and calculate it
from the dimensions and thread sizes
---
 test/test_linearizer.py          | 34 +++++++++++++++++++-----
 test/test_linearizer_failures.py |  8 ------
 tinygrad/codegen/kernel.py       | 45 ++++++++++++++++----------------
 tinygrad/codegen/linearizer.py   | 17 ++++++------
 4 files changed, 58 insertions(+), 46 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 1bcf32ce140d..b7907c5cbe3b 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -526,9 +526,28 @@ def test_double_reduce(self):
       Opt(OptOps.UPCAST, 0, 2)], # No globals
     ])
 
+  def test_invalid_tensor_core_extra_opts(self):
+    if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
+      self.skipTest("device doesn't have tensor cores")
+    if Device.DEFAULT not in tensor_cores:
+      self.skipTest("No tensor cores for device")
+
+    N = 128
+    Tensor.manual_seed(1552)
+    a = Tensor.rand(N, N)
+    b = Tensor.rand(N, N)
+    realized_ast, _ = helper_realized_ast(a@b)
+    invalid_opts = [
+      [Opt(OptOps.LOCAL, 2, 2)],
+      [Opt(OptOps.UPCAST, 2, 2)],
+      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
+    ]
+    for x in invalid_opts:
+      k = Linearizer(realized_ast)
+      with self.assertRaises(AssertionError):
+        assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
+
   def test_tensor_core_opts(self):
-    if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local:
-      self.skipTest("Only Compiled uses linearizer with locals")
     if not Device[Device.DEFAULT].compiler.linearizer_opts.has_tensor_cores:
       self.skipTest("device doesn't have tensor cores")
     if Device.DEFAULT not in tensor_cores:
@@ -545,17 +564,18 @@ def test_tensor_core_opts(self):
       [Opt(OptOps.UPCAST, 0, 4)], [Opt(OptOps.UPCAST, 1, 4)], [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
-      [Opt(OptOps.UNROLL, 0, 2)], # check last unroll
-      [Opt(OptOps.LASTLOCAL, 0, 4)], # check last local
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of last unroll and last local
+      [Opt(OptOps.UNROLL, 0, 2)], # check unroll
+      [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
+      [Opt(OptOps.LOCAL, 0, 4)], # check local
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of upcast and unroll
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LASTLOCAL, 0, 2)],
+      [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
       [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
       [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
       [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
       [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
-      [Opt(OptOps.LASTLOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
+      [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
       # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
     ], apply_tc=True, atol=atol, rtol=rtol)
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index 219c6dc2cc75..5730e2fdab04 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -65,14 +65,6 @@ def test_failure_3(self):
     ast = helper_add_store(ast)
     helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"])
 
-  def test_failure_4(self):
-    ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1, 4, 1, 12, 2, 29), strides=(0, 0, 0, 2, 0, 216, 1, 8), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 1), (0, 11), (0, 2), (0, 27)), contiguous=False), View(shape=(1, 1, 1, 4, 22, 84), strides=(0, 0, 0, 696, 58, 1), offset=0, mask=((0, 1), (0, 1), (0, 1), (0, 4), (0, 12), (0, 58)), contiguous=False), View(shape=(1, 1, 1, 4, 2, 11, 3, 28), strides=(0, 0, 0, 1848, 924, 84, 28, 1), offset=0, mask=None, contiguous=True))))),), arg=(1, 1, 1, 4, 1, 11, 1, 28))
-    opts = [Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
-    # related to OptOps.NOLOCALS
-    # IndexError: list index out of range
-    ast = helper_add_store(ast)
-    helper_test_lin(Linearizer(ast), opts, failed_platforms=["METAL", "WEBGPU"])
-
   def test_failure_5(self):
     ast = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.1464405059814453, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None)), arg=None),), arg=(1, 1, 1, 1, 1, 1, 1, 1))
     opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 82c242186b78..cb111642ca5c 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -12,7 +12,7 @@
 from enum import Enum, auto
 
 class OptOps(Enum):
-  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto(); LASTLOCAL = auto() # noqa: E702
+  UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
   GROUP = auto(); GROUPTOP = auto(); NOLOCALS = auto(); PADTO = auto() # noqa: E702
   def __lt__(self, x:OptOps): return self.value < x.value
 
@@ -24,26 +24,27 @@ class Opt:
   def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
 
 @dataclass(frozen=True)
-class TensorCore:
-  dims: List[int]
-  dtype_in: DType
-  dtype_out: DType
+class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
+  dims: List[int] # N, M, K
+  dtype_in: DType # dtype for A and B
+  dtype_out: DType # dtype for C and D
   threads: List[Tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
-  upcast_dim: int # which TC dim to upcast
   thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
   thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
   wmma_func: str # name of wmma function to call
   def __str__(self): return f"tensor_core<{self.dims}, {self.dtype_in}, {self.dtype_out}>"
+  def num_threads(self): return len(self.threads)
+  def num_upcasts(self): return len(self.thread_local_aliases[0]) - self.num_threads()
 
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
-    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
+    TensorCore(dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma", threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ]), # noqa: E501
   ],
   "HIP": [
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
-    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
+    TensorCore(dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__hip_wmma_f16_f16", threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
   ]
 }
 
@@ -242,7 +243,7 @@ def simplify_ones(self) -> bool:
     if self.shape_len == 0: return False
     all_ones = [s==1 for s in self.full_shape]
     self.local_dims -= sum(all_ones[self.first_reduce-self.local_dims:self.first_reduce])
-    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:])
+    self.upcasted -= sum(all_ones[self.shape_len-self.upcasted:]) # TODO: not necessary since an upcasted axis can't be un-upcasted
     self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)
     return any(all_ones)
 
@@ -365,9 +366,10 @@ def fix(needed, ax):
 
     # tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
     self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
-    self.apply_opt(Opt(OptOps.UPCAST, s0 if tc.upcast_dim == 0 else s1, (tc.dims[0]*tc.dims[2])//prod([a[1] for a in tc.threads])))
+    for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
+      if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, s0 if i == 0 else s1, tc.dims[i]//sz))
     for (tc_dim, tc_amt) in tc.threads:
-      fix(self.apply_opt(Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
+      fix(self.apply_opt(Opt(OptOps.LOCAL, s0 if tc_dim == 0 else s1, tc_amt)), s0 if tc_dim == 0 else s1)
 
     # assert tensor core and prevent extra_opts from altering the key shape structure
     if use_tensor_cores == 1: self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
@@ -385,7 +387,7 @@ def fix(needed, ax):
       if self.tensor_core and s0_exists:
         for upc in [4,2]:
           if self.full_shape[s0] % upc == 0:
-            self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
+            self.apply_opt(Opt(OptOps.LOCAL, s0, upc))
             break
 
       # alias buffer
@@ -396,7 +398,7 @@ def fix(needed, ax):
     return False
 
   def apply_opt(self, opt:Opt):
-    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.LASTLOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals" # noqa: E501
+    assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
     self.applied_opts.append(opt)
     if opt.axis is not None:
       axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
@@ -408,14 +410,10 @@ def apply_opt(self, opt:Opt):
       if opt.op != OptOps.PADTO: assert self.full_shape[axis] % amt == 0, "no longer valid shift"
     else: amt = -1
-    if opt.op in [OptOps.LOCAL, OptOps.LASTLOCAL]: # cyan
+    if opt.op == OptOps.LOCAL: # cyan
       assert self.opts.has_local, "target does not support local"
-      assert axis < self.first_reduce, "can't local a reduce"
-      if opt.op == OptOps.LOCAL:
-        assert not self.tensor_core, "can't local with tensor cores"
-        self.shift_to(axis, amt, insert_before=self.first_reduce)
-      else:
-        self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
+      assert axis < self.global_dims, "local is for globals"
+      self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
       self.local_dims += 1
     elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]: # green
       assert self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem"
@@ -435,6 +433,7 @@ def apply_opt(self, opt:Opt):
       self.upcast()
     elif opt.op == OptOps.UPCAST: # yellow
       assert axis < self.first_reduce, "upcast is for non-reduce"
+      assert not(self.tensor_core and axis >= self.first_reduce-len(self.tensor_core.threads)), "can't upcast TC locals"
       assert amt <= 8, "don't upcast more than 8"
       self.shift_to(axis, amt, insert_before=None)
       self.upcast()
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 8e4447b7c2d8..acd33d18fa48 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -264,6 +264,7 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
         local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
       for n in range(len(replace_acc_idxs)-len(tc.threads)):
         upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
+      if DEBUG >= 3: print("store alias: idxs=", global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs)
 
     # reduce loop
     loop_ctx = render_loop(reduce_idxs)
@@ -276,23 +277,23 @@ def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
       for i in self.local_alias:
         localbuf_idx = self.bufs.index(self.local_alias[i])
         buf_idxs = [idx*0 if s == 0 else idx for idx,s in zip(global_idxs+local_idxs+reduce_idxs+full_upcast_idxs,self.sts[i].real_strides())]
-        if self.tensor_core:
+        if (tc:=self.tensor_core):
           min_alias_idx = min(self.local_alias.keys())
-          replace_input_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[i-min_alias_idx], self.tensor_core.thread_local_aliases[i-min_alias_idx]) # noqa: E501
-          for n in range(len(self.tensor_core.threads)):
-            buf_idxs[self.first_reduce-len(self.tensor_core.threads)+n] = replace_input_idxs[n] # replace locals
-          for n in range(len(replace_input_idxs)-len(self.tensor_core.threads)):
-            buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(self.tensor_core.threads)+n] # replace upcasts
+          replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
+          for n in range(tc.num_threads()):
+            buf_idxs[self.first_reduce-tc.num_threads()+n] = replace_input_idxs[n] # replace locals
+          for n in range(tc.num_upcasts()):
+            buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[tc.num_threads()+n] # replace upcasts
         if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
         ll = self.global_load(i, buf_idxs)
         locals_to_store.append((localbuf_idx, buf_idxs, ll))
 
     # copy in any global buffers
     if (tc:=self.tensor_core):
-      wmma_sz, num_tc_upcast = tc.thread_local_sizes, 2 # 2 is for UNROLL and one UPCAST
+      wmma_sz = tc.thread_local_sizes
       def upcast_strides(buf:int):
         strides, next = [], 1
-        for (sz, stride, reduce) in self.upcasted_axis(buf)[num_tc_upcast:]:
+        for (sz, stride, reduce) in self.upcasted_axis(buf)[tc.num_upcasts():]:
           strides.append((0 if stride == 0 else next, sz))
           next *= 1 if stride == 0 else sz
         return strides
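
A few notes on the mechanics of the change, each with a standalone Python sketch; the helper names are illustrative stand-ins, not tinygrad API, and the sketches are not part of the patch.

First, the removed upcast_dim is now derived rather than stored: for each output dim (N is TC dim 0, M is TC dim 1), the warp's threads already cover some size, and apply_tensor_cores upcasts only whatever is left of tc.dims[i]. A minimal sketch of that arithmetic (derive_upcasts is a hypothetical name):

from math import prod

def derive_upcasts(dims, threads):
  # dims = [N, M, K]; threads = [(tc_dim, amt), ...] describes the warp layout.
  # For N (dim 0) and M (dim 1), multiply the thread amounts assigned to that
  # dim; any remainder of the TC dim has to come from an UPCAST.
  upcasts = {}
  for i in range(2):
    sz = prod(amt for d, amt in threads if d == i)
    if dims[i] > sz: upcasts[i] = dims[i] // sz
  return upcasts

# METAL 8x8x8: threads cover 2*2=4 of dim 0 and 4*2=8 of dim 1 -> upcast dim 0 by 2
print(derive_upcasts([8,8,8], [(0,2),(1,4),(0,2),(1,2)]))  # {0: 2}
# HIP 16x16x16: threads cover 16 of dim 0 and 2 of dim 1 -> upcast dim 1 by 8
print(derive_upcasts([16,16,16], [(0,16),(1,2)]))          # {1: 8}

These reproduce the deleted upcast_dim values (0 for METAL, 1 for HIP), which is why the field could go away.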
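
Second, TensorCore.num_threads/num_upcasts replace both the hard-coded num_tc_upcast = 2 in the linearizer and the repeated len(self.tensor_core.threads) lookups. Each per-dim pattern in thread_local_aliases lists one entry per warp thread followed by one entry per upcast axis (the UNROLL plus any UPCAST), so both counts fall out of the lengths. A quick check against the tables above:

# first alias pattern and thread count for the METAL and HIP tensor cores
metal_alias, metal_nthreads = [[4],[0],[2],[0],[-1, 1, 3],[0]], 4  # threads=[(0,2),(1,4),(0,2),(1,2)]
hip_alias, hip_nthreads = [[0],[0],[-1],[1]], 2                    # threads=[(0,16),(1,2)]
print(len(metal_alias) - metal_nthreads)  # 2 upcast axes (UNROLL + one UPCAST)
print(len(hip_alias) - hip_nthreads)      # 2 upcast axes (UNROLL + one UPCAST)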
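
Finally, the "protect TC locals" half of the title comes down to the two asserts added to apply_opt: LOCAL may now only shift a global axis (all of which sit before the TC thread locals), and UPCAST must stay clear of the last len(tc.threads) axes before first_reduce. A compressed, hypothetical restatement of that invariant:

def opt_preserves_tc_locals(op:str, axis:int, global_dims:int, first_reduce:int, n_tc_threads:int) -> bool:
  # stand-in mirroring the asserts added to Kernel.apply_opt above
  if op == "LOCAL": return axis < global_dims                   # "local is for globals"
  if op == "UPCAST": return axis < first_reduce - n_tc_threads  # "can't upcast TC locals"
  return True

# e.g. global_dims=2, first_reduce=6, a 4-thread METAL warp: axes 2-5 hold the TC locals
print(opt_preserves_tc_locals("LOCAL", 2, 2, 6, 4))   # False: would disturb the TC warp layout
print(opt_preserves_tc_locals("UPCAST", 3, 2, 6, 4))  # False: would upcast a TC local
print(opt_preserves_tc_locals("UPCAST", 1, 2, 6, 4))  # True: plain upcast of a global axis

This is exactly what test_invalid_tensor_core_extra_opts exercises: each of its opt lists targets axis 2 or adds a second local, and apply_tensor_cores is expected to raise an AssertionError instead of silently breaking the WMMA layout.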