Commit
Fix constant propagation in builtins and UserClasses (pytorch#131354)
rec authored and pytorchmergebot committed Sep 25, 2024
1 parent a0c76ea commit dd4a51b
Showing 10 changed files with 91 additions and 43 deletions.
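
As a quick illustration of the behavior this commit enables, here is a minimal sketch distilled from the new test_const_getattr test added below; the class name _B and the dictionary key are illustrative, and compiling with fullgraph=True assumes the getattr constant folding now succeeds without a graph break.

import torch


class _B:  # a plain user-defined class, as in the new test below
    pass


def fn(x):
    # __module__/__name__ lookups on a user class previously fell back to
    # GetAttrVariable and could not be folded into the dictionary key.
    return x[f"{_B.__module__}.{_B.__name__}"]


inputs = {f"{_B.__module__}._B": torch.randn(10)}
opt_fn = torch.compile(fn, fullgraph=True)
assert torch.allclose(fn(inputs), opt_fn(inputs))
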
6 changes: 3 additions & 3 deletions test/dynamo/test_higher_order_ops.py
@@ -3810,10 +3810,10 @@ def wrapper_fn(model, params, buffers, inputs):
if torch._dynamo.config.inline_inbuilt_nn_modules:
expected = """\
class GraphModule(torch.nn.Module):
def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
def forward(self, L_params_l1_weight_: "f32[1, 1]", L_buffers_buffer_: "f32[1]", L_params_l1_bias_: "f32[1]", L_inputs_: "f32[1, 1]"):
l_params_l1_weight_ = L_params_l1_weight_
l_params_l1_bias_ = L_params_l1_bias_
l_buffers_buffer_ = L_buffers_buffer_
l_params_l1_bias_ = L_params_l1_bias_
l_inputs_ = L_inputs_
linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None
@@ -6130,7 +6130,7 @@ def wrapper_fn(x, y):
return torch.func.vmap(f)(x, y)

actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
expected = torch.compile(wrapper_fn, backend="aot_eager")(x, y)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
self.assertEqual(some_list, [1, 1])
28 changes: 28 additions & 0 deletions test/dynamo/test_misc.py
@@ -100,6 +100,12 @@
TPFLAGS_MAPPING = 1 << 6


# A class defined in the global scope, used in MiscTests.test_const_getattr
class _B:
def __init__(self):
pass


# Specializes a test to run only if translation validation is set.
def onlyIfTranslationValidation(fn: typing.Callable) -> typing.Callable:
@functools.wraps(fn)
@@ -1410,6 +1416,28 @@ def fn(x, s):
# One recompile per differing input type
self.assertEqual(cnts.frame_count, 3)

def test_const_getattr(self):
# See https://github.com/pytorch/pytorch/issues/118675
def fn(x):
y = x[f"{_B.__module__}.{_B.__name__}"]
z = x[f"{_B.__class__.__module__}.{_B.__name__}"]
u = x[f"{_B.__class__.__module__}.{_B.__class__.__qualname__}"]
return y + z + u

args = (
{
f"{_B.__module__}._B": torch.randn(10),
"builtins._B": torch.randn(10),
"builtins.type": torch.randn(10),
},
)

cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)

self.assertEqual(fn(*args), opt_fn(*args))
self.assertEqual(cnts.frame_count, 1)

def test_cell_output1(self):
out = None

Empty file.
33 changes: 22 additions & 11 deletions torch/_dynamo/variables/base.py
@@ -6,9 +6,9 @@

from .. import variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented
from ..exc import unimplemented, Unsupported
from ..source import AttrSource, Source
from ..utils import istype
from ..utils import is_function_or_wrapper, istype


if TYPE_CHECKING:
@@ -236,18 +236,29 @@ def make_guard(self, fn):
raise NotImplementedError

def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
"""getattr(self, name) returning a python constant"""
raise NotImplementedError
v = self.as_python_constant()
try:
return getattr(v, name)
except AttributeError:
raise NotImplementedError from None

def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
"""getattr(self, name) returning a new variable"""
value = self.const_getattr(tx, name)
if not variables.ConstantVariable.is_literal(value):
raise NotImplementedError
source = None
if self.source:
source = AttrSource(self.source, name)
return variables.ConstantVariable.create(value, source=source)

from .builder import SourcelessBuilder, VariableBuilder
from .misc import GetAttrVariable

source = self.source and AttrSource(self.source, name)
try:
value = self.const_getattr(tx, name)
if not is_function_or_wrapper(value):
if source:
return VariableBuilder(tx, source)(value=value)
else:
return SourcelessBuilder.create(tx=tx, value=value)
except (NotImplementedError, Unsupported):
pass
return GetAttrVariable(self, name, source=source)

def is_proxy(self):
try:
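
To make the new fallthrough in VariableTracker.var_getattr easier to follow, here is a simplified, self-contained paraphrase of the control flow above; GetAttrVar and ConstVar are stand-ins for Dynamo's GetAttrVariable and for a builder-wrapped result, not the actual classes.

import types


class GetAttrVar:  # stand-in for GetAttrVariable (the last-resort fallback)
    def __init__(self, base, name):
        self.base, self.name = base, name


class ConstVar:  # stand-in for a variable tracker over a Python constant
    def __init__(self, value):
        self.value = value

    def as_python_constant(self):
        return self.value

    def var_getattr(self, name):
        try:
            # const_getattr's new default: getattr on the underlying constant
            attr = getattr(self.as_python_constant(), name)
        except AttributeError:
            return GetAttrVar(self, name)
        if isinstance(attr, (types.FunctionType, types.BuiltinFunctionType)):
            # functions and wrappers keep the old GetAttrVariable handling
            return GetAttrVar(self, name)
        # in Dynamo this goes through VariableBuilder / SourcelessBuilder
        return ConstVar(attr)


class _B:
    pass


print(ConstVar(_B).var_getattr("__module__").value)        # __main__
print(type(ConstVar(_B).var_getattr("missing")).__name__)  # GetAttrVar
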
7 changes: 2 additions & 5 deletions torch/_dynamo/variables/builtin.py
@@ -1706,11 +1706,8 @@ def call_getattr(
return SourcelessBuilder.create(tx, member)
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
return ConstantVariable.create(getattr(obj.fn, name))
else:
try:
return obj.var_getattr(tx, name)
except NotImplementedError:
return GetAttrVariable(obj, name, **options)

return obj.var_getattr(tx, name)

def call_setattr(
self,
5 changes: 4 additions & 1 deletion torch/_dynamo/variables/distributed.py
@@ -1,7 +1,7 @@
# mypy: ignore-errors
import functools
import inspect
from typing import Dict, List, TYPE_CHECKING
from typing import Any, Dict, List, TYPE_CHECKING

import torch
from torch.fx.experimental._backward_state import BackwardState
@@ -214,6 +214,9 @@ def is_device_mesh(value):
def as_python_constant(self):
return self.value

def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
raise NotImplementedError

def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
if name == "ndim":
return ConstantVariable.create(self.value.ndim)
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/functions.py
@@ -150,7 +150,8 @@ def __init__(self, fn, is_constant=False, **kwargs) -> None:
self.is_constant = False

assert isinstance(
fn, (types.FunctionType, torch.jit.ScriptFunction)
fn,
(types.BuiltinFunctionType, types.FunctionType, torch.jit.ScriptFunction),
), f"expected FunctionType found {typestr(fn)} {fn}"
# TODO(anijain2305) - Replace directly calling UserFunctionVariable with
# VariableBuilder, which handles the wrapping of _torchdynamo_inline.
5 changes: 4 additions & 1 deletion torch/_dynamo/variables/misc.py
@@ -8,7 +8,7 @@
import re
import sys
import types
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING

import torch._C
import torch._numpy as tnp
@@ -1629,6 +1629,9 @@ def python_type(self):
def as_python_constant(self):
return self.random

def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any:
raise NotImplementedError

@staticmethod
def is_supported_random_obj(val):
if type(val) is not random.Random:
35 changes: 18 additions & 17 deletions torch/_dynamo/variables/tensor.py
@@ -225,7 +225,6 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
# (1) the tensor is a traceable tensor subclass
# (2) We are getattr'ing an inner tensor from that subclass
if not self.source and is_traceable_wrapper_subclass(fake_val):
fake_val = self.proxy.node.meta["example_value"]
attrs, ctx = fake_val.__tensor_flatten__()
proxy = getattr(self.as_proxy(), name)
example_value = getattr(fake_val, name)
@@ -243,14 +242,19 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
return SourcelessBuilder.create(tx, example_value)

if not (self.source and self.source.subguards_allowed()):
raise NotImplementedError
return

from ..guards import CLOSURE_VARS, GuardBuilder

# For local source, we associate the real value. We use this real value
# for implementing getattr fallthrough on the variable tracker base class.

# Note - this scope construction is mirrored in guards
# A subsequent PR will introduce a util.
scope = {"L": tx.output.local_scope, "G": tx.output.global_scope}
scope = {
"L": tx.output.local_scope,
"G": tx.output.global_scope,
**CLOSURE_VARS,
}
try:
# We raise in case we get a typerror bug w/ SuperSource.
# SuperSource has bugs in it atm, and can produce code like
@@ -259,25 +263,25 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
# Which is incorrect, and violates the invariant that all sources should be eval()-able against the scope.
_input_associated_real_value = eval(self.source.name(), scope)
except Exception as exc:
raise NotImplementedError from exc
msg = f"{exc!r} raised in eval('{self.source.name()}')"
raise NotImplementedError(msg) from exc

real_value = getattr(_input_associated_real_value, name)
if _input_associated_real_value is None:
raise NotImplementedError
return

if object_has_getattribute(_input_associated_real_value):
raise NotImplementedError
return

if get_custom_getattr(_input_associated_real_value):
raise NotImplementedError
return

real_value = getattr(_input_associated_real_value, name)
if callable(real_value):
# Callables have more nuanced handling, and we should let the existing system delegate here.
# Raising was past behavior and so should always be sound to fall back.
# Note - at a certain point we may want to handle
raise NotImplementedError
return

from ..guards import GuardBuilder
from .builder import VariableBuilder

attr_source = AttrSource(self.source, name)
@@ -1169,8 +1173,6 @@ def var_getattr(self, tx: "InstructionTranslator", name):
from ..utils import numpy_attr_wrapper
from .builder import wrap_fx_proxy

result = None

example_value = self.as_proxy().node.meta["example_value"]
example_ndarray = tnp.ndarray(example_value)

@@ -1189,7 +1191,7 @@ def insert_into_graph():
(self.as_proxy(), name),
{},
)
result = NumpyNdarrayVariable.create(tx, proxy)
return NumpyNdarrayVariable.create(tx, proxy)

# These are awkward to implement. The standard playbook for torch._numpy
# interop is to trace a call into the torch._numpy wrapper which works for
@@ -1218,9 +1220,8 @@ def insert_into_graph():
unimplemented(f"TODO: add support for ndarray.{name}")
elif name in ["__version__"]:
unimplemented("delegate np.__version__ to NumPy")
if result is None:
raise NotImplementedError
return result
else:
return super().var_getattr(tx, name)

@staticmethod
def patch_args(name, args, kwargs):
12 changes: 8 additions & 4 deletions torch/_dynamo/variables/user_defined.py
@@ -181,10 +181,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
):
return super().var_getattr(tx, name)

try:
obj = inspect.getattr_static(self.value, name)
except AttributeError:
obj = None
obj = inspect.getattr_static(self.value, name, None)

if isinstance(obj, staticmethod):
func = obj.__get__(self.value)
@@ -213,6 +210,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
):
return VariableBuilder(tx, source)(obj.__get__(self.value))

if inspect.ismemberdescriptor(obj) or inspect.isdatadescriptor(obj):
value = getattr(self.value, name)
if source is not None:
return VariableBuilder(tx, source)(value=value)
else:
return SourcelessBuilder.create(tx=tx, value=value)

if ConstantVariable.is_literal(obj):
return ConstantVariable.create(obj)
elif isinstance(obj, enum.Enum):
