Commit

[dynamo] Solve Save/Load OptimizedModule pytorch#101651
weiyusheng committed Jun 3, 2024
1 parent 2d1ad0c commit 345f2b1
Showing 4 changed files with 139 additions and 65 deletions.
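This commit makes modules returned by torch.compile (OptimizedModule) serializable with torch.save/torch.load by adding __reduce__ support to the Dynamo context objects and converting the closure-based frame-conversion callbacks into module-level classes. The new tests save a compiled module, reload it, and check numerical parity against both the eager module and the compiled one. Roughly, the enabled usage looks like the sketch below (MyModule is a hypothetical stand-in for the MockModule used in the tests):

import os
import tempfile

import torch

class MyModule(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.linear(x))

mod = MyModule()
opt_mod = torch.compile(mod)      # returns an OptimizedModule wrapping mod
inp = torch.randn(10, 10)
opt_mod(inp)                      # trigger compilation

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "model.pt")
    torch.save(opt_mod, path)     # previously failed: the attached callbacks were unpicklable closures
    loaded = torch.load(path)     # rebuilt via OptimizedModule.__reduce__
    loaded(inp)
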
31 changes: 31 additions & 0 deletions test/dynamo/test_modules.py
@@ -3,6 +3,8 @@
import collections
import copy
import itertools
import os
import tempfile
import traceback
import types
import unittest
@@ -15,6 +17,7 @@

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.debug_utils import same_two_models
import torch.nn.functional as F
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
@@ -2739,6 +2742,34 @@ def fn(x):
self.assertEqual(test_functions._variable, 1)
self.assertEqual(res, 3 * torch.ones(10))


def test_save_and_load1(self):
mod = MockModule()
opt_mod = torch.compile(mod)
inp = torch.randn(10, 10)
opt_mod(inp)

with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
loaded_model(inp)
self.assertTrue(same_two_models(loaded_model, mod, [inp]))
self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))

def test_save_and_load2(self):
mod = MockModule()
opt_mod = torch.compile(mod, backend='inductor')
inp = torch.randn(10, 10)
opt_mod(inp)

with tempfile.TemporaryDirectory() as tmpdirname:
torch.save(opt_mod, os.path.join(tmpdirname, "model.pt"))
loaded_model = torch.load(os.path.join(tmpdirname, "model.pt"))
loaded_model(inp)
self.assertTrue(same_two_models(loaded_model, mod, [inp]))
self.assertTrue(same_two_models(loaded_model, opt_mod, [inp]))


def test_monkeypatching_forward(self):
class FakeModule(torch.nn.Module):
def forward(self, x):
102 changes: 61 additions & 41 deletions torch/_dynamo/convert_frame.py
@@ -361,17 +361,26 @@ def profile_wrapper(*args, **kwargs):
return profile_wrapper


def convert_frame_assert(
compiler_fn: CompilerFn,
one_graph: bool = True,
export: bool = False,
export_constraints=None,
):
"""Fully convert a frame into an FX graph"""
reset_graph_break_dup_checker()

def _convert_frame_assert(
frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
class ConvertFrameAssert:
def __init__(
self,
compiler_fn: CompilerFn,
one_graph: bool = True,
export: bool = False,
export_constraints=None,
):
reset_graph_break_dup_checker()
self._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
self._one_graph = one_graph
self._export = export
self._export_constraints = export_constraints

@property
def _clone_with_backend(self):
return lambda backend : convert_frame_assert(backend, self._one_graph, self._export, self._export_constraints)

def __call__(
self, frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0
):
increment_frame()

@@ -458,10 +467,10 @@ def _convert_frame_assert(
frame.f_globals,
frame.f_locals,
frame.f_builtins,
compiler_fn,
one_graph,
export,
export_constraints,
self._torchdynamo_orig_callable,
self._one_graph,
self._export,
self._export_constraints,
hooks,
cache_entry,
cache_size,
@@ -471,13 +480,14 @@
skip=skip + 1,
)

_convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]

def _clone_with_backend(backend):
return convert_frame_assert(backend, one_graph, export, export_constraints)

_convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined]
return _convert_frame_assert
def convert_frame_assert(
compiler_fn: CompilerFn,
one_graph: bool = True,
export: bool = False,
export_constraints=None,
):
"""Fully convert a frame into an FX graph"""
return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)


from collections import OrderedDict
@@ -907,16 +917,22 @@ def format_guard_failures():
torch._dynamo.callback_handler.run_end_callbacks()


def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
class ConvertFrame:
def __init__(self, compiler_fn: CompilerFn, hooks: Hooks):
self._torchdynamo_orig_callable = compiler_fn
self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
self._hooks = hooks

@property
def _clone_with_backend(self):
return lambda backend : convert_frame(backend, self._hooks)

def _convert_frame(
frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
def __call__(
self, frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0
):
counters["frames"]["total"] += 1
try:
result = inner_convert(
result = self._inner_convert(
frame, cache_entry, hooks, frame_state, skip=skip + 1
)
counters["frames"]["ok"] += 1
@@ -980,9 +996,9 @@ def _convert_frame(
log.warning(error_msg, exc_info=True)
return None

_convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
_convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks) # type: ignore[attr-defined]
return _convert_frame
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
return ConvertFrame(compiler_fn, hooks)


# TODO mlazos: add support for same args, or record them
@@ -1023,9 +1039,13 @@ def first_real_inst_idx(code):
raise RuntimeError("RESUME instruction not found in code")


def catch_errors_wrapper(callback, hooks: Hooks):
@functools.wraps(callback)
def catch_errors(frame, cache_entry, frame_state):
class CatchErrorsWrapper:
def __init__(self, callback, hooks):
functools.wraps(callback)(self)
self._torchdynamo_orig_callable = callback
self.hooks = hooks

def __call__(self, frame, cache_entry, frame_state):
assert frame_state is not None

is_skipfile = trace_rules.check(frame.f_code)
@@ -1063,19 +1083,19 @@ def catch_errors(frame, cache_entry, frame_state):

ddp_optimizer = DDPOptimizer(
bucket_bytes_cap=ddp_module.bucket_bytes_cap,
backend_compile_fn=callback._torchdynamo_orig_callable,
backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable,
)
assert hasattr(
callback, "_clone_with_backend"
self._torchdynamo_orig_callable, "_clone_with_backend"
), "DDPOptimizer only supports callback fns that know how to clone themselves."
hijacked_callback = callback._clone_with_backend(
hijacked_callback = self._torchdynamo_orig_callable._clone_with_backend(
ddp_optimizer.compile_fn,
)
return hijacked_callback(frame, cache_entry, hooks, frame_state)
return hijacked_callback(frame, cache_entry, self.hooks, frame_state)

with compile_lock, _disable_current_modes():
# skip=1: skip this frame
return callback(frame, cache_entry, hooks, frame_state, skip=1)
return self._torchdynamo_orig_callable(frame, cache_entry, self.hooks, frame_state, skip=1)

catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined]
return catch_errors
def catch_errors_wrapper(callback, hooks: Hooks):
return CatchErrorsWrapper(callback, hooks)
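
The convert_frame.py changes replace the closure-based callbacks, which carried attributes such as _torchdynamo_orig_callable and _clone_with_backend, with module-level classes (ConvertFrameAssert, ConvertFrame, CatchErrorsWrapper) that expose the same interface through __call__, properties, and instance attributes. The point of the refactor is serialization: the standard pickle machinery cannot serialize a function defined inside another function, but it can serialize an instance of a module-level class as long as the instance's attributes are themselves picklable. A minimal sketch of the pattern, using hypothetical names rather than Dynamo's actual types:

import functools
import pickle

def make_wrapper(fn):
    # Closure-style wrapper, roughly how the old callbacks were built.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapper._orig_callable = fn
    return wrapper

class WrapperClass:
    # Class-style wrapper: instances pickle fine because the class is
    # importable by its module path and the stored attributes are picklable.
    def __init__(self, fn):
        self._orig_callable = fn

    def __call__(self, *args, **kwargs):
        return self._orig_callable(*args, **kwargs)

def backend(gm):
    return gm

pickle.dumps(WrapperClass(backend))       # works
try:
    pickle.dumps(make_wrapper(backend))   # fails: nested functions are not picklable
except (AttributeError, pickle.PicklingError) as exc:
    print("closure not picklable:", exc)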
20 changes: 19 additions & 1 deletion torch/_dynamo/eval_frame.py
@@ -168,6 +168,9 @@ def _initialize(self):
self._forward = self.forward
self.forward = self._call_lazy_check

def __reduce__(self):
return (self.__class__, (self._orig_mod, self.dynamo_ctx))

def __getstate__(self):
state = dict(self.__dict__)
state.pop("forward", None)
@@ -273,9 +276,11 @@ def __init__(
super().__init__()
assert callable(callback) or callback is False or callback is None
self.callback: DynamoCallback = callback
self._backend_ctx_ctor = backend_ctx_ctor
self.prior: Union[Unset, DynamoCallback] = unset
self.first_ctx = first_ctx
self.export = export
self._dynamic = dynamic
self.compiler_config = compiler_config
self.cleanup_fns: List[Callable[[], Any]] = []
self.enter_exit_hooks = []
@@ -379,7 +384,11 @@ def get_compiler_config():
# call to a builtin without a frame for us to capture
fn = external_utils.wrap_inline(fn)

callback = self.callback
def do_nothing(*arg, **kwargs):
pass
callback = do_nothing
if hasattr(self, 'callback'):
callback = self.callback

is_jit_tracing = torch._C._is_tracing
is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing
@@ -523,6 +532,9 @@ def call_compiled_autograd():
self.enter_exit_hooks.append(call_compiled_autograd)


def __reduce__(self):
return (self.__class__, (self.callback, self._backend_ctx_ctor, self.first_ctx), {'export':self.export, 'dynamic':self._dynamic, 'compiler_config':self.compiler_config})

class RunOnlyContext(_TorchDynamoContext):
def __init__(self):
# cudagraph trees relies on generation increment
@@ -531,6 +543,9 @@ def on_enter():

super().__init__(callback=False, on_enter=on_enter)

def __reduce__(self):
return (self.__class__, ())


class DisableContext(_TorchDynamoContext):
def __init__(self):
@@ -583,6 +598,9 @@ def _fn(*args, **kwargs):

return _fn

def __reduce__(self):
return (self.__class__, ())


def _optimize_catch_errors(
compile_fn,
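
The eval_frame.py changes add __reduce__ methods so that OptimizedModule and the Dynamo context objects (_TorchDynamoContext, RunOnlyContext, DisableContext) can be rebuilt on load. __reduce__ returns the callable to invoke at unpickling time plus its arguments, and optionally a third element holding extra state; pickle merges a state dict into the new instance's __dict__ (or passes it to __setstate__ when one is defined). A minimal sketch of that protocol, with hypothetical names:

import pickle

def my_backend(gm, example_inputs):   # stands in for a compiler callback
    return gm

class Context:
    def __init__(self, callback, first_ctx=False):
        self.callback = callback
        self.first_ctx = first_ctx
        self.export = False           # extra state not covered by __init__ args

    def __reduce__(self):
        # On load: Context(self.callback, self.first_ctx) is called, then the
        # state dict in the third slot is merged into the new instance's __dict__.
        return (self.__class__, (self.callback, self.first_ctx), {"export": self.export})

ctx = Context(my_backend, first_ctx=True)
ctx.export = True
restored = pickle.loads(pickle.dumps(ctx))
assert restored.first_ctx and restored.export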
51 changes: 28 additions & 23 deletions torch/_dynamo/repro/after_dynamo.py
@@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn):
)


def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
"""
A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
As opposed to wrap_compiler_debug, this wrapper intercepts at the
TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
level, e.g., it is useful for minifying issues related to Aot Autograd
tracing. If an error is found, we minify and save the minified repro in
repro.tar.gz.
"""

@functools.wraps(unconfigured_compiler_fn)
def debug_wrapper(gm, example_inputs, **kwargs):
compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
class WrapBackendDebug:
def __init__(self, unconfigured_compiler_fn, compiler_name: str):
functools.wraps(unconfigured_compiler_fn)(self)
self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined]
self._compiler_name = compiler_name
if hasattr(unconfigured_compiler_fn, "__name__"):
self.__name__ = unconfigured_compiler_fn.__name__
if hasattr(unconfigured_compiler_fn, "compiler_name"):
self.__name__ = unconfigured_compiler_fn.compiler_name
if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined]

def __call__(self, gm, example_inputs, **kwargs):
compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs)
assert config.repro_after in ("dynamo", "aot", None)

if config.repro_after == "dynamo":
@@ -82,7 +83,7 @@ def add_paths(exc):
)

if config.repro_level == 3:
dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)
dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name)

# Check for either accuracy (level 4) or other type of failures.
if config.repro_level == 4:
@@ -95,7 +96,7 @@ def add_paths(exc):
dump_to_minify_after_dynamo(
fx.GraphModule(gm, copy.deepcopy(gm.graph)),
example_inputs,
compiler_name,
self._compiler_name,
)
exc = AccuracyError("Bad accuracy detected.")
add_paths(exc)
@@ -110,7 +111,7 @@ def add_paths(exc):
)
if config.repro_level == 1:
dump_state_fn = functools.partial(
dump_backend_state, compiler_name=compiler_name
dump_backend_state, compiler_name=self._compiler_name
)
dump_state_fn(
fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
@@ -119,7 +120,7 @@ def add_paths(exc):
dump_to_minify_after_dynamo(
fx.GraphModule(gm, copy.deepcopy(gm.graph)),
example_inputs,
compiler_name,
self._compiler_name,
)
add_paths(exc)
raise
@@ -128,12 +129,16 @@ def add_paths(exc):

return compiled_gm

debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined]
if hasattr(unconfigured_compiler_fn, "compiler_name"):
debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name
if hasattr(unconfigured_compiler_fn, "get_compiler_config"):
debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined]
return debug_wrapper
def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
"""
A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
As opposed to wrap_compiler_debug, this wrapper intercepts at the
TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
level, e.g., it is useful for minifying issues related to Aot Autograd
tracing. If an error is found, we minify and save the minified repro in
repro.tar.gz.
"""
return WrapBackendDebug(unconfigured_compiler_fn, compiler_name)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
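
The after_dynamo.py change applies the same closure-to-class refactor to the minifier wrapper: wrap_backend_debug now returns a WrapBackendDebug instance, and functools.wraps(unconfigured_compiler_fn)(self) copies the wrapped compiler's metadata (__name__, __doc__, __wrapped__, ...) onto that instance so callers that inspect those attributes keep working. A small sketch of the wraps-on-an-instance idiom, again with hypothetical names:

import functools

def relu_backend(gm, example_inputs):
    """Hypothetical stand-in for a backend compiler function."""
    return gm

class DebugWrapper:
    def __init__(self, fn):
        # functools.wraps works on any object that accepts attribute assignment,
        # so applying it to self copies __name__, __doc__, __wrapped__, etc.
        # from the wrapped function onto this instance.
        functools.wraps(fn)(self)
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

wrapped = DebugWrapper(relu_backend)
print(wrapped.__name__)                      # relu_backend
print(wrapped.__wrapped__ is relu_backend)   # True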
