[FRONTEND] added support for tuples #5220

Merged: 73 commits, Dec 9, 2024
Changes from all commits
Commits (73)
51ab367
progress
ptillet Sep 3, 2024
65da71f
.
ptillet Sep 4, 2024
746a2a3
prototype works
ptillet Sep 4, 2024
630ec6c
added test
ptillet Sep 5, 2024
f2439f9
fixup
ptillet Sep 5, 2024
d9af0ba
cleanup
ptillet Sep 6, 2024
f758ef2
.
ptillet Sep 6, 2024
1b58df2
progress
ptillet Sep 8, 2024
1558d6c
bugfix
ptillet Sep 8, 2024
812af43
progress
ptillet Sep 9, 2024
e226cfd
.
ptillet Oct 9, 2024
98c526e
Merge remote-tracking branch 'origin/main' into phil/tuple-support
ptillet Oct 9, 2024
8cee89d
.
ptillet Oct 11, 2024
627bef2
.
ptillet Oct 12, 2024
756d75a
.
ptillet Oct 12, 2024
2a86fb4
.
ptillet Oct 12, 2024
d614226
.
ptillet Oct 13, 2024
5d29bef
fails again?
ptillet Oct 13, 2024
a790867
more hacks
ptillet Oct 13, 2024
fa23bfc
giant mess; more tests pass
ptillet Oct 13, 2024
d88cca0
very hacky but tests pass; TO REFACTOR
ptillet Oct 13, 2024
e299bf2
.
ptillet Oct 15, 2024
fcae528
.
ptillet Oct 16, 2024
d0168c9
progress
ptillet Nov 16, 2024
0ba41ff
more progress
ptillet Nov 16, 2024
e7289dc
more progress
ptillet Nov 17, 2024
b7d8117
.
ptillet Nov 19, 2024
33505ac
more progress
ptillet Nov 21, 2024
3c08877
more fixes
ptillet Nov 21, 2024
dba9b2d
all tests pass
ptillet Nov 21, 2024
a35e89a
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Nov 21, 2024
ae2ebf6
Merge branch 'main' into phil/tuple-support-2
ptillet Nov 22, 2024
18f24ef
fixed TMA descriptors
ptillet Nov 22, 2024
bba29ae
.
ptillet Nov 22, 2024
67fc1b4
more fixes
ptillet Nov 22, 2024
04d463f
.
ptillet Nov 24, 2024
aa74737
.
ptillet Nov 24, 2024
6161e78
.
ptillet Nov 29, 2024
2ab9b39
more fixes
ptillet Nov 30, 2024
394baf7
more bugfixes
ptillet Dec 2, 2024
8a01c91
fix naming
ptillet Dec 2, 2024
f2cf8d6
.
ptillet Dec 3, 2024
3366e8d
.
ptillet Dec 3, 2024
0a910c2
.
ptillet Dec 3, 2024
6fcdfed
mpre cleaning
ptillet Dec 3, 2024
4fc12d9
cleanup
ptillet Dec 3, 2024
a3fea49
cleanup
ptillet Dec 3, 2024
4f92962
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Dec 6, 2024
4286515
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Dec 6, 2024
8bdadd2
amd
ptillet Dec 6, 2024
d6799df
.
ptillet Dec 6, 2024
74d6277
.
ptillet Dec 6, 2024
5002698
.
ptillet Dec 6, 2024
d14ffe2
.
ptillet Dec 6, 2024
2cb01d8
.
ptillet Dec 6, 2024
91e04a5
.
ptillet Dec 7, 2024
8a528ec
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Dec 7, 2024
26ac4e6
make sure `do_not_specialize` is ignored for None values
ptillet Dec 8, 2024
68b52f5
adding dtype test
ptillet Dec 8, 2024
d64d82a
fix dtype handling
ptillet Dec 8, 2024
10f197f
do not materialize none arguments
ptillet Dec 8, 2024
fdf9948
fix amd
ptillet Dec 8, 2024
7d71fbd
.
ptillet Dec 9, 2024
546dc43
.
ptillet Dec 9, 2024
44f446f
.
ptillet Dec 9, 2024
305dede
.
ptillet Dec 9, 2024
122e523
.
ptillet Dec 9, 2024
243062d
.
ptillet Dec 9, 2024
4abfc4b
Merge commit '89c0b0abdfac' into phil/tuple-support-2
ptillet Dec 9, 2024
5a02aef
.
ptillet Dec 9, 2024
06d2abd
.
ptillet Dec 9, 2024
cb67cbc
.
ptillet Dec 9, 2024
1032df3
.
ptillet Dec 9, 2024
python/src/ir.cc: 1 addition, 0 deletions
@@ -605,6 +605,7 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
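For context, a minimal sketch of how the added binding might be used from Python; the `fn` handle and the neighboring accessor name are assumptions inferred from the surrounding bindings, not verified API:

```python
# Hedged sketch: assumes `fn` is a function handle obtained from Triton's IR
# bindings (triton._C.libtriton.ir), and that the accessor shown partially in
# the hunk above is bound as `args` (assumed name).
n_args = fn.get_num_args()  # new binding; wraps mlir::FuncOp::getNumArguments()
for i in range(n_args):
    arg = fn.args(i)  # existing accessor, guarded by an out-of-range check
```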
python/test/unit/language/test_compile_errors.py: 25 additions, 24 deletions
@@ -17,7 +17,7 @@ def kernel():
a += 1 # noqa

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "is not defined" in str(e.value), "error should mention the undefined variable"
@@ -32,7 +32,7 @@ def kernel():
0 + "a"

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 0"
@@ -47,7 +47,7 @@ def kernel():
tl.static_assert(isinstance(0, tl.tensor))

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert isinstance(e.value, CompileTimeAssertionFailure)
@@ -66,7 +66,7 @@ def kernel():
not (0, 0)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert e.value.__cause__ is None
@@ -83,7 +83,7 @@ def kernel():
1.0 << 1

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the 1.0"
@@ -107,7 +107,7 @@ def kernel():
nested_call()

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
inner = e.value.__cause__
@@ -130,7 +130,7 @@ def kernel():
tl.expand_dims(None, -1)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))

try:
inner = e.value.__cause__
@@ -157,7 +157,7 @@ def kernel():
a = two_returns()
a + tl.arange(0, 4) # only works if we took the first return

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


def test_not_const_annotate_no_err():
@@ -166,7 +166,7 @@ def test_not_const_annotate_no_err():
def kernel(N: int = 1):
pass

-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))


@triton.jit
@@ -186,14 +186,14 @@ def kernel1(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 4)

triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={}, constants={"N": 0}))
triton.compile(triton.compiler.ASTSource(fn=kernel1, signature={"N": "constexpr"}, constexprs={"N": 0}))

@triton.jit
def kernel2(N: tl.constexpr):
a = returns_branched_on_constexpr(N)
a + tl.arange(0, 8)

triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={}, constants={"N": 1}))
triton.compile(triton.compiler.ASTSource(fn=kernel2, signature={"N": "constexpr"}, constexprs={"N": 1}))


@triton.jit
@@ -211,7 +211,7 @@ def kernel(N: int):
returns_branched_on_non_constexpr(N)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'N': 'i32'}, constexprs={}))

try:
assert "at 2:4:" in str(e.value), "error should point to the function call"
@@ -227,7 +227,7 @@ def kernel():
tl.arange(2, 7)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert str(e.value.__cause__) == "arange's range must be a power of 2"


@@ -238,7 +238,7 @@ def kernel():
tl.full((33, ), 0, dtype=tl.int64)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert str(e.value.__cause__) == "Shape element 0 must be a power of 2"


@@ -251,7 +251,7 @@ def kernel():
a = CAPTURED # noqa

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert "CAPTURED is not defined" in str(e.value)


@@ -265,7 +265,7 @@ def kernel():
a = GLOBAL # noqa

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
assert "global variable" in str(e.value)


@@ -279,7 +279,7 @@ def kernel():
a = CONSTEXPR_ANNOTATED_GLOBAL # noqa

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


CONSTEXPR_GLOBAL = tl.constexpr(42)
@@ -292,7 +292,7 @@ def kernel():
a = CONSTEXPR_GLOBAL # noqa

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


TYPE_ALIAS = tl.pointer_type(tl.int32)
@@ -305,7 +305,7 @@ def kernel():
a = TYPE_ALIAS # noqa

# No error.
-    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+    triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


def test_global_access_in_fn_default_arg():
@@ -315,7 +315,7 @@ def kernel(a=GLOBAL):
pass

# No error.
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constants={}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': "i32"}, constexprs={}))


def test_defaults_assign_no_err():
@@ -324,7 +324,7 @@ def test_defaults_assign_no_err():
def kernel(a=1, B: tl.constexpr = ""):
pass

triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32'}, constants={'B': ""}))
triton.compile(triton.compiler.ASTSource(fn=kernel, signature={'a': 'i32', 'B': 'constexpr'}, constexprs={'B': ""}))


def test_where_warning(fresh_triton_cache):
@@ -337,7 +337,7 @@ def kernel():
tl.where(a, b, c)

with pytest.warns(UserWarning):
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))


@pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
@@ -369,7 +369,8 @@ def dtype_kernel(dtype: tl.constexpr):
ctx = pytest.raises(CompilationError, match="")

with ctx as e:
triton.compile(triton.compiler.ASTSource(fn=dtype_kernel, signature={}, constants={"dtype": dtype}))
triton.compile(
triton.compiler.ASTSource(fn=dtype_kernel, signature={"dtype": "constexpr"}, constexprs={"dtype": dtype}))

if dtype not in supported_dtypes:
try:
@@ -388,7 +389,7 @@ def dot_kernel():
tl.dot(a, b, max_num_imprecise_acc=128)

with pytest.raises(CompilationError) as e:
-        triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=dot_kernel, signature={}, constexprs={}))
try:
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
except AssertionError as assertion_err:
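The change repeated throughout this file is the new `ASTSource` calling convention: the `constants` keyword is renamed to `constexprs`, and constexpr parameters now appear in `signature` under the `"constexpr"` type tag rather than being omitted. A minimal sketch of the new convention (the kernel and its values are illustrative, not taken from the PR):

```python
import triton
import triton.language as tl


@triton.jit
def kernel(x_ptr, BLOCK: tl.constexpr):
    pass


# Every parameter is now listed in `signature`; constexpr parameters carry the
# "constexpr" type tag and their values move to the renamed `constexprs` dict.
# Previously this was: signature={"x_ptr": "*fp32"}, constants={"BLOCK": 128}.
src = triton.compiler.ASTSource(
    fn=kernel,
    signature={"x_ptr": "*fp32", "BLOCK": "constexpr"},
    constexprs={"BLOCK": 128},
)
triton.compile(src)
```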
python/test/unit/language/test_core.py: 19 additions, 4 deletions
@@ -4352,15 +4352,17 @@ def kernel(x):
def test_value_specialization(value: int, value_type: str, device) -> None:

def repr(specialization):
spec_type = specialization.signature["VALUE"]
return f"kernel_{spec_type}"
ty = specialization.signature["value1"]
cst = '_'.join([k for k, v in specialization.constants.items() if v == 1])
return f"kernel_{ty}_{cst}"

@triton.jit(repr=repr)
-    def kernel(VALUE, X):
+    def kernel(value1, is_one, X):
pass

x = torch.tensor([3.14159], device=device)
-    h = kernel[(1, )](value, x)
+    h = kernel[(1, )](value, 1, x)
+    assert "is_one" in h.name
assert value_type in h.name


Expand Down Expand Up @@ -6130,6 +6132,19 @@ def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, r
torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))


+def test_dtype(device):
+
+    @triton.jit
+    def kernel(X):
+        dtype_x: tl.constexpr = X.dtype.element_ty
+        tl.static_assert(dtype_x == tl.int32)
+        tl.static_assert(dtype_x == tl.constexpr(tl.int32))
+        tl.static_assert(dtype_x == tl.int8 or (dtype_x == tl.int16 or dtype_x == tl.int32))
+
+    X = torch.zeros(1, dtype=torch.int32, device=device)
+    kernel[(1, )](X)


def test_side_effectful_scan(device):
if device != "cuda":
pytest.skip()
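As a reading aid for the updated specialization test above, a sketch of what the `repr` callback plausibly receives for that launch; the field contents are inferred from the test's assertions rather than from a documented API:

```python
# For h = kernel[(1, )](value, 1, x) with kernel(value1, is_one, X):
#   specialization.signature["value1"] -> the type string of `value`, e.g. "i32"
#   specialization.constants           -> includes {"is_one": 1}, because the
#                                         second argument was specialized on == 1
# so repr() returns e.g. "kernel_i32_is_one", which satisfies both
# `assert "is_one" in h.name` and `assert value_type in h.name`.
```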
python/test/unit/language/test_decorator.py: 1 addition, 1 deletion
@@ -23,7 +23,7 @@ def kernel():
pass

try:
-        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constants={}))
+        triton.compile(triton.compiler.ASTSource(fn=kernel, signature={}, constexprs={}))
except Exception as e:
pytest.fail(f"triton compile failed with error: {e}")

python/test/unit/language/test_tuple.py: 100 additions, 0 deletions (new file)
@@ -0,0 +1,100 @@
import pytest
import triton
import triton.language as tl
import torch


@triton.jit
def _tuple_increment(values):
    for i in tl.static_range(len(values)):
        values[i] = values[i] + 1
    return values


@triton.jit
def _tuple_index_func(Ptrs, values):
    for i in tl.static_range(len(values)):
        tl.store(Ptrs[i], values[i])


@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
    values = _tuple_increment(values)
    _tuple_index_func(Ptrs, values)


@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device="cuda"):
    vals = tuple([i + 1 for i in range(size)])
    rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
    _tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
    assert vals == tuple([x.item() - 1 for x in rets])


# ----


@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
    # assign from tuple
    X0, X1 = XPtrs
    x0, x1 = values
    tl.store(X0, x0)
    tl.store(X1, x1)
    # assign to tuple
    Y0, Y1, Y2 = YPtrs
    Y = Y0, Y1, Y2
    y = x0, 10, x1
    tl.store(Y[0], y[0])
    tl.store(Y[1], y[1])
    tl.store(Y[2], y[2])


def test_assign(device="cuda"):
    vals = (2., 3.)
    x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
    y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
    _tuple_assign[(1, )](x, y, vals)
    assert x[0] == vals[0]
    assert x[1] == vals[1]
    assert y[0] == vals[0]
    assert y[1] == 10
    assert y[2] == vals[1]


# -------


@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
    tl.static_assert(tuple1[1] is None)
    tl.store(Ptr + 5, cst2)
    tl.store(Ptr + 6, tuple1[0])
    tl.store(Ptr + 7, tl.load(tuple1[2][0]))
    tl.store(Ptr + 8, tuple1[2][1][0])
    tl.store(Ptr + 9, tl.load(tuple1[2][1][2]))


# test serialization/deserialization of tuple arguments in
# the frontend.
@triton.jit
def _tuple_serialize(Ptr, N1, tuple1, cst1: tl.constexpr, val1, tuple2):
    tl.static_assert(N1 is None)
    tl.static_assert(tuple1[1][1] is None)
    tl.store(Ptr + 0, tl.load(tuple1[0]))
    tl.store(Ptr + 1, tuple1[1][0])
    tl.store(Ptr + 2, tl.load(tuple1[1][2]))
    tl.store(Ptr + 3, cst1 + val1)
    tl.store(Ptr + 4, tl.load(tuple2[0]))
    _tuple_fn0(Ptr, 15, (-1, None, tuple1))


def test_serialize(device="cuda"):
    x0 = torch.tensor([8], dtype=torch.int32, device=device)
    x1 = torch.tensor([12], dtype=torch.int32, device=device)
    y0 = torch.tensor([10], dtype=torch.int32, device=device)
    z = torch.empty((10, ), dtype=torch.int32, device=device)
    # we want to check that JIT specialization propagates to tuples:
    _tuple_serialize[(1, )](z, None, (x0, (1, None, x1)), 20, 1, (y0, ))
    ref = torch.tensor([8, 1, 12, 21, 10, 15, -1, 8, 1, 12], device=device)
    assert torch.equal(z, ref)
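To summarize what the new test file exercises, a minimal standalone sketch (assuming a CUDA device) of passing a tuple mixing a scalar and a pointer straight into a kernel and indexing it with constant indices:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _pair_store(Out, pair):
    # `pair` arrives as a (scalar, pointer) tuple; the frontend flattens it
    # into individual kernel arguments, so constant indexing works as usual.
    tl.store(Out, pair[0] + tl.load(pair[1]))


out = torch.zeros(1, dtype=torch.float32, device="cuda")
x = torch.full((1, ), 3.0, device="cuda")
_pair_store[(1, )](out, (2.0, x))
assert out.item() == 5.0  # 2.0 + 3.0
```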
python/test/unit/runtime/test_bindings.py: 6 additions, 9 deletions
@@ -63,15 +63,12 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
src = triton.compiler.compiler.ASTSource(
fn=kernel,
-        signature={
-            kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
-            for i, arg in enumerate(args)
-            if i not in kernel.constexprs
-        },
-        constants={kernel.arg_names[i]: arg
-                   for i, arg in enumerate(args)
-                   if not isinstance(arg, torch.Tensor)},
-        attrs=backend.get_attrs_descriptor(args, kernel.params),
+        signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
+                   for i, arg in enumerate(args)},
+        constexprs={kernel.arg_names[i]: arg
+                    for i, arg in enumerate(args)
+                    if not isinstance(arg, torch.Tensor)},
+        attrs=backend.get_attrs_descriptor(kernel.params, args),
)

context = triton._C.libtriton.ir.context()