From 9743ec0dca5bbd9dbce20adc3ee273af6b095f94 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 9 Dec 2024 00:58:30 -0800 Subject: [PATCH] [FRONTEND] added support for tuples (#5220) --- python/src/ir.cc | 1 + .../test/unit/language/test_compile_errors.py | 49 ++-- python/test/unit/language/test_core.py | 23 +- python/test/unit/language/test_decorator.py | 2 +- python/test/unit/language/test_tuple.py | 100 +++++++ python/test/unit/runtime/test_bindings.py | 15 +- python/test/unit/runtime/test_cache.py | 4 +- python/test/unit/runtime/test_subproc.py | 8 +- python/test/unit/test_perf_warning.py | 5 +- python/triton/_utils.py | 49 ++++ python/triton/backends/compiler.py | 41 ++- python/triton/compiler/code_generator.py | 258 +++++++++++------- python/triton/compiler/compiler.py | 29 +- python/triton/language/__init__.py | 19 +- python/triton/language/core.py | 141 ++++++++-- python/triton/language/semantic.py | 4 +- python/triton/runtime/jit.py | 44 ++- python/triton/tools/compile.py | 16 +- third_party/amd/backend/compiler.py | 13 +- third_party/amd/backend/driver.py | 59 ++-- third_party/nvidia/backend/driver.py | 43 ++- 21 files changed, 635 insertions(+), 288 deletions(-) create mode 100644 python/test/unit/language/test_tuple.py diff --git a/python/src/ir.cc b/python/src/ir.cc index 23bb86e5eb2c..53ba39ae1026 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) { "Function argument index out of range"); return self.getArgument(idx); }) + .def("get_num_args", &FuncOp::getNumArguments) .def( "add_entry_block", [](FuncOp &self) -> Block * { return self.addEntryBlock(); }, diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 1bafa551e379..2760d26bc5ef 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -17,7 +17,7 @@ def kernel(): a += 1 # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "is not defined" in str(e.value), "error should mention the undefined variable" @@ -32,7 +32,7 @@ def kernel(): 0 + "a" with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the 0" @@ -47,7 +47,7 @@ def kernel(): tl.static_assert(isinstance(0, tl.tensor)) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert isinstance(e.value, CompileTimeAssertionFailure) @@ -66,7 +66,7 @@ def kernel(): not (0, 0) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert e.value.__cause__ is None @@ -83,7 +83,7 @@ def kernel(): 1.0 << 1 with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the 1.0" @@ 
-107,7 +107,7 @@ def kernel(): nested_call() with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: inner = e.value.__cause__ @@ -130,7 +130,7 @@ def kernel(): tl.expand_dims(None, -1) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) try: inner = e.value.__cause__ @@ -157,7 +157,7 @@ def kernel(): a = two_returns() a + tl.arange(0, 4) # only works if we took the first return - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) def test_not_const_annotate_no_err(): @@ -166,7 +166,7 @@ def test_not_const_annotate_no_err(): def kernel(N: int = 1): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) @triton.jit @@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 4) - triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0})) + triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0})) @triton.jit def kernel2(N: tl.constexpr): a = returns_branched_on_constexpr(N) a + tl.arange(0, 8) - triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1})) + triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1})) @triton.jit @@ -211,7 +211,7 @@ def kernel(N: int): returns_branched_on_non_constexpr(N) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={})) try: assert "at 2:4:" in str(e.value), "error should point to the function call" @@ -227,7 +227,7 @@ def kernel(): tl.arange(2, 7) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert str(e.value.__cause__) == "arange's range must be a power of 2" @@ -238,7 +238,7 @@ def kernel(): tl.full((33, ), 0, dtype=tl.int64) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert str(e.value.__cause__) == "Shape element 0 must be a power of 2" @@ -251,7 +251,7 @@ def kernel(): a = CAPTURED # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert "CAPTURED is not defined" in str(e.value) @@ -265,7 +265,7 @@ def kernel(): a = GLOBAL # noqa with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) assert "global variable" in str(e.value) @@ -279,7 +279,7 @@ def kernel(): a = CONSTEXPR_ANNOTATED_GLOBAL 
# noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) CONSTEXPR_GLOBAL = tl.constexpr(42) @@ -292,7 +292,7 @@ def kernel(): a = CONSTEXPR_GLOBAL # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) TYPE_ALIAS = tl.pointer_type(tl.int32) @@ -305,7 +305,7 @@ def kernel(): a = TYPE_ALIAS # noqa # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) def test_global_access_in_fn_default_arg(): @@ -315,7 +315,7 @@ def kernel(a=GLOBAL): pass # No error. - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={})) def test_defaults_assign_no_err(): @@ -324,7 +324,7 @@ def test_defaults_assign_no_err(): def kernel(a=1, B: tl.constexpr = ""): pass - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""})) def test_where_warning(fresh_triton_cache): @@ -337,7 +337,7 @@ def kernel(): tl.where(a, b, c) with pytest.warns(UserWarning): - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]) @@ -369,7 +369,8 @@ def dtype_kernel(dtype: tl.constexpr): ctx = pytest.raises(CompilationError, match="") with ctx as e: - triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype})) + triton.compile( + triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype})) if dtype not in supported_dtypes: try: @@ -388,7 +389,7 @@ def dot_kernel(): tl.dot(a, b, max_num_imprecise_acc=128) with pytest.raises(CompilationError) as e: - triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={})) try: assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)") except AssertionError as assertion_err: diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2daa8aaf07d6..ac22cdee4335 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4352,15 +4352,17 @@ def kernel(x): def test_value_specialization(value: int, value_type: str, device) -> None: def repr(specialization): - spec_type = specialization.signature["VALUE"] - return f"kernel_{spec_type}" + ty = specialization.signature["value1"] + cst = '_'.join([k for k, v in specialization.constants.items() if v == 1]) + return f"kernel_{ty}_{cst}" @triton.jit(repr=repr) - def kernel(VALUE, X): + def kernel(value1, is_one, X): pass x = torch.tensor([3.14159], device=device) - h = kernel[(1, )](value, x) + h = kernel[(1, )](value, 1, x) + assert "is_one" in h.name assert value_type in h.name @@ -6130,6 +6132,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r 
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32)) +def test_dtype(device): + + @triton.jit + def kernel(X): + dtype_x: tl.constexpr = X.dtype.element_ty + tl.static_assert(dtype_x == tl.int32) + tl.static_assert(dtype_x == tl.constexpr(tl.int32)) + tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32)) + + X = torch.zeros(1, dtype=torch.int32, device=device) + kernel[(1, )](X) + + def test_side_effectful_scan(device): if device != "cuda": pytest.skip() diff --git a/python/test/unit/language/test_decorator.py b/python/test/unit/language/test_decorator.py index fbbfb7144680..42207cc1fab0 100644 --- a/python/test/unit/language/test_decorator.py +++ b/python/test/unit/language/test_decorator.py @@ -23,7 +23,7 @@ def kernel(): pass try: - triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={})) + triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={})) except Exception as e: pytest.fail(f"triton compile failed with error: {e}") diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py new file mode 100644 index 000000000000..863034579a7f --- /dev/null +++ b/python/test/unit/language/test_tuple.py @@ -0,0 +1,100 @@ +import pytest +import triton +import triton.language as tl +import torch + + +@triton.jit +def _tuple_increment(values): + for i in tl.static_range(len(values)): + values[i] = values[i] + 1 + return values + + +@triton.jit +def _tuple_index_func(Ptrs, values): + for i in tl.static_range(len(values)): + tl.store(Ptrs[i], values[i]) + + +@triton.jit +def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4): + values = _tuple_increment(values) + _tuple_index_func(Ptrs, values) + + +@pytest.mark.parametrize("size", [0, 1, 2, 3, 4]) +def test_index(size, device="cuda"): + vals = tuple([i + 1 for i in range(size)]) + rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals]) + _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0) + assert vals == tuple([x.item() - 1 for x in rets]) + + +# ---- + + +@triton.jit +def _tuple_assign(XPtrs, YPtrs, values): + # assign from tuple + X0, X1 = XPtrs + x0, x1 = values + tl.store(X0, x0) + tl.store(X1, x1) + # assign to tuple + Y0, Y1, Y2 = YPtrs + Y = Y0, Y1, Y2 + y = x0, 10, x1 + tl.store(Y[0], y[0]) + tl.store(Y[1], y[1]) + tl.store(Y[2], y[2]) + + +def test_assign(device="cuda"): + vals = (2., 3.) + x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) + y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) + _tuple_assign[(1, )](x, y, vals) + assert x[0] == vals[0] + assert x[1] == vals[1] + assert y[0] == vals[0] + assert y[1] == 10 + assert y[2] == vals[1] + + +# ------- + + +@triton.jit +def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1): + tl.static_assert(tuple1[1] is None) + tl.store(Ptr + 5, cst2) + tl.store(Ptr + 6, tuple1[0]) + tl.store(Ptr + 7, tl.load(tuple1[2][0])) + tl.store(Ptr + 8, tuple1[2][1][0]) + tl.store(Ptr + 9, tl.load(tuple1[2][1][2])) + + +# test serialization/deserialization of tuple arguments in +# the frontend. 
+@triton.jit +def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2): + tl.static_assert(N1 is None) + tl.static_assert(tuple1[1][1] is None) + tl.store(Ptr + 0, tl.load(tuple1[0])) + tl.store(Ptr + 1, tuple1[1][0]) + tl.store(Ptr + 2, tl.load(tuple1[1][2])) + tl.store(Ptr + 3, cst1 + val1) + tl.store(Ptr + 4, tl.load(tuple2[0])) + _tuple_fn0(Ptr, 15, (-1, None, tuple1)) + + +def test_serialize(device="cuda"): + x0 = torch.tensor([8], dtype=torch.int32, device=device) + x1 = torch.tensor([12], dtype=torch.int32, device=device) + y0 = torch.tensor([10], dtype=torch.int32, device=device) + z = torch.empty((10, ), dtype=torch.int32, device=device) + # we want to check that JIT specialization propagates to tuples: + _tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, )) + ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device) + assert torch.equal(z, ref) diff --git a/python/test/unit/runtime/test_bindings.py b/python/test/unit/runtime/test_bindings.py index 206d1323017e..e621eefc0110 100644 --- a/python/test/unit/runtime/test_bindings.py +++ b/python/test/unit/runtime/test_bindings.py @@ -63,15 +63,12 @@ def walk_fn(op): backend = triton.compiler.compiler.make_backend(target) src = triton.compiler.compiler.ASTSource( fn=kernel, - signature={ - kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) - for i, arg in enumerate(args) - if i not in kernel.constexprs - }, - constants={kernel.arg_names[i]: arg - for i, arg in enumerate(args) - if not isinstance(arg, torch.Tensor)}, - attrs=backend.get_attrs_descriptor(args, kernel.params), + signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg)) + for i, arg in enumerate(args)}, + constexprs={kernel.arg_names[i]: arg + for i, arg in enumerate(args) + if not isinstance(arg, torch.Tensor)}, + attrs=backend.get_attrs_descriptor(kernel.params, args), ) context = triton._C.libtriton.ir.context() diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index a5f381dc9f56..23c943aeb1d4 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -592,10 +592,10 @@ def cache_hook(*args, **kwargs): JITFunction.cache_hook = cache_hook # In warmup we assume that the pointer range is 32 bits kernel_add.warmup(torch.float32, grid=(1, )) - assert pointer_range_32 == [0] + assert pointer_range_32 == [(0, )] # Torch tensor > 2GB kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device)) assert len(pointer_range_32) == 0 # Torch tensor <= 2GB kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device)) - assert pointer_range_32 == [0] + assert pointer_range_32 == [(0, )] diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 334d5d635f67..ecd7227a30c4 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -19,8 +19,8 @@ def kernel_sub(a, b, o, N: tl.constexpr): src = ASTSource( fn=kernel_sub, - constants={'N': 32}, - signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32"}, + constexprs={'N': 32}, + signature={'a': "*fp32", 'b': "*fp32", 'o': "*fp32", 'N': 'constexpr'}, attrs=attrs, ) triton.compile(src=src, target=target) @@ -44,7 +44,7 @@ def kernel_dot(Z): z = tl.dot(z, z) tl.store(Z + offs, z) - src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constants={}) + src = ASTSource(fn=kernel_dot, signature={'Z': "*fp32"}, attrs=attrs, constexprs={}) triton.compile(src=src, target=target) @@ 
-65,7 +65,7 @@ def empty_kernel(): import gc gc.collect() - src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constants={}) + src = ASTSource(fn=empty_kernel, signature={}, attrs=attrs, constexprs={}) triton.compile(src=src, target=target) diff --git a/python/test/unit/test_perf_warning.py b/python/test/unit/test_perf_warning.py index 6646d94f50a8..461dcb46b43c 100644 --- a/python/test/unit/test_perf_warning.py +++ b/python/test/unit/test_perf_warning.py @@ -92,7 +92,7 @@ def matmul_kernel( "stride_cm": "i32", "stride_cn": "i32", }, - constants={}, + constexprs={}, )) captured = capfd.readouterr() @@ -136,8 +136,9 @@ def ldst_vec(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, XBLOCK: tl.constexpr) "in_ptr2": "*fp16", "in_ptr3": "*fp32", "out_ptr0": "*fp16", + "XBLOCK": "constexpr", }, - constants={"XBLOCK": XBLOCK}, + constexprs={"XBLOCK": XBLOCK}, ), options={"num_warps": 1}, ) diff --git a/python/triton/_utils.py b/python/triton/_utils.py index ca60c8c3cbca..0ce1a53a701b 100644 --- a/python/triton/_utils.py +++ b/python/triton/_utils.py @@ -20,3 +20,52 @@ def list_list_unflatten(spec: List[int], flat: List[Any]) -> List[List[Any]]: idx += size assert idx == len(flat) return ret + + +def find_paths_if(iterable, pred): + from .language import core + is_iterable = lambda x: isinstance(x, (list, tuple, core.tuple, core.tuple_type)) + ret = dict() + + def _impl(current, path): + path = (path[0], ) if len(path) == 1 else tuple(path) + if is_iterable(current): + for idx, item in enumerate(current): + _impl(item, path + (idx, )) + elif pred(path, current): + if len(path) == 1: + ret[(path[0], )] = current + else: + ret[tuple(path)] = current + + if is_iterable(iterable): + _impl(iterable, []) + elif pred(list(), iterable): + ret = {tuple(): iterable} + else: + ret = dict() + return ret + + +def parse_list_string(s): + s = s.strip() + if s.startswith('[') and s.endswith(']'): + s = s[1:-1] + result = [] + current = '' + depth = 0 + for c in s: + if c == '[': + depth += 1 + current += c + elif c == ']': + depth -= 1 + current += c + elif c == ',' and depth == 0: + result.append(current.strip()) + current = '' + else: + current += c + if current.strip(): + result.append(current.strip()) + return result diff --git a/python/triton/backends/compiler.py b/python/triton/backends/compiler.py index 6d33dbd6fa9b..4c5ac74cf23b 100644 --- a/python/triton/backends/compiler.py +++ b/python/triton/backends/compiler.py @@ -3,11 +3,11 @@ import hashlib import subprocess import sysconfig - from abc import ABCMeta, abstractmethod from dataclasses import dataclass from typing import Dict, List, Tuple, Union from types import ModuleType +from .._utils import find_paths_if # Table that associates strings to AttrsDescriptor (sub)classes. 
# In this way we can dynamically select the correct class @@ -52,7 +52,8 @@ class AttrsDescriptor: `constant_properties`: a set containing the properties that can be used to determine if a parameter is constant """ - __slots__ = ('divisibility_16', 'equal_to_1', 'arg_properties', 'property_values', 'constant_properties') + __slots__ = ('divisibility_16', 'equal_to_1', 'equal_to_none', 'arg_properties', 'property_values', + 'constant_properties') def __init__(self, params=None, values=None): """ @@ -67,6 +68,7 @@ def __init__(self, params=None, values=None): # Default initialization self.arg_properties = {} self.property_values = {} + self.equal_to_none = {} self.constant_properties = set() self._add_common_properties(params, values) @@ -86,17 +88,30 @@ def _add_common_properties(self, params, values): assert (len(params) == len(values)) # Divisibility property - self.arg_properties["tt.divisibility"] = [ - param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg) - and not param.do_not_specialize and not param.do_not_specialize_on_alignment - ] + divisibility_16 = [] + for param, arg in zip(params, values): + if param.do_not_specialize or \ + param.do_not_specialize_on_alignment: + continue + paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_divisible_by_16(val)) + divisibility_16 += [(param.num, ) + x for x in paths] + self.arg_properties["tt.divisibility"] = divisibility_16 # Equal to 1 property - self.arg_properties["tt.equal_to"] = [ - param.num - for param, arg in zip(params, values) - if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize - ] + equal_to_1 = [] + for param, arg in zip(params, values): + if param.do_not_specialize: + continue + paths = find_paths_if(arg, lambda path, val: AttrsDescriptor.is_equal_to_1(val)) + equal_to_1 += [(param.num, ) + x for x in paths] + self.arg_properties["tt.equal_to"] = equal_to_1 + + # Equal to None property + equal_to_none = [] + for param, arg in zip(params, values): + paths = find_paths_if(arg, lambda path, val: val is None) + equal_to_none += [(param.num, ) + x for x in paths] + self.equal_to_none = equal_to_none def _add_backend_properties(self, params=None, values=None): """ This method is for different subclasses to implement their own compile-time properties """ @@ -130,6 +145,8 @@ def get_constants(self) -> Dict: for prop_name in self.constant_properties: for p in self.arg_properties.get(prop_name, []): constants[p] = self.property_values[prop_name] + for v in self.equal_to_none: + constants[v] = None return constants def filter_out_constants(self): @@ -166,7 +183,7 @@ def from_dict(data): """ attrs_descriptor = _descriptor_table[data["cls"]]() for prop_name, param_ids in data["arg_properties"].items(): - attrs_descriptor.arg_properties[prop_name] = param_ids + attrs_descriptor.arg_properties[prop_name] = list(map(tuple, param_ids)) attrs_descriptor._init_slots() return attrs_descriptor diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 1c39d778ec0f..050e8ad0d728 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -15,9 +15,13 @@ from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) from types import ModuleType from triton._utils import list_list_flatten, list_list_unflatten +from functools import reduce +from .._utils import find_paths_if def mangle_ty(ty): + if ty.is_tuple(): + return 'T' + '_'.join(map(mangle_ty, ty.types)) 
+ 'T' if ty.is_ptr(): return 'P' + mangle_ty(ty.element_ty) if ty.is_int(): @@ -56,7 +60,7 @@ def _is_triton_tensor(o: Any) -> bool: def _is_constexpr(o: Any) -> bool: - return isinstance(o, constexpr) + return o is None or isinstance(o, (constexpr, language.core.dtype)) def _is_triton_scalar(o: Any) -> bool: @@ -189,11 +193,66 @@ def visit_Call(self, node: ast.Call) -> bool: return self.visit(node.func) +class ASTFunction: + + def get_path(self, x, path): + return reduce(lambda a, idx: a[idx], path, x) + + def set_path(self, x, path, val): + prev = x if len(path) == 1 else self.get_path(x, path[:-1]) + prev[path[-1]] = val + + def __init__(self, ret_types, arg_types, constexprs, constants, attrs): + self.ret_types = ret_types + self.arg_types = arg_types + self.constexprs = constexprs + self.constants = constants + self.attrs = attrs + + def serialize(self, builder: ir.builder): + # fill up IR values in template + # > build function + is_val = lambda path, _: path not in self.constexprs and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + arg_types = [self.get_path(self.arg_types, path).to_ir(builder) for path in val_paths] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(arg_types, ret_types) + + def deserialize(self, fn): + # create "template" + def make_template(val): + if isinstance(val, (list, tuple, language.tuple_type)): + return language.tuple([make_template(x) for x in val]) + return language.constexpr(None) + + vals = make_template(self.arg_types) + is_val = lambda path, _: path not in self.constexprs and _ is not None + val_paths = list(find_paths_if(self.arg_types, is_val).keys()) + # > set attributes + for attr_path, attr_specs in self.attrs.items(): + for attr_name, attr_val in attr_specs: + if attr_path in val_paths: + fn.set_arg_attr(val_paths.index(attr_path), attr_name, attr_val) + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + if isinstance(ty, nv_tma_desc_type): + fn.set_arg_attr(i, "tt.nv_tma_desc", 1) + # > add IR values to the template + for i, path in enumerate(val_paths): + ty = self.get_path(self.arg_types, path) + self.set_path(vals, path, language.tensor(fn.args(i), ty)) + # > add constexpr values to the template + constants = self.constants | self.constexprs + for path, val in constants.items(): + self.set_path(vals, path, language.constexpr(val)) + return vals + + class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, - codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, - noinline=False, file_name: Optional[str] = None, begin_line=0): + def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction, options, codegen_fns, module_map, + module=None, is_kernel=False, function_types: Optional[Dict] = None, noinline=False, + file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -223,8 +282,6 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n self.gscope[k] = v self.lscope = {} - self.attributes = attributes - self.constants = constants self.jit_fn = jit_fn self.function_name = function_name self.is_kernel = is_kernel @@ -342,7 +399,6 @@ def visit_compound_statement(self, stmts): stmts = [stmts] for stmt in stmts: self.visit(stmt) - # Stop parsing as soon as we hit a 
`return` statement; everything # after this is dead code. if isinstance(stmt, ast.Return): @@ -354,7 +410,7 @@ def visit_Module(self, node): def visit_List(self, node): ctx = self.visit(node.ctx) assert ctx is None - elts = [self.visit(elt) for elt in node.elts] + elts = language.tuple([self.visit(elt) for elt in node.elts]) return elts # By design, only non-kernel functions can return @@ -363,16 +419,15 @@ def visit_Return(self, node): if ret_value is None: self.builder.ret([]) ret_ty = language.void - elif isinstance(ret_value, tuple): - ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + elif isinstance(ret_value, language.tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value.values] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) - ret_ty = tuple(ret_types) + ret_ty = language.tuple_type(ret_types) else: ret = language.semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type - if self.ret_type is None: self.ret_type = ret_ty elif self.ret_type != ret_ty: @@ -397,7 +452,6 @@ def visit_FunctionDef(self, node): init_node = ast.Assign(targets=[st_target], value=default_value) else: init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) - try: assert not self.visiting_arg_default_value self.visiting_arg_default_value = True @@ -407,34 +461,15 @@ def visit_FunctionDef(self, node): # initialize function visibility = "public" if self.is_kernel else "private" - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, - self.prototype.to_ir(self.builder), visibility, self.noinline) + fn_ty = self.prototype.serialize(self.builder) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, fn_ty, visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() - arg_values = [] - idx = 0 - for i in range(len(arg_names)): - if i in self.constants: - cst = self.constants[i] - if not _is_constexpr(cst): - cst = constexpr(self.constants[i]) - arg_values.append(cst) - continue - else: - if i in self.attributes: - for name, value in self.attributes[i]: - self.fn.set_arg_attr(idx, name, value) - - # Mark this argument as a pass-by-value TMA descriptor (nvidia) - if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): - self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) - - arg_values.append(tensor(self.fn.args(idx), self.prototype.param_types[idx])) - idx += 1 - - insert_pt = self.builder.get_insertion_block() + arg_values = self.prototype.deserialize(self.fn) + # bind arguments to symbols for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) + insert_pt = self.builder.get_insertion_block() self.builder.set_insertion_point_to_start(entry) # visit function body self.visit_compound_statement(node.body) @@ -445,8 +480,11 @@ def visit_FunctionDef(self, node): self.ret_type = language.void self.builder.ret([]) else: - self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] - self.fn.reset_type(self.prototype.to_ir(self.builder)) + if isinstance(self.ret_type, language.tuple_type): + self.prototype.ret_types = self.ret_type.types + else: + self.prototype.ret_types = [self.ret_type] + self.fn.reset_type(self.prototype.serialize(self.builder)) self.builder.ret([ self.builder.create_poison(ty.to_ir(self.builder)) for ty in self.prototype.ret_types @@ -478,37 +516,41 @@ def 
visit_AnnAssign(self, node): if target in self.lscope: raise ValueError(f'{target} is already defined.' f' constexpr cannot be reassigned.') - if not _is_constexpr(value): - value = constexpr(value) + value = constexpr(value) self.lscope[target] = value return self.lscope[target] # default: call visit_Assign return self.visit_Assign(node) + def assignTarget(self, target, value): + if isinstance(target, ast.Subscript): + assert target.ctx.__class__.__name__ == "Store" + return self.visit_Subscript_Store(target, value) + if isinstance(target, ast.Tuple): + assert target.ctx.__class__.__name__ == "Store" + for i, name in enumerate(target.elts): + self.set_value(self.visit(name), value.values[i]) + return + assert isinstance(target, ast.Name) + self.set_value(self.visit(target), value) + def visit_Assign(self, node): - _names = [] - if isinstance(node, ast.AnnAssign): - _names += [self.visit(node.target)] - else: - for target in node.targets: - _names += [self.visit(target)] - if len(_names) > 1: - raise self._unsupported(node, "simultaneous multiple assignment is not supported.") - names = _names[0] - values = self.visit(node.value) - if not _is_list_like(names): - names = [names] - if not _is_list_like(values): - values = [values] - native_nontensor_types = (language.dtype, ) - for name, value in zip(names, values): - # by default, constexpr are assigned into python variable + # construct values to assign + def _sanitize_value(value): + if isinstance(value, language.tuple): + return language.tuple([_sanitize_value(v) for v in value.values]) + native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) if value is not None and \ - not _is_triton_value(value) and \ - not isinstance(value, native_nontensor_types): + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): value = language.semantic.to_tensor(value, self.builder) - self.set_value(name, value) + return value + + values = _sanitize_value(self.visit(node.value)) + targets = [node.target] if isinstance(node, ast.AnnAssign) else node.targets + assert len(targets) == 1 + self.assignTarget(targets[0], values) def visit_AugAssign(self, node): name = node.target.id @@ -531,7 +573,7 @@ def visit_Load(self, node): def visit_Tuple(self, node): args = [self.visit(x) for x in node.elts] - return tuple(args) + return language.tuple(args) def _apply_binary_method(self, method_name, lhs, rhs): # TODO: raise something meaningful if getattr fails below, esp for reverse method @@ -903,7 +945,7 @@ def visit_While(self, node): assert False, "Not implemented" ast.NodeVisitor.generic_visit(self, stmt) - def visit_Subscript(self, node): + def visit_Subscript_Load(self, node): assert node.ctx.__class__.__name__ == "Load" lhs = self.visit(node.value) slices = self.visit(node.slice) @@ -911,6 +953,16 @@ def visit_Subscript(self, node): return lhs.__getitem__(slices, _builder=self.builder) return lhs[slices] + def visit_Subscript_Store(self, node, value): + assert node.ctx.__class__.__name__ == "Store" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + assert isinstance(lhs, language.tuple) + lhs.__setitem__(slices, value) + + def visit_Subscript(self, node): + return self.visit_Subscript_Load(node) + def visit_ExtSlice(self, node): return [self.visit(dim) for dim in node.dims] @@ -1067,7 +1119,7 @@ def visit_Slice(self, node): lower = self.visit(node.lower) upper = self.visit(node.upper) step = self.visit(node.step) - return slice(lower, upper, step) + return language.slice(lower, 
upper, step) def visit_Index(self, node): return self.visit(node.value) @@ -1083,24 +1135,26 @@ def visit_Assert(self, node) -> Any: def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] - # generate function def - attributes = {} - constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] - constants = {i: args[i] for i in constexprs} - # generate call - args = [None if i in constexprs else arg for i, arg in enumerate(args)] - arg_vals = [arg.handle for arg in args if arg is not None] - arg_types = [arg.type for arg in args if arg is not None] - fn_name = mangle_fn(fn.__name__, arg_types, constants) + for i, arg in enumerate(args): + if isinstance(arg, (language.dtype, float, int, bool)): + args[i] = language.core.constexpr(arg) + args_cst = find_paths_if(args, lambda _, x: _is_constexpr(x)) + args_val = find_paths_if(args, lambda _, x: not _is_constexpr(x)).values() + # mangle + fn_name = mangle_fn(fn.__name__, [arg.type for arg in args_val], args_cst) # generate function def if necessary if not self.module.has_function(fn_name): - prototype = language.function_type([], arg_types) gscope = fn.__globals__ # If the callee is not set, we use the same debug setting as the caller file_name, begin_line = get_jit_fn_file_line(fn) - generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + arg_types = [ + language.core.constexpr if arg is None or isinstance(arg, + (bool, int, language.core.dtype)) else arg.type + for arg in args + ] + prototype = ASTFunction([], arg_types, args_cst, dict(), dict()) + generator = CodeGenerator(self.context, prototype, gscope, module=self.module, jit_fn=fn, + function_name=fn_name, function_types=self.function_ret_types, noinline=fn.noinline, file_name=file_name, begin_line=begin_line, options=self.builder.options, codegen_fns=self.builder.codegen_fns, module_map=self.builder.module_map) @@ -1115,8 +1169,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): else: callee_ret_type = self.function_ret_types[fn_name] symbol = self.module.get_function(fn_name) - call_op = self.builder.call(symbol, arg_vals) - if call_op.get_num_results() == 0 or callee_ret_type is None: + args_val = [arg.handle for arg in args_val] + call_op = self.builder.call(symbol, args_val) + if callee_ret_type is None: return None elif call_op.get_num_results() == 1: return tensor(call_op.get_result(0), callee_ret_type) @@ -1124,8 +1179,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): # should return a tuple of tl.tensor results = [] for i in range(call_op.get_num_results()): - results.append(tensor(call_op.get_result(i), callee_ret_type[i])) - return tuple(results) + results.append(tensor(call_op.get_result(i), callee_ret_type.types[i])) + return language.tuple(results) def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func)) @@ -1144,7 +1199,11 @@ def visit_Call(self, node): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: - return fn(*args, **extra_kwargs, **kws) + ret = fn(*args, **extra_kwargs, **kws) + # builtin functions return plain tuples for readability + if isinstance(ret, tuple): + ret = language.tuple(ret) + return ret except Exception as e: # Normally when we raise a CompilationError, we raise it as # `from 
None`, because the original fileline from the exception @@ -1285,38 +1344,29 @@ def kernel_suffix(signature, specialization): suffix = '' for i, _ in enumerate(signature): suffix += str(i) - if i in specialization.equal_to_1: + if (i, ) in specialization.equal_to_1: suffix += 'c' - if i in specialization.divisibility_16: + if (i, ) in specialization.divisibility_16: suffix += 'd' return suffix def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map): + constexprs = specialization.constexprs + arg_idx = lambda x: (fn.arg_names.index(x), ) if isinstance(x, str) else x + constants = specialization.attrs.get_constants() + constexprs = {arg_idx(k): v for k, v in constexprs.items()} + arg_types = [str_to_ty(ty) for ty in specialization.signature.values()] + # find index of constants in serialized order attrs = specialization.attrs - # create kernel prototype - cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in specialization.constants.items()} - # visit kernel AST - gscope = fn.__globals__.copy() - function_name = fn.repr(specialization) - tys = list(specialization.signature.values()) - new_constants = attrs.get_constants() - for k in new_constants: - if k in tys and tys[k] == "i1" and new_constants[k] == 1: - new_constants[k] = True - new_attrs = attrs.filter_out_constants() fn_attrs = new_attrs.get_fn_attrs() - all_constants = constants.copy() - all_constants.update(new_constants) - arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + fn_attrs = {k: v for k, v in fn_attrs.items() if k not in constants} file_name, begin_line = get_jit_fn_file_line(fn) - - prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, - jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, - begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map) + prototype = ASTFunction([], arg_types, constexprs, constants, fn_attrs) + generator = CodeGenerator(context, prototype, gscope=fn.__globals__.copy(), function_name=fn.repr(specialization), + jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options, + codegen_fns=codegen_fns, module_map=module_map) generator.visit(fn.parse()) ret = generator.module diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index f70c46a9d406..52b8afea14a9 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -51,12 +51,12 @@ def convert_type_repr(x): class ASTSource: - def __init__(self, fn, signature, constants=None, attrs=None) -> None: + def __init__(self, fn, signature, constexprs=None, attrs=None) -> None: self.fn = fn self.ext = "ttir" self.name = fn.__name__ self.signature = signature - self.constants = constants + self.constexprs = constexprs self.attrs = attrs if isinstance(self.signature, str): self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} @@ -64,20 +64,19 @@ def __init__(self, fn, signature, constants=None, attrs=None) -> None: for k in self.signature.keys(): if not isinstance(k, str): raise TypeError("Signature keys must be string") - if self.constants is None: - self.constants = {} - else: - for k in self.constants.keys(): - if not isinstance(k, str): - raise TypeError("Constants keys must be string") + if self.constexprs is None: + self.constexprs 
= {} if self.attrs is None: self.attrs = AttrsDescriptor() + # this is the constexprs plus the specialized constants + spec_constants = {self.fn.arg_names[k[0]]: v for k, v in self.attrs.get_constants().items() if len(k) == 1} + self.constants = self.constexprs | spec_constants def hash(self): sorted_sig = [v for k, v in sorted(self.signature.items())] # Note - we stringify the keys here to allow sorting to work for cases # where constants have mixed int/str keys. - sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + sorted_constants = sorted((str(k), v) for k, v in self.constexprs.items()) key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" return hashlib.sha256(key.encode("utf-8")).hexdigest() @@ -276,11 +275,11 @@ def compile(src, target=None, options=None): codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() - try: - module = src.make_ir(options, codegen_fns, module_map, context) - except Exception as e: - filter_traceback(e) - raise + # try: + module = src.make_ir(options, codegen_fns, module_map, context) + # except Exception as e: + # filter_traceback(e) + # raise use_ir_loc = os.environ.get("USE_IR_LOC", None) for ext, compile_ir in list(stages.items())[first_stage:]: next_module = compile_ir(module, metadata) @@ -412,7 +411,7 @@ def launch_metadata(self, grid, stream, *args): arg_idx = 0 for i, arg_name in enumerate(self.src.fn.arg_names): if i in self.src.fn.constexprs: - arg_dict[arg_name] = self.src.constants[arg_name] + arg_dict[arg_name] = self.src.constexprs[arg_name] else: arg_dict[arg_name] = args[arg_idx] arg_idx += 1 diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0c8965fc520a..5f5d464d6379 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,6 +1,7 @@ """isort:skip_file""" # Import order is significant here. +from .._utils import parse_list_string from . import math from . 
import extra from .standard import ( @@ -69,7 +70,6 @@ float8e5, float8e5b16, full, - function_type, gather, histogram, inline_asm_elementwise, @@ -95,6 +95,7 @@ range, reduce, reshape, + slice, split, static_assert, static_print, @@ -102,6 +103,8 @@ store, tensor, trans, + tuple, + tuple_type, uint16, uint32, uint64, @@ -188,7 +191,6 @@ "floor", "fma", "full", - "function_type", "gather", "histogram", "inline_asm_elementwise", @@ -232,6 +234,7 @@ "reduce", "reshape", "rsqrt", + "slice", "sigmoid", "sin", "softmax", @@ -248,6 +251,7 @@ "tensor", "trans", "triton", + "tuple", "uint16", "uint32", "uint64", @@ -264,6 +268,9 @@ def str_to_ty(name): + if name == "none": + return None + if name[0] == "*": name = name[1:] const = False @@ -273,9 +280,17 @@ def str_to_ty(name): ty = str_to_ty(name) return pointer_type(element_ty=ty, const=const) + if name[0] == "[": + names = parse_list_string(name) + tys = [str_to_ty(x) for x in names] + return tuple_type(types=tys) + if name == "nvTmaDesc": return nv_tma_desc_type() + if name == "constexpr": + return constexpr + tys = { "fp8e4nv": float8e4nv, "fp8e4b8": float8e4b8, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 85d5f6beba5f..31b19754c63c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -140,6 +140,7 @@ def __init__(self, value): self.value = value.value else: self.value = value + self.type = constexpr def __repr__(self) -> str: return f"constexpr[{self.value}]" @@ -473,6 +474,10 @@ def is_ptr(): def is_const(): return False + @staticmethod + def is_tuple(): + return False + def __eq__(self, other: dtype): if not isinstance(other, dtype): return False @@ -608,11 +613,10 @@ def __init__(self, element_ty: dtype, shape: List): # Note that block_type's shape is a list of int # while tensor's shape is a list of constexpr. - - assert (isinstance(shape, list)) + assert (isinstance(shape, (list, tuple))) # shape can be empty ([]) when an input is a 0D tensor. 
- self.shape = _unwrap_shape(shape) + self.shape = tuple(_unwrap_shape(shape)) if not self.shape: raise TypeError('0d block_type is forbidden') @@ -647,19 +651,32 @@ def scalar(self): return self.element_ty -class function_type(dtype): +class tuple_type(dtype): - def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: - self.ret_types = ret_types - self.param_types = param_types + def __init__(self, types): + self.types = types + self.name = f"[{','.join(map(str, self.types))}]" def __str__(self): - return f'fn ({self.param_types}) -> {self.ret_types}' + return self.name + + def __iter__(self): + return iter(self.types) def to_ir(self, builder: ir.builder): - ir_param_types = [ty.to_ir(builder) for ty in self.param_types] - ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] - return builder.get_function_ty(ir_param_types, ret_types) + return [ty.to_ir(builder) for ty in self.types] + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def is_tuple(self): + return True + + +class slice_type(dtype): + + def __init__(self): + self.name = 'slice_type' # scalar types @@ -761,7 +778,7 @@ def __init__(self, handle, type: dtype): self.type = type # Tensor type (can be block_type) # Following the practice in pytorch, dtype is scalar type self.dtype = type.scalar - self.shape = [constexpr(s) for s in self.shape] + self.shape = tuple([constexpr(s) for s in self.shape]) def _flatten_ir(self): return [self.handle] @@ -982,13 +999,16 @@ def __not__(self, _builder=None): @builtin def __getitem__(self, slices, _builder=None): - if isinstance(slices, (slice, constexpr)) or slices is None: + import builtins + if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None: slices = [slices] + if isinstance(slices, tuple): + slices = slices.values ret = self for dim, sl in enumerate(slices): if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) - elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None: pass else: raise ValueError(f"unsupported tensor index: {sl}") @@ -1147,6 +1167,77 @@ def flip(self, dim=None) -> tensor: ... 
+class tuple: + + def __init__(self, args: list): + self.values = [i for i in args] + + @property + def type(self): + + def get_type(x): + if isinstance(x, dtype): + return dtype + return x.type + + return tuple_type([get_type(x) for x in self.values]) + + def __getitem__(self, idx: constexpr): + if isinstance(idx, int): + idx = constexpr(idx) + if isinstance(idx, constexpr): + return self.values[idx] + else: + import builtins + assert isinstance(idx, (slice, builtins.slice)) + return tuple(self.values[idx.start:idx.stop:idx.step]) + + # TODO: remove + def __setitem__(self, idx: constexpr, value): + if isinstance(idx, int): + idx = constexpr(idx) + assert isinstance(idx, constexpr) + self.values[idx] = value + + def __add__(self, other): + if isinstance(other, list): + other = tuple(other) + return tuple(self.values + other.values) + # return tuple(a + b for a, b in zip(self.values, other.values)) + + def __mul__(self, other): + assert isinstance(other, constexpr) + return tuple(self.values * other.value) + + def __eq__(self, other): + import builtins + if isinstance(other, (list, builtins.tuple)): + other = tuple(other) + return constexpr(self.values == other.values) + + def __hash__(self): + import builtins + return hash(builtins.tuple(self.values)) + + def __str__(self): + return str([str(x) for x in self.values]) + + def __iter__(self): + return iter(self.values) + + def __len__(self): + return len(self.values) + + +class slice: + + def __init__(self, start, stop, step): + self.start = start + self.stop = stop + self.step = step + self.type = slice_type() + + class _experimental_tensor_descriptor_base(_value): """" A tensor descriptor with unknown shape and strides @@ -1562,7 +1653,7 @@ def expand_dims(input, axis, _builder=None): """ input = semantic.to_tensor(input, _builder) axis = _constexpr_to_value(axis) - axes = list(axis) if isinstance(axis, Sequence) else [axis] + axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis] new_ndim = len(input.shape) + len(axes) axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] @@ -2215,14 +2306,12 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = reduce_op.get_region(0) with _insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] @@ -2316,14 +2405,12 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): - in_scalar_tys = [t.type.scalar for t in input] - prototype = function_type(in_scalar_tys, in_scalar_tys * 2) - + param_types = [t.type.scalar for t in input] * 2 region = scan_op.get_region(0) with 
_insertion_guard(_builder): - param_types = [ty.to_ir(_builder) for ty in prototype.param_types] - block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + to_ir = lambda T: T.to_ir(_builder) + block = _builder.create_block_with_parent(region, list(map(to_ir, param_types))) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 60890ac596eb..2f7dba929be0 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -759,14 +759,14 @@ def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> # Add new axes to lhs for _ in range(len(lhs_shape), len(rhs_shape)): lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), - tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + tl.block_type(lhs_ty.scalar, [1] + lhs_shape.values)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for _ in range(len(rhs_shape), len(lhs_shape)): rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), - tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d04f516e8152..4ae7a918a192 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -308,6 +308,8 @@ def mangle_type(arg, is_const=False): return "fp32" elif hasattr(arg, "tma_desc_cpu_ptr"): return "nvTmaDesc" + elif isinstance(arg, tuple): + return "[" + ",".join(map(mangle_type, arg)) + "]" else: # dtypes are hashable so we can memoize this mapping: dsk = (arg.dtype, is_const) @@ -335,8 +337,8 @@ def serialize_specialization_data(name, signature, constants, attrs, options, ke constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} import json obj = { - 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': - options.__dict__, 'key': key + 'name': name, 'signature': signature, 'constant_keys': list(constants.keys()), 'constant_vals': + list(constants.values()), 'attrs': attrs.to_dict(), 'options': options.__dict__, 'key': key } serialized_obj = json.dumps(obj) return serialized_obj @@ -368,6 +370,7 @@ def create_function_from_signature(sig, kparams, backend): func_args.append(f"{name}=default_{name}") dict_entries.append(f"'{name}': {name}") if kp.is_constexpr: + signature_types.append('"constexpr"') constexpr_vals.append(name) else: non_constexpr_vals.append(name) @@ -601,32 +604,23 @@ def run(self, *args, grid, warmup, **kwargs): # done here rather than when we build the signature as otherwise # the kernel cache key could not distinguish between byte pointers # and None arguments, resulting in a downstream mismatch: - sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigkeys = [param.name for param in self.params] sigvals = sig_and_spec[:len(sigkeys)] - signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} - - configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) - constant_params = configs[0].get_constants() - 
constants = {
-                p.name: v
-                for (v, p) in zip(bound_vals, self.params)
-                if p.is_constexpr or (p.num in constant_params) or v is None
-            }
-            for i, arg in constants.items():
+            signature = {k: v for (k, v) in zip(sigkeys, sigvals)}
+
+            attrs = backend.get_attrs_descriptor(self.params, bound_vals)
+            constexprs = {p.name: v for (v, p) in zip(bound_vals, self.params) if p.is_constexpr}
+            for i, arg in constexprs.items():
                 if callable(arg):
                     raise TypeError(f"Callable constexpr at index {i} is not supported")
-            if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True):
+            if self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=True):
                 return None
             # compile the kernel
-            src = self.ASTSource(self, signature, constants, configs[0])
-            kernel = self.compile(
-                src,
-                target=target,
-                options=options.__dict__,
-            )
+            src = self.ASTSource(self, signature, constexprs, attrs)
+            kernel = self.compile(src, target=target, options=options.__dict__)
             self.cache[device][key] = kernel
-            self._call_hook(key, signature, device, constants, options, configs, warmup, before=False)
+            self._call_hook(key, signature, device, constexprs, options, [attrs], warmup, before=False)
 
         # Check that used global values have not changed.
         not_present = object()
@@ -639,15 +633,11 @@ def run(self, *args, grid, warmup, **kwargs):
         # canonicalize grid
         assert grid is not None
         if callable(grid):
-            # Arguments are passed as a dict to `grid`, by contract.
-            # TODO(jlebar): In the new launch API, pass the compiler flags as a
-            # second parameter to `grid`.
             grid = grid(bound_args)
         grid_size = len(grid)
         grid_0 = grid[0]
         grid_1 = grid[1] if grid_size > 1 else 1
         grid_2 = grid[2] if grid_size > 2 else 1
-        # launch kernel
         launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
         kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
@@ -738,9 +728,11 @@ def preload(self, specialization_data):
         if deserialized_obj['name'] != self.fn.__name__:
             raise RuntimeError(
                 f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}")
+        constant_keys = deserialized_obj['constant_keys']
+        constant_vals = deserialized_obj['constant_vals']
         constants = {
             key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
-            for key, value in deserialized_obj['constants'].items()
+            for key, value in zip(constant_keys, constant_vals)
         }
         signature = dict(deserialized_obj['signature'].items())
         src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs']))
diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py
index 6adf7794cc44..50483b236241 100644
--- a/python/triton/tools/compile.py
+++ b/python/triton/tools/compile.py
@@ -91,15 +91,13 @@ def constexpr(s):
             pass
         return None
 
-    hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
+    hints = {(i, ): constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s}
     hints = {k: v for k, v in hints.items() if v is not None}
     constants = {kernel.arg_names[i]: constexpr(s) for i, s in enumerate(signature)}
     constants = {k: v for k, v in constants.items() if v is not None}
-    signature = {
-        kernel.arg_names[i]: s.split(":")[0]
-        for i, s in enumerate(signature)
-        if kernel.arg_names[i] not in constants
-    }
+    signature = {kernel.arg_names[i]: s.split(":")[0] for i, s in enumerate(signature)}
+    for key in constants:
+        signature[key] = 'constexpr'
     const_sig = 'x'.join([str(v) for v in constants.values()])
     doc_string = [f"{k}={v}" for k, v in constants.items()]
     doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"]
@@ -109,8 +107,8 @@ def constexpr(s):
         assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}"
     attrs = triton.backends.compiler.AttrsDescriptor.from_hints(hints)
     for p, v in attrs.get_constants().items():
-        constants.update({kernel.arg_names[p]: v})
-    src = triton.compiler.ASTSource(fn=kernel, constants=constants, signature=signature, attrs=attrs)
+        constants.update({kernel.arg_names[p[0]]: v})
+    src = triton.compiler.ASTSource(fn=kernel, constexprs=constants, signature=signature, attrs=attrs)
     opts = {"num_warps": args.num_warps, "num_stages": args.num_stages}
     ccinfo = triton.compile(src, options=opts)
     if ccinfo.metadata.global_scratch_size > 0:
@@ -126,7 +124,7 @@ def constexpr(s):
             arg_types.append(signature[arg_name])
             arg_names_not_1.append(arg_name)
             arg_types_not_1.append(signature[arg_name])
-        elif i in attrs.equal_to_1:
+        elif (i, ) in attrs.equal_to_1:
             arg_names.append(arg_name)
             arg_types.append(signature[arg_name])
diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py
index 81b07f2e7d86..a8d806a8b104 100644
--- a/third_party/amd/backend/compiler.py
+++ b/third_party/amd/backend/compiler.py
@@ -1,5 +1,6 @@
 from triton.backends.compiler import BaseBackend, GPUTarget, AttrsDescriptor, register_descriptor
 from triton._C.libtriton import ir, passes, llvm, amd
+from triton._utils import find_paths_if
 from dataclasses import dataclass
 from typing import Any, Dict, Tuple
 from types import ModuleType
@@ -100,10 +101,14 @@ def _add_backend_properties(self, params=None, values=None):
         if params is None or values is None:
             return
 
-        self.arg_properties["tt.pointer_range"] = [
-            param.num for param, arg in zip(params, values) if HIPAttrsDescriptor.is_within2gb(arg)
-            and not param.do_not_specialize and not param.do_not_specialize_on_alignment
-        ]
+        pointer_range = []
+        for param, arg in zip(params, values):
+            if param.do_not_specialize or \
+               param.do_not_specialize_on_alignment:
+                continue
+            paths = find_paths_if(arg, lambda path, val: HIPAttrsDescriptor.is_within2gb(val))
+            pointer_range += [(param.num, ) + x for x in paths]
+        self.arg_properties["tt.pointer_range"] = pointer_range
 
     @staticmethod
     def is_within2gb(arg):
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 99e5509eca8d..965341b96e28 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -8,6 +8,7 @@
 from triton.runtime.cache import get_cache_manager
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
+from triton._utils import parse_list_string
 
 dirname = os.path.dirname(os.path.realpath(__file__))
 include_dir = [os.path.join(dirname, "include")]
@@ -164,7 +165,7 @@ def __init__(self):
 # -------------------- Launcher ----------------------------
 
 
 def ty_to_cpp(ty):
-    if ty[0] == '*':
+    if ty[0] == '*' or ty == "none":
         return "hipDeviceptr_t"
     return {
         "i1": "int32_t",
@@ -186,32 +187,27 @@ def ty_to_cpp(ty):
 
 
 def make_launcher(constants, signature, ids, warp_size):
-    start_desc = len(signature)
-    #signature = generate_cu_signature(constants, signature, ids)
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     def _extracted_type(ty):
-        if ty[0] == '*':
+        if ty[0] == '*' or ty == "none":
             return "PyObject*"
-        return {
-            'i1': 'int32_t',
-            'i8': 'int8_t',
-            'i16': 'int16_t',
-            'i32': 'int32_t',
-            'i64': 'int64_t',
-            'u1': 'uint32_t',
-            'u8': 'uint8_t',
-            'u16': 'uint16_t',
-            'u32': 'uint32_t',
-            'u64': 'uint64_t',
-            'fp16': 'float',
-            'bf16': 'float',
-            'fp32': 'float',
-            'f32': 'float',
-            'fp64': 'double',
-        }[ty]
+        if ty[0] == '[':
+            if ty == "[]":
+                return "[]"
+            tys = parse_list_string(ty)
+            val = ','.join(map(_extracted_type, tys))
+            return f"[{val}]"
+        return ty_to_cpp(ty)
 
     def format_of(ty):
+        if ty == "hipDeviceptr_t":
+            return "O"
+        if ty[0] == "[":
+            if ty == "[]":
+                return "()"
+            tys = parse_list_string(ty)
+            val = ''.join(map(format_of, tys))
+            return f"({val})"
         return {
             "PyObject*": "O",
             "float": "f",
@@ -227,14 +223,22 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty]
 
+    signature = {k: v for k, v in signature.items() if v != 'constexpr'}
     args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
     format = "iiiKKOOOO" + args_format
+    signature = ','.join(signature.values()).replace('[', '').replace(']', '')
+    signature = list(filter(bool, signature.split(',')))
+    signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     libhip_path = _get_path_to_hip_runtime_dylib()
 
     # generate glue code
-    params = [f"&arg{i}" for i in signature.keys() if i not in constants]
+    params = list(range(len(signature)))
+    params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
     params.append("&global_scratch")
     src = f"""
 #define __HIP_PLATFORM_AMD__
@@ -416,8 +420,8 @@ def format_of(ty):
   // raise exception asap
-  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
+  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''});
 
   if(launch_exit_hook != Py_None){{
     PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -468,9 +472,8 @@ class HIPLauncher(object):
     def __init__(self, src, metadata):
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        constants = {cst_key(key): value for key, value in constants.items()}
-        signature = {cst_key(key): value for key, value in src.signature.items()}
+        constants = {idx: value for idx, value in constants.items()}
+        signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, ids, metadata.warp_size)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 196f189caa4a..468a2e9deac2 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -10,6 +10,7 @@
 from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
 from triton.backends.driver import GPUDriver
+from triton._utils import parse_list_string
 
 dirname = os.path.dirname(os.path.realpath(__file__))
 include_dir = [os.path.join(dirname, "include")]
@@ -95,7 +96,7 @@ def __init__(self):
 
 
 def ty_to_cpp(ty):
-    if ty[0] == '*':
+    if ty[0] == '*' or ty == "none":
         return "CUdeviceptr"
     return {
         "i1": "int32_t",
@@ -118,19 +119,29 @@ def make_launcher(constants, signature, ids):
-    # Record the end of regular arguments;
-    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
-    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     def _extracted_type(ty):
-        if ty[0] == '*':
+        if ty[0] == '*' or ty == "none":
             return "PyObject*"
         if ty == "nvTmaDesc":
             return "PyObject*"
-
+        if ty[0] == '[':
+            if ty == "[]":
+                return "[]"
+            tys = parse_list_string(ty)
+            val = ','.join(map(_extracted_type, tys))
+            return f"[{val}]"
         return ty_to_cpp(ty)
 
     def format_of(ty):
+        if ty == "CUdeviceptr":
+            return "O"
+        if ty[0] == "[":
+            if ty == "[]":
+                return "()"
+            tys = parse_list_string(ty)
+            val = ''.join(map(format_of, tys))
+            return f"({val})"
         return {
             "PyObject*": "O",
             "float": "f",
@@ -146,22 +157,29 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty]
 
+    signature = {k: v for k, v in signature.items() if v != 'constexpr'}
     args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
     format = "iiiKKOOOOO" + args_format
+    signature = ','.join(signature.values()).replace('[', '').replace(']', '')
+    signature = list(filter(bool, signature.split(',')))
+    signature = {i: s for i, s in enumerate(signature)}
     args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
-
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
+    arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     internal_args_list = []
     for i, ty in signature.items():
-        if ty[0] == "*":
+        if ty[0] == "*" or ty == "none":
             internal_args_list.append(f"ptr_info{i}.dev_ptr")
         elif ty == "nvTmaDesc":
             # Note: we have to dereference the pointer
             internal_args_list.append(f"*tma_ptr{i}")
         else:
             internal_args_list.append(f"_arg{i}")
+    params = range(len(signature))
 
     # generate glue code
-    params = [f"&arg{i}" for i in signature.keys() if i not in constants]
+    params = [f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none"]
     params.append("&global_scratch")
     src = f"""
 #include \"cuda.h\"
@@ -395,7 +413,7 @@ def format_of(ty):
 }}
 
   // raise exception asap
-  {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
+  {"".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
   {"".join([f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" if ty == "nvTmaDesc" else "" for i, ty in signature.items()])};
   Py_BEGIN_ALLOW_THREADS;
   _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
@@ -446,9 +464,8 @@ class CudaLauncher(object):
     def __init__(self, src, metadata):
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else dict()
-        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        constants = {cst_key(key): value for key, value in constants.items()}
-        signature = {cst_key(key): value for key, value in src.signature.items()}
+        constants = {idx: value for idx, value in constants.items()}
+        signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, ids)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
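
Background on the `tt.pointer_range` change in third_party/amd/backend/compiler.py: it relies on
`find_paths_if` from `triton._utils` (added by this patch) to locate qualifying leaves inside
nested tuple arguments and to return their index paths. The snippet below is a rough,
self-contained stand-in, not the actual `triton._utils` implementation; the `< 2**31` predicate
merely stands in for `HIPAttrsDescriptor.is_within2gb`, and `param_num` is an illustrative value.

# Hedged stand-in for triton._utils.find_paths_if: walk a (possibly nested) tuple
# argument and return the index path of every leaf for which pred(path, value) holds.
def find_paths_if(obj, pred, _path=()):
    if isinstance(obj, (list, tuple)):
        paths = []
        for idx, item in enumerate(obj):
            paths += find_paths_if(item, pred, _path + (idx, ))
        return paths
    return [_path] if pred(_path, obj) else []

# Mirroring the HIPAttrsDescriptor change: collect paths of qualifying leaves inside a
# nested argument, then prefix each path with the parameter index, as the backend does.
param_num = 3                        # illustrative parameter index
arg = (2**40, (7, 2**40, 5))         # nested tuple argument
paths = find_paths_if(arg, lambda path, val: val < 2**31)
pointer_range = [(param_num, ) + p for p in paths]
print(pointer_range)                 # [(3, 1, 0), (3, 1, 2)]

For a scalar (non-tuple) argument the path is the empty tuple, so the entry degenerates to
`(param.num, )`, matching the pre-tuple behavior of the old list comprehension.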
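Both launcher rewrites above follow the same recipe: drop `constexpr` entries from the signature,
derive the `PyArg_ParseTuple` format string from the (possibly nested) argument types, then
flatten bracketed tuple types into a flat index-to-type map before emitting argument declarations.
The following is a minimal standalone sketch of that flattening; the simplified `format_of` and
the `parse_list_string` stand-in are illustrative only, not the functions defined in the patch.

# Hypothetical, simplified sketch of the tuple-aware signature handling in the launchers above.
def parse_list_string(s):
    # Stand-in for triton._utils.parse_list_string: split "[a,b,[c,d]]" into its
    # top-level elements ["a", "b", "[c,d]"].
    assert s[0] == '[' and s[-1] == ']'
    items, depth, start = [], 0, 1
    for i, ch in enumerate(s[1:-1], start=1):
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
        elif ch == ',' and depth == 0:
            items.append(s[start:i])
            start = i + 1
    if start < len(s) - 1:
        items.append(s[start:len(s) - 1])
    return items

def format_of(ty):
    # Pointers and "none" become "O" (a PyObject*); tuple types become parenthesized groups.
    if ty[0] == '*' or ty == "none":
        return "O"
    if ty[0] == '[':
        return "()" if ty == "[]" else "(" + ''.join(map(format_of, parse_list_string(ty))) + ")"
    return {"i32": "i", "i64": "L", "fp32": "f", "fp64": "d"}[ty]

signature = {"x_ptr": "*fp32", "shape": "[i32,i32]", "BLOCK": "constexpr"}

# 1) constexpr parameters are dropped from the launcher signature
signature = {k: v for k, v in signature.items() if v != 'constexpr'}
# 2) the PyArg_ParseTuple format string is built from the nested types
args_format = ''.join(format_of(ty) for ty in signature.values())
# 3) the nested signature is flattened into a flat index -> scalar-type map
flat = ','.join(signature.values()).replace('[', '').replace(']', '')
flat = {i: s for i, s in enumerate(filter(bool, flat.split(',')))}

print(args_format)  # O(ii)
print(flat)         # {0: '*fp32', 1: 'i32', 2: 'i32'}

The flat map is what the generated C glue code iterates over, which is why tuple elements end up
as individual `arg{i}` declarations even though Python passes them as one nested object.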