Skip to content

Commit

Permalink
Variable builder - handle slice (pytorch#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored Apr 19, 2022
1 parent 222696b commit 1e131a6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,3 +1080,31 @@ def fn():
res = fn()

self.assertTrue(same(res, ref_run1))

def test_slice_input(self):
cnts = torchdynamo.testing.CompileCounter()

def getitem(a, idx):
if isinstance(idx, slice):
return (
torch.zeros(1),
a[idx]
+ [
100,
],
)
else:
return (torch.zeros(1), a[idx])

layers = list(range(10))
ref0 = getitem(layers, slice(0, 2, 1))
ref1 = getitem(layers, 2)
ref2 = getitem(layers, slice(3, 8, 2))
with torchdynamo.optimize(cnts, nopython=True):
res0 = getitem(layers, slice(0, 2, 1))
res1 = getitem(layers, 2)
res2 = getitem(layers, slice(3, 8, 2))

self.assertTrue(ref0 == res0)
self.assertTrue(ref1 == res1)
self.assertTrue(ref2 == res2)
1 change: 1 addition & 0 deletions torchdynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def EQUALS_MATCH(self, guard: Guard):
list,
tuple,
set,
slice,
frozenset,
range,
torch.Size,
Expand Down
8 changes: 8 additions & 0 deletions torchdynamo/variables/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .lists import ListVariable
from .lists import NamedTupleVariable
from .lists import RangeVariable
from .lists import SliceVariable
from .lists import TupleVariable
from .misc import AutogradFunctionVariable
from .misc import InspectSignatureVariable
Expand Down Expand Up @@ -265,6 +266,13 @@ def _wrap(self, value):
return DataClassVariable.wrap(self, value).add_guards(
make_guards(GuardBuilder.TYPE_MATCH)
)
elif isinstance(value, slice):
start = ConstantVariable(value.start)
stop = ConstantVariable(value.stop)
step = ConstantVariable(value.step)
return SliceVariable(
[start, stop, step], guards=make_guards(GuardBuilder.CONSTANT_MATCH)
)
else:
result = UserDefinedObjectVariable(
value,
Expand Down

0 comments on commit 1e131a6

Please sign in to comment.