diff --git a/sheeprl/utils/fabric.py b/sheeprl/utils/fabric.py index bd6a60cf..e153993c 100644 --- a/sheeprl/utils/fabric.py +++ b/sheeprl/utils/fabric.py @@ -1,3 +1,5 @@ +from unittest import mock + from lightning.fabric import Fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies import SingleDeviceStrategy, SingleDeviceXLAStrategy @@ -23,4 +25,11 @@ def get_single_device_fabric(fabric: Fabric) -> Fabric: checkpoint_io=None, precision=fabric._precision, ) - return Fabric(strategy=strategy) + with mock.patch.dict("os.environ") as mocked_os_environ: + mocked_os_environ.pop("LT_DEVICES", None) + mocked_os_environ.pop("LT_STRATEGY", None) + mocked_os_environ.pop("LT_NUM_NODES", None) + mocked_os_environ.pop("LT_PRECISION", None) + mocked_os_environ.pop("LT_ACCELERATOR", None) + fabric = Fabric(strategy=strategy) + return fabric diff --git a/tests/test_utils/test_fabric.py b/tests/test_utils/test_fabric.py new file mode 100644 index 00000000..77c22620 --- /dev/null +++ b/tests/test_utils/test_fabric.py @@ -0,0 +1,13 @@ +from lightning import Fabric +from lightning.fabric.strategies import SingleDeviceStrategy + +from sheeprl.utils.fabric import get_single_device_fabric + + +def test_get_single_device_fabric(): + fabric = Fabric(devices=2, accelerator="cpu", precision=16) + single_device_fabric = get_single_device_fabric(fabric) + assert single_device_fabric.device == fabric.device + assert single_device_fabric._precision == fabric._precision + assert single_device_fabric.accelerator == fabric.accelerator + assert isinstance(single_device_fabric.strategy, SingleDeviceStrategy)