Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] added support for tuples #5220

Merged
merged 73 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
73 commits
Select commit Hold shift + click to select a range
51ab367
progress
ptillet Sep 3, 2024
65da71f
.
ptillet Sep 4, 2024
746a2a3
prototype works
ptillet Sep 4, 2024
630ec6c
added test
ptillet Sep 5, 2024
f2439f9
fixup
ptillet Sep 5, 2024
d9af0ba
cleanup
ptillet Sep 6, 2024
f758ef2
.
ptillet Sep 6, 2024
1b58df2
progress
ptillet Sep 8, 2024
1558d6c
bugfix
ptillet Sep 8, 2024
812af43
progress
ptillet Sep 9, 2024
e226cfd
.
ptillet Oct 9, 2024
98c526e
Merge remote-tracking branch 'origin/main' into phil/tuple-support
ptillet Oct 9, 2024
8cee89d
.
ptillet Oct 11, 2024
627bef2
.
ptillet Oct 12, 2024
756d75a
.
ptillet Oct 12, 2024
2a86fb4
.
ptillet Oct 12, 2024
d614226
.
ptillet Oct 13, 2024
5d29bef
fails again?
ptillet Oct 13, 2024
a790867
more hacks
ptillet Oct 13, 2024
fa23bfc
giant mess; more tests pass
ptillet Oct 13, 2024
d88cca0
very hacky but tests pass; TO REFACTOR
ptillet Oct 13, 2024
e299bf2
.
ptillet Oct 15, 2024
fcae528
.
ptillet Oct 16, 2024
d0168c9
progress
ptillet Nov 16, 2024
0ba41ff
more progress
ptillet Nov 16, 2024
e7289dc
more progress
ptillet Nov 17, 2024
b7d8117
.
ptillet Nov 19, 2024
33505ac
more progress
ptillet Nov 21, 2024
3c08877
more fixes
ptillet Nov 21, 2024
dba9b2d
all tests pass
ptillet Nov 21, 2024
a35e89a
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Nov 21, 2024
ae2ebf6
Merge branch 'main' into phil/tuple-support-2
ptillet Nov 22, 2024
18f24ef
fixed TMA descriptors
ptillet Nov 22, 2024
bba29ae
.
ptillet Nov 22, 2024
67fc1b4
more fixes
ptillet Nov 22, 2024
04d463f
.
ptillet Nov 24, 2024
aa74737
.
ptillet Nov 24, 2024
6161e78
.
ptillet Nov 29, 2024
2ab9b39
more fixes
ptillet Nov 30, 2024
394baf7
more bugfixes
ptillet Dec 2, 2024
8a01c91
fix naming
ptillet Dec 2, 2024
f2cf8d6
.
ptillet Dec 3, 2024
3366e8d
.
ptillet Dec 3, 2024
0a910c2
.
ptillet Dec 3, 2024
6fcdfed
mpre cleaning
ptillet Dec 3, 2024
4fc12d9
cleanup
ptillet Dec 3, 2024
a3fea49
cleanup
ptillet Dec 3, 2024
4f92962
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Dec 6, 2024
4286515
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Dec 6, 2024
8bdadd2
amd
ptillet Dec 6, 2024
d6799df
.
ptillet Dec 6, 2024
74d6277
.
ptillet Dec 6, 2024
5002698
.
ptillet Dec 6, 2024
d14ffe2
.
ptillet Dec 6, 2024
2cb01d8
.
ptillet Dec 6, 2024
91e04a5
.
ptillet Dec 7, 2024
8a528ec
Merge remote-tracking branch 'origin/main' into phil/tuple-support-2
ptillet Dec 7, 2024
26ac4e6
make sure `do_not_specialize` is ignored for None values
ptillet Dec 8, 2024
68b52f5
adding dtype test
ptillet Dec 8, 2024
d64d82a
fix dtype handling
ptillet Dec 8, 2024
10f197f
do not materialize none arguments
ptillet Dec 8, 2024
fdf9948
fix amd
ptillet Dec 8, 2024
7d71fbd
.
ptillet Dec 9, 2024
546dc43
.
ptillet Dec 9, 2024
44f446f
.
ptillet Dec 9, 2024
305dede
.
ptillet Dec 9, 2024
122e523
.
ptillet Dec 9, 2024
243062d
.
ptillet Dec 9, 2024
4abfc4b
Merge commit '89c0b0abdfac' into phil/tuple-support-2
ptillet Dec 9, 2024
5a02aef
.
ptillet Dec 9, 2024
06d2abd
.
ptillet Dec 9, 2024
cb67cbc
.
ptillet Dec 9, 2024
1032df3
.
ptillet Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ void init_triton_ir(py::module &&m) {
"Function argument index out of range");
return self.getArgument(idx);
})
.def("get_num_args", &FuncOp::getNumArguments)
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
Expand Down
98 changes: 98 additions & 0 deletions python/test/unit/language/test_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import pytest
import triton
import triton.language as tl
import torch


@triton.jit
def _tuple_increment(values):
for i in tl.static_range(len(values)):
values[i] = values[i] + 1
return values


@triton.jit
def _tuple_index_func(Ptrs, values):
for i in tl.static_range(len(values)):
tl.store(Ptrs[i], values[i])


@triton.jit
def _tuple_index(_0, Ptrs, _1: tl.constexpr, values, _2, _3: tl.constexpr, _4):
values = _tuple_increment(values)
_tuple_index_func(Ptrs, values)


@pytest.mark.parametrize("size", [0, 1, 2, 3, 4])
def test_index(size, device="cuda"):
vals = tuple([i + 1 for i in range(size)])
rets = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in vals])
_tuple_index[(1, )](0, rets, 0, vals, 0, 0, 0)
assert vals == tuple([x.item() - 1 for x in rets])


# ----


@triton.jit
def _tuple_assign(XPtrs, YPtrs, values):
# assign from tuple
X0, X1 = XPtrs
x0, x1 = values
tl.store(X0, x0)
tl.store(X1, x1)
# assign to tuple
Y0, Y1, Y2 = YPtrs
Y = Y0, Y1, Y2
y = x0, 10, x1
tl.store(Y[0], y[0])
tl.store(Y[1], y[1])
tl.store(Y[2], y[2])


def test_assign(device="cuda"):
vals = (2., 3.)
x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)])
y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)])
_tuple_assign[(1, )](x, y, vals)
assert x[0] == vals[0]
assert x[1] == vals[1]
assert y[0] == vals[0]
assert y[1] == 10
assert y[2] == vals[1]

# -------
ptillet marked this conversation as resolved.
Show resolved Hide resolved

@triton.jit
def _tuple_fn0(Ptr, cst2: tl.constexpr, tuple1):
tl.store(Ptr + 5, cst2)
tl.store(Ptr + 6, tuple1[0])
tl.store(Ptr + 7, tl.load(tuple1[1][0]))
tl.store(Ptr + 8, tuple1[1][1][0])
tl.store(Ptr + 9, tl.load(tuple1[1][1][1]))

# test serialization/deserialization of tuple arguments in
# the frontend.
@triton.jit
def _tuple_serdes(Ptr, tuple1, cst1: tl.constexpr, val1, tuple2):
tl.store(Ptr + 0, tl.load(tuple1[0]))
tl.store(Ptr + 1, tuple1[1][0])
tl.store(Ptr + 2, tl.load(tuple1[1][1]))
tl.store(Ptr + 3, cst1 + val1)
tl.store(Ptr + 4, tl.load(tuple2[0]))
_tuple_fn0(Ptr, 15, (-1, tuple1))

def test_serdes(device="cuda"):
x0 = torch.tensor([8], dtype=torch.int32, device=device)
x1 = torch.tensor([12], dtype=torch.int32, device=device)
y0 = torch.tensor([10], dtype=torch.int32, device=device)
z = torch.empty((10,), dtype=torch.int32, device=device)
# we want to check that JIT specialization propagates to tuples:
_tuple_serdes[(1,)](z, (x0, (1, x1)), 20, 1, (y0,))
print(z)


# function call (tuple argument)
# function call (tuple return value)
# __getitem__ and __setitem__
# assignment (into a tuple, from a tuple)
41 changes: 31 additions & 10 deletions python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,28 @@
import hashlib
import subprocess
import sysconfig

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
from types import ModuleType

def find_paths_if(iterable, pred):
ptillet marked this conversation as resolved.
Show resolved Hide resolved
is_iterable = lambda x: isinstance(x, (list, tuple))
ret = []
def _impl(current, path):
if pred(current):
if len(path) == 1:
ret.append((path[0],))
else:
ret.append(tuple(path))
elif is_iterable(current):
for idx, item in enumerate(current):
_impl(item, path + [idx])
if is_iterable(iterable):
_impl(iterable, [])
else:
ret = [tuple()] if pred(iterable) else []
return ret
# Table that associates strings to AttrsDescriptor (sub)classes.
# In this way we can dynamically select the correct class
# constructor
Expand Down Expand Up @@ -86,17 +102,22 @@ def _add_common_properties(self, params, values):
assert (len(params) == len(values))

# Divisibility property
self.arg_properties["tt.divisibility"] = [
param.num for param, arg in zip(params, values) if AttrsDescriptor.is_divisible_by_16(arg)
and not param.do_not_specialize and not param.do_not_specialize_on_alignment
]
divisibility_16 = []
for param, arg in zip(params, values):
if param.do_not_specialize or param.do_not_specialize_on_alignment:
continue
paths = find_paths_if(arg, AttrsDescriptor.is_divisible_by_16)
divisibility_16 += [(param.num,) + x for x in paths]
self.arg_properties["tt.divisibility"] = divisibility_16

# Equal to 1 property
self.arg_properties["tt.equal_to"] = [
param.num
for param, arg in zip(params, values)
if AttrsDescriptor.is_equal_to_1(arg) and not param.do_not_specialize
]
equal_to_1 = []
for param, arg in zip(params, values):
if param.do_not_specialize:
continue
paths = find_paths_if(arg, AttrsDescriptor.is_equal_to_1)
equal_to_1 += [(param.num,) + x for x in paths]
self.arg_properties["tt.equal_to"] = equal_to_1

def _add_backend_properties(self, params=None, values=None):
""" This method is for different subclasses to implement their own compile-time properties """
Expand Down
Loading
Loading