From 05b15e63f09e8efedc06f5e12403e4a05632a3d8 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Wed, 13 Oct 2021 18:04:06 +0530 Subject: [PATCH] Add `strategy` argument to Trainer (#8597) Co-authored-by: Rohit Gupta Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + .../connectors/accelerator_connector.py | 107 ++++++++++++++++-- pytorch_lightning/trainer/trainer.py | 5 + .../test_accelerator_connector.py | 75 +++++++++++- tests/accelerators/test_ipu.py | 19 +++- tests/accelerators/test_tpu_backend.py | 13 +++ tests/deprecated_api/test_remove_1-7.py | 10 ++ tests/models/test_tpu.py | 12 +- tests/trainer/test_trainer.py | 95 ++++++++++++++++ 9 files changed, 323 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b58098b4b1b9..ba1be2433463f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699)) +- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597)) + + ### Changed - Module imports are now catching `ModuleNotFoundError` instead of `ImportError` ([#9867](https://github.com/PyTorchLightning/pytorch-lightning/pull/9867)) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 467fa3b898ada..7e9b25869039e 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -94,6 +94,7 @@ def __init__( ipus, distributed_backend, accelerator, + strategy: Optional[Union[str, TrainingTypePlugin]], gpus, gpu_ids, num_nodes, @@ -111,12 +112,9 @@ def __init__( self._distrib_type = None self._accelerator_type = None - if distributed_backend is not None: - rank_zero_deprecation( - f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5." - f" Use `Trainer(accelerator={distributed_backend})` instead." - ) - distributed_backend = distributed_backend or accelerator + self.strategy = strategy.lower() if isinstance(strategy, str) else strategy + self.distributed_backend = distributed_backend or accelerator + self._init_deterministic(deterministic) self.num_processes = num_processes @@ -126,7 +124,6 @@ def __init__( self.parallel_device_ids = gpu_ids self.tpu_cores = tpu_cores self.ipus = ipus - self.distributed_backend = distributed_backend self.num_nodes = num_nodes self.sync_batchnorm = sync_batchnorm self.benchmark = benchmark @@ -151,16 +148,23 @@ def __init__( self.plugins = plugins + self._handle_accelerator_and_distributed_backend(distributed_backend, accelerator) + self._validate_accelerator_and_devices() self._warn_if_devices_flag_ignored() self.select_accelerator_type() - self.set_distributed_mode() + + if self.strategy is not None: + self._set_training_type_plugin() + else: + self.set_distributed_mode() self.configure_slurm_ddp() self.handle_given_plugins() self.update_device_type_if_ipu_plugin() + self.update_device_type_if_training_type_plugin_passed() self._validate_accelerator_type() self._set_devices_if_none() @@ -228,11 +232,11 @@ def select_accelerator_type(self) -> None: self._set_devices_to_cpu_num_processes() self._accelerator_type = DeviceType.CPU - if self.distributed_backend in ["auto"] + list(DeviceType): + if self.distributed_backend in self.accelerator_types: self.distributed_backend = None def _validate_accelerator_and_devices(self) -> None: - if self.distributed_backend not in ["auto"] + list(DeviceType) and self.devices is not None: + if self.distributed_backend not in self.accelerator_types and self.devices is not None: raise MisconfigurationException( f"You passed `devices={self.devices}` but haven't specified" " `accelerator=('auto'|'tpu'|'gpu'|'ipu'|'cpu')` for the devices mapping," @@ -285,9 +289,56 @@ def _set_devices_if_none(self) -> None: elif self._accelerator_type == DeviceType.CPU: self.devices = self.num_processes + def _handle_accelerator_and_distributed_backend( + self, distributed_backend: Optional[str], accelerator: Optional[Union[str, Accelerator]] + ) -> None: + if distributed_backend is not None: + rank_zero_deprecation( + f"`Trainer(distributed_backend={distributed_backend})` has been deprecated and will be removed in v1.5." + f" Use `Trainer(strategy={distributed_backend})` instead." + ) + if self.strategy is not None: + raise MisconfigurationException( + f"You have passed `Trainer(strategy={self.strategy})` but have" + f" also passed `Trainer(distributed_backend={distributed_backend})`." + f"HINT: Use just `Trainer(strategy={self.strategy})` instead." + ) + + if accelerator is not None and accelerator in list(DistributedType): + rank_zero_deprecation( + f"Passing {accelerator} `strategy` to the `accelerator` flag in Trainer has been deprecated" + f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={accelerator})` instead." + ) + if self.strategy is not None: + raise MisconfigurationException( + f"You have passed `Trainer(strategy={self.strategy})` but have" + f" also passed `Trainer(accelerator={accelerator})`." + f"HINT: Use just `Trainer(strategy={self.strategy})` instead." + ) + + def _set_training_type_plugin(self) -> None: + if isinstance(self.strategy, str) and self.strategy in TrainingTypePluginsRegistry: + self._training_type_plugin = TrainingTypePluginsRegistry.get(self.strategy) + if isinstance(self.strategy, str): + self.set_distributed_mode(self.strategy) + elif isinstance(self.strategy, TrainingTypePlugin): + self._training_type_plugin = self.strategy + def handle_given_plugins(self) -> None: - training_type = None + for plug in self.plugins: + if self.strategy is not None and self._is_plugin_training_type(plug): + raise MisconfigurationException( + f"You have passed `Trainer(strategy={self.strategy})`" + f" and you can only specify one training type plugin, but you have passed {plug} as a plugin." + ) + if self._is_plugin_training_type(plug): + rank_zero_deprecation( + f"Passing {plug} `strategy` to the `plugins` flag in Trainer has been deprecated" + f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plug})` instead." + ) + + training_type = self._training_type_plugin or None checkpoint = None precision = None cluster_environment = None @@ -350,6 +401,10 @@ def handle_given_plugins(self) -> None: self._checkpoint_io = checkpoint self._cluster_environment = cluster_environment or self.select_cluster_environment() + @property + def accelerator_types(self) -> List[str]: + return ["auto"] + list(DeviceType) + @property def precision_plugin(self) -> PrecisionPlugin: if self._precision_plugin is None: @@ -540,9 +595,18 @@ def root_gpu(self) -> Optional[int]: else None ) + @staticmethod + def _is_plugin_training_type(plugin: Union[str, TrainingTypePlugin]) -> bool: + if isinstance(plugin, str) and (plugin in TrainingTypePluginsRegistry or plugin in list(DistributedType)): + return True + return isinstance(plugin, TrainingTypePlugin) + @property def is_training_type_in_plugins(self) -> bool: - return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins) + return any( + (isinstance(plug, str) and plug in TrainingTypePluginsRegistry) or isinstance(plug, TrainingTypePlugin) + for plug in self.plugins + ) def select_precision_plugin(self) -> PrecisionPlugin: # set precision type @@ -875,6 +939,25 @@ def update_device_type_if_ipu_plugin(self) -> None: if isinstance(self._training_type_plugin, IPUPlugin) and self._device_type != DeviceType.IPU: self._device_type = DeviceType.IPU + def update_device_type_if_training_type_plugin_passed(self) -> None: + if isinstance(self.strategy, TrainingTypePlugin) or any( + isinstance(plug, TrainingTypePlugin) for plug in self.plugins + ): + if self._accelerator_type is not None: + if self.use_ipu: + self._device_type = DeviceType.IPU + elif self.use_tpu: + self._device_type = DeviceType.TPU + elif self.use_gpu: + self._device_type = DeviceType.GPU + else: + if self.has_ipu: + self._device_type = DeviceType.IPU + elif self.has_tpu: + self._device_type = DeviceType.TPU + elif self.has_gpu: + self._device_type = DeviceType.GPU + def configure_slurm_ddp(self): # extract SLURM flag vars # whenever we have the correct number of tasks, we let slurm manage processes diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f5b4d6bdeda28..6699f3554b80c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -155,6 +155,7 @@ def __init__( flush_logs_every_n_steps: Optional[int] = None, log_every_n_steps: int = 50, accelerator: Optional[Union[str, Accelerator]] = None, + strategy: Optional[Union[str, TrainingTypePlugin]] = None, sync_batchnorm: bool = False, precision: Union[int, str] = 32, enable_model_summary: bool = True, @@ -354,6 +355,9 @@ def __init__( no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. + strategy: Supports different training strategies with aliases + as well custom training type plugins. + sync_batchnorm: Synchronize batch norm layers between process groups/whole world. terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the @@ -424,6 +428,7 @@ def __init__( ipus, distributed_backend, accelerator, + strategy, gpus, gpu_ids, num_nodes, diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 0fe5b2824b82f..6a41e9032ea16 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -26,6 +26,7 @@ from pytorch_lightning.accelerators.gpu import GPUAccelerator from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import ( + DataParallelPlugin, DDP2Plugin, DDPPlugin, DDPShardedPlugin, @@ -42,7 +43,7 @@ SLURMEnvironment, TorchElasticEnvironment, ) -from pytorch_lightning.utilities import DistributedType +from pytorch_lightning.utilities import DeviceType, DistributedType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -631,6 +632,78 @@ def test_accelerator_ddp_for_cpu(tmpdir): assert isinstance(trainer.training_type_plugin, DDPPlugin) +def test_exception_when_strategy_used_with_distributed_backend(): + with pytest.raises(MisconfigurationException, match="but have also passed"): + Trainer(distributed_backend="ddp_cpu", strategy="ddp_spawn") + + +def test_exception_when_strategy_used_with_accelerator(): + with pytest.raises(MisconfigurationException, match="but have also passed"): + Trainer(accelerator="ddp", strategy="ddp_spawn") + + +def test_exception_when_strategy_used_with_plugins(): + with pytest.raises(MisconfigurationException, match="only specify one training type plugin, but you have passed"): + Trainer(plugins="ddp_find_unused_parameters_false", strategy="ddp_spawn") + + +@pytest.mark.parametrize( + ["strategy", "plugin"], + [ + ("ddp_spawn", DDPSpawnPlugin), + ("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin), + ("ddp", DDPPlugin), + ("ddp_find_unused_parameters_false", DDPPlugin), + ], +) +def test_strategy_choice_cpu_str(tmpdir, strategy, plugin): + trainer = Trainer(strategy=strategy, accelerator="cpu", devices=2) + assert isinstance(trainer.training_type_plugin, plugin) + + +@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin]) +def test_strategy_choice_cpu_plugin(tmpdir, plugin): + trainer = Trainer(strategy=plugin(), accelerator="cpu", devices=2) + assert isinstance(trainer.training_type_plugin, plugin) + + +@RunIf(min_gpus=2) +@pytest.mark.parametrize( + ["strategy", "plugin"], + [ + ("ddp_spawn", DDPSpawnPlugin), + ("ddp_spawn_find_unused_parameters_false", DDPSpawnPlugin), + ("ddp", DDPPlugin), + ("ddp_find_unused_parameters_false", DDPPlugin), + ("ddp2", DDP2Plugin), + ("dp", DataParallelPlugin), + ("ddp_sharded", DDPShardedPlugin), + ("ddp_sharded_spawn", DDPSpawnShardedPlugin), + pytest.param("deepspeed", DeepSpeedPlugin, marks=RunIf(deepspeed=True)), + ], +) +def test_strategy_choice_gpu_str(tmpdir, strategy, plugin): + trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2) + assert isinstance(trainer.training_type_plugin, plugin) + + +@RunIf(min_gpus=2) +@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin]) +def test_strategy_choice_gpu_plugin(tmpdir, plugin): + trainer = Trainer(strategy=plugin(), accelerator="gpu", devices=2) + assert isinstance(trainer.training_type_plugin, plugin) + + +@RunIf(min_gpus=2) +@pytest.mark.parametrize("plugin", [DDPSpawnPlugin, DDPPlugin]) +def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin): + + trainer = Trainer(strategy=plugin(), gpus=2) + assert isinstance(trainer.training_type_plugin, plugin) + assert trainer._device_type == DeviceType.GPU + assert isinstance(trainer.accelerator, GPUAccelerator) + + @pytest.mark.parametrize("precision", [1, 12, "invalid"]) def test_validate_precision_type(tmpdir, precision): diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index acb3fd65959eb..c8c557eab4ebf 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -24,7 +24,7 @@ from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities import _IPU_AVAILABLE +from pytorch_lightning.utilities import _IPU_AVAILABLE, DeviceType from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir): @RunIf(ipu=True) def test_no_warning_plugin(tmpdir): with pytest.warns(None) as record: - Trainer(default_root_dir=tmpdir, plugins=IPUPlugin(training_opts=poptorch.Options())) + Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options())) assert len(record) == 0 @@ -528,3 +528,18 @@ def test_set_devices_if_none_ipu(): trainer = Trainer(accelerator="ipu", ipus=8) assert trainer.devices == 8 + + +@RunIf(ipu=True) +def test_strategy_choice_ipu_plugin(tmpdir): + trainer = Trainer(strategy=IPUPlugin(), accelerator="ipu", devices=8) + assert isinstance(trainer.training_type_plugin, IPUPlugin) + + +@RunIf(ipu=True) +def test_device_type_when_training_plugin_ipu_passed(tmpdir): + + trainer = Trainer(strategy=IPUPlugin(), ipus=8) + assert isinstance(trainer.training_type_plugin, IPUPlugin) + assert trainer._device_type == DeviceType.IPU + assert isinstance(trainer.accelerator, IPUAccelerator) diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 5d021851bad4e..df5444ac776a6 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -227,6 +227,19 @@ def test_ddp_cpu_not_supported_on_tpus(): Trainer(accelerator="ddp_cpu") +@RunIf(tpu=True) +@pytest.mark.parametrize("strategy", ["tpu_spawn", "tpu_spawn_debug"]) +def test_strategy_choice_tpu_str(tmpdir, strategy): + trainer = Trainer(strategy=strategy, accelerator="tpu", devices=8) + assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) + + +@RunIf(tpu=True) +def test_strategy_choice_tpu_plugin(tmpdir): + trainer = Trainer(strategy=TPUSpawnPlugin(), accelerator="tpu", devices=8) + assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) + + @RunIf(tpu=True) def test_auto_parameters_tying_tpus(tmpdir): diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 3d0990e33db01..08449f6fbbcff 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -343,6 +343,16 @@ def test_v1_7_0_deprecate_parameter_validation(): from pytorch_lightning.core.decorators import parameter_validation # noqa: F401 +def test_v1_7_0_passing_strategy_to_accelerator_trainer_flag(): + with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."): + Trainer(accelerator="ddp_spawn") + + +def test_v1_7_0_passing_strategy_to_plugins_flag(): + with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."): + Trainer(plugins="ddp_spawn") + + def test_v1_7_0_weights_summary_trainer(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"): t = Trainer(weights_summary="full") diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b320842f5ae81..13003676b176b 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -26,7 +26,7 @@ from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync -from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset @@ -473,3 +473,13 @@ def teardown(self, stage): model = DebugModel() tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_device_type_when_training_plugin_tpu_passed(tmpdir): + + trainer = Trainer(strategy=TPUSpawnPlugin(), tpu_cores=8) + assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin) + assert trainer._device_type == DeviceType.TPU + assert isinstance(trainer.accelerator, TPUAccelerator) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 213eac01fce68..38034aa65b360 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2077,3 +2077,98 @@ def training_step(self, batch, batch_idx): UserWarning, match=r".*Error detected in.* Traceback of forward call that caused the error.*" ): trainer.fit(model) + + +@pytest.mark.parametrize( + "trainer_kwargs,expected", + [ + ( + dict(strategy=None, gpus=None), + dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + ), + ( + dict(strategy="dp", gpus=None), + dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + ), + ( + dict(strategy="ddp", gpus=None), + dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + ), + ( + dict(strategy="ddp", num_processes=2, gpus=None), + dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + ), + ( + dict(strategy="ddp", num_nodes=2, gpus=None), + dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + ), + ( + dict(strategy="ddp_cpu", num_processes=2, gpus=None), + dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + ), + ( + dict(strategy="ddp2", gpus=None), + dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + ), + ( + dict(strategy=None, gpus=1), + dict(_distrib_type=None, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + ), + ( + dict(strategy="dp", gpus=1), + dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + ), + ( + dict(strategy="ddp", gpus=1), + dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + ), + ( + dict(strategy="ddp_cpu", num_processes=2, gpus=1), + dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + ), + ( + dict(strategy="ddp2", gpus=1), + dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=1, num_processes=1), + ), + ( + dict(strategy=None, gpus=2), + dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + ), + ( + dict(strategy="dp", gpus=2), + dict(_distrib_type=DistributedType.DP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + ), + ( + dict(strategy="ddp", gpus=2), + dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.GPU, num_gpus=2, num_processes=2), + ), + ( + dict(strategy="ddp2", gpus=2), + dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1), + ), + ( + dict(strategy="ddp2", num_processes=2, gpus=None), + dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + ), + ( + dict(strategy="dp", num_processes=2, gpus=None), + dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + ), + ( + dict(strategy="ddp_spawn", num_processes=2, gpus=None), + dict(_distrib_type=DistributedType.DDP_SPAWN, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2), + ), + ( + dict(strategy="ddp_spawn", num_processes=1, gpus=None), + dict(_distrib_type=None, _device_type=DeviceType.CPU, num_gpus=0, num_processes=1), + ), + ], +) +def test_trainer_config_strategy(trainer_kwargs, expected, monkeypatch): + if trainer_kwargs["gpus"] is not None: + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "device_count", lambda: trainer_kwargs["gpus"]) + trainer = Trainer(**trainer_kwargs) + assert len(expected) == 4 + for k, v in expected.items(): + assert getattr(trainer, k) == v, f"Failed {k}: {v}"