From 345f2b1694e949a0a355026fc0617bd37d16a3cf Mon Sep 17 00:00:00 2001 From: weiyusheng Date: Tue, 28 May 2024 11:56:15 +0800 Subject: [PATCH] [dynamo] Solve Save/Load OptimizedModule https://github.com/pytorch/pytorch/pull/101651 --- test/dynamo/test_modules.py | 31 +++++++++ torch/_dynamo/convert_frame.py | 102 +++++++++++++++++----------- torch/_dynamo/eval_frame.py | 20 +++++- torch/_dynamo/repro/after_dynamo.py | 51 +++++++------- 4 files changed, 139 insertions(+), 65 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index b22f02ee2fcc40..80863ac435c05d 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -3,6 +3,8 @@ import collections import copy import itertools +import os +import tempfile import traceback import types import unittest @@ -15,6 +17,7 @@ import torch._dynamo.test_case import torch._dynamo.testing +from torch._dynamo.debug_utils import same_two_models import torch.nn.functional as F from torch._dynamo.eval_frame import unsupported from torch._dynamo.mutation_guard import GenerationTracker @@ -2739,6 +2742,34 @@ def fn(x): self.assertEqual(test_functions._variable, 1) self.assertEqual(res, 3 * torch.ones(10)) + + def test_save_and_load1(self): + mod = MockModule() + opt_mod = torch.compile(mod) + inp = torch.randn(10, 10) + opt_mod(inp) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) + loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) + loaded_model(inp) + self.assertTrue(same_two_models(loaded_model, mod, [inp])) + self.assertTrue(same_two_models(loaded_model, opt_mod, [inp])) + + def test_save_and_load2(self): + mod = MockModule() + opt_mod = torch.compile(mod, backend='inductor') + inp = torch.randn(10, 10) + opt_mod(inp) + + with tempfile.TemporaryDirectory() as tmpdirname: + torch.save(opt_mod, os.path.join(tmpdirname, "model.pt")) + loaded_model = torch.load(os.path.join(tmpdirname, "model.pt")) + loaded_model(inp) + self.assertTrue(same_two_models(loaded_model, mod, [inp])) + self.assertTrue(same_two_models(loaded_model, opt_mod, [inp])) + + def test_monkeypatching_forward(self): class FakeModule(torch.nn.Module): def forward(self, x): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index e779ccef9e3897..4bb4c4f52e9358 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -361,17 +361,26 @@ def profile_wrapper(*args, **kwargs): return profile_wrapper -def convert_frame_assert( - compiler_fn: CompilerFn, - one_graph: bool = True, - export: bool = False, - export_constraints=None, -): - """Fully convert a frame into an FX graph""" - reset_graph_break_dup_checker() - - def _convert_frame_assert( - frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0 +class ConvertFrameAssert: + def __init__( + self, + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, + ): + reset_graph_break_dup_checker() + self._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + self._one_graph = one_graph + self._export = export + self._export_constraints = export_constraints + + @property + def _clone_with_backend(self): + return lambda backend : convert_frame_assert(backend, self._one_graph, self._export, self._export_constraints) + + def __call__( + self, frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, *, skip: int = 0 ): increment_frame() @@ -458,10 +467,10 @@ def 
_convert_frame_assert( frame.f_globals, frame.f_locals, frame.f_builtins, - compiler_fn, - one_graph, - export, - export_constraints, + self._torchdynamo_orig_callable, + self._one_graph, + self._export, + self._export_constraints, hooks, cache_entry, cache_size, @@ -471,13 +480,14 @@ def _convert_frame_assert( skip=skip + 1, ) - _convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] - - def _clone_with_backend(backend): - return convert_frame_assert(backend, one_graph, export, export_constraints) - - _convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined] - return _convert_frame_assert +def convert_frame_assert( + compiler_fn: CompilerFn, + one_graph: bool = True, + export: bool = False, + export_constraints=None, +): + """Fully convert a frame into an FX graph""" + return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) from collections import OrderedDict @@ -907,16 +917,22 @@ def format_guard_failures(): torch._dynamo.callback_handler.run_end_callbacks() -def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): - """Try to convert a frame into an FX graph, if error leave frame unmodified""" - inner_convert = convert_frame_assert(compiler_fn, one_graph=False) +class ConvertFrame: + def __init__(self, compiler_fn: CompilerFn, hooks: Hooks): + self._torchdynamo_orig_callable = compiler_fn + self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._hooks = hooks + + @property + def _clone_with_backend(self): + return lambda backend : convert_frame(backend, self._hooks) - def _convert_frame( - frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0 + def __call__( + self, frame: types.FrameType, cache_entry, hooks: Hooks, frame_state, skip: int = 0 ): counters["frames"]["total"] += 1 try: - result = inner_convert( + result = self._inner_convert( frame, cache_entry, hooks, frame_state, skip=skip + 1 ) counters["frames"]["ok"] += 1 @@ -980,9 +996,9 @@ def _convert_frame( log.warning(error_msg, exc_info=True) return None - _convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] - _convert_frame._clone_with_backend = lambda backend: convert_frame(backend, hooks) # type: ignore[attr-defined] - return _convert_frame +def convert_frame(compiler_fn: CompilerFn, hooks: Hooks): + """Try to convert a frame into an FX graph, if error leave frame unmodified""" + return ConvertFrame(compiler_fn, hooks) # TODO mlazos: add support for same args, or record them @@ -1023,9 +1039,13 @@ def first_real_inst_idx(code): raise RuntimeError("RESUME instruction not found in code") -def catch_errors_wrapper(callback, hooks: Hooks): - @functools.wraps(callback) - def catch_errors(frame, cache_entry, frame_state): +class CatchErrorsWrapper: + def __init__(self, callback, hooks): + functools.wraps(callback)(self) + self._torchdynamo_orig_callable = callback + self.hooks = hooks + + def __call__(self, frame, cache_entry, frame_state): assert frame_state is not None is_skipfile = trace_rules.check(frame.f_code) @@ -1063,19 +1083,19 @@ def catch_errors(frame, cache_entry, frame_state): ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_module.bucket_bytes_cap, - backend_compile_fn=callback._torchdynamo_orig_callable, + backend_compile_fn=self._torchdynamo_orig_callable._torchdynamo_orig_callable, ) assert hasattr( - callback, "_clone_with_backend" + self._torchdynamo_orig_callable, "_clone_with_backend" ), "DDPOptimizer only supports callback fns that know 
how to clone themselves." - hijacked_callback = callback._clone_with_backend( + hijacked_callback = self._torchdynamo_orig_callable._clone_with_backend( ddp_optimizer.compile_fn, ) - return hijacked_callback(frame, cache_entry, hooks, frame_state) + return hijacked_callback(frame, cache_entry, self.hooks, frame_state) with compile_lock, _disable_current_modes(): # skip=1: skip this frame - return callback(frame, cache_entry, hooks, frame_state, skip=1) + return self._torchdynamo_orig_callable(frame, cache_entry, self.hooks, frame_state, skip=1) - catch_errors._torchdynamo_orig_callable = callback # type: ignore[attr-defined] - return catch_errors +def catch_errors_wrapper(callback, hooks: Hooks): + return CatchErrorsWrapper(callback, hooks) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 318fdd2650857c..0a8b0f4f264ff1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -168,6 +168,9 @@ def _initialize(self): self._forward = self.forward self.forward = self._call_lazy_check + def __reduce__(self): + return (self.__class__, (self._orig_mod, self.dynamo_ctx)) + def __getstate__(self): state = dict(self.__dict__) state.pop("forward", None) @@ -273,9 +276,11 @@ def __init__( super().__init__() assert callable(callback) or callback is False or callback is None self.callback: DynamoCallback = callback + self._backend_ctx_ctor = backend_ctx_ctor self.prior: Union[Unset, DynamoCallback] = unset self.first_ctx = first_ctx self.export = export + self._dynamic = dynamic self.compiler_config = compiler_config self.cleanup_fns: List[Callable[[], Any]] = [] self.enter_exit_hooks = [] @@ -379,7 +384,11 @@ def get_compiler_config(): # call to a builtin without a frame for us to capture fn = external_utils.wrap_inline(fn) - callback = self.callback + def do_nothing(*arg, **kwargs): + pass + callback = do_nothing + if hasattr(self, 'callback'): + callback = self.callback is_jit_tracing = torch._C._is_tracing is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing @@ -523,6 +532,9 @@ def call_compiled_autograd(): self.enter_exit_hooks.append(call_compiled_autograd) + def __reduce__(self): + return (self.__class__, (self.callback, self._backend_ctx_ctor, self.first_ctx), {'export':self.export, 'dynamic':self._dynamic, 'compiler_config':self.compiler_config}) + class RunOnlyContext(_TorchDynamoContext): def __init__(self): # cudagraph trees relies on generation increment @@ -531,6 +543,9 @@ def on_enter(): super().__init__(callback=False, on_enter=on_enter) + def __reduce__(self): + return (self.__class__, ()) + class DisableContext(_TorchDynamoContext): def __init__(self): @@ -583,6 +598,9 @@ def _fn(*args, **kwargs): return _fn + def __reduce__(self): + return (self.__class__, ()) + def _optimize_catch_errors( compile_fn, diff --git a/torch/_dynamo/repro/after_dynamo.py b/torch/_dynamo/repro/after_dynamo.py index 76b9128e699520..73060f8806d235 100644 --- a/torch/_dynamo/repro/after_dynamo.py +++ b/torch/_dynamo/repro/after_dynamo.py @@ -56,19 +56,20 @@ def _accuracy_fails(gm, example_inputs, compiler_fn): ) -def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): - """ - A minifier decorator that wraps the TorchDynamo produced Fx graph modules. - As opposed to wrap_compiler_debug, this wrapper intercepts at the - TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some - level, e.g., it is useful for minifying issues related to Aot Autograd - tracing. 
If an error is found, we minify and save the minified repro in - repro.tar.gz. - """ - - @functools.wraps(unconfigured_compiler_fn) - def debug_wrapper(gm, example_inputs, **kwargs): - compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) +class WrapBackendDebug: + def __init__(self, unconfigured_compiler_fn, compiler_name: str): + functools.wraps(unconfigured_compiler_fn)(self) + self._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] + self._compiler_name = compiler_name + if hasattr(unconfigured_compiler_fn, "__name__"): + self.__name__ = unconfigured_compiler_fn.__name__ + if hasattr(unconfigured_compiler_fn, "compiler_name"): + self.__name__ = unconfigured_compiler_fn.compiler_name + if hasattr(unconfigured_compiler_fn, "get_compiler_config"): + self.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] + + def __call__(self, gm, example_inputs, **kwargs): + compiler_fn = functools.partial(self._torchdynamo_orig_callable, **kwargs) assert config.repro_after in ("dynamo", "aot", None) if config.repro_after == "dynamo": @@ -82,7 +83,7 @@ def add_paths(exc): ) if config.repro_level == 3: - dump_to_minify_after_dynamo(gm, example_inputs, compiler_name) + dump_to_minify_after_dynamo(gm, example_inputs, self._compiler_name) # Check for either accuracy (level 4) or other type of failures. if config.repro_level == 4: @@ -95,7 +96,7 @@ def add_paths(exc): dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, - compiler_name, + self._compiler_name, ) exc = AccuracyError("Bad accuracy detected.") add_paths(exc) @@ -110,7 +111,7 @@ def add_paths(exc): ) if config.repro_level == 1: dump_state_fn = functools.partial( - dump_backend_state, compiler_name=compiler_name + dump_backend_state, compiler_name=self._compiler_name ) dump_state_fn( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs @@ -119,7 +120,7 @@ def add_paths(exc): dump_to_minify_after_dynamo( fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs, - compiler_name, + self._compiler_name, ) add_paths(exc) raise @@ -128,12 +129,16 @@ def add_paths(exc): return compiled_gm - debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn # type: ignore[attr-defined] - if hasattr(unconfigured_compiler_fn, "compiler_name"): - debug_wrapper.__name__ = unconfigured_compiler_fn.compiler_name - if hasattr(unconfigured_compiler_fn, "get_compiler_config"): - debug_wrapper.get_compiler_config = unconfigured_compiler_fn.get_compiler_config # type: ignore[attr-defined] - return debug_wrapper +def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str): + """ + A minifier decorator that wraps the TorchDynamo produced Fx graph modules. + As opposed to wrap_compiler_debug, this wrapper intercepts at the + TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some + level, e.g., it is useful for minifying issues related to Aot Autograd + tracing. If an error is found, we minify and save the minified repro in + repro.tar.gz. + """ + return WrapBackendDebug(unconfigured_compiler_fn, compiler_name) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
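
Note (illustrative, not part of the patch): the closure-to-class rewrites in convert_frame.py and after_dynamo.py exist so the Dynamo callbacks become plain picklable objects, and the added __reduce__ methods tell pickle how to rebuild OptimizedModule and the Dynamo context objects from their constructor arguments instead of serializing internal state (compiled callbacks, closures, hooks). A minimal, self-contained sketch of the same pickle protocol, using a hypothetical Wrapper class rather than the real OptimizedModule:

# Sketch only: shows how __reduce__ lets pickle rebuild a wrapper by calling
# its constructor again, skipping attributes that cannot be serialized.
import io
import pickle


class Wrapper:
    def __init__(self, inner, flag=True):
        self.inner = inner
        self.flag = flag
        # Stand-in for state that pickle cannot handle directly,
        # e.g. a closure or a C-level callback.
        self._unpicklable = lambda x: x

    def __reduce__(self):
        # Tell pickle: "call Wrapper(inner, flag) again on load" and never
        # touch __dict__, so the un-picklable attribute is simply rebuilt
        # by __init__ on the loading side.
        return (self.__class__, (self.inner, self.flag))


buf = io.BytesIO()
pickle.dump(Wrapper([1, 2, 3]), buf)
restored = pickle.loads(buf.getvalue())
assert restored.inner == [1, 2, 3] and restored.flag is True

On load, pickle invokes the constructor with the saved arguments, which mirrors how OptimizedModule.__reduce__ returning (self.__class__, (self._orig_mod, self.dynamo_ctx)) lets torch.load re-wrap the original module and recompile lazily on the first call, as exercised by the new test_save_and_load tests above.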