[Inductor] Fix the index_put lowering when `self` and `values` come from the same input (pytorch#139366)

**Summary**
Fixes pytorch#138908; the root cause is analyzed in pytorch#138908 (comment).
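
For context, a minimal reproducer distilled from the new tests below (the right-hand side reads the same buffer that the in-place store writes, and the index sets overlap):

```
import torch

def fn(y, index0, index1):
    # y[index0] reads the same buffer that the index_put_ store writes,
    # and index0/index1 address overlapping rows.
    y[index1] += y[index0]

y = torch.randn(2, 32)
y_clone = y.clone()
index0 = torch.tensor([[0, 1]])
index1 = torch.tensor([[1, 0]])

with torch.no_grad():
    fn(y, index0, index1)
    torch.compile(fn)(y_clone, index0, index1)

# Before this fix, the compiled kernel could load rows of y that it had
# already overwritten, so eager and compiled results diverged.
torch.testing.assert_close(y, y_clone, atol=1e-3, rtol=1e-3)
```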

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_index_put
python -u -m pytest -s -v test/inductor/test_cpu_repro.py -k test_index_add
```

Pull Request resolved: pytorch#139366
Approved by: https://github.com/jgong5, https://github.com/eellison
leslie-fang-intel authored and pytorchmergebot committed Dec 16, 2024
1 parent 7ab3177 · commit ccf35af
Showing 2 changed files with 114 additions and 5 deletions.
**test/inductor/test_cpu_repro.py** — 66 additions, 0 deletions

```
@@ -944,6 +944,72 @@ def fn(x):
             (torch.randn(8),),
         )
 
+    def test_index_put(self):
+        # https://github.com/pytorch/pytorch/issues/138908
+        def fn(x, y):
+            x = x + 10
+            y[x] += y[x]
+
+        x = torch.randint(-10, -9, (1, 2), dtype=torch.int64)
+        y = torch.randn((2, 32), dtype=torch.float32)
+        x_clone = x.clone()
+        y_clone = y.clone()
+        with torch.no_grad():
+            fn(x, y)
+            torch.compile(fn)(x_clone, y_clone)
+        self.assertEqual(y, y_clone, atol=1e-3, rtol=1e-3)
+
+    def test_index_put2(self):
+        # https://github.com/pytorch/pytorch/issues/138908
+        def fn(y, index0, index1):
+            y[index1] += y[index0]
+
+        y = torch.randn((2, 32), dtype=torch.float32)
+        index0 = torch.tensor([[0, 1]])
+        index1 = torch.tensor([[1, 0]])
+        y_clone = y.clone()
+        index0_clone = index0.clone()
+        index1_clone = index1.clone()
+        with torch.no_grad():
+            fn(y, index0, index1)
+            torch.compile(fn)(y_clone, index0_clone, index1_clone)
+        self.assertEqual(y, y_clone, atol=1e-3, rtol=1e-3)
+
+    def test_index_add(self):
+        # https://github.com/pytorch/pytorch/issues/138908
+        def fn(x, y, scale_y, index):
+            values = x[index] + y * scale_y
+            out = x.index_add_(dim=0, source=values, index=index)
+            return out
+
+        inp = (
+            torch.randn(10, 10),
+            torch.randn(5, 10),
+            torch.randn(10),
+            torch.randperm(10, device="cpu")[:5].to(torch.int32),
+        )
+        inp_clones = []
+        for i in range(3):
+            inp_clones.append(
+                [
+                    inp[0].clone(),
+                    inp[1].clone(),
+                    inp[2].clone(),
+                    inp[3].clone()
+                    if i == 0
+                    else torch.zeros(10, device="cpu")[:5].to(torch.int32),
+                ]
+            )
+        inp_clone, inp_clone2, inp_clone3 = inp_clones
+        with torch.no_grad():
+            cfn = torch.compile(fn)
+            ref = fn(*inp)
+            res = cfn(*inp_clone)
+            self.assertEqual(ref, res, atol=1e-3, rtol=1e-3)
+            ref = fn(*inp_clone2)
+            res = cfn(*inp_clone3)
+            self.assertEqual(ref, res, atol=1e-3, rtol=1e-3)
+
     def test_ModularIndexing_range_issue_103133(self):
         def fn(q, k):
             einsum = torch.einsum("bcxd,bcyd->bcxy", (q, k))
```
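
A note on `test_index_add`: the first compiled call keeps the indices produced by `torch.randperm`, which are unique by construction, so `values` may stay fused; the second call swaps in all-zero indices, i.e. maximal duplication, which forces `values` into a temporary buffer (see the lowering change below). As a reminder of the semantics being tested, `index_add_` accumulates every source row whose destination index repeats:

```
import torch

x = torch.zeros(3)
index = torch.tensor([0, 0, 1])  # destination row 0 appears twice
src = torch.tensor([1.0, 2.0, 3.0])
x.index_add_(0, index, src)      # x[0] = 1 + 2, x[1] = 3
print(x)                         # tensor([3., 3., 0.])
```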
**torch/_inductor/lowering.py** — 48 additions, 5 deletions

```
@@ -3416,12 +3416,16 @@ def _unsafe_index(x, indices):
 # https://github.com/pytorch/torchdynamo/issues/1863
 @register_lowering(aten.index_put)
 def index_put(x, indices, values, accumulate=False):
-    return index_put_(clone(x), indices, values, accumulate)
+    return index_put_impl_(
+        clone(x), indices, values, accumulate, check=True, may_realize=False
+    )
 
 
 @register_lowering(aten._unsafe_index_put)
 def _unsafe_index_put(x, indices, values, accumulate=False):
-    return index_put_impl_(clone(x), indices, values, accumulate, check=False)
+    return index_put_impl_(
+        clone(x), indices, values, accumulate, check=False, may_realize=False
+    )
 
 
 def index_put_as_masked_fill(self, indices, value, accumulate):
@@ -3450,15 +3454,54 @@ def index_put_fallback(self, indices, values, accumulate):
 
 @register_lowering(aten.index_put_, type_promotion_kind=None)
 def index_put_(self, indices, values, accumulate=False):
-    return index_put_impl_(self, indices, values, accumulate, check=True)
+    return index_put_impl_(
+        self, indices, values, accumulate, check=True, may_realize=True
+    )
 
 
 @register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None)
 def _unsafe_index_put_(self, indices, values, accumulate=False):
-    return index_put_impl_(self, indices, values, accumulate, check=False)
+    return index_put_impl_(
+        self, indices, values, accumulate, check=False, may_realize=True
+    )
 
 
-def index_put_impl_(self, indices, values, accumulate, check):
+def index_put_impl_(self, indices, values, accumulate, check, may_realize=False):
+    if may_realize:
+
+        def try_get_name(x):
+            if isinstance(x, ir.TensorBox):
+                x = x.data
+            if isinstance(x, ir.BaseView):
+                x = x.unwrap_view()
+            if isinstance(x, ir.StorageBox):
+                x = x.data
+            return x.get_name() if isinstance(x, ir.Buffer) else None
+
+        def indice_slice_from_randperm(indice):
+            # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660
+            # In this specific pattern the indices are unique, since they come
+            # from torch.randperm. Because the content of the indices is
+            # otherwise unknown, we have to match this exact pattern.
+            if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView):
+                indice = indice.data.unwrap_view()
+                return (
+                    isinstance(indice, ir.StorageBox)
+                    and isinstance(indice.data, ir.ExternKernel)
+                    and getattr(indice.data, "fx_node", None)
+                    and indice.data.fx_node.target == torch.ops.aten.randperm.default
+                )
+            return False
+
+        if try_get_name(self) in values.get_read_names() and not all(
+            indice_slice_from_randperm(indice) for indice in indices
+        ):
+            # Fix for https://github.com/pytorch/pytorch/issues/138908.
+            # When self and values overlap in memory and indices contain
+            # duplicates, the result can be incorrect: a load from `values`
+            # may observe a value already modified by a store into `self`.
+            # To avoid this, realize values into a temporary buffer.
+            values.realize()
+
     # Dispatch to masked fill for single boolean index with single value
     if (
         values.get_numel() == 1
```
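
To see why realizing `values` matters, here is a rough eager-mode analogue (a sketch, not Inductor's actual codegen) contrasting the two evaluation orders. Materializing `values` into a temporary buffer first matches eager semantics; loading it lazily inside the store loop lets a load observe an earlier store whenever the read and write index sets overlap:

```
import torch

def realized(y, index0, index1):
    # values is materialized up front, so every load sees
    # the original contents of y.
    values = y[index0].clone()
    y[index1] += values
    return y

def fused_lazily(y, index0, index1):
    # Sketch of the buggy fusion: values is loaded inside the same
    # loop that stores into y, so iteration i can read a row that an
    # earlier iteration already modified.
    for i in range(index0.shape[1]):
        y[index1[0, i]] += y[index0[0, i]]
    return y

index0 = torch.tensor([[0, 1]])
index1 = torch.tensor([[1, 0]])
print(realized(torch.tensor([1.0, 2.0]), index0, index1))      # tensor([3., 3.])
print(fused_lazily(torch.tensor([1.0, 2.0]), index0, index1))  # tensor([4., 3.])
```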
