diff --git a/test/conftest.py b/test/conftest.py index 0a909a1e2..cd9b12cc6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,12 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import multiprocessing import os import time from collections import defaultdict import pytest +try: + multiprocessing.set_start_method("spawn") +except Exception: + assert multiprocessing.get_start_method() == "spawn" + CALL_TIMES = defaultdict(lambda: 0.0) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 07eac0ec1..bc9e45963 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -134,7 +134,7 @@ ), ] -mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn" +mp_ctx = "spawn" @pytest.fixture diff --git a/tutorials/sphinx_tuto/export.py b/tutorials/sphinx_tuto/export.py index df8e3fda5..d7944e81d 100644 --- a/tutorials/sphinx_tuto/export.py +++ b/tutorials/sphinx_tuto/export.py @@ -120,13 +120,13 @@ t0 = time.time() model(x=x) -print(f"Time for TDModule: {(time.time()-t0)*1e6: 4.2f} micro-seconds") +print(f"Time for TDModule: {(time.time() - t0) * 1e6: 4.2f} micro-seconds") exported = model_export.module() # Exported version t0 = time.time() exported(x=x) -print(f"Time for exported module: {(time.time()-t0)*1e6: 4.2f} micro-seconds") +print(f"Time for exported module: {(time.time() - t0) * 1e6: 4.2f} micro-seconds") ################################################## # and the FX graph: diff --git a/tutorials/sphinx_tuto/tensordict_module.py b/tutorials/sphinx_tuto/tensordict_module.py index b31b6c7bd..7e8099fac 100644 --- a/tutorials/sphinx_tuto/tensordict_module.py +++ b/tutorials/sphinx_tuto/tensordict_module.py @@ -213,13 +213,13 @@ def forward(self, x): from torch.utils.benchmark import Timer print( - f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" + f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us" ) print( - f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" + f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us" ) print( - f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" + f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us" ) print("Compiled versions") @@ -227,19 +227,19 @@ def forward(self, x): for _ in range(5): # warmup block_notd_c(x) print( - f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" + f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us" ) block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead") for _ in range(5): # warmup block_tdm_c(x=x) print( - f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" + f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us" ) block_tds_c = torch.compile(block_tds, mode="reduce-overhead") for _ in range(5): # warmup block_tds_c(x=x) print( - f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median*1_000_000: 4.4f} us" + f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us" ) ###############################################################################