Commit a178a05

Ignore fresh unbacked when doing recursive make_fx inside HOPs (pytorch#135053)

Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7705964779531357/

I'm not sure this is the right approach, though...

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#135053
Approved by: https://github.com/ydwu4
ghstack dependencies: pytorch#134407
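
For readers new to the mechanism: unbacked SymInts (e.g. the values produced by Tensor.item() under capture_scalar_outputs=True) are queued on the ShapeEnv's pending_fresh_unbacked_symbols list until the tracer binds them to nodes. This commit switches AOTAutograd's metadata-only pass from clearing that queue after the fact to suppressing it up front via shape_env.ignore_fresh_unbacked_symbols(). Below is a minimal toy sketch of that suppression pattern, assuming nothing about PyTorch's real ShapeEnv internals; ToyShapeEnv and all of its members are invented for illustration.

import contextlib


class ToyShapeEnv:
    """Toy stand-in for a shape environment; not PyTorch's implementation."""

    def __init__(self):
        # Unbacked symbols allocated but not yet bound to a trace node.
        self.pending_fresh_unbacked_symbols = []
        self._suppress_depth = 0

    def create_unbacked_symint(self, name):
        # Mimics allocating a fresh unbacked SymInt, e.g. from Tensor.item().
        if self._suppress_depth == 0:
            self.pending_fresh_unbacked_symbols.append(name)
        return name

    @contextlib.contextmanager
    def ignore_fresh_unbacked_symbols(self):
        # While active, freshly allocated symbols are never recorded as pending.
        self._suppress_depth += 1
        try:
            yield
        finally:
            self._suppress_depth -= 1


env = ToyShapeEnv()

# Old style: allocate first, clear the queue afterwards. Anything that runs
# in between (say, a recursive trace inside a HOP body) still sees "u0" pending.
env.create_unbacked_symint("u0")
assert env.pending_fresh_unbacked_symbols == ["u0"]
env.pending_fresh_unbacked_symbols.clear()

# New style: suppress recording up front, so nothing is ever pending.
with env.ignore_fresh_unbacked_symbols():
    env.create_unbacked_symint("u1")
    assert env.pending_fresh_unbacked_symbols == []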
ezyang authored and pytorchmergebot committed Sep 4, 2024
1 parent 46cb2af commit a178a05
Showing 2 changed files with 28 additions and 3 deletions.
21 changes: 21 additions & 0 deletions test/export/test_export.py
@@ -686,6 +686,27 @@ def false_fn(x):
         M()(torch.randn(7))
         torch.export.export(M(), (torch.randn(7),))
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_cond_contains_unbacked_no_escape(self):
+        class M(torch.nn.Module):
+            def forward(self, a, b1, b2, c):
+                def true_fn(x):
+                    return x * b1.item()
+
+                def false_fn(x):
+                    return x * b2.item()
+
+                r = torch.cond(a, true_fn, false_fn, (c,))
+                return r * 2
+
+        args = (
+            torch.tensor(True),
+            torch.tensor([4]),
+            torch.tensor([4]),
+            torch.randn(10, requires_grad=True),
+        )
+        torch.export.export(M(), args)
+
     def test_state_tensors(self):
         class M(torch.nn.Module):  # simple with register buffer
             def __init__(self) -> None:
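
A note on what the new test exercises (my reading; the commit message does not spell it out): with capture_scalar_outputs=True, each b1.item() / b2.item() call inside a torch.cond branch allocates a fresh unbacked SymInt while the branch is traced recursively, and the test asserts that export succeeds without those symbols escaping the subgraph. Assuming a PyTorch checkout containing this commit, the test should be runnable in isolation via unittest's -k filter:

python test/export/test_export.py -k test_cond_contains_unbacked_no_escape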
10 changes: 7 additions & 3 deletions torch/_functorch/_aot_autograd/collect_metadata_analysis.py
@@ -9,6 +9,7 @@
"""

import collections
import contextlib
import logging
from functools import wraps
from typing import Callable, DefaultDict, Dict, List, Optional
Expand Down Expand Up @@ -162,15 +163,18 @@ def inner(*flat_args):
         # It doesn't matter if we run this under predispatch or not because it is
         # only for figuring out metadata
         mode = FunctionalTensorMode(_allow_token_discovery=True)
-        with disable_above, mode:
+        suppress_pending = contextlib.nullcontext()
+        fake_mode = detect_fake_mode()
+        if fake_mode and (shape_env := fake_mode.shape_env):
+            suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
+        with disable_above, mode, suppress_pending:
             # precondition: The passed in function already handles unflattening inputs + flattening outputs
             flat_f_args = pytree.tree_map(_to_fun, flat_args)
             flat_f_outs = f(*flat_f_args)
         # We didn't do any tracing, so we don't need to process the
         # unbacked symbols, they will just disappear into the ether.
         # Also, prevent memoization from applying.
-        if (fake_mode := detect_fake_mode()) and (shape_env := fake_mode.shape_env):
-            shape_env.pending_fresh_unbacked_symbols.clear()
+        if fake_mode:
             fake_mode.epoch += 1
             fake_mode.reset_nt_tensor_id_counter()
 
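
One small idiom worth noting in the patch: suppress_pending defaults to contextlib.nullcontext() and is swapped for the real context manager only when a ShapeEnv is available, so the with statement itself needs no branching. A self-contained sketch of the same pattern (run_for_metadata and maybe_shape_env are hypothetical names, not part of the patch):

import contextlib

def run_for_metadata(fn, args, maybe_shape_env=None):
    # Default to a no-op context; substitute the real suppression context
    # only when a shape environment is actually present.
    suppress = contextlib.nullcontext()
    if maybe_shape_env is not None:
        suppress = maybe_shape_env.ignore_fresh_unbacked_symbols()
    with suppress:
        return fn(*args)

print(run_for_metadata(lambda x: x + 1, (41,)))  # 42, with the no-op context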
