[FRONTEND] added support for tuples (triton-lang#5220)
ptillet authored Dec 9, 2024
1 parent e5be006 commit 9743ec0
Showing 21 changed files with 635 additions and 288 deletions.
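
The headline change: @triton.jit kernels can now take Python tuples as arguments, index them with constexpr indices, and iterate them with tl.static_range. A minimal sketch of the usage this enables, modeled on the new tests in python/test/unit/language/test_tuple.py below (kernel and variable names are illustrative, not part of the commit):

import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(out_ptrs, values):
    # Both arguments are plain Python tuples; len() is a compile-time
    # constant here, so the loop fully unrolls.
    for i in tl.static_range(len(values)):
        tl.store(out_ptrs[i], values[i] + 1)


outs = tuple(torch.zeros(1, dtype=torch.float32, device="cuda") for _ in range(3))
add_one_kernel[(1, )](outs, (10.0, 20.0, 30.0))
assert [o.item() for o in outs] == [11.0, 21.0, 31.0]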
1 change: 1 addition & 0 deletions python/src/ir.cc
@@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
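The one-line addition above exposes FuncOp::getNumArguments to Python alongside the existing bounds-checked per-index accessor. A sketch of how Python-side code can use it (the accessor's bound name is truncated in this hunk; "args" is assumed here):

def func_args(fn):
    # Collect every MLIR block argument of a FuncOp handle; get_num_args is
    # the new binding, fn.args(i) the assumed name of the accessor above.
    return [fn.args(i) for i in range(fn.get_num_args())]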
49 changes: 25 additions & 24 deletions python/test/unit/language/test_compile_errors.py
@@ -17,7 +17,7 @@ def kernel():
a += 1 # noqa

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "is not defined" in str(e.value), "error should mention the undefined variable"
@@ -32,7 +32,7 @@ def kernel():
0 + "a"

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 0"
@@ -47,7 +47,7 @@ def kernel():
tl.static_assert(isinstance(0, tl.tensor))

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert isinstance(e.value, CompileTimeAssertionFailure)
@@ -66,7 +66,7 @@ def kernel():
not (0, 0)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert e.value.__cause__ is None
@@ -83,7 +83,7 @@ def kernel():
1.0 << 1

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
@@ -107,7 +107,7 @@ def kernel():
nested_call()

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
inner = e.value.__cause__
@@ -130,7 +130,7 @@ def kernel():
tl.expand_dims(None, -1)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
inner = e.value.__cause__
@@ -157,7 +157,7 @@ def kernel():
a = two_returns()
a + tl.arange(0, 4) # only works if we took the first return

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


def test_not_const_annotate_no_err():
@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
def kernel(N: int = 1):
pass

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))


@triton.jit
@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 4)

-    triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))

@triton.jit
def kernel2(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 8)

-    triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))


@triton.jit
@@ -211,7 +211,7 @@ def kernel(N: int):
returns_branched_on_non_constexpr(N)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the function call"
@@ -227,7 +227,7 @@ def kernel():
tl.arange(2, 7)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert str(e.value.__cause__) == "arange's range must be a power of 2"


@@ -238,7 +238,7 @@ def kernel():
tl.full((33, ), 0, dtype=tl.int64)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"


@@ -251,7 +251,7 @@ def kernel():
a = CAPTURED # noqa

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert "CAPTURED is not defined" in str(e.value)


@@ -265,7 +265,7 @@ def kernel():
a = GLOBAL # noqa

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert "global variable" in str(e.value)


@@ -279,7 +279,7 @@ def kernel():
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


CONSTEXPR_GLOBAL = tl.constexpr(42)
@@ -292,7 +292,7 @@ def kernel():
a = CONSTEXPR_GLOBAL # noqa

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


TYPE_ALIAS = tl.pointer_type(tl.int32)
@@ -305,7 +305,7 @@ def kernel():
a = TYPE_ALIAS # noqa

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


def test_global_access_in_fn_default_arg():
@@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
pass

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))


def test_defaults_assign_no_err():
@@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
def kernel(a=1, B: tl.constexpr = ""):
pass

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))


def test_where_warning(fresh_triton_cache):
@@ -337,7 +337,7 @@ def kernel():
tl.where(a, b, c)

with pytest.warns(UserWarning):
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
@@ -369,7 +369,8 @@ def dtype_kernel(dtype: tl.constexpr):
ctx = pytest.raises(CompilationError, match="")

with ctx as e:
-        triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
+        triton.compile(
+            triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))

if dtype not in supported_dtypes:
try:
@@ -388,7 +389,7 @@ def dot_kernel():
tl.dot(a, b, max_num_imprecise_acc=128)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
try:
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
except AssertionError as assertion_err:
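The mechanical change running through these tests is the ASTSource calling convention: constexpr parameters now appear in signature under the sentinel type "constexpr", and their values move from the old constants argument to constexprs. The new form, taken from the updated kernel1 test above:

src = triton.compiler.ASTSource(
    fn=kernel1,
    signature={"N": "constexpr"},  # constexpr params are now part of the signature
    constexprs={"N": 0},           # their values go here instead of constants=
)
triton.compile(src)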
23 changes: 19 additions & 4 deletions python/test/unit/language/test_core.py
@@ -4352,15 +4352,17 @@ def kernel(x):
def test_value_specialization(value: int, value_type: str, device) -> None:

def repr(specialization):
-        spec_type = specialization.signature["VALUE"]
-        return f"kernel_{spec_type}"
+        ty = specialization.signature["value1"]
+        cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
+        return f"kernel_{ty}_{cst}"

@triton.jit(repr=repr)
-    def kernel(VALUE, X):
+    def kernel(value1, is_one, X):
pass

x = torch.tensor([3.14159], device=device)
-    h = kernel[(1, )](value, x)
+    h = kernel[(1, )](value, 1, x)
+    assert "is_one" in h.name
assert value_type in h.name
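
With the updated repr, the kernel name encodes both the signature type of value1 and the names of arguments specialized as equal to 1. A sketch of what the launch above produces, assuming value falls in int32 range (the parametrization is not shown in this hunk):

# value=3 specializes value1 as "i32"; is_one=1 is keyed into
# specialization.constants, so the generated name looks like:
#   repr(specialization) == "kernel_i32_is_one"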


@@ -6130,6 +6132,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))


def test_dtype(device):

@triton.jit
def kernel(X):
dtype_x: tl.constexpr = X.dtype.element_ty
tl.static_assert(dtype_x == tl.int32)
tl.static_assert(dtype_x == tl.constexpr(tl.int32))
tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))

X = torch.zeros(1, dtype=torch.int32, device=device)
kernel[(1, )](X)


def test_side_effectful_scan(device):
if device != "cuda":
pytest.skip()
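The new test_dtype above relies on element types behaving as constexpr values that compare with == at compile time, which makes compile-time dispatch on an argument's dtype straightforward. A small sketch under that assumption (names illustrative, not part of the commit):

import triton
import triton.language as tl


@triton.jit
def load_as_f32(X, Y):
    # X.dtype.element_ty is a constexpr, so this branch resolves at compile time.
    if X.dtype.element_ty == tl.int32:
        tl.store(Y, tl.load(X).to(tl.float32))
    else:
        tl.store(Y, tl.load(X))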
2 changes: 1 addition & 1 deletion python/test/unit/language/test_decorator.py
@@ -23,7 +23,7 @@ def kernel():
pass

try:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
except Exception as e:
pytest.fail(f"triton compile failed with error: {e}")

100 changes: 100 additions & 0 deletions python/test/unit/language/test_tuple.py
@@ -0,0 +1,100 @@
import pytest
import triton
import triton.language as tl
import torch


@triton.jit
def _tuple_increment(values):
for i in tl.static_range(len(values)):
values[i] = values[i] + 1
return values


@triton.jit
def _tuple_index_func(Ptrs, values):
for i in tl.static_range(len(values)):
tl.store(Ptrs[i], values[i])


@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
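    # Note (annotation, not in the commit): the underscore-named filler
    # arguments interleave scalars and constexprs around the tuple, checking
    # that launcher argument (de)serialization stays aligned.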
values = _tuple_increment(values)
_tuple_index_func(Ptrs, values)


@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device="cuda"):
vals = tuple([i + 1 for i in range(size)])
rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
_tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
assert vals == tuple([x.item() - 1 for x in rets])


# ----


@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
# assign from tuple
X0, X1 = XPtrs
x0, x1 = values
tl.store(X0, x0)
tl.store(X1, x1)
# assign to tuple
Y0, Y1, Y2 = YPtrs
Y = Y0, Y1, Y2
y = x0, 10, x1
tl.store(Y[0], y[0])
tl.store(Y[1], y[1])
tl.store(Y[2], y[2])


def test_assign(device="cuda"):
vals = (2., 3.)
x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
_tuple_assign[(1, )](x, y, vals)
assert x[0] == vals[0]
assert x[1] == vals[1]
assert y[0] == vals[0]
assert y[1] == 10
assert y[2] == vals[1]


# -------


@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
tl.static_assert(tuple1[1] is None)
tl.store(Ptr + 5, cst2)
tl.store(Ptr + 6, tuple1[0])
tl.store(Ptr + 7, tl.load(tuple1[2][0]))
tl.store(Ptr + 8, tuple1[2][1][0])
tl.store(Ptr + 9, tl.load(tuple1[2][1][2]))


# test serialization/deserialization of tuple arguments in
# the frontend.
@triton.jit
def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
tl.static_assert(N1 is None)
tl.static_assert(tuple1[1][1] is None)
tl.store(Ptr + 0, tl.load(tuple1[0]))
tl.store(Ptr + 1, tuple1[1][0])
tl.store(Ptr + 2, tl.load(tuple1[1][2]))
tl.store(Ptr + 3, cst1 + val1)
tl.store(Ptr + 4, tl.load(tuple2[0]))
_tuple_fn0(Ptr, 15, (-1, None, tuple1))


def test_serialize(device="cuda"):
x0 = torch.tensor([8], dtype=torch.int32, device=device)
x1 = torch.tensor([12], dtype=torch.int32, device=device)
y0 = torch.tensor([10], dtype=torch.int32, device=device)
z = torch.empty((10, ), dtype=torch.int32, device=device)
# we want to check that JIT specialization propagates to tuples:
_tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, ))
ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
assert torch.equal(z, ref)
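
For reference, each element of ref corresponds to one store in the kernels above (a walk-through of the expected values, not part of the commit):

# Stores from _tuple_serialize(z, None, (x0, (1, None, x1)), 20, 1, (y0,)):
#   z[0] = load(x0)              = 8
#   z[1] = tuple1[1][0]          = 1
#   z[2] = load(x1)              = 12
#   z[3] = cst1 + val1           = 20 + 1 = 21
#   z[4] = load(y0)              = 10
# Stores from the nested _tuple_fn0(Ptr, 15, (-1, None, tuple1)) call:
#   z[5] = cst2                  = 15
#   z[6] = tuple1[0]             = -1
#   z[7] = load(tuple1[2][0])    = load(x0) = 8
#   z[8] = tuple1[2][1][0]       = 1
#   z[9] = load(tuple1[2][1][2]) = load(x1) = 12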
15 changes: 6 additions & 9 deletions python/test/unit/runtime/test_bindings.py
@@ -63,15 +63,12 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
src = triton.compiler.compiler.ASTSource(
fn=kernel,
-        signature={
-            kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
-            for i, arg in enumerate(args)
-            if i not in kernel.constexprs
-        },
-        constants={kernel.arg_names[i]: arg
-                   for i, arg in enumerate(args)
-                   if not isinstance(arg, torch.Tensor)},
-        attrs=backend.get_attrs_descriptor(args, kernel.params),
+        signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
+                   for i, arg in enumerate(args)},
+        constexprs={kernel.arg_names[i]: arg
+                    for i, arg in enumerate(args)
+                    if not isinstance(arg, torch.Tensor)},
+        attrs=backend.get_attrs_descriptor(kernel.params, args),
)

context = triton._C.libtriton.ir.context()
(diffs for the remaining 15 changed files are not shown)
